mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Replace Pydantic Settings with TypedDict + load_settings() (#3843)
* Replace Pydantic Settings with TypedDict + load_settings() - Remove pydantic-settings dependency, add python-dotenv - Delete _pydantic.py (AFBaseSettings, HTTPsUrl) - Add _settings.py with generic load_settings() function, SecretString, type coercion, and Required field validation (SettingNotFoundError) - Convert all 13 settings classes from AFBaseSettings subclasses to TypedDict definitions with load_settings() calls - Update all consumers from attribute access to dict access - Add 20 unit tests for load_settings() covering basic loading, dotenv, SecretString, type coercion, and required field validation - Update all existing tests for new settings patterns * Fix mypy type errors from settings conversion - Fix str | None attribute access in responses_client (walrus operator) - Fix SecretString | None narrowing in bedrock (type: ignore after guard) - Convert _context_provider.py attribute access to dict access (missed file) - Fix endpoint type narrowing in search_provider and context_provider - Fix purview: str | None .rstrip(), int | None defaults, urlparse bytes * Address PR review: required_fields param, type validation, fixes - Move required field validation from TypedDict annotations (Required) to a required_fields parameter on load_settings(), enabling runtime decisions about which fields are required - Remove Required imports and restore from __future__ import annotations in ollama and foundry_local - Add _check_override_type() for deterministic ServiceInitializationError on invalid override types (e.g. dict passed for str field) - Fix all multi-exception test catches back to single exception type - Fix Ollama host=None: use .get() so None is passed through to SDK default - Fix Purview processor: use explicit is-None checks instead of or operator - Remove unused BaseModel import from openai/_shared.py - Add 4 new tests (24 total): required_fields param, type validation * Fix type validation: allow int for float fields _check_override_type now permits int values for float-typed fields, matching Python's standard numeric promotion behavior. * fix: wrap urlparse arg with str() to fix mypy bytes endswith error
This commit is contained in:
committed by
GitHub
Unverified
parent
b488158abe
commit
8457533c69
@@ -27,7 +27,7 @@ from agent_framework import (
|
||||
get_logger,
|
||||
prepare_function_call_results,
|
||||
)
|
||||
from agent_framework._pydantic import AFBaseSettings
|
||||
from agent_framework._settings import SecretString, load_settings
|
||||
from agent_framework._types import _get_data_bytes_as_str # type: ignore
|
||||
from agent_framework.exceptions import ServiceInitializationError
|
||||
from agent_framework.observability import ChatTelemetryLayer
|
||||
@@ -47,7 +47,7 @@ from anthropic.types.beta.beta_bash_code_execution_tool_result_error import (
|
||||
from anthropic.types.beta.beta_code_execution_tool_result_error import (
|
||||
BetaCodeExecutionToolResultError,
|
||||
)
|
||||
from pydantic import BaseModel, SecretStr, ValidationError
|
||||
from pydantic import BaseModel
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
from typing import TypedDict # type: ignore # pragma: no cover
|
||||
@@ -192,40 +192,20 @@ FINISH_REASON_MAP: dict[str, FinishReasonLiteral] = {
|
||||
}
|
||||
|
||||
|
||||
class AnthropicSettings(AFBaseSettings):
|
||||
class AnthropicSettings(TypedDict, total=False):
|
||||
"""Anthropic Project settings.
|
||||
|
||||
The settings are first loaded from environment variables with the prefix 'ANTHROPIC_'.
|
||||
If the environment variables are not found, the settings can be loaded from a .env file
|
||||
with the encoding 'utf-8'. If the settings are not found in the .env file, the settings
|
||||
are ignored; however, validation will fail alerting that the settings are missing.
|
||||
with the encoding 'utf-8'.
|
||||
|
||||
Keyword Args:
|
||||
Keys:
|
||||
api_key: The Anthropic API key.
|
||||
chat_model_id: The Anthropic chat model ID.
|
||||
env_file_path: If provided, the .env settings are read from this file path location.
|
||||
env_file_encoding: The encoding of the .env file, defaults to 'utf-8'.
|
||||
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
|
||||
from agent_framework.anthropic import AnthropicSettings
|
||||
|
||||
# Using environment variables
|
||||
# Set ANTHROPIC_API_KEY=your_anthropic_api_key
|
||||
# ANTHROPIC_CHAT_MODEL_ID=claude-sonnet-4-5-20250929
|
||||
|
||||
# Or passing parameters directly
|
||||
settings = AnthropicSettings(chat_model_id="claude-sonnet-4-5-20250929")
|
||||
|
||||
# Or loading from a .env file
|
||||
settings = AnthropicSettings(env_file_path="path/to/.env")
|
||||
"""
|
||||
|
||||
env_prefix: ClassVar[str] = "ANTHROPIC_"
|
||||
|
||||
api_key: SecretStr | None = None
|
||||
chat_model_id: str | None = None
|
||||
api_key: SecretString | None
|
||||
chat_model_id: str | None
|
||||
|
||||
|
||||
class AnthropicClient(
|
||||
@@ -311,25 +291,24 @@ class AnthropicClient(
|
||||
response = await client.get_response("Hello", options={"my_custom_option": "value"})
|
||||
|
||||
"""
|
||||
try:
|
||||
anthropic_settings = AnthropicSettings(
|
||||
api_key=api_key, # type: ignore[arg-type]
|
||||
chat_model_id=model_id,
|
||||
env_file_path=env_file_path,
|
||||
env_file_encoding=env_file_encoding,
|
||||
)
|
||||
except ValidationError as ex:
|
||||
raise ServiceInitializationError("Failed to create Anthropic settings.", ex) from ex
|
||||
anthropic_settings = load_settings(
|
||||
AnthropicSettings,
|
||||
env_prefix="ANTHROPIC_",
|
||||
api_key=api_key,
|
||||
chat_model_id=model_id,
|
||||
env_file_path=env_file_path,
|
||||
env_file_encoding=env_file_encoding,
|
||||
)
|
||||
|
||||
if anthropic_client is None:
|
||||
if not anthropic_settings.api_key:
|
||||
if not anthropic_settings["api_key"]:
|
||||
raise ServiceInitializationError(
|
||||
"Anthropic API key is required. Set via 'api_key' parameter "
|
||||
"or 'ANTHROPIC_API_KEY' environment variable."
|
||||
)
|
||||
|
||||
anthropic_client = AsyncAnthropic(
|
||||
api_key=anthropic_settings.api_key.get_secret_value(),
|
||||
api_key=anthropic_settings["api_key"].get_secret_value(),
|
||||
default_headers={"User-Agent": AGENT_FRAMEWORK_USER_AGENT},
|
||||
)
|
||||
|
||||
@@ -343,7 +322,7 @@ class AnthropicClient(
|
||||
# Initialize instance variables
|
||||
self.anthropic_client = anthropic_client
|
||||
self.additional_beta_flags = additional_beta_flags or []
|
||||
self.model_id = anthropic_settings.chat_model_id
|
||||
self.model_id = anthropic_settings["chat_model_id"]
|
||||
# streaming requires tracking the last function call ID and name
|
||||
self._last_call_id_name: tuple[str, str] | None = None
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ from agent_framework import (
|
||||
SupportsChatGetResponse,
|
||||
tool,
|
||||
)
|
||||
from agent_framework._settings import load_settings
|
||||
from agent_framework.exceptions import ServiceInitializationError
|
||||
from anthropic.types.beta import (
|
||||
BetaMessage,
|
||||
@@ -20,7 +21,7 @@ from anthropic.types.beta import (
|
||||
BetaToolUseBlock,
|
||||
BetaUsage,
|
||||
)
|
||||
from pydantic import Field, ValidationError
|
||||
from pydantic import Field
|
||||
|
||||
from agent_framework_anthropic import AnthropicClient
|
||||
from agent_framework_anthropic._chat_client import AnthropicSettings
|
||||
@@ -41,8 +42,12 @@ def create_test_anthropic_client(
|
||||
) -> AnthropicClient:
|
||||
"""Helper function to create AnthropicClient instances for testing, bypassing normal validation."""
|
||||
if anthropic_settings is None:
|
||||
anthropic_settings = AnthropicSettings(
|
||||
api_key="test-api-key-12345", chat_model_id="claude-3-5-sonnet-20241022", env_file_path="test.env"
|
||||
anthropic_settings = load_settings(
|
||||
AnthropicSettings,
|
||||
env_prefix="ANTHROPIC_",
|
||||
api_key="test-api-key-12345",
|
||||
chat_model_id="claude-3-5-sonnet-20241022",
|
||||
env_file_path="test.env",
|
||||
)
|
||||
|
||||
# Create client instance directly
|
||||
@@ -50,7 +55,7 @@ def create_test_anthropic_client(
|
||||
|
||||
# Set attributes directly
|
||||
client.anthropic_client = mock_anthropic_client
|
||||
client.model_id = model_id or anthropic_settings.chat_model_id
|
||||
client.model_id = model_id or anthropic_settings["chat_model_id"]
|
||||
client._last_call_id_name = None
|
||||
client.additional_properties = {}
|
||||
client.middleware = None
|
||||
@@ -64,30 +69,34 @@ def create_test_anthropic_client(
|
||||
|
||||
def test_anthropic_settings_init(anthropic_unit_test_env: dict[str, str]) -> None:
|
||||
"""Test AnthropicSettings initialization."""
|
||||
settings = AnthropicSettings(env_file_path="test.env")
|
||||
settings = load_settings(AnthropicSettings, env_prefix="ANTHROPIC_", env_file_path="test.env")
|
||||
|
||||
assert settings.api_key is not None
|
||||
assert settings.api_key.get_secret_value() == anthropic_unit_test_env["ANTHROPIC_API_KEY"]
|
||||
assert settings.chat_model_id == anthropic_unit_test_env["ANTHROPIC_CHAT_MODEL_ID"]
|
||||
assert settings["api_key"] is not None
|
||||
assert settings["api_key"].get_secret_value() == anthropic_unit_test_env["ANTHROPIC_API_KEY"]
|
||||
assert settings["chat_model_id"] == anthropic_unit_test_env["ANTHROPIC_CHAT_MODEL_ID"]
|
||||
|
||||
|
||||
def test_anthropic_settings_init_with_explicit_values() -> None:
|
||||
"""Test AnthropicSettings initialization with explicit values."""
|
||||
settings = AnthropicSettings(
|
||||
api_key="custom-api-key", chat_model_id="claude-3-opus-20240229", env_file_path="test.env"
|
||||
settings = load_settings(
|
||||
AnthropicSettings,
|
||||
env_prefix="ANTHROPIC_",
|
||||
api_key="custom-api-key",
|
||||
chat_model_id="claude-3-opus-20240229",
|
||||
env_file_path="test.env",
|
||||
)
|
||||
|
||||
assert settings.api_key is not None
|
||||
assert settings.api_key.get_secret_value() == "custom-api-key"
|
||||
assert settings.chat_model_id == "claude-3-opus-20240229"
|
||||
assert settings["api_key"] is not None
|
||||
assert settings["api_key"].get_secret_value() == "custom-api-key"
|
||||
assert settings["chat_model_id"] == "claude-3-opus-20240229"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("exclude_list", [["ANTHROPIC_API_KEY"]], indirect=True)
|
||||
def test_anthropic_settings_missing_api_key(anthropic_unit_test_env: dict[str, str]) -> None:
|
||||
"""Test AnthropicSettings when API key is missing."""
|
||||
settings = AnthropicSettings(env_file_path="test.env")
|
||||
assert settings.api_key is None
|
||||
assert settings.chat_model_id == anthropic_unit_test_env["ANTHROPIC_CHAT_MODEL_ID"]
|
||||
settings = load_settings(AnthropicSettings, env_prefix="ANTHROPIC_", env_file_path="test.env")
|
||||
assert settings["api_key"] is None
|
||||
assert settings["chat_model_id"] == anthropic_unit_test_env["ANTHROPIC_CHAT_MODEL_ID"]
|
||||
|
||||
|
||||
# Client Initialization Tests
|
||||
@@ -116,23 +125,13 @@ def test_anthropic_client_init_auto_create_client(anthropic_unit_test_env: dict[
|
||||
|
||||
def test_anthropic_client_init_missing_api_key() -> None:
|
||||
"""Test AnthropicClient initialization when API key is missing."""
|
||||
with patch("agent_framework_anthropic._chat_client.AnthropicSettings") as mock_settings:
|
||||
mock_settings.return_value.api_key = None
|
||||
mock_settings.return_value.chat_model_id = "claude-3-5-sonnet-20241022"
|
||||
with patch("agent_framework_anthropic._chat_client.load_settings") as mock_load:
|
||||
mock_load.return_value = {"api_key": None, "chat_model_id": "claude-3-5-sonnet-20241022"}
|
||||
|
||||
with pytest.raises(ServiceInitializationError, match="Anthropic API key is required"):
|
||||
AnthropicClient()
|
||||
|
||||
|
||||
def test_anthropic_client_init_validation_error() -> None:
|
||||
"""Test that ValidationError in AnthropicSettings is properly handled."""
|
||||
with patch("agent_framework_anthropic._chat_client.AnthropicSettings") as mock_settings:
|
||||
mock_settings.side_effect = ValidationError.from_exception_data("test", [])
|
||||
|
||||
with pytest.raises(ServiceInitializationError, match="Failed to create Anthropic settings"):
|
||||
AnthropicClient()
|
||||
|
||||
|
||||
def test_anthropic_client_service_url(mock_anthropic_client: MagicMock) -> None:
|
||||
"""Test service_url method."""
|
||||
client = create_test_anthropic_client(mock_anthropic_client)
|
||||
|
||||
+23
-24
@@ -16,6 +16,7 @@ from typing import TYPE_CHECKING, Any, ClassVar, Literal
|
||||
from agent_framework import AGENT_FRAMEWORK_USER_AGENT, Message
|
||||
from agent_framework._logging import get_logger
|
||||
from agent_framework._sessions import AgentSession, BaseContextProvider, SessionContext
|
||||
from agent_framework._settings import load_settings
|
||||
from agent_framework.exceptions import ServiceInitializationError
|
||||
from azure.core.credentials import AzureKeyCredential
|
||||
from azure.core.credentials_async import AsyncTokenCredential
|
||||
@@ -41,7 +42,6 @@ from azure.search.documents.models import (
|
||||
VectorizableTextQuery,
|
||||
VectorizedQuery,
|
||||
)
|
||||
from pydantic import ValidationError
|
||||
|
||||
from ._search_provider import AzureAISearchSettings
|
||||
|
||||
@@ -180,40 +180,39 @@ class _AzureAISearchContextProvider(BaseContextProvider):
|
||||
super().__init__(source_id)
|
||||
|
||||
# Load settings from environment/file
|
||||
try:
|
||||
settings = AzureAISearchSettings(
|
||||
endpoint=endpoint,
|
||||
index_name=index_name,
|
||||
knowledge_base_name=knowledge_base_name,
|
||||
api_key=api_key if isinstance(api_key, str) else None,
|
||||
env_file_path=env_file_path,
|
||||
env_file_encoding=env_file_encoding,
|
||||
)
|
||||
except ValidationError as ex:
|
||||
raise ServiceInitializationError("Failed to create Azure AI Search settings.", ex) from ex
|
||||
settings = load_settings(
|
||||
AzureAISearchSettings,
|
||||
env_prefix="AZURE_SEARCH_",
|
||||
endpoint=endpoint,
|
||||
index_name=index_name,
|
||||
knowledge_base_name=knowledge_base_name,
|
||||
api_key=api_key if isinstance(api_key, str) else None,
|
||||
env_file_path=env_file_path,
|
||||
env_file_encoding=env_file_encoding,
|
||||
)
|
||||
|
||||
if not settings.endpoint:
|
||||
if not settings.get("endpoint"):
|
||||
raise ServiceInitializationError(
|
||||
"Azure AI Search endpoint is required. Set via 'endpoint' parameter "
|
||||
"or 'AZURE_SEARCH_ENDPOINT' environment variable."
|
||||
)
|
||||
|
||||
if mode == "semantic":
|
||||
if not settings.index_name:
|
||||
if not settings.get("index_name"):
|
||||
raise ServiceInitializationError(
|
||||
"Azure AI Search index name is required for semantic mode. "
|
||||
"Set via 'index_name' parameter or 'AZURE_SEARCH_INDEX_NAME' environment variable."
|
||||
)
|
||||
elif mode == "agentic":
|
||||
if settings.index_name and settings.knowledge_base_name:
|
||||
if settings.get("index_name") and settings.get("knowledge_base_name"):
|
||||
raise ServiceInitializationError(
|
||||
"For agentic mode, provide either 'index_name' OR 'knowledge_base_name', not both."
|
||||
)
|
||||
if not settings.index_name and not settings.knowledge_base_name:
|
||||
if not settings.get("index_name") and not settings.get("knowledge_base_name"):
|
||||
raise ServiceInitializationError(
|
||||
"For agentic mode, provide either 'index_name' or 'knowledge_base_name'."
|
||||
)
|
||||
if settings.index_name and not model_deployment_name:
|
||||
if settings.get("index_name") and not model_deployment_name:
|
||||
raise ServiceInitializationError(
|
||||
"model_deployment_name is required for agentic mode when creating Knowledge Base from index."
|
||||
)
|
||||
@@ -223,16 +222,16 @@ class _AzureAISearchContextProvider(BaseContextProvider):
|
||||
resolved_credential = credential
|
||||
elif isinstance(api_key, AzureKeyCredential):
|
||||
resolved_credential = api_key
|
||||
elif settings.api_key:
|
||||
resolved_credential = AzureKeyCredential(settings.api_key.get_secret_value())
|
||||
elif settings.get("api_key"):
|
||||
resolved_credential = AzureKeyCredential(settings["api_key"].get_secret_value()) # type: ignore[union-attr]
|
||||
else:
|
||||
raise ServiceInitializationError(
|
||||
"Azure credential is required. Provide 'api_key' or 'credential' parameter "
|
||||
"or set 'AZURE_SEARCH_API_KEY' environment variable."
|
||||
)
|
||||
|
||||
self.endpoint = settings.endpoint
|
||||
self.index_name = settings.index_name
|
||||
self.endpoint: str = settings["endpoint"] # type: ignore[assignment] # validated above
|
||||
self.index_name = settings.get("index_name")
|
||||
self.credential = resolved_credential
|
||||
self.mode = mode
|
||||
self.top_k = top_k
|
||||
@@ -244,7 +243,7 @@ class _AzureAISearchContextProvider(BaseContextProvider):
|
||||
self.azure_openai_resource_url = azure_openai_resource_url
|
||||
self.azure_openai_deployment_name = model_deployment_name
|
||||
self.model_name = model_name or model_deployment_name
|
||||
self.knowledge_base_name = settings.knowledge_base_name
|
||||
self.knowledge_base_name = settings.get("knowledge_base_name")
|
||||
self.retrieval_instructions = retrieval_instructions
|
||||
self.azure_openai_api_key = azure_openai_api_key
|
||||
self.knowledge_base_output_mode = knowledge_base_output_mode
|
||||
@@ -253,10 +252,10 @@ class _AzureAISearchContextProvider(BaseContextProvider):
|
||||
|
||||
self._use_existing_knowledge_base = False
|
||||
if mode == "agentic":
|
||||
if settings.knowledge_base_name:
|
||||
if settings.get("knowledge_base_name"):
|
||||
self._use_existing_knowledge_base = True
|
||||
else:
|
||||
self.knowledge_base_name = f"{settings.index_name}-kb"
|
||||
self.knowledge_base_name = f"{settings.get('index_name', '')}-kb"
|
||||
|
||||
self._auto_discovered_vector_field = False
|
||||
self._use_vectorizable_query = False
|
||||
|
||||
+32
-36
@@ -5,11 +5,11 @@ from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from collections.abc import Awaitable, Callable, MutableSequence
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Literal
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
from agent_framework import AGENT_FRAMEWORK_USER_AGENT, Context, ContextProvider, Message
|
||||
from agent_framework._logging import get_logger
|
||||
from agent_framework._pydantic import AFBaseSettings
|
||||
from agent_framework._settings import SecretString, load_settings
|
||||
from agent_framework.exceptions import ServiceInitializationError
|
||||
from azure.core.credentials import AzureKeyCredential
|
||||
from azure.core.credentials_async import AsyncTokenCredential
|
||||
@@ -35,7 +35,6 @@ from azure.search.documents.models import (
|
||||
VectorizableTextQuery,
|
||||
VectorizedQuery,
|
||||
)
|
||||
from pydantic import SecretStr, ValidationError
|
||||
|
||||
# Type checking imports for optional agentic mode dependencies
|
||||
if TYPE_CHECKING:
|
||||
@@ -99,9 +98,9 @@ else:
|
||||
from typing_extensions import override # type: ignore[import] # pragma: no cover
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
from typing import Self # pragma: no cover
|
||||
from typing import Self, TypedDict # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import Self # pragma: no cover
|
||||
from typing_extensions import Self, TypedDict # pragma: no cover
|
||||
|
||||
"""Azure AI Search Context Provider for Agent Framework.
|
||||
|
||||
@@ -120,7 +119,7 @@ logger = get_logger("agent_framework.azure")
|
||||
_DEFAULT_AGENTIC_MESSAGE_HISTORY_COUNT = 10
|
||||
|
||||
|
||||
class AzureAISearchSettings(AFBaseSettings):
|
||||
class AzureAISearchSettings(TypedDict, total=False):
|
||||
"""Settings for Azure AI Search Context Provider with auto-loading from environment.
|
||||
|
||||
The settings are first loaded from environment variables with the prefix 'AZURE_SEARCH_'.
|
||||
@@ -158,12 +157,10 @@ class AzureAISearchSettings(AFBaseSettings):
|
||||
settings = AzureAISearchSettings(env_file_path="path/to/.env")
|
||||
"""
|
||||
|
||||
env_prefix: ClassVar[str] = "AZURE_SEARCH_"
|
||||
|
||||
endpoint: str | None = None
|
||||
index_name: str | None = None
|
||||
knowledge_base_name: str | None = None
|
||||
api_key: SecretStr | None = None
|
||||
endpoint: str | None
|
||||
index_name: str | None
|
||||
knowledge_base_name: str | None
|
||||
api_key: SecretString | None
|
||||
|
||||
|
||||
class AzureAISearchContextProvider(ContextProvider):
|
||||
@@ -336,42 +333,41 @@ class AzureAISearchContextProvider(ContextProvider):
|
||||
provider = AzureAISearchContextProvider(credential=credential, env_file_path="path/to/.env")
|
||||
"""
|
||||
# Load settings from environment/file
|
||||
try:
|
||||
settings = AzureAISearchSettings(
|
||||
endpoint=endpoint,
|
||||
index_name=index_name,
|
||||
knowledge_base_name=knowledge_base_name,
|
||||
api_key=api_key if isinstance(api_key, str) else None,
|
||||
env_file_path=env_file_path,
|
||||
env_file_encoding=env_file_encoding,
|
||||
)
|
||||
except ValidationError as ex:
|
||||
raise ServiceInitializationError("Failed to create Azure AI Search settings.", ex) from ex
|
||||
settings = load_settings(
|
||||
AzureAISearchSettings,
|
||||
env_prefix="AZURE_SEARCH_",
|
||||
endpoint=endpoint,
|
||||
index_name=index_name,
|
||||
knowledge_base_name=knowledge_base_name,
|
||||
api_key=api_key if isinstance(api_key, str) else None,
|
||||
env_file_path=env_file_path,
|
||||
env_file_encoding=env_file_encoding,
|
||||
)
|
||||
|
||||
# Validate required parameters
|
||||
if not settings.endpoint:
|
||||
if not settings.get("endpoint"):
|
||||
raise ServiceInitializationError(
|
||||
"Azure AI Search endpoint is required. Set via 'endpoint' parameter "
|
||||
"or 'AZURE_SEARCH_ENDPOINT' environment variable."
|
||||
)
|
||||
|
||||
# Validate index_name and knowledge_base_name based on mode
|
||||
# Note: settings.* contains the resolved value (explicit param OR env var)
|
||||
# Note: settings["field"] / settings.get("field") contains the resolved value (explicit param OR env var)
|
||||
if mode == "semantic":
|
||||
# Semantic mode: always requires index_name
|
||||
if not settings.index_name:
|
||||
if not settings.get("index_name"):
|
||||
raise ServiceInitializationError(
|
||||
"Azure AI Search index name is required for semantic mode. "
|
||||
"Set via 'index_name' parameter or 'AZURE_SEARCH_INDEX_NAME' environment variable."
|
||||
)
|
||||
elif mode == "agentic":
|
||||
# Agentic mode: requires exactly ONE of index_name or knowledge_base_name
|
||||
if settings.index_name and settings.knowledge_base_name:
|
||||
if settings.get("index_name") and settings.get("knowledge_base_name"):
|
||||
raise ServiceInitializationError(
|
||||
"For agentic mode, provide either 'index_name' OR 'knowledge_base_name', not both. "
|
||||
"Use 'index_name' to auto-create a Knowledge Base, or 'knowledge_base_name' to use an existing one."
|
||||
)
|
||||
if not settings.index_name and not settings.knowledge_base_name:
|
||||
if not settings.get("index_name") and not settings.get("knowledge_base_name"):
|
||||
raise ServiceInitializationError(
|
||||
"For agentic mode, provide either 'index_name' (to auto-create Knowledge Base) "
|
||||
"or 'knowledge_base_name' (to use existing Knowledge Base). "
|
||||
@@ -379,7 +375,7 @@ class AzureAISearchContextProvider(ContextProvider):
|
||||
"AZURE_SEARCH_INDEX_NAME / AZURE_SEARCH_KNOWLEDGE_BASE_NAME."
|
||||
)
|
||||
# If using index_name to create KB, model config is required
|
||||
if settings.index_name and not model_deployment_name:
|
||||
if settings.get("index_name") and not model_deployment_name:
|
||||
raise ServiceInitializationError(
|
||||
"model_deployment_name is required for agentic mode when creating Knowledge Base from index. "
|
||||
"This is the Azure OpenAI deployment used by the Knowledge Base for query planning."
|
||||
@@ -392,16 +388,16 @@ class AzureAISearchContextProvider(ContextProvider):
|
||||
resolved_credential = credential
|
||||
elif isinstance(api_key, AzureKeyCredential):
|
||||
resolved_credential = api_key
|
||||
elif settings.api_key:
|
||||
resolved_credential = AzureKeyCredential(settings.api_key.get_secret_value())
|
||||
elif resolved_api_key := settings.get("api_key"):
|
||||
resolved_credential = AzureKeyCredential(resolved_api_key.get_secret_value())
|
||||
else:
|
||||
raise ServiceInitializationError(
|
||||
"Azure credential is required. Provide 'api_key' or 'credential' parameter "
|
||||
"or set 'AZURE_SEARCH_API_KEY' environment variable."
|
||||
)
|
||||
|
||||
self.endpoint = settings.endpoint
|
||||
self.index_name = settings.index_name
|
||||
self.endpoint: str = settings["endpoint"] # type: ignore[assignment] # validated above
|
||||
self.index_name = settings.get("index_name")
|
||||
self.credential = resolved_credential
|
||||
self.mode = mode
|
||||
self.top_k = top_k
|
||||
@@ -416,7 +412,7 @@ class AzureAISearchContextProvider(ContextProvider):
|
||||
# If model_name not provided, default to deployment name
|
||||
self.model_name = model_name or model_deployment_name
|
||||
# Use resolved KB name (from explicit param or env var)
|
||||
self.knowledge_base_name = settings.knowledge_base_name
|
||||
self.knowledge_base_name = settings.get("knowledge_base_name")
|
||||
self.retrieval_instructions = retrieval_instructions
|
||||
self.azure_openai_api_key = azure_openai_api_key
|
||||
self.knowledge_base_output_mode = knowledge_base_output_mode
|
||||
@@ -429,12 +425,12 @@ class AzureAISearchContextProvider(ContextProvider):
|
||||
# - index_name provided: auto-create KB from index
|
||||
self._use_existing_knowledge_base = False
|
||||
if mode == "agentic":
|
||||
if settings.knowledge_base_name:
|
||||
if settings.get("knowledge_base_name"):
|
||||
# Use existing KB directly (supports any source type: web, blob, index, etc.)
|
||||
self._use_existing_knowledge_base = True
|
||||
else:
|
||||
# Auto-generate KB name from index name
|
||||
self.knowledge_base_name = f"{settings.index_name}-kb"
|
||||
self.knowledge_base_name = f"{settings.get('index_name', '')}-kb"
|
||||
|
||||
# Auto-discover vector field if not specified
|
||||
self._auto_discovered_vector_field = False
|
||||
|
||||
@@ -6,6 +6,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from agent_framework import Context, Message
|
||||
from agent_framework._settings import load_settings
|
||||
from agent_framework.azure import AzureAISearchContextProvider, AzureAISearchSettings
|
||||
from agent_framework.exceptions import ServiceInitializationError
|
||||
from azure.core.credentials import AzureKeyCredential
|
||||
@@ -48,25 +49,27 @@ class TestAzureAISearchSettings:
|
||||
|
||||
def test_settings_with_direct_values(self) -> None:
|
||||
"""Test settings with direct values."""
|
||||
settings = AzureAISearchSettings(
|
||||
settings = load_settings(
|
||||
AzureAISearchSettings,
|
||||
env_prefix="AZURE_SEARCH_",
|
||||
endpoint="https://test.search.windows.net",
|
||||
index_name="test-index",
|
||||
api_key="test-key",
|
||||
)
|
||||
assert settings.endpoint == "https://test.search.windows.net"
|
||||
assert settings.index_name == "test-index"
|
||||
# api_key is now SecretStr
|
||||
assert settings.api_key.get_secret_value() == "test-key"
|
||||
assert settings["endpoint"] == "https://test.search.windows.net"
|
||||
assert settings["index_name"] == "test-index"
|
||||
assert settings["api_key"] == "test-key"
|
||||
|
||||
def test_settings_with_env_file_path(self) -> None:
|
||||
"""Test settings with env_file_path parameter."""
|
||||
settings = AzureAISearchSettings(
|
||||
settings = load_settings(
|
||||
AzureAISearchSettings,
|
||||
env_prefix="AZURE_SEARCH_",
|
||||
endpoint="https://test.search.windows.net",
|
||||
index_name="test-index",
|
||||
env_file_path="test.env",
|
||||
)
|
||||
assert settings.endpoint == "https://test.search.windows.net"
|
||||
assert settings.index_name == "test-index"
|
||||
assert settings["endpoint"] == "https://test.search.windows.net"
|
||||
assert settings["index_name"] == "test-index"
|
||||
|
||||
def test_provider_uses_settings_from_env(self) -> None:
|
||||
"""Test that provider creates settings internally from env."""
|
||||
|
||||
@@ -15,12 +15,13 @@ from agent_framework import (
|
||||
normalize_tools,
|
||||
)
|
||||
from agent_framework._mcp import MCPTool
|
||||
from agent_framework._settings import load_settings
|
||||
from agent_framework.exceptions import ServiceInitializationError
|
||||
from azure.ai.agents.aio import AgentsClient
|
||||
from azure.ai.agents.models import Agent as AzureAgent
|
||||
from azure.ai.agents.models import ResponseFormatJsonSchema, ResponseFormatJsonSchemaType
|
||||
from azure.core.credentials_async import AsyncTokenCredential
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ._chat_client import AzureAIAgentClient, AzureAIAgentOptions
|
||||
from ._shared import AzureAISettings, from_azure_ai_agent_tools, to_azure_ai_agent_tools
|
||||
@@ -112,21 +113,21 @@ class AzureAIAgentsProvider(Generic[OptionsCoT]):
|
||||
Raises:
|
||||
ServiceInitializationError: If required parameters are missing or invalid.
|
||||
"""
|
||||
try:
|
||||
self._settings = AzureAISettings(
|
||||
project_endpoint=project_endpoint,
|
||||
env_file_path=env_file_path,
|
||||
env_file_encoding=env_file_encoding,
|
||||
)
|
||||
except ValidationError as ex:
|
||||
raise ServiceInitializationError("Failed to create Azure AI settings.", ex) from ex
|
||||
self._settings = load_settings(
|
||||
AzureAISettings,
|
||||
env_prefix="AZURE_AI_",
|
||||
project_endpoint=project_endpoint,
|
||||
env_file_path=env_file_path,
|
||||
env_file_encoding=env_file_encoding,
|
||||
)
|
||||
|
||||
self._should_close_client = False
|
||||
|
||||
if agents_client is not None:
|
||||
self._agents_client = agents_client
|
||||
else:
|
||||
if not self._settings.project_endpoint:
|
||||
resolved_endpoint = self._settings.get("project_endpoint")
|
||||
if not resolved_endpoint:
|
||||
raise ServiceInitializationError(
|
||||
"Azure AI project endpoint is required. Provide 'project_endpoint' parameter "
|
||||
"or set 'AZURE_AI_PROJECT_ENDPOINT' environment variable."
|
||||
@@ -134,7 +135,7 @@ class AzureAIAgentsProvider(Generic[OptionsCoT]):
|
||||
if not credential:
|
||||
raise ServiceInitializationError("Azure credential is required when agents_client is not provided.")
|
||||
self._agents_client = AgentsClient(
|
||||
endpoint=self._settings.project_endpoint,
|
||||
endpoint=resolved_endpoint,
|
||||
credential=credential,
|
||||
user_agent=AGENT_FRAMEWORK_USER_AGENT,
|
||||
)
|
||||
@@ -211,7 +212,7 @@ class AzureAIAgentsProvider(Generic[OptionsCoT]):
|
||||
tools=get_weather,
|
||||
)
|
||||
"""
|
||||
resolved_model = model or self._settings.model_deployment_name
|
||||
resolved_model = model or self._settings.get("model_deployment_name")
|
||||
if not resolved_model:
|
||||
raise ServiceInitializationError(
|
||||
"Model deployment name is required. Provide 'model' parameter "
|
||||
|
||||
@@ -35,6 +35,7 @@ from agent_framework import (
|
||||
get_logger,
|
||||
prepare_function_call_results,
|
||||
)
|
||||
from agent_framework._settings import load_settings
|
||||
from agent_framework.exceptions import ServiceInitializationError, ServiceInvalidRequestError, ServiceResponseException
|
||||
from agent_framework.observability import ChatTelemetryLayer
|
||||
from azure.ai.agents.aio import AgentsClient
|
||||
@@ -85,7 +86,7 @@ from azure.ai.agents.models import (
|
||||
ToolOutput,
|
||||
)
|
||||
from azure.core.credentials_async import AsyncTokenCredential
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ._shared import AzureAISettings, to_azure_ai_agent_tools
|
||||
|
||||
@@ -482,26 +483,26 @@ class AzureAIAgentClient(
|
||||
client: AzureAIAgentClient[MyOptions] = AzureAIAgentClient(credential=credential)
|
||||
response = await client.get_response("Hello", options={"my_custom_option": "value"})
|
||||
"""
|
||||
try:
|
||||
azure_ai_settings = AzureAISettings(
|
||||
project_endpoint=project_endpoint,
|
||||
model_deployment_name=model_deployment_name,
|
||||
env_file_path=env_file_path,
|
||||
env_file_encoding=env_file_encoding,
|
||||
)
|
||||
except ValidationError as ex:
|
||||
raise ServiceInitializationError("Failed to create Azure AI settings.", ex) from ex
|
||||
azure_ai_settings = load_settings(
|
||||
AzureAISettings,
|
||||
env_prefix="AZURE_AI_",
|
||||
project_endpoint=project_endpoint,
|
||||
model_deployment_name=model_deployment_name,
|
||||
env_file_path=env_file_path,
|
||||
env_file_encoding=env_file_encoding,
|
||||
)
|
||||
|
||||
# If no agents_client is provided, create one
|
||||
should_close_client = False
|
||||
if agents_client is None:
|
||||
if not azure_ai_settings.project_endpoint:
|
||||
resolved_endpoint = azure_ai_settings.get("project_endpoint")
|
||||
if not resolved_endpoint:
|
||||
raise ServiceInitializationError(
|
||||
"Azure AI project endpoint is required. Set via 'project_endpoint' parameter "
|
||||
"or 'AZURE_AI_PROJECT_ENDPOINT' environment variable."
|
||||
)
|
||||
|
||||
if agent_id is None and not azure_ai_settings.model_deployment_name:
|
||||
if agent_id is None and not azure_ai_settings.get("model_deployment_name"):
|
||||
raise ServiceInitializationError(
|
||||
"Azure AI model deployment name is required. Set via 'model_deployment_name' parameter "
|
||||
"or 'AZURE_AI_MODEL_DEPLOYMENT_NAME' environment variable."
|
||||
@@ -511,7 +512,7 @@ class AzureAIAgentClient(
|
||||
if not credential:
|
||||
raise ServiceInitializationError("Azure credential is required when agents_client is not provided.")
|
||||
agents_client = AgentsClient(
|
||||
endpoint=azure_ai_settings.project_endpoint,
|
||||
endpoint=resolved_endpoint,
|
||||
credential=credential,
|
||||
user_agent=AGENT_FRAMEWORK_USER_AGENT,
|
||||
)
|
||||
@@ -530,7 +531,7 @@ class AzureAIAgentClient(
|
||||
self.agent_id = agent_id
|
||||
self.agent_name = agent_name
|
||||
self.agent_description = agent_description
|
||||
self.model_id = azure_ai_settings.model_deployment_name
|
||||
self.model_id = azure_ai_settings.get("model_deployment_name")
|
||||
self.thread_id = thread_id
|
||||
self.should_cleanup_agent = should_cleanup_agent # Track whether we should delete the agent
|
||||
self._agent_created = False # Track whether agent was created inside this class
|
||||
|
||||
@@ -20,6 +20,7 @@ from agent_framework import (
|
||||
MiddlewareTypes,
|
||||
get_logger,
|
||||
)
|
||||
from agent_framework._settings import load_settings
|
||||
from agent_framework.exceptions import ServiceInitializationError
|
||||
from agent_framework.observability import ChatTelemetryLayer
|
||||
from agent_framework.openai import OpenAIResponsesOptions
|
||||
@@ -40,7 +41,6 @@ from azure.ai.projects.models import (
|
||||
from azure.ai.projects.models import FileSearchTool as ProjectsFileSearchTool
|
||||
from azure.core.credentials_async import AsyncTokenCredential
|
||||
from azure.core.exceptions import ResourceNotFoundError
|
||||
from pydantic import ValidationError
|
||||
|
||||
from ._shared import AzureAISettings, create_text_format_config
|
||||
|
||||
@@ -171,20 +171,20 @@ class RawAzureAIClient(RawOpenAIResponsesClient[AzureAIClientOptionsT], Generic[
|
||||
client: AzureAIClient[MyOptions] = AzureAIClient(credential=credential)
|
||||
response = await client.get_response("Hello", options={"my_custom_option": "value"})
|
||||
"""
|
||||
try:
|
||||
azure_ai_settings = AzureAISettings(
|
||||
project_endpoint=project_endpoint,
|
||||
model_deployment_name=model_deployment_name,
|
||||
env_file_path=env_file_path,
|
||||
env_file_encoding=env_file_encoding,
|
||||
)
|
||||
except ValidationError as ex:
|
||||
raise ServiceInitializationError("Failed to create Azure AI settings.", ex) from ex
|
||||
azure_ai_settings = load_settings(
|
||||
AzureAISettings,
|
||||
env_prefix="AZURE_AI_",
|
||||
project_endpoint=project_endpoint,
|
||||
model_deployment_name=model_deployment_name,
|
||||
env_file_path=env_file_path,
|
||||
env_file_encoding=env_file_encoding,
|
||||
)
|
||||
|
||||
# If no project_client is provided, create one
|
||||
should_close_client = False
|
||||
if project_client is None:
|
||||
if not azure_ai_settings.project_endpoint:
|
||||
resolved_endpoint = azure_ai_settings.get("project_endpoint")
|
||||
if not resolved_endpoint:
|
||||
raise ServiceInitializationError(
|
||||
"Azure AI project endpoint is required. Set via 'project_endpoint' parameter "
|
||||
"or 'AZURE_AI_PROJECT_ENDPOINT' environment variable."
|
||||
@@ -194,7 +194,7 @@ class RawAzureAIClient(RawOpenAIResponsesClient[AzureAIClientOptionsT], Generic[
|
||||
if not credential:
|
||||
raise ServiceInitializationError("Azure credential is required when project_client is not provided.")
|
||||
project_client = AIProjectClient(
|
||||
endpoint=azure_ai_settings.project_endpoint,
|
||||
endpoint=resolved_endpoint,
|
||||
credential=credential,
|
||||
user_agent=AGENT_FRAMEWORK_USER_AGENT,
|
||||
)
|
||||
@@ -212,7 +212,7 @@ class RawAzureAIClient(RawOpenAIResponsesClient[AzureAIClientOptionsT], Generic[
|
||||
self.use_latest_version = use_latest_version
|
||||
self.project_client = project_client
|
||||
self.credential = credential
|
||||
self.model_id = azure_ai_settings.model_deployment_name
|
||||
self.model_id = azure_ai_settings.get("model_deployment_name")
|
||||
self.conversation_id = conversation_id
|
||||
|
||||
# Track whether the application endpoint is used
|
||||
|
||||
@@ -16,6 +16,7 @@ from agent_framework import (
|
||||
normalize_tools,
|
||||
)
|
||||
from agent_framework._mcp import MCPTool
|
||||
from agent_framework._settings import load_settings
|
||||
from agent_framework.exceptions import ServiceInitializationError
|
||||
from azure.ai.projects.aio import AIProjectClient
|
||||
from azure.ai.projects.models import (
|
||||
@@ -28,7 +29,6 @@ from azure.ai.projects.models import (
|
||||
FunctionTool as AzureFunctionTool,
|
||||
)
|
||||
from azure.core.credentials_async import AsyncTokenCredential
|
||||
from pydantic import ValidationError
|
||||
|
||||
from ._client import AzureAIClient, AzureAIProjectAgentOptions
|
||||
from ._shared import AzureAISettings, create_text_format_config, from_azure_ai_tools, to_azure_ai_tools
|
||||
@@ -123,21 +123,21 @@ class AzureAIProjectAgentProvider(Generic[OptionsCoT]):
|
||||
Raises:
|
||||
ServiceInitializationError: If required parameters are missing or invalid.
|
||||
"""
|
||||
try:
|
||||
self._settings = AzureAISettings(
|
||||
project_endpoint=project_endpoint,
|
||||
model_deployment_name=model,
|
||||
env_file_path=env_file_path,
|
||||
env_file_encoding=env_file_encoding,
|
||||
)
|
||||
except ValidationError as ex:
|
||||
raise ServiceInitializationError("Failed to create Azure AI settings.", ex) from ex
|
||||
self._settings = load_settings(
|
||||
AzureAISettings,
|
||||
env_prefix="AZURE_AI_",
|
||||
project_endpoint=project_endpoint,
|
||||
model_deployment_name=model,
|
||||
env_file_path=env_file_path,
|
||||
env_file_encoding=env_file_encoding,
|
||||
)
|
||||
|
||||
# Track whether we should close client connection
|
||||
self._should_close_client = False
|
||||
|
||||
if project_client is None:
|
||||
if not self._settings.project_endpoint:
|
||||
resolved_endpoint = self._settings.get("project_endpoint")
|
||||
if not resolved_endpoint:
|
||||
raise ServiceInitializationError(
|
||||
"Azure AI project endpoint is required. Set via 'project_endpoint' parameter "
|
||||
"or 'AZURE_AI_PROJECT_ENDPOINT' environment variable."
|
||||
@@ -147,7 +147,7 @@ class AzureAIProjectAgentProvider(Generic[OptionsCoT]):
|
||||
raise ServiceInitializationError("Azure credential is required when project_client is not provided.")
|
||||
|
||||
project_client = AIProjectClient(
|
||||
endpoint=self._settings.project_endpoint,
|
||||
endpoint=resolved_endpoint,
|
||||
credential=credential,
|
||||
user_agent=AGENT_FRAMEWORK_USER_AGENT,
|
||||
)
|
||||
@@ -191,7 +191,7 @@ class AzureAIProjectAgentProvider(Generic[OptionsCoT]):
|
||||
ServiceInitializationError: If required parameters are missing.
|
||||
"""
|
||||
# Resolve model from parameter or environment variable
|
||||
resolved_model = model or self._settings.model_deployment_name
|
||||
resolved_model = model or self._settings.get("model_deployment_name")
|
||||
if not resolved_model:
|
||||
raise ServiceInitializationError(
|
||||
"Model deployment name is required. Provide 'model' parameter "
|
||||
|
||||
@@ -2,14 +2,14 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from collections.abc import Mapping, MutableMapping, Sequence
|
||||
from typing import Any, ClassVar, cast
|
||||
from typing import Any, cast
|
||||
|
||||
from agent_framework import (
|
||||
FunctionTool,
|
||||
get_logger,
|
||||
)
|
||||
from agent_framework._pydantic import AFBaseSettings
|
||||
from agent_framework.exceptions import ServiceInvalidRequestError
|
||||
from azure.ai.agents.models import (
|
||||
CodeInterpreterToolDefinition,
|
||||
@@ -32,10 +32,15 @@ from azure.ai.projects.models import (
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
from typing import TypedDict # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import TypedDict # type: ignore # pragma: no cover
|
||||
|
||||
logger = get_logger("agent_framework.azure")
|
||||
|
||||
|
||||
class AzureAISettings(AFBaseSettings):
|
||||
class AzureAISettings(TypedDict, total=False):
|
||||
"""Azure AI Project settings.
|
||||
|
||||
The settings are first loaded from environment variables with the prefix 'AZURE_AI_'.
|
||||
@@ -70,10 +75,8 @@ class AzureAISettings(AFBaseSettings):
|
||||
settings = AzureAISettings(env_file_path="path/to/.env")
|
||||
"""
|
||||
|
||||
env_prefix: ClassVar[str] = "AZURE_AI_"
|
||||
|
||||
project_endpoint: str | None = None
|
||||
model_deployment_name: str | None = None
|
||||
project_endpoint: str | None
|
||||
model_deployment_name: str | None
|
||||
|
||||
|
||||
def _extract_project_connection_id(additional_properties: dict[str, Any] | None) -> str | None:
|
||||
|
||||
@@ -86,12 +86,9 @@ def test_provider_init_missing_endpoint_raises(
|
||||
mock_azure_credential: MagicMock,
|
||||
) -> None:
|
||||
"""Test AzureAIAgentsProvider raises error when endpoint is missing."""
|
||||
# Mock AzureAISettings to return None for project_endpoint
|
||||
with patch("agent_framework_azure_ai._agent_provider.AzureAISettings") as mock_settings_class:
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.project_endpoint = None
|
||||
mock_settings.model_deployment_name = "test-model"
|
||||
mock_settings_class.return_value = mock_settings
|
||||
# Mock load_settings to return a dict with None for project_endpoint
|
||||
with patch("agent_framework_azure_ai._agent_provider.load_settings") as mock_load_settings:
|
||||
mock_load_settings.return_value = {"project_endpoint": None, "model_deployment_name": "test-model"}
|
||||
|
||||
with pytest.raises(ServiceInitializationError) as exc_info:
|
||||
AzureAIAgentsProvider(credential=mock_azure_credential)
|
||||
@@ -270,11 +267,8 @@ async def test_create_agent_missing_model_raises(
|
||||
) -> None:
|
||||
"""Test that create_agent raises error when model is not specified."""
|
||||
# Create provider with mocked settings that has no model
|
||||
with patch("agent_framework_azure_ai._agent_provider.AzureAISettings") as mock_settings_class:
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.project_endpoint = "https://test.com"
|
||||
mock_settings.model_deployment_name = None # No model configured
|
||||
mock_settings_class.return_value = mock_settings
|
||||
with patch("agent_framework_azure_ai._agent_provider.load_settings") as mock_load_settings:
|
||||
mock_load_settings.return_value = {"project_endpoint": "https://test.com", "model_deployment_name": None}
|
||||
|
||||
provider = AzureAIAgentsProvider(agents_client=mock_agents_client)
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ from agent_framework import (
|
||||
tool,
|
||||
)
|
||||
from agent_framework._serialization import SerializationMixin
|
||||
from agent_framework._settings import load_settings
|
||||
from agent_framework.exceptions import ServiceInitializationError, ServiceInvalidRequestError
|
||||
from azure.ai.agents.models import (
|
||||
AgentsNamedToolChoice,
|
||||
@@ -44,7 +45,7 @@ from azure.ai.agents.models import (
|
||||
)
|
||||
from azure.core.credentials_async import AsyncTokenCredential
|
||||
from azure.identity.aio import AzureCliCredential
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from agent_framework_azure_ai import AzureAIAgentClient, AzureAISettings
|
||||
|
||||
@@ -67,7 +68,7 @@ def create_test_azure_ai_chat_client(
|
||||
) -> AzureAIAgentClient:
|
||||
"""Helper function to create AzureAIAgentClient instances for testing, bypassing normal validation."""
|
||||
if azure_ai_settings is None:
|
||||
azure_ai_settings = AzureAISettings(env_file_path="test.env")
|
||||
azure_ai_settings = load_settings(AzureAISettings, env_prefix="AZURE_AI_", env_file_path="test.env")
|
||||
|
||||
# Create client instance directly
|
||||
client = object.__new__(AzureAIAgentClient)
|
||||
@@ -78,7 +79,7 @@ def create_test_azure_ai_chat_client(
|
||||
client.agent_id = agent_id
|
||||
client.agent_name = agent_name
|
||||
client.agent_description = None
|
||||
client.model_id = azure_ai_settings.model_deployment_name
|
||||
client.model_id = azure_ai_settings.get("model_deployment_name")
|
||||
client.thread_id = thread_id
|
||||
client.should_cleanup_agent = should_cleanup_agent
|
||||
client._agent_created = False
|
||||
@@ -104,21 +105,23 @@ def create_test_azure_ai_chat_client(
|
||||
|
||||
def test_azure_ai_settings_init(azure_ai_unit_test_env: dict[str, str]) -> None:
|
||||
"""Test AzureAISettings initialization."""
|
||||
settings = AzureAISettings()
|
||||
settings = load_settings(AzureAISettings, env_prefix="AZURE_AI_")
|
||||
|
||||
assert settings.project_endpoint == azure_ai_unit_test_env["AZURE_AI_PROJECT_ENDPOINT"]
|
||||
assert settings.model_deployment_name == azure_ai_unit_test_env["AZURE_AI_MODEL_DEPLOYMENT_NAME"]
|
||||
assert settings["project_endpoint"] == azure_ai_unit_test_env["AZURE_AI_PROJECT_ENDPOINT"]
|
||||
assert settings["model_deployment_name"] == azure_ai_unit_test_env["AZURE_AI_MODEL_DEPLOYMENT_NAME"]
|
||||
|
||||
|
||||
def test_azure_ai_settings_init_with_explicit_values() -> None:
|
||||
"""Test AzureAISettings initialization with explicit values."""
|
||||
settings = AzureAISettings(
|
||||
settings = load_settings(
|
||||
AzureAISettings,
|
||||
env_prefix="AZURE_AI_",
|
||||
project_endpoint="https://custom-endpoint.com/",
|
||||
model_deployment_name="custom-model",
|
||||
)
|
||||
|
||||
assert settings.project_endpoint == "https://custom-endpoint.com/"
|
||||
assert settings.model_deployment_name == "custom-model"
|
||||
assert settings["project_endpoint"] == "https://custom-endpoint.com/"
|
||||
assert settings["model_deployment_name"] == "custom-model"
|
||||
|
||||
|
||||
def test_azure_ai_chat_client_init_with_client(mock_agents_client: MagicMock) -> None:
|
||||
@@ -138,33 +141,29 @@ def test_azure_ai_chat_client_init_auto_create_client(
|
||||
mock_agents_client: MagicMock,
|
||||
) -> None:
|
||||
"""Test AzureAIAgentClient initialization with auto-created agents_client."""
|
||||
azure_ai_settings = AzureAISettings(**azure_ai_unit_test_env) # type: ignore
|
||||
azure_ai_settings = load_settings(AzureAISettings, env_prefix="AZURE_AI_", **azure_ai_unit_test_env) # type: ignore
|
||||
|
||||
# Create client instance directly
|
||||
client = object.__new__(AzureAIAgentClient)
|
||||
client.agents_client = mock_agents_client
|
||||
client.agent_id = None
|
||||
client.thread_id = None
|
||||
client._should_close_client = False # type: ignore
|
||||
client.credential = None
|
||||
client.model_id = azure_ai_settings.model_deployment_name
|
||||
client.agent_name = None
|
||||
client.additional_properties = {}
|
||||
client.middleware = None
|
||||
chat_client = object.__new__(AzureAIAgentClient)
|
||||
chat_client.agents_client = mock_agents_client
|
||||
chat_client.agent_id = None
|
||||
chat_client.thread_id = None
|
||||
chat_client._should_close_client = False # type: ignore
|
||||
chat_client.credential = None
|
||||
chat_client.model_id = azure_ai_settings.get("model_deployment_name")
|
||||
chat_client.agent_name = None
|
||||
chat_client.additional_properties = {}
|
||||
chat_client.middleware = None
|
||||
|
||||
assert client.agents_client is mock_agents_client
|
||||
assert client.agent_id is None
|
||||
assert chat_client.agents_client is mock_agents_client
|
||||
assert chat_client.agent_id is None
|
||||
|
||||
|
||||
def test_azure_ai_chat_client_init_missing_project_endpoint() -> None:
|
||||
"""Test AzureAIAgentClient initialization when project_endpoint is missing and no agents_client provided."""
|
||||
# Mock AzureAISettings to return settings with None project_endpoint
|
||||
with patch("agent_framework_azure_ai._chat_client.AzureAISettings") as mock_settings:
|
||||
mock_settings_instance = MagicMock()
|
||||
mock_settings_instance.project_endpoint = None # This should trigger the error
|
||||
mock_settings_instance.model_deployment_name = "test-model"
|
||||
mock_settings_instance.agent_name = "test-agent"
|
||||
mock_settings.return_value = mock_settings_instance
|
||||
with patch("agent_framework_azure_ai._chat_client.load_settings") as mock_load_settings:
|
||||
mock_load_settings.return_value = {"project_endpoint": None, "model_deployment_name": "test-model"}
|
||||
|
||||
with pytest.raises(ServiceInitializationError, match="project endpoint is required"):
|
||||
AzureAIAgentClient(
|
||||
@@ -179,12 +178,8 @@ def test_azure_ai_chat_client_init_missing_project_endpoint() -> None:
|
||||
def test_azure_ai_chat_client_init_missing_model_deployment_for_agent_creation() -> None:
|
||||
"""Test AzureAIAgentClient initialization when model deployment is missing for agent creation."""
|
||||
# Mock AzureAISettings to return settings with None model_deployment_name
|
||||
with patch("agent_framework_azure_ai._chat_client.AzureAISettings") as mock_settings:
|
||||
mock_settings_instance = MagicMock()
|
||||
mock_settings_instance.project_endpoint = "https://test.com"
|
||||
mock_settings_instance.model_deployment_name = None # This should trigger the error
|
||||
mock_settings_instance.agent_name = "test-agent"
|
||||
mock_settings.return_value = mock_settings_instance
|
||||
with patch("agent_framework_azure_ai._chat_client.load_settings") as mock_load_settings:
|
||||
mock_load_settings.return_value = {"project_endpoint": "https://test.com", "model_deployment_name": None}
|
||||
|
||||
with pytest.raises(ServiceInitializationError, match="model deployment name is required"):
|
||||
AzureAIAgentClient(
|
||||
@@ -210,20 +205,6 @@ def test_azure_ai_chat_client_init_missing_credential(azure_ai_unit_test_env: di
|
||||
)
|
||||
|
||||
|
||||
def test_azure_ai_chat_client_init_validation_error(mock_azure_credential: MagicMock) -> None:
|
||||
"""Test that ValidationError in AzureAISettings is properly handled."""
|
||||
with patch("agent_framework_azure_ai._chat_client.AzureAISettings") as mock_settings:
|
||||
# Create a proper ValidationError with empty errors list and model dict
|
||||
mock_settings.side_effect = ValidationError.from_exception_data("AzureAISettings", [])
|
||||
|
||||
with pytest.raises(ServiceInitializationError, match="Failed to create Azure AI settings."):
|
||||
AzureAIAgentClient(
|
||||
project_endpoint="https://test.com",
|
||||
model_deployment_name="test-model",
|
||||
credential=mock_azure_credential,
|
||||
)
|
||||
|
||||
|
||||
def test_azure_ai_chat_client_from_dict() -> None:
|
||||
"""Test from_settings class method."""
|
||||
mock_agents_client = MagicMock()
|
||||
@@ -248,11 +229,15 @@ async def test_azure_ai_chat_client_get_agent_id_or_create_with_temperature_and_
|
||||
mock_agents_client: MagicMock, azure_ai_unit_test_env: dict[str, str]
|
||||
) -> None:
|
||||
"""Test _get_agent_id_or_create with temperature and top_p in run_options."""
|
||||
azure_ai_settings = AzureAISettings(model_deployment_name=azure_ai_unit_test_env["AZURE_AI_MODEL_DEPLOYMENT_NAME"])
|
||||
azure_ai_settings = load_settings(
|
||||
AzureAISettings,
|
||||
env_prefix="AZURE_AI_",
|
||||
model_deployment_name=azure_ai_unit_test_env["AZURE_AI_MODEL_DEPLOYMENT_NAME"],
|
||||
)
|
||||
client = create_test_azure_ai_chat_client(mock_agents_client, azure_ai_settings=azure_ai_settings)
|
||||
|
||||
run_options = {
|
||||
"model": azure_ai_settings.model_deployment_name,
|
||||
"model": azure_ai_settings.get("model_deployment_name"),
|
||||
"temperature": 0.7,
|
||||
"top_p": 0.9,
|
||||
}
|
||||
@@ -284,13 +269,19 @@ async def test_azure_ai_chat_client_get_agent_id_or_create_create_new(
|
||||
azure_ai_unit_test_env: dict[str, str],
|
||||
) -> None:
|
||||
"""Test _get_agent_id_or_create when creating a new agent."""
|
||||
azure_ai_settings = AzureAISettings(model_deployment_name=azure_ai_unit_test_env["AZURE_AI_MODEL_DEPLOYMENT_NAME"])
|
||||
client = create_test_azure_ai_chat_client(mock_agents_client, azure_ai_settings=azure_ai_settings)
|
||||
azure_ai_settings = load_settings(
|
||||
AzureAISettings,
|
||||
env_prefix="AZURE_AI_",
|
||||
model_deployment_name=azure_ai_unit_test_env["AZURE_AI_MODEL_DEPLOYMENT_NAME"],
|
||||
)
|
||||
chat_client = create_test_azure_ai_chat_client(mock_agents_client, azure_ai_settings=azure_ai_settings)
|
||||
|
||||
agent_id = await client._get_agent_id_or_create(run_options={"model": azure_ai_settings.model_deployment_name}) # type: ignore
|
||||
agent_id = await chat_client._get_agent_id_or_create(
|
||||
run_options={"model": azure_ai_settings.get("model_deployment_name")}
|
||||
) # type: ignore
|
||||
|
||||
assert agent_id == "test-agent-id"
|
||||
assert client._agent_created
|
||||
assert chat_client._agent_created
|
||||
|
||||
|
||||
async def test_azure_ai_chat_client_thread_management_through_public_api(mock_agents_client: MagicMock) -> None:
|
||||
@@ -547,14 +538,18 @@ async def test_azure_ai_chat_client_get_agent_id_or_create_with_run_options(
|
||||
mock_agents_client: MagicMock, azure_ai_unit_test_env: dict[str, str]
|
||||
) -> None:
|
||||
"""Test _get_agent_id_or_create with run_options containing tools and instructions."""
|
||||
azure_ai_settings = AzureAISettings(model_deployment_name=azure_ai_unit_test_env["AZURE_AI_MODEL_DEPLOYMENT_NAME"])
|
||||
azure_ai_settings = load_settings(
|
||||
AzureAISettings,
|
||||
env_prefix="AZURE_AI_",
|
||||
model_deployment_name=azure_ai_unit_test_env["AZURE_AI_MODEL_DEPLOYMENT_NAME"],
|
||||
)
|
||||
client = create_test_azure_ai_chat_client(mock_agents_client, azure_ai_settings=azure_ai_settings)
|
||||
|
||||
run_options = {
|
||||
"tools": [{"type": "function", "function": {"name": "test_tool"}}],
|
||||
"instructions": "Test instructions",
|
||||
"response_format": {"type": "json_object"},
|
||||
"model": azure_ai_settings.model_deployment_name,
|
||||
"model": azure_ai_settings.get("model_deployment_name"),
|
||||
}
|
||||
|
||||
agent_id = await client._get_agent_id_or_create(run_options) # type: ignore
|
||||
@@ -1134,13 +1129,19 @@ async def test_azure_ai_chat_client_get_agent_id_or_create_with_agent_name(
|
||||
mock_agents_client: MagicMock, azure_ai_unit_test_env: dict[str, str]
|
||||
) -> None:
|
||||
"""Test _get_agent_id_or_create uses default name when no agent_name set."""
|
||||
azure_ai_settings = AzureAISettings(model_deployment_name=azure_ai_unit_test_env["AZURE_AI_MODEL_DEPLOYMENT_NAME"])
|
||||
azure_ai_settings = load_settings(
|
||||
AzureAISettings,
|
||||
env_prefix="AZURE_AI_",
|
||||
model_deployment_name=azure_ai_unit_test_env["AZURE_AI_MODEL_DEPLOYMENT_NAME"],
|
||||
)
|
||||
client = create_test_azure_ai_chat_client(mock_agents_client, azure_ai_settings=azure_ai_settings)
|
||||
|
||||
# Ensure agent_name is None to test the default
|
||||
client.agent_name = None # type: ignore
|
||||
|
||||
agent_id = await client._get_agent_id_or_create(run_options={"model": azure_ai_settings.model_deployment_name}) # type: ignore
|
||||
agent_id = await client._get_agent_id_or_create(
|
||||
run_options={"model": azure_ai_settings.get("model_deployment_name")}
|
||||
) # type: ignore
|
||||
|
||||
assert agent_id == "test-agent-id"
|
||||
# Verify create_agent was called with default "UnnamedAgent"
|
||||
@@ -1153,11 +1154,15 @@ async def test_azure_ai_chat_client_get_agent_id_or_create_with_response_format(
|
||||
mock_agents_client: MagicMock, azure_ai_unit_test_env: dict[str, str]
|
||||
) -> None:
|
||||
"""Test _get_agent_id_or_create with response_format in run_options."""
|
||||
azure_ai_settings = AzureAISettings(model_deployment_name=azure_ai_unit_test_env["AZURE_AI_MODEL_DEPLOYMENT_NAME"])
|
||||
azure_ai_settings = load_settings(
|
||||
AzureAISettings,
|
||||
env_prefix="AZURE_AI_",
|
||||
model_deployment_name=azure_ai_unit_test_env["AZURE_AI_MODEL_DEPLOYMENT_NAME"],
|
||||
)
|
||||
client = create_test_azure_ai_chat_client(mock_agents_client, azure_ai_settings=azure_ai_settings)
|
||||
|
||||
# Test with response_format in run_options
|
||||
run_options = {"response_format": {"type": "json_object"}, "model": azure_ai_settings.model_deployment_name}
|
||||
run_options = {"response_format": {"type": "json_object"}, "model": azure_ai_settings.get("model_deployment_name")}
|
||||
|
||||
agent_id = await client._get_agent_id_or_create(run_options) # type: ignore
|
||||
|
||||
@@ -1172,13 +1177,17 @@ async def test_azure_ai_chat_client_get_agent_id_or_create_with_tool_resources(
|
||||
mock_agents_client: MagicMock, azure_ai_unit_test_env: dict[str, str]
|
||||
) -> None:
|
||||
"""Test _get_agent_id_or_create with tool_resources in run_options."""
|
||||
azure_ai_settings = AzureAISettings(model_deployment_name=azure_ai_unit_test_env["AZURE_AI_MODEL_DEPLOYMENT_NAME"])
|
||||
azure_ai_settings = load_settings(
|
||||
AzureAISettings,
|
||||
env_prefix="AZURE_AI_",
|
||||
model_deployment_name=azure_ai_unit_test_env["AZURE_AI_MODEL_DEPLOYMENT_NAME"],
|
||||
)
|
||||
client = create_test_azure_ai_chat_client(mock_agents_client, azure_ai_settings=azure_ai_settings)
|
||||
|
||||
# Test with tool_resources in run_options
|
||||
run_options = {
|
||||
"tool_resources": {"vector_store_ids": ["vs-123"]},
|
||||
"model": azure_ai_settings.model_deployment_name,
|
||||
"model": azure_ai_settings.get("model_deployment_name"),
|
||||
}
|
||||
|
||||
agent_id = await client._get_agent_id_or_create(run_options) # type: ignore
|
||||
|
||||
@@ -20,6 +20,7 @@ from agent_framework import (
|
||||
SupportsChatGetResponse,
|
||||
tool,
|
||||
)
|
||||
from agent_framework._settings import load_settings
|
||||
from agent_framework.exceptions import ServiceInitializationError
|
||||
from azure.ai.projects.aio import AIProjectClient
|
||||
from azure.ai.projects.models import (
|
||||
@@ -36,7 +37,7 @@ from azure.core.exceptions import ResourceNotFoundError
|
||||
from azure.identity.aio import AzureCliCredential
|
||||
from openai.types.responses.parsed_response import ParsedResponse
|
||||
from openai.types.responses.response import Response as OpenAIResponse
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationError
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from pytest import fixture, param
|
||||
|
||||
from agent_framework_azure_ai import AzureAIClient, AzureAISettings
|
||||
@@ -113,7 +114,7 @@ def create_test_azure_ai_client(
|
||||
) -> AzureAIClient:
|
||||
"""Helper function to create AzureAIClient instances for testing, bypassing normal validation."""
|
||||
if azure_ai_settings is None:
|
||||
azure_ai_settings = AzureAISettings(env_file_path="test.env")
|
||||
azure_ai_settings = load_settings(AzureAISettings, env_prefix="AZURE_AI_", env_file_path="test.env")
|
||||
|
||||
# Create client instance directly
|
||||
client = object.__new__(AzureAIClient)
|
||||
@@ -125,7 +126,7 @@ def create_test_azure_ai_client(
|
||||
client.agent_version = agent_version
|
||||
client.agent_description = None
|
||||
client.use_latest_version = use_latest_version
|
||||
client.model_id = azure_ai_settings.model_deployment_name
|
||||
client.model_id = azure_ai_settings.get("model_deployment_name")
|
||||
client.conversation_id = conversation_id
|
||||
client._is_application_endpoint = False # type: ignore
|
||||
client._should_close_client = should_close_client # type: ignore
|
||||
@@ -143,28 +144,29 @@ def create_test_azure_ai_client(
|
||||
|
||||
def test_azure_ai_settings_init(azure_ai_unit_test_env: dict[str, str]) -> None:
|
||||
"""Test AzureAISettings initialization."""
|
||||
settings = AzureAISettings()
|
||||
settings = load_settings(AzureAISettings, env_prefix="AZURE_AI_")
|
||||
|
||||
assert settings.project_endpoint == azure_ai_unit_test_env["AZURE_AI_PROJECT_ENDPOINT"]
|
||||
assert settings.model_deployment_name == azure_ai_unit_test_env["AZURE_AI_MODEL_DEPLOYMENT_NAME"]
|
||||
assert settings["project_endpoint"] == azure_ai_unit_test_env["AZURE_AI_PROJECT_ENDPOINT"]
|
||||
assert settings["model_deployment_name"] == azure_ai_unit_test_env["AZURE_AI_MODEL_DEPLOYMENT_NAME"]
|
||||
|
||||
|
||||
def test_azure_ai_settings_init_with_explicit_values() -> None:
|
||||
"""Test AzureAISettings initialization with explicit values."""
|
||||
settings = AzureAISettings(
|
||||
settings = load_settings(
|
||||
AzureAISettings,
|
||||
env_prefix="AZURE_AI_",
|
||||
project_endpoint="https://custom-endpoint.com/",
|
||||
model_deployment_name="custom-model",
|
||||
)
|
||||
|
||||
assert settings.project_endpoint == "https://custom-endpoint.com/"
|
||||
assert settings.model_deployment_name == "custom-model"
|
||||
assert settings["project_endpoint"] == "https://custom-endpoint.com/"
|
||||
assert settings["model_deployment_name"] == "custom-model"
|
||||
|
||||
|
||||
def test_init_with_project_client(mock_project_client: MagicMock) -> None:
|
||||
"""Test AzureAIClient initialization with existing project_client."""
|
||||
with patch("agent_framework_azure_ai._client.AzureAISettings") as mock_settings:
|
||||
mock_settings.return_value.project_endpoint = None
|
||||
mock_settings.return_value.model_deployment_name = "test-model"
|
||||
with patch("agent_framework_azure_ai._client.load_settings") as mock_load_settings:
|
||||
mock_load_settings.return_value = {"project_endpoint": None, "model_deployment_name": "test-model"}
|
||||
|
||||
client = AzureAIClient(
|
||||
project_client=mock_project_client,
|
||||
@@ -205,9 +207,8 @@ def test_init_auto_create_client(
|
||||
|
||||
def test_init_missing_project_endpoint() -> None:
|
||||
"""Test AzureAIClient initialization when project_endpoint is missing and no project_client provided."""
|
||||
with patch("agent_framework_azure_ai._client.AzureAISettings") as mock_settings:
|
||||
mock_settings.return_value.project_endpoint = None
|
||||
mock_settings.return_value.model_deployment_name = "test-model"
|
||||
with patch("agent_framework_azure_ai._client.load_settings") as mock_load_settings:
|
||||
mock_load_settings.return_value = {"project_endpoint": None, "model_deployment_name": "test-model"}
|
||||
|
||||
with pytest.raises(ServiceInitializationError, match="Azure AI project endpoint is required"):
|
||||
AzureAIClient(credential=MagicMock())
|
||||
@@ -224,15 +225,6 @@ def test_init_missing_credential(azure_ai_unit_test_env: dict[str, str]) -> None
|
||||
)
|
||||
|
||||
|
||||
def test_init_validation_error(mock_azure_credential: MagicMock) -> None:
|
||||
"""Test that ValidationError in AzureAISettings is properly handled."""
|
||||
with patch("agent_framework_azure_ai._client.AzureAISettings") as mock_settings:
|
||||
mock_settings.side_effect = ValidationError.from_exception_data("test", [])
|
||||
|
||||
with pytest.raises(ServiceInitializationError, match="Failed to create Azure AI settings"):
|
||||
AzureAIClient(credential=mock_azure_credential)
|
||||
|
||||
|
||||
async def test_get_agent_reference_or_create_existing_version(
|
||||
mock_project_client: MagicMock,
|
||||
) -> None:
|
||||
@@ -259,7 +251,11 @@ async def test_get_agent_reference_or_create_new_agent(
|
||||
azure_ai_unit_test_env: dict[str, str],
|
||||
) -> None:
|
||||
"""Test _get_agent_reference_or_create when creating a new agent."""
|
||||
azure_ai_settings = AzureAISettings(model_deployment_name=azure_ai_unit_test_env["AZURE_AI_MODEL_DEPLOYMENT_NAME"])
|
||||
azure_ai_settings = load_settings(
|
||||
AzureAISettings,
|
||||
env_prefix="AZURE_AI_",
|
||||
model_deployment_name=azure_ai_unit_test_env["AZURE_AI_MODEL_DEPLOYMENT_NAME"],
|
||||
)
|
||||
client = create_test_azure_ai_client(
|
||||
mock_project_client, agent_name="new-agent", azure_ai_settings=azure_ai_settings
|
||||
)
|
||||
@@ -270,7 +266,7 @@ async def test_get_agent_reference_or_create_new_agent(
|
||||
mock_agent.version = "1.0"
|
||||
mock_project_client.agents.create_version = AsyncMock(return_value=mock_agent)
|
||||
|
||||
run_options = {"model": azure_ai_settings.model_deployment_name}
|
||||
run_options = {"model": azure_ai_settings.get("model_deployment_name")}
|
||||
agent_ref = await client._get_agent_reference_or_create(run_options, None) # type: ignore
|
||||
|
||||
assert agent_ref == {"name": "new-agent", "version": "1.0", "type": "agent_reference"}
|
||||
|
||||
@@ -107,9 +107,8 @@ def test_provider_init_with_credential_and_endpoint(
|
||||
|
||||
def test_provider_init_missing_endpoint() -> None:
|
||||
"""Test AzureAIProjectAgentProvider initialization when endpoint is missing."""
|
||||
with patch("agent_framework_azure_ai._project_provider.AzureAISettings") as mock_settings:
|
||||
mock_settings.return_value.project_endpoint = None
|
||||
mock_settings.return_value.model_deployment_name = "test-model"
|
||||
with patch("agent_framework_azure_ai._project_provider.load_settings") as mock_load_settings:
|
||||
mock_load_settings.return_value = {"project_endpoint": None, "model_deployment_name": "test-model"}
|
||||
|
||||
with pytest.raises(ServiceInitializationError, match="Azure AI project endpoint is required"):
|
||||
AzureAIProjectAgentProvider(credential=MagicMock())
|
||||
@@ -130,9 +129,11 @@ async def test_provider_create_agent(
|
||||
azure_ai_unit_test_env: dict[str, str],
|
||||
) -> None:
|
||||
"""Test AzureAIProjectAgentProvider.create_agent method."""
|
||||
with patch("agent_framework_azure_ai._project_provider.AzureAISettings") as mock_settings:
|
||||
mock_settings.return_value.project_endpoint = azure_ai_unit_test_env["AZURE_AI_PROJECT_ENDPOINT"]
|
||||
mock_settings.return_value.model_deployment_name = azure_ai_unit_test_env["AZURE_AI_MODEL_DEPLOYMENT_NAME"]
|
||||
with patch("agent_framework_azure_ai._project_provider.load_settings") as mock_load_settings:
|
||||
mock_load_settings.return_value = {
|
||||
"project_endpoint": azure_ai_unit_test_env["AZURE_AI_PROJECT_ENDPOINT"],
|
||||
"model_deployment_name": azure_ai_unit_test_env["AZURE_AI_MODEL_DEPLOYMENT_NAME"],
|
||||
}
|
||||
|
||||
provider = AzureAIProjectAgentProvider(project_client=mock_project_client)
|
||||
|
||||
@@ -168,9 +169,11 @@ async def test_provider_create_agent_with_env_model(
|
||||
azure_ai_unit_test_env: dict[str, str],
|
||||
) -> None:
|
||||
"""Test AzureAIProjectAgentProvider.create_agent uses model from env var."""
|
||||
with patch("agent_framework_azure_ai._project_provider.AzureAISettings") as mock_settings:
|
||||
mock_settings.return_value.project_endpoint = azure_ai_unit_test_env["AZURE_AI_PROJECT_ENDPOINT"]
|
||||
mock_settings.return_value.model_deployment_name = azure_ai_unit_test_env["AZURE_AI_MODEL_DEPLOYMENT_NAME"]
|
||||
with patch("agent_framework_azure_ai._project_provider.load_settings") as mock_load_settings:
|
||||
mock_load_settings.return_value = {
|
||||
"project_endpoint": azure_ai_unit_test_env["AZURE_AI_PROJECT_ENDPOINT"],
|
||||
"model_deployment_name": azure_ai_unit_test_env["AZURE_AI_MODEL_DEPLOYMENT_NAME"],
|
||||
}
|
||||
|
||||
provider = AzureAIProjectAgentProvider(project_client=mock_project_client)
|
||||
|
||||
@@ -200,9 +203,8 @@ async def test_provider_create_agent_with_env_model(
|
||||
|
||||
async def test_provider_create_agent_missing_model(mock_project_client: MagicMock) -> None:
|
||||
"""Test AzureAIProjectAgentProvider.create_agent raises when model is missing."""
|
||||
with patch("agent_framework_azure_ai._project_provider.AzureAISettings") as mock_settings:
|
||||
mock_settings.return_value.project_endpoint = "https://test.com"
|
||||
mock_settings.return_value.model_deployment_name = None
|
||||
with patch("agent_framework_azure_ai._project_provider.load_settings") as mock_load_settings:
|
||||
mock_load_settings.return_value = {"project_endpoint": "https://test.com", "model_deployment_name": None}
|
||||
|
||||
provider = AzureAIProjectAgentProvider(project_client=mock_project_client)
|
||||
|
||||
@@ -215,9 +217,11 @@ async def test_provider_create_agent_with_rai_config(
|
||||
azure_ai_unit_test_env: dict[str, str],
|
||||
) -> None:
|
||||
"""Test AzureAIProjectAgentProvider.create_agent passes rai_config from default_options."""
|
||||
with patch("agent_framework_azure_ai._project_provider.AzureAISettings") as mock_settings:
|
||||
mock_settings.return_value.project_endpoint = azure_ai_unit_test_env["AZURE_AI_PROJECT_ENDPOINT"]
|
||||
mock_settings.return_value.model_deployment_name = azure_ai_unit_test_env["AZURE_AI_MODEL_DEPLOYMENT_NAME"]
|
||||
with patch("agent_framework_azure_ai._project_provider.load_settings") as mock_load_settings:
|
||||
mock_load_settings.return_value = {
|
||||
"project_endpoint": azure_ai_unit_test_env["AZURE_AI_PROJECT_ENDPOINT"],
|
||||
"model_deployment_name": azure_ai_unit_test_env["AZURE_AI_MODEL_DEPLOYMENT_NAME"],
|
||||
}
|
||||
|
||||
provider = AzureAIProjectAgentProvider(project_client=mock_project_client)
|
||||
|
||||
@@ -258,9 +262,11 @@ async def test_provider_create_agent_with_reasoning(
|
||||
azure_ai_unit_test_env: dict[str, str],
|
||||
) -> None:
|
||||
"""Test AzureAIProjectAgentProvider.create_agent passes reasoning from default_options."""
|
||||
with patch("agent_framework_azure_ai._project_provider.AzureAISettings") as mock_settings:
|
||||
mock_settings.return_value.project_endpoint = azure_ai_unit_test_env["AZURE_AI_PROJECT_ENDPOINT"]
|
||||
mock_settings.return_value.model_deployment_name = azure_ai_unit_test_env["AZURE_AI_MODEL_DEPLOYMENT_NAME"]
|
||||
with patch("agent_framework_azure_ai._project_provider.load_settings") as mock_load_settings:
|
||||
mock_load_settings.return_value = {
|
||||
"project_endpoint": azure_ai_unit_test_env["AZURE_AI_PROJECT_ENDPOINT"],
|
||||
"model_deployment_name": azure_ai_unit_test_env["AZURE_AI_MODEL_DEPLOYMENT_NAME"],
|
||||
}
|
||||
|
||||
provider = AzureAIProjectAgentProvider(project_client=mock_project_client)
|
||||
|
||||
@@ -465,9 +471,11 @@ async def test_provider_context_manager(mock_project_client: MagicMock) -> None:
|
||||
mock_client.close = AsyncMock()
|
||||
mock_ai_project_client.return_value = mock_client
|
||||
|
||||
with patch("agent_framework_azure_ai._project_provider.AzureAISettings") as mock_settings:
|
||||
mock_settings.return_value.project_endpoint = "https://test.com"
|
||||
mock_settings.return_value.model_deployment_name = "test-model"
|
||||
with patch("agent_framework_azure_ai._project_provider.load_settings") as mock_load_settings:
|
||||
mock_load_settings.return_value = {
|
||||
"project_endpoint": "https://test.com",
|
||||
"model_deployment_name": "test-model",
|
||||
}
|
||||
|
||||
async with AzureAIProjectAgentProvider(credential=MagicMock()) as provider:
|
||||
assert provider._project_client is mock_client # type: ignore
|
||||
@@ -494,9 +502,11 @@ async def test_provider_close_method(mock_project_client: MagicMock) -> None:
|
||||
mock_client.close = AsyncMock()
|
||||
mock_ai_project_client.return_value = mock_client
|
||||
|
||||
with patch("agent_framework_azure_ai._project_provider.AzureAISettings") as mock_settings:
|
||||
mock_settings.return_value.project_endpoint = "https://test.com"
|
||||
mock_settings.return_value.model_deployment_name = "test-model"
|
||||
with patch("agent_framework_azure_ai._project_provider.load_settings") as mock_load_settings:
|
||||
mock_load_settings.return_value = {
|
||||
"project_endpoint": "https://test.com",
|
||||
"model_deployment_name": "test-model",
|
||||
}
|
||||
|
||||
provider = AzureAIProjectAgentProvider(credential=MagicMock())
|
||||
await provider.close()
|
||||
@@ -581,12 +591,14 @@ async def test_provider_create_agent_with_mcp_tool(
|
||||
return [tools]
|
||||
|
||||
with (
|
||||
patch("agent_framework_azure_ai._project_provider.AzureAISettings") as mock_settings,
|
||||
patch("agent_framework_azure_ai._project_provider.load_settings") as mock_load_settings,
|
||||
patch("agent_framework_azure_ai._project_provider.to_azure_ai_tools") as mock_to_azure_tools,
|
||||
patch("agent_framework_azure_ai._project_provider.normalize_tools", side_effect=mock_normalize_tools),
|
||||
):
|
||||
mock_settings.return_value.project_endpoint = azure_ai_unit_test_env["AZURE_AI_PROJECT_ENDPOINT"]
|
||||
mock_settings.return_value.model_deployment_name = azure_ai_unit_test_env["AZURE_AI_MODEL_DEPLOYMENT_NAME"]
|
||||
mock_load_settings.return_value = {
|
||||
"project_endpoint": azure_ai_unit_test_env["AZURE_AI_PROJECT_ENDPOINT"],
|
||||
"model_deployment_name": azure_ai_unit_test_env["AZURE_AI_MODEL_DEPLOYMENT_NAME"],
|
||||
}
|
||||
mock_to_azure_tools.return_value = [{"type": "function", "name": "mcp_function_1"}]
|
||||
|
||||
provider = AzureAIProjectAgentProvider(project_client=mock_project_client)
|
||||
@@ -642,12 +654,14 @@ async def test_provider_create_agent_with_mcp_and_regular_tools(
|
||||
return [tools]
|
||||
|
||||
with (
|
||||
patch("agent_framework_azure_ai._project_provider.AzureAISettings") as mock_settings,
|
||||
patch("agent_framework_azure_ai._project_provider.load_settings") as mock_load_settings,
|
||||
patch("agent_framework_azure_ai._project_provider.to_azure_ai_tools") as mock_to_azure_tools,
|
||||
patch("agent_framework_azure_ai._project_provider.normalize_tools", side_effect=mock_normalize_tools),
|
||||
):
|
||||
mock_settings.return_value.project_endpoint = azure_ai_unit_test_env["AZURE_AI_PROJECT_ENDPOINT"]
|
||||
mock_settings.return_value.model_deployment_name = azure_ai_unit_test_env["AZURE_AI_MODEL_DEPLOYMENT_NAME"]
|
||||
mock_load_settings.return_value = {
|
||||
"project_endpoint": azure_ai_unit_test_env["AZURE_AI_PROJECT_ENDPOINT"],
|
||||
"model_deployment_name": azure_ai_unit_test_env["AZURE_AI_MODEL_DEPLOYMENT_NAME"],
|
||||
}
|
||||
mock_to_azure_tools.return_value = []
|
||||
|
||||
provider = AzureAIProjectAgentProvider(project_client=mock_project_client)
|
||||
|
||||
@@ -30,13 +30,13 @@ from agent_framework import (
|
||||
prepare_function_call_results,
|
||||
validate_tool_mode,
|
||||
)
|
||||
from agent_framework._pydantic import AFBaseSettings
|
||||
from agent_framework._settings import SecretString, load_settings
|
||||
from agent_framework.exceptions import ServiceInitializationError, ServiceInvalidResponseError
|
||||
from agent_framework.observability import ChatTelemetryLayer
|
||||
from boto3.session import Session as Boto3Session
|
||||
from botocore.client import BaseClient
|
||||
from botocore.config import Config as BotoConfig
|
||||
from pydantic import BaseModel, SecretStr, ValidationError
|
||||
from pydantic import BaseModel
|
||||
|
||||
if sys.version_info >= (3, 13):
|
||||
from typing import TypeVar # type: ignore # pragma: no cover
|
||||
@@ -205,16 +205,14 @@ FINISH_REASON_MAP: dict[str, FinishReasonLiteral] = {
|
||||
}
|
||||
|
||||
|
||||
class BedrockSettings(AFBaseSettings):
|
||||
class BedrockSettings(TypedDict, total=False):
|
||||
"""Bedrock configuration settings pulled from environment variables or .env files."""
|
||||
|
||||
env_prefix: ClassVar[str] = "BEDROCK_"
|
||||
|
||||
region: str = DEFAULT_REGION
|
||||
chat_model_id: str | None = None
|
||||
access_key: SecretStr | None = None
|
||||
secret_key: SecretStr | None = None
|
||||
session_token: SecretStr | None = None
|
||||
region: str | None
|
||||
chat_model_id: str | None
|
||||
access_key: SecretString | None
|
||||
secret_key: SecretString | None
|
||||
session_token: SecretString | None
|
||||
|
||||
|
||||
class BedrockChatClient(
|
||||
@@ -280,24 +278,25 @@ class BedrockChatClient(
|
||||
client = BedrockChatClient[MyOptions](model_id="<model name>")
|
||||
response = await client.get_response("Hello", options={"my_custom_option": "value"})
|
||||
"""
|
||||
try:
|
||||
settings = BedrockSettings(
|
||||
region=region,
|
||||
chat_model_id=model_id,
|
||||
access_key=access_key, # type: ignore[arg-type]
|
||||
secret_key=secret_key, # type: ignore[arg-type]
|
||||
session_token=session_token, # type: ignore[arg-type]
|
||||
env_file_path=env_file_path,
|
||||
env_file_encoding=env_file_encoding,
|
||||
)
|
||||
except ValidationError as ex:
|
||||
raise ServiceInitializationError("Failed to initialize Bedrock settings.", ex) from ex
|
||||
settings = load_settings(
|
||||
BedrockSettings,
|
||||
env_prefix="BEDROCK_",
|
||||
region=region,
|
||||
chat_model_id=model_id,
|
||||
access_key=access_key,
|
||||
secret_key=secret_key,
|
||||
session_token=session_token,
|
||||
env_file_path=env_file_path,
|
||||
env_file_encoding=env_file_encoding,
|
||||
)
|
||||
if not settings.get("region"):
|
||||
settings["region"] = DEFAULT_REGION
|
||||
|
||||
if client is None:
|
||||
session = boto3_session or self._create_session(settings)
|
||||
client = session.client(
|
||||
"bedrock-runtime",
|
||||
region_name=settings.region,
|
||||
region_name=settings["region"],
|
||||
config=BotoConfig(user_agent_extra=AGENT_FRAMEWORK_USER_AGENT),
|
||||
)
|
||||
|
||||
@@ -307,17 +306,17 @@ class BedrockChatClient(
|
||||
**kwargs,
|
||||
)
|
||||
self._bedrock_client = client
|
||||
self.model_id = settings.chat_model_id
|
||||
self.region = settings.region
|
||||
self.model_id = settings["chat_model_id"]
|
||||
self.region = settings["region"]
|
||||
|
||||
@staticmethod
|
||||
def _create_session(settings: BedrockSettings) -> Boto3Session:
|
||||
session_kwargs: dict[str, Any] = {"region_name": settings.region or DEFAULT_REGION}
|
||||
if settings.access_key and settings.secret_key:
|
||||
session_kwargs["aws_access_key_id"] = settings.access_key.get_secret_value()
|
||||
session_kwargs["aws_secret_access_key"] = settings.secret_key.get_secret_value()
|
||||
if settings.session_token:
|
||||
session_kwargs["aws_session_token"] = settings.session_token.get_secret_value()
|
||||
session_kwargs: dict[str, Any] = {"region_name": settings.get("region") or DEFAULT_REGION}
|
||||
if settings.get("access_key") and settings.get("secret_key"):
|
||||
session_kwargs["aws_access_key_id"] = settings["access_key"].get_secret_value() # type: ignore[union-attr]
|
||||
session_kwargs["aws_secret_access_key"] = settings["secret_key"].get_secret_value() # type: ignore[union-attr]
|
||||
if settings.get("session_token"):
|
||||
session_kwargs["aws_session_token"] = settings["session_token"].get_secret_value() # type: ignore[union-attr]
|
||||
return Boto3Session(**session_kwargs)
|
||||
|
||||
@override
|
||||
|
||||
@@ -11,6 +11,7 @@ from agent_framework import (
|
||||
FunctionTool,
|
||||
Message,
|
||||
)
|
||||
from agent_framework._settings import load_settings
|
||||
from pydantic import BaseModel
|
||||
|
||||
from agent_framework_bedrock._chat_client import BedrockChatClient, BedrockSettings
|
||||
@@ -33,9 +34,9 @@ def _dummy_weather(location: str) -> str: # pragma: no cover - helper
|
||||
def test_settings_load_from_environment(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setenv("BEDROCK_REGION", "us-west-2")
|
||||
monkeypatch.setenv("BEDROCK_CHAT_MODEL_ID", "anthropic.claude-v2")
|
||||
settings = BedrockSettings()
|
||||
assert settings.region == "us-west-2"
|
||||
assert settings.chat_model_id == "anthropic.claude-v2"
|
||||
settings = load_settings(BedrockSettings, env_prefix="BEDROCK_")
|
||||
assert settings["region"] == "us-west-2"
|
||||
assert settings["chat_model_id"] == "anthropic.claude-v2"
|
||||
|
||||
|
||||
def test_build_request_includes_tool_config() -> None:
|
||||
|
||||
@@ -21,8 +21,9 @@ from agent_framework import (
|
||||
get_logger,
|
||||
normalize_messages,
|
||||
)
|
||||
from agent_framework._settings import load_settings
|
||||
from agent_framework._types import normalize_tools
|
||||
from agent_framework.exceptions import ServiceException, ServiceInitializationError
|
||||
from agent_framework.exceptions import ServiceException
|
||||
from claude_agent_sdk import (
|
||||
AssistantMessage,
|
||||
ClaudeSDKClient,
|
||||
@@ -34,7 +35,6 @@ from claude_agent_sdk import (
|
||||
ClaudeAgentOptions as SDKOptions,
|
||||
)
|
||||
from claude_agent_sdk.types import StreamEvent, TextBlock
|
||||
from pydantic import ValidationError
|
||||
|
||||
from ._settings import ClaudeAgentSettings
|
||||
|
||||
@@ -273,19 +273,18 @@ class ClaudeAgent(BaseAgent, Generic[OptionsT]):
|
||||
self._mcp_servers: dict[str, Any] = opts.pop("mcp_servers", None) or {}
|
||||
|
||||
# Load settings from environment and options
|
||||
try:
|
||||
self._settings = ClaudeAgentSettings(
|
||||
cli_path=cli_path,
|
||||
model=model,
|
||||
cwd=cwd,
|
||||
permission_mode=permission_mode,
|
||||
max_turns=max_turns,
|
||||
max_budget_usd=max_budget_usd,
|
||||
env_file_path=env_file_path,
|
||||
env_file_encoding=env_file_encoding,
|
||||
)
|
||||
except ValidationError as ex:
|
||||
raise ServiceInitializationError("Failed to create Claude Agent settings.", ex) from ex
|
||||
self._settings = load_settings(
|
||||
ClaudeAgentSettings,
|
||||
env_prefix="CLAUDE_AGENT_",
|
||||
cli_path=cli_path,
|
||||
model=model,
|
||||
cwd=cwd,
|
||||
permission_mode=permission_mode,
|
||||
max_turns=max_turns,
|
||||
max_budget_usd=max_budget_usd,
|
||||
env_file_path=env_file_path,
|
||||
env_file_encoding=env_file_encoding,
|
||||
)
|
||||
|
||||
# Separate built-in tools (strings) from custom tools (callables/FunctionTool)
|
||||
self._builtin_tools: list[str] = []
|
||||
@@ -411,18 +410,18 @@ class ClaudeAgent(BaseAgent, Generic[OptionsT]):
|
||||
opts["resume"] = resume_session_id
|
||||
|
||||
# Apply settings from environment
|
||||
if self._settings.cli_path:
|
||||
opts["cli_path"] = self._settings.cli_path
|
||||
if self._settings.model:
|
||||
opts["model"] = self._settings.model
|
||||
if self._settings.cwd:
|
||||
opts["cwd"] = self._settings.cwd
|
||||
if self._settings.permission_mode:
|
||||
opts["permission_mode"] = self._settings.permission_mode
|
||||
if self._settings.max_turns:
|
||||
opts["max_turns"] = self._settings.max_turns
|
||||
if self._settings.max_budget_usd:
|
||||
opts["max_budget_usd"] = self._settings.max_budget_usd
|
||||
if self._settings["cli_path"]:
|
||||
opts["cli_path"] = self._settings["cli_path"]
|
||||
if self._settings["model"]:
|
||||
opts["model"] = self._settings["model"]
|
||||
if self._settings["cwd"]:
|
||||
opts["cwd"] = self._settings["cwd"]
|
||||
if self._settings["permission_mode"]:
|
||||
opts["permission_mode"] = self._settings["permission_mode"]
|
||||
if self._settings["max_turns"]:
|
||||
opts["max_turns"] = self._settings["max_turns"]
|
||||
if self._settings["max_budget_usd"]:
|
||||
opts["max_budget_usd"] = self._settings["max_budget_usd"]
|
||||
|
||||
# Apply default options
|
||||
for key, value in self._default_options.items():
|
||||
|
||||
@@ -1,51 +1,29 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from typing import ClassVar
|
||||
|
||||
from agent_framework._pydantic import AFBaseSettings
|
||||
from typing import TypedDict
|
||||
|
||||
__all__ = ["ClaudeAgentSettings"]
|
||||
|
||||
|
||||
class ClaudeAgentSettings(AFBaseSettings):
|
||||
class ClaudeAgentSettings(TypedDict, total=False):
|
||||
"""Claude Agent settings.
|
||||
|
||||
The settings are first loaded from environment variables with the prefix 'CLAUDE_AGENT_'.
|
||||
If the environment variables are not found, the settings can be loaded from a .env file
|
||||
with the encoding 'utf-8'. If the settings are not found in the .env file, the settings
|
||||
are ignored; however, validation will fail alerting that the settings are missing.
|
||||
with the encoding 'utf-8'.
|
||||
|
||||
Keyword Args:
|
||||
Keys:
|
||||
cli_path: The path to Claude CLI executable.
|
||||
model: The model to use (sonnet, opus, haiku).
|
||||
cwd: The working directory for Claude CLI.
|
||||
permission_mode: Permission mode (default, acceptEdits, plan, bypassPermissions).
|
||||
max_turns: Maximum number of conversation turns.
|
||||
max_budget_usd: Maximum budget in USD.
|
||||
env_file_path: If provided, the .env settings are read from this file path location.
|
||||
env_file_encoding: The encoding of the .env file, defaults to 'utf-8'.
|
||||
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
|
||||
from agent_framework.anthropic import ClaudeAgentSettings
|
||||
|
||||
# Using environment variables
|
||||
# Set CLAUDE_AGENT_MODEL=sonnet
|
||||
# CLAUDE_AGENT_PERMISSION_MODE=default
|
||||
|
||||
# Or passing parameters directly
|
||||
settings = ClaudeAgentSettings(model="sonnet")
|
||||
|
||||
# Or loading from a .env file
|
||||
settings = ClaudeAgentSettings(env_file_path="path/to/.env")
|
||||
"""
|
||||
|
||||
env_prefix: ClassVar[str] = "CLAUDE_AGENT_"
|
||||
|
||||
cli_path: str | None = None
|
||||
model: str | None = None
|
||||
cwd: str | None = None
|
||||
permission_mode: str | None = None
|
||||
max_turns: int | None = None
|
||||
max_budget_usd: float | None = None
|
||||
cli_path: str | None
|
||||
model: str | None
|
||||
cwd: str | None
|
||||
permission_mode: str | None
|
||||
max_turns: int | None
|
||||
max_budget_usd: float | None
|
||||
|
||||
@@ -5,6 +5,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from agent_framework import AgentResponseUpdate, AgentThread, Content, Message, tool
|
||||
from agent_framework._settings import load_settings
|
||||
|
||||
from agent_framework_claude import ClaudeAgent, ClaudeAgentOptions, ClaudeAgentSettings
|
||||
from agent_framework_claude._agent import TOOLS_MCP_SERVER_NAME
|
||||
@@ -15,23 +16,21 @@ from agent_framework_claude._agent import TOOLS_MCP_SERVER_NAME
|
||||
class TestClaudeAgentSettings:
|
||||
"""Tests for ClaudeAgentSettings."""
|
||||
|
||||
def test_env_prefix(self) -> None:
|
||||
"""Test that env_prefix is correctly set."""
|
||||
assert ClaudeAgentSettings.env_prefix == "CLAUDE_AGENT_"
|
||||
|
||||
def test_default_values(self) -> None:
|
||||
"""Test default values are None."""
|
||||
settings = ClaudeAgentSettings()
|
||||
assert settings.cli_path is None
|
||||
assert settings.model is None
|
||||
assert settings.cwd is None
|
||||
assert settings.permission_mode is None
|
||||
assert settings.max_turns is None
|
||||
assert settings.max_budget_usd is None
|
||||
settings = load_settings(ClaudeAgentSettings, env_prefix="CLAUDE_AGENT_")
|
||||
assert settings["cli_path"] is None
|
||||
assert settings["model"] is None
|
||||
assert settings["cwd"] is None
|
||||
assert settings["permission_mode"] is None
|
||||
assert settings["max_turns"] is None
|
||||
assert settings["max_budget_usd"] is None
|
||||
|
||||
def test_explicit_values(self) -> None:
|
||||
"""Test explicit values override defaults."""
|
||||
settings = ClaudeAgentSettings(
|
||||
settings = load_settings(
|
||||
ClaudeAgentSettings,
|
||||
env_prefix="CLAUDE_AGENT_",
|
||||
cli_path="/usr/local/bin/claude",
|
||||
model="sonnet",
|
||||
cwd="/home/user/project",
|
||||
@@ -39,20 +38,20 @@ class TestClaudeAgentSettings:
|
||||
max_turns=10,
|
||||
max_budget_usd=5.0,
|
||||
)
|
||||
assert settings.cli_path == "/usr/local/bin/claude"
|
||||
assert settings.model == "sonnet"
|
||||
assert settings.cwd == "/home/user/project"
|
||||
assert settings.permission_mode == "default"
|
||||
assert settings.max_turns == 10
|
||||
assert settings.max_budget_usd == 5.0
|
||||
assert settings["cli_path"] == "/usr/local/bin/claude"
|
||||
assert settings["model"] == "sonnet"
|
||||
assert settings["cwd"] == "/home/user/project"
|
||||
assert settings["permission_mode"] == "default"
|
||||
assert settings["max_turns"] == 10
|
||||
assert settings["max_budget_usd"] == 5.0
|
||||
|
||||
def test_env_variable_loading(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test loading from environment variables."""
|
||||
monkeypatch.setenv("CLAUDE_AGENT_MODEL", "opus")
|
||||
monkeypatch.setenv("CLAUDE_AGENT_MAX_TURNS", "20")
|
||||
settings = ClaudeAgentSettings()
|
||||
assert settings.model == "opus"
|
||||
assert settings.max_turns == 20
|
||||
settings = load_settings(ClaudeAgentSettings, env_prefix="CLAUDE_AGENT_")
|
||||
assert settings["model"] == "opus"
|
||||
assert settings["max_turns"] == 20
|
||||
|
||||
|
||||
# region Test ClaudeAgent Initialization
|
||||
@@ -95,9 +94,9 @@ class TestClaudeAgentInit:
|
||||
"max_turns": 10,
|
||||
}
|
||||
agent = ClaudeAgent(default_options=options)
|
||||
assert agent._settings.model == "sonnet" # type: ignore[reportPrivateUsage]
|
||||
assert agent._settings.permission_mode == "default" # type: ignore[reportPrivateUsage]
|
||||
assert agent._settings.max_turns == 10 # type: ignore[reportPrivateUsage]
|
||||
assert agent._settings["model"] == "sonnet" # type: ignore[reportPrivateUsage]
|
||||
assert agent._settings["permission_mode"] == "default" # type: ignore[reportPrivateUsage]
|
||||
assert agent._settings["max_turns"] == 10 # type: ignore[reportPrivateUsage]
|
||||
|
||||
def test_with_function_tool(self) -> None:
|
||||
"""Test agent with function tool."""
|
||||
@@ -620,13 +619,13 @@ class TestClaudeAgentPermissions:
|
||||
def test_default_permission_mode(self) -> None:
|
||||
"""Test default permission mode."""
|
||||
agent = ClaudeAgent()
|
||||
assert agent._settings.permission_mode is None # type: ignore[reportPrivateUsage]
|
||||
assert agent._settings["permission_mode"] is None # type: ignore[reportPrivateUsage]
|
||||
|
||||
def test_permission_mode_from_settings(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Test permission mode from environment settings."""
|
||||
monkeypatch.setenv("CLAUDE_AGENT_PERMISSION_MODE", "acceptEdits")
|
||||
settings = ClaudeAgentSettings()
|
||||
assert settings.permission_mode == "acceptEdits"
|
||||
settings = load_settings(ClaudeAgentSettings, env_prefix="CLAUDE_AGENT_")
|
||||
assert settings["permission_mode"] == "acceptEdits"
|
||||
|
||||
def test_permission_mode_in_options(self) -> None:
|
||||
"""Test permission mode in options."""
|
||||
@@ -634,7 +633,7 @@ class TestClaudeAgentPermissions:
|
||||
"permission_mode": "bypassPermissions",
|
||||
}
|
||||
agent = ClaudeAgent(default_options=options)
|
||||
assert agent._settings.permission_mode == "bypassPermissions" # type: ignore[reportPrivateUsage]
|
||||
assert agent._settings["permission_mode"] == "bypassPermissions" # type: ignore[reportPrivateUsage]
|
||||
|
||||
|
||||
# region Test ClaudeAgent Error Handling
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterable, Awaitable, Sequence
|
||||
from typing import Any, ClassVar, Literal, overload
|
||||
from typing import Any, Literal, TypedDict, overload
|
||||
|
||||
from agent_framework import (
|
||||
AgentMiddlewareTypes,
|
||||
@@ -17,23 +17,21 @@ from agent_framework import (
|
||||
ResponseStream,
|
||||
normalize_messages,
|
||||
)
|
||||
from agent_framework._pydantic import AFBaseSettings
|
||||
from agent_framework._settings import load_settings
|
||||
from agent_framework.exceptions import ServiceException, ServiceInitializationError
|
||||
from microsoft_agents.copilotstudio.client import AgentType, ConnectionSettings, CopilotClient, PowerPlatformCloud
|
||||
from pydantic import ValidationError
|
||||
|
||||
from ._acquire_token import acquire_token
|
||||
|
||||
|
||||
class CopilotStudioSettings(AFBaseSettings):
|
||||
class CopilotStudioSettings(TypedDict, total=False):
|
||||
"""Copilot Studio model settings.
|
||||
|
||||
The settings are first loaded from environment variables with the prefix 'COPILOTSTUDIOAGENT__'.
|
||||
If the environment variables are not found, the settings can be loaded from a .env file
|
||||
with the encoding 'utf-8'. If the settings are not found in the .env file, the settings
|
||||
are ignored; however, validation will fail alerting that the settings are missing.
|
||||
with the encoding 'utf-8'.
|
||||
|
||||
Keyword Args:
|
||||
Keys:
|
||||
environmentid: Environment ID of environment with the Copilot Studio App.
|
||||
Can be set via environment variable COPILOTSTUDIOAGENT__ENVIRONMENTID.
|
||||
schemaname: The agent identifier or schema name of the Copilot to use.
|
||||
@@ -42,32 +40,12 @@ class CopilotStudioSettings(AFBaseSettings):
|
||||
Can be set via environment variable COPILOTSTUDIOAGENT__AGENTAPPID.
|
||||
tenantid: The tenant ID of the App Registration used to login.
|
||||
Can be set via environment variable COPILOTSTUDIOAGENT__TENANTID.
|
||||
env_file_path: If provided, the .env settings are read from this file path location.
|
||||
env_file_encoding: The encoding of the .env file, defaults to 'utf-8'.
|
||||
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
|
||||
from agent_framework_copilotstudio import CopilotStudioSettings
|
||||
|
||||
# Using environment variables
|
||||
# Set COPILOTSTUDIOAGENT__ENVIRONMENTID=env-123
|
||||
# Set COPILOTSTUDIOAGENT__SCHEMANAME=my-agent
|
||||
settings = CopilotStudioSettings()
|
||||
|
||||
# Or passing parameters directly
|
||||
settings = CopilotStudioSettings(environmentid="env-123", schemaname="my-agent")
|
||||
|
||||
# Or loading from a .env file
|
||||
settings = CopilotStudioSettings(env_file_path="path/to/.env")
|
||||
"""
|
||||
|
||||
env_prefix: ClassVar[str] = "COPILOTSTUDIOAGENT__"
|
||||
|
||||
environmentid: str | None = None
|
||||
schemaname: str | None = None
|
||||
agentappid: str | None = None
|
||||
tenantid: str | None = None
|
||||
environmentid: str | None
|
||||
schemaname: str | None
|
||||
agentappid: str | None
|
||||
tenantid: str | None
|
||||
|
||||
|
||||
class CopilotStudioAgent(BaseAgent):
|
||||
@@ -144,54 +122,53 @@ class CopilotStudioAgent(BaseAgent):
|
||||
middleware=middleware,
|
||||
)
|
||||
if not client:
|
||||
try:
|
||||
copilot_studio_settings = CopilotStudioSettings(
|
||||
environmentid=environment_id,
|
||||
schemaname=agent_identifier,
|
||||
agentappid=client_id,
|
||||
tenantid=tenant_id,
|
||||
env_file_path=env_file_path,
|
||||
env_file_encoding=env_file_encoding,
|
||||
)
|
||||
except ValidationError as ex:
|
||||
raise ServiceInitializationError("Failed to create Copilot Studio settings.", ex) from ex
|
||||
copilot_studio_settings = load_settings(
|
||||
CopilotStudioSettings,
|
||||
env_prefix="COPILOTSTUDIOAGENT__",
|
||||
environmentid=environment_id,
|
||||
schemaname=agent_identifier,
|
||||
agentappid=client_id,
|
||||
tenantid=tenant_id,
|
||||
env_file_path=env_file_path,
|
||||
env_file_encoding=env_file_encoding,
|
||||
)
|
||||
|
||||
if not settings:
|
||||
if not copilot_studio_settings.environmentid:
|
||||
if not copilot_studio_settings["environmentid"]:
|
||||
raise ServiceInitializationError(
|
||||
"Copilot Studio environment ID is required. Set via 'environment_id' parameter "
|
||||
"or 'COPILOTSTUDIOAGENT__ENVIRONMENTID' environment variable."
|
||||
)
|
||||
if not copilot_studio_settings.schemaname:
|
||||
if not copilot_studio_settings["schemaname"]:
|
||||
raise ServiceInitializationError(
|
||||
"Copilot Studio agent identifier/schema name is required. Set via 'agent_identifier' parameter "
|
||||
"or 'COPILOTSTUDIOAGENT__SCHEMANAME' environment variable."
|
||||
)
|
||||
|
||||
settings = ConnectionSettings(
|
||||
environment_id=copilot_studio_settings.environmentid,
|
||||
agent_identifier=copilot_studio_settings.schemaname,
|
||||
environment_id=copilot_studio_settings["environmentid"],
|
||||
agent_identifier=copilot_studio_settings["schemaname"],
|
||||
cloud=cloud,
|
||||
copilot_agent_type=agent_type,
|
||||
custom_power_platform_cloud=custom_power_platform_cloud,
|
||||
)
|
||||
|
||||
if not token:
|
||||
if not copilot_studio_settings.agentappid:
|
||||
if not copilot_studio_settings["agentappid"]:
|
||||
raise ServiceInitializationError(
|
||||
"Copilot Studio client ID is required. Set via 'client_id' parameter "
|
||||
"or 'COPILOTSTUDIOAGENT__AGENTAPPID' environment variable."
|
||||
)
|
||||
|
||||
if not copilot_studio_settings.tenantid:
|
||||
if not copilot_studio_settings["tenantid"]:
|
||||
raise ServiceInitializationError(
|
||||
"Copilot Studio tenant ID is required. Set via 'tenant_id' parameter "
|
||||
"or 'COPILOTSTUDIOAGENT__TENANTID' environment variable."
|
||||
)
|
||||
|
||||
token = acquire_token(
|
||||
client_id=copilot_studio_settings.agentappid,
|
||||
tenant_id=copilot_studio_settings.tenantid,
|
||||
client_id=copilot_studio_settings["agentappid"],
|
||||
tenant_id=copilot_studio_settings["tenantid"],
|
||||
username=username,
|
||||
token_cache=token_cache,
|
||||
scopes=scopes,
|
||||
|
||||
@@ -38,49 +38,57 @@ class TestCopilotStudioAgent:
|
||||
return MagicMock(spec=CopilotClient)
|
||||
|
||||
@patch("agent_framework_copilotstudio._acquire_token.acquire_token")
|
||||
@patch("agent_framework_copilotstudio._agent.CopilotStudioSettings")
|
||||
def test_init_missing_environment_id(self, mock_settings: MagicMock, mock_acquire_token: MagicMock) -> None:
|
||||
@patch("agent_framework_copilotstudio._agent.load_settings")
|
||||
def test_init_missing_environment_id(self, mock_load_settings: MagicMock, mock_acquire_token: MagicMock) -> None:
|
||||
mock_acquire_token.return_value = "fake-token"
|
||||
mock_settings.return_value.environmentid = None
|
||||
mock_settings.return_value.schemaname = "test-bot"
|
||||
mock_settings.return_value.tenantid = "test-tenant"
|
||||
mock_settings.return_value.agentappid = "test-client"
|
||||
mock_load_settings.return_value = {
|
||||
"environmentid": None,
|
||||
"schemaname": "test-bot",
|
||||
"tenantid": "test-tenant",
|
||||
"agentappid": "test-client",
|
||||
}
|
||||
|
||||
with pytest.raises(ServiceInitializationError, match="environment ID is required"):
|
||||
CopilotStudioAgent()
|
||||
|
||||
@patch("agent_framework_copilotstudio._acquire_token.acquire_token")
|
||||
@patch("agent_framework_copilotstudio._agent.CopilotStudioSettings")
|
||||
def test_init_missing_bot_id(self, mock_settings: MagicMock, mock_acquire_token: MagicMock) -> None:
|
||||
@patch("agent_framework_copilotstudio._agent.load_settings")
|
||||
def test_init_missing_bot_id(self, mock_load_settings: MagicMock, mock_acquire_token: MagicMock) -> None:
|
||||
mock_acquire_token.return_value = "fake-token"
|
||||
mock_settings.return_value.environmentid = "test-env"
|
||||
mock_settings.return_value.schemaname = None
|
||||
mock_settings.return_value.tenantid = "test-tenant"
|
||||
mock_settings.return_value.agentappid = "test-client"
|
||||
mock_load_settings.return_value = {
|
||||
"environmentid": "test-env",
|
||||
"schemaname": None,
|
||||
"tenantid": "test-tenant",
|
||||
"agentappid": "test-client",
|
||||
}
|
||||
|
||||
with pytest.raises(ServiceInitializationError, match="agent identifier"):
|
||||
CopilotStudioAgent()
|
||||
|
||||
@patch("agent_framework_copilotstudio._acquire_token.acquire_token")
|
||||
@patch("agent_framework_copilotstudio._agent.CopilotStudioSettings")
|
||||
def test_init_missing_tenant_id(self, mock_settings: MagicMock, mock_acquire_token: MagicMock) -> None:
|
||||
@patch("agent_framework_copilotstudio._agent.load_settings")
|
||||
def test_init_missing_tenant_id(self, mock_load_settings: MagicMock, mock_acquire_token: MagicMock) -> None:
|
||||
mock_acquire_token.return_value = "fake-token"
|
||||
mock_settings.return_value.environmentid = "test-env"
|
||||
mock_settings.return_value.schemaname = "test-bot"
|
||||
mock_settings.return_value.tenantid = None
|
||||
mock_settings.return_value.agentappid = "test-client"
|
||||
mock_load_settings.return_value = {
|
||||
"environmentid": "test-env",
|
||||
"schemaname": "test-bot",
|
||||
"tenantid": None,
|
||||
"agentappid": "test-client",
|
||||
}
|
||||
|
||||
with pytest.raises(ServiceInitializationError, match="tenant ID is required"):
|
||||
CopilotStudioAgent()
|
||||
|
||||
@patch("agent_framework_copilotstudio._acquire_token.acquire_token")
|
||||
@patch("agent_framework_copilotstudio._agent.CopilotStudioSettings")
|
||||
def test_init_missing_client_id(self, mock_settings: MagicMock, mock_acquire_token: MagicMock) -> None:
|
||||
@patch("agent_framework_copilotstudio._agent.load_settings")
|
||||
def test_init_missing_client_id(self, mock_load_settings: MagicMock, mock_acquire_token: MagicMock) -> None:
|
||||
mock_acquire_token.return_value = "fake-token"
|
||||
mock_settings.return_value.environmentid = "test-env"
|
||||
mock_settings.return_value.schemaname = "test-bot"
|
||||
mock_settings.return_value.tenantid = "test-tenant"
|
||||
mock_settings.return_value.agentappid = None
|
||||
mock_load_settings.return_value = {
|
||||
"environmentid": "test-env",
|
||||
"schemaname": "test-bot",
|
||||
"tenantid": "test-tenant",
|
||||
"agentappid": None,
|
||||
}
|
||||
|
||||
with pytest.raises(ServiceInitializationError, match="client ID is required"):
|
||||
CopilotStudioAgent()
|
||||
@@ -93,11 +101,13 @@ class TestCopilotStudioAgent:
|
||||
@patch("agent_framework_copilotstudio._acquire_token.acquire_token")
|
||||
def test_init_empty_environment_id(self, mock_acquire_token: MagicMock) -> None:
|
||||
mock_acquire_token.return_value = "fake-token"
|
||||
with patch("agent_framework_copilotstudio._agent.CopilotStudioSettings") as mock_settings:
|
||||
mock_settings.return_value.environmentid = ""
|
||||
mock_settings.return_value.schemaname = "test-bot"
|
||||
mock_settings.return_value.tenantid = "test-tenant"
|
||||
mock_settings.return_value.agentappid = "test-client"
|
||||
with patch("agent_framework_copilotstudio._agent.load_settings") as mock_load_settings:
|
||||
mock_load_settings.return_value = {
|
||||
"environmentid": "",
|
||||
"schemaname": "test-bot",
|
||||
"tenantid": "test-tenant",
|
||||
"agentappid": "test-client",
|
||||
}
|
||||
|
||||
with pytest.raises(ServiceInitializationError, match="environment ID is required"):
|
||||
CopilotStudioAgent()
|
||||
@@ -105,11 +115,13 @@ class TestCopilotStudioAgent:
|
||||
@patch("agent_framework_copilotstudio._acquire_token.acquire_token")
|
||||
def test_init_empty_schema_name(self, mock_acquire_token: MagicMock) -> None:
|
||||
mock_acquire_token.return_value = "fake-token"
|
||||
with patch("agent_framework_copilotstudio._agent.CopilotStudioSettings") as mock_settings:
|
||||
mock_settings.return_value.environmentid = "test-env"
|
||||
mock_settings.return_value.schemaname = ""
|
||||
mock_settings.return_value.tenantid = "test-tenant"
|
||||
mock_settings.return_value.agentappid = "test-client"
|
||||
with patch("agent_framework_copilotstudio._agent.load_settings") as mock_load_settings:
|
||||
mock_load_settings.return_value = {
|
||||
"environmentid": "test-env",
|
||||
"schemaname": "",
|
||||
"tenantid": "test-tenant",
|
||||
"agentappid": "test-client",
|
||||
}
|
||||
|
||||
with pytest.raises(ServiceInitializationError, match="agent identifier"):
|
||||
CopilotStudioAgent()
|
||||
|
||||
@@ -1,70 +0,0 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated, Any, ClassVar, TypeVar
|
||||
|
||||
from pydantic import Field, UrlConstraints
|
||||
from pydantic.networks import AnyUrl
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
HTTPsUrl = Annotated[AnyUrl, UrlConstraints(max_length=2083, allowed_schemes=["https"])]
|
||||
|
||||
__all__ = ["AFBaseSettings", "HTTPsUrl"]
|
||||
|
||||
|
||||
SettingsT = TypeVar("SettingsT", bound="AFBaseSettings")
|
||||
|
||||
|
||||
class AFBaseSettings(BaseSettings):
|
||||
"""Base class for all settings classes in the Agent Framework.
|
||||
|
||||
A subclass creates it's fields and overrides the env_prefix class variable
|
||||
with the prefix for the environment variables.
|
||||
|
||||
In the case where a value is specified for the same Settings field in multiple ways,
|
||||
the selected value is determined as follows (in descending order of priority):
|
||||
- Arguments passed to the Settings class initializer.
|
||||
- Environment variables, e.g. my_prefix_special_function as described above.
|
||||
- Variables loaded from a dotenv (.env) file.
|
||||
- Variables loaded from the secrets directory.
|
||||
- The default field values for the Settings model.
|
||||
"""
|
||||
|
||||
env_prefix: ClassVar[str] = ""
|
||||
env_file_path: str | None = Field(default=None, exclude=True)
|
||||
env_file_encoding: str | None = Field(default="utf-8", exclude=True)
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
extra="ignore",
|
||||
case_sensitive=False,
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize the settings class."""
|
||||
# Remove any None values from the kwargs so that defaults are used.
|
||||
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def __new__(cls: type[SettingsT], *args: Any, **kwargs: Any) -> SettingsT:
|
||||
"""Override the __new__ method to set the env_prefix."""
|
||||
# for both, if supplied but None, set to default
|
||||
if "env_file_encoding" in kwargs and kwargs["env_file_encoding"] is not None:
|
||||
env_file_encoding = kwargs["env_file_encoding"]
|
||||
else:
|
||||
env_file_encoding = "utf-8"
|
||||
if "env_file_path" in kwargs and kwargs["env_file_path"] is not None:
|
||||
env_file_path = kwargs["env_file_path"]
|
||||
else:
|
||||
env_file_path = ".env"
|
||||
cls.model_config.update( # type: ignore
|
||||
env_prefix=cls.env_prefix,
|
||||
env_file=env_file_path,
|
||||
env_file_encoding=env_file_encoding,
|
||||
)
|
||||
cls.model_rebuild()
|
||||
return super().__new__(cls) # type: ignore[return-value]
|
||||
@@ -0,0 +1,262 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Generic settings loader with environment variable resolution.
|
||||
|
||||
This module provides a ``load_settings()`` function that populates a ``TypedDict``
|
||||
from environment variables, ``.env`` files, and explicit overrides. It replaces
|
||||
the previous pydantic-settings-based ``AFBaseSettings`` with a lighter-weight,
|
||||
function-based approach that has no pydantic-settings dependency.
|
||||
|
||||
Usage::
|
||||
|
||||
class MySettings(TypedDict, total=False):
|
||||
api_key: str | None # optional — resolves to None if not set
|
||||
model_id: str | None # optional by default
|
||||
|
||||
|
||||
# Make model_id required at call time:
|
||||
settings = load_settings(
|
||||
MySettings,
|
||||
env_prefix="MY_APP_",
|
||||
required_fields=["model_id"],
|
||||
model_id="gpt-4",
|
||||
)
|
||||
settings["api_key"] # type-checked dict access
|
||||
settings["model_id"] # str | None per type, but guaranteed not None at runtime
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
from collections.abc import Callable, Sequence
|
||||
from contextlib import suppress
|
||||
from typing import Any, Union, get_args, get_origin, get_type_hints
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from .exceptions import SettingNotFoundError
|
||||
|
||||
if sys.version_info >= (3, 13):
|
||||
from typing import TypeVar # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import TypeVar # type: ignore # pragma: no cover
|
||||
|
||||
__all__ = ["SecretString", "load_settings"]
|
||||
|
||||
SettingsT = TypeVar("SettingsT", default=dict[str, Any])
|
||||
|
||||
|
||||
class SecretString(str):
|
||||
"""A string subclass that masks its value in repr() to prevent accidental exposure.
|
||||
|
||||
SecretString behaves exactly like a regular string in all operations,
|
||||
but its repr() shows '**********' instead of the actual value.
|
||||
This helps prevent secrets from being accidentally logged or displayed.
|
||||
|
||||
It also provides a ``get_secret_value()`` method for backward compatibility
|
||||
with code that previously used ``pydantic.SecretStr``.
|
||||
|
||||
Example:
|
||||
```python
|
||||
api_key = SecretString("sk-secret-key")
|
||||
print(api_key) # sk-secret-key (normal string behavior)
|
||||
print(repr(api_key)) # SecretString('**********')
|
||||
print(f"Key: {api_key}") # Key: sk-secret-key
|
||||
print(api_key.get_secret_value()) # sk-secret-key
|
||||
```
|
||||
"""
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Return a masked representation to prevent secret exposure."""
|
||||
return "SecretString('**********')"
|
||||
|
||||
def get_secret_value(self) -> str:
|
||||
"""Return the underlying string value.
|
||||
|
||||
Provided for backward compatibility with ``pydantic.SecretStr``.
|
||||
Since SecretString *is* a str, this simply returns ``str(self)``.
|
||||
"""
|
||||
return str(self)
|
||||
|
||||
|
||||
def _coerce_value(value: str, target_type: type) -> Any:
|
||||
"""Coerce a string value to the target type."""
|
||||
origin = get_origin(target_type)
|
||||
args = get_args(target_type)
|
||||
|
||||
# Handle Union types (e.g., str | None) — try each non-None arm
|
||||
if origin is type(None):
|
||||
return None
|
||||
|
||||
if args and type(None) in args:
|
||||
for arg in args:
|
||||
if arg is not type(None):
|
||||
with suppress(ValueError, TypeError):
|
||||
return _coerce_value(value, arg)
|
||||
return value
|
||||
|
||||
# Handle SecretString
|
||||
if target_type is SecretString or (isinstance(target_type, type) and issubclass(target_type, SecretString)):
|
||||
return SecretString(value)
|
||||
|
||||
# Handle basic types
|
||||
if target_type is str:
|
||||
return value
|
||||
if target_type is int:
|
||||
return int(value)
|
||||
if target_type is float:
|
||||
return float(value)
|
||||
if target_type is bool:
|
||||
return value.lower() in ("true", "1", "yes", "on")
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def _check_override_type(value: Any, field_type: type, field_name: str) -> None:
|
||||
"""Validate that *value* is compatible with *field_type*.
|
||||
|
||||
Raises ``ServiceInitializationError`` when the override is clearly
|
||||
incompatible (e.g. a ``dict`` passed where ``str`` is expected).
|
||||
Callable values and ``None`` are always accepted.
|
||||
"""
|
||||
if value is None:
|
||||
return
|
||||
|
||||
# Callables are always allowed (e.g. lazy token providers)
|
||||
if callable(value) and not isinstance(value, (str, bytes)):
|
||||
return
|
||||
|
||||
# Collect the concrete types that *field_type* allows
|
||||
origin = get_origin(field_type)
|
||||
args = get_args(field_type)
|
||||
|
||||
allowed: tuple[type, ...]
|
||||
if origin is Union or origin is type(int | str):
|
||||
allowed = tuple(a for a in args if isinstance(a, type) and a is not type(None))
|
||||
# If any arm is a Callable, allow anything callable
|
||||
if any(get_origin(a) is Callable or a is Callable for a in args):
|
||||
return
|
||||
elif isinstance(field_type, type):
|
||||
allowed = (field_type,)
|
||||
else:
|
||||
return # complex / unknown annotation — skip check
|
||||
|
||||
if not allowed:
|
||||
return
|
||||
|
||||
if not isinstance(value, allowed):
|
||||
# Allow str for SecretString fields (will be coerced)
|
||||
if isinstance(value, str) and any(isinstance(a, type) and issubclass(a, str) for a in allowed):
|
||||
return
|
||||
# Allow int for float fields (standard numeric promotion)
|
||||
if isinstance(value, int) and float in allowed:
|
||||
return
|
||||
|
||||
from .exceptions import ServiceInitializationError
|
||||
|
||||
allowed_names = ", ".join(t.__name__ for t in allowed)
|
||||
raise ServiceInitializationError(
|
||||
f"Invalid type for setting '{field_name}': expected {allowed_names}, got {type(value).__name__}."
|
||||
)
|
||||
|
||||
|
||||
def load_settings(
|
||||
settings_type: type[SettingsT],
|
||||
*,
|
||||
env_prefix: str = "",
|
||||
env_file_path: str | None = None,
|
||||
env_file_encoding: str | None = None,
|
||||
required_fields: Sequence[str] | None = None,
|
||||
**overrides: Any,
|
||||
) -> SettingsT:
|
||||
"""Load settings from environment variables, a ``.env`` file, and explicit overrides.
|
||||
|
||||
The *settings_type* must be a ``TypedDict`` subclass. Values are resolved in
|
||||
this order (highest priority first):
|
||||
|
||||
1. Explicit keyword *overrides* (``None`` values are filtered out).
|
||||
2. Environment variables (``<env_prefix><FIELD_NAME>``).
|
||||
3. A ``.env`` file (loaded via ``python-dotenv``; existing env vars take precedence).
|
||||
4. Default values — fields with class-level defaults on the TypedDict, or
|
||||
``None`` for optional fields.
|
||||
|
||||
Fields listed in *required_fields* are validated after resolution. If any
|
||||
required field resolves to ``None``, a ``SettingNotFoundError`` is raised.
|
||||
This allows callers to decide which fields are required based on runtime
|
||||
context (e.g. ``endpoint`` is only required when no pre-built client is
|
||||
provided).
|
||||
|
||||
Args:
|
||||
settings_type: A ``TypedDict`` class describing the settings schema.
|
||||
env_prefix: Prefix for environment variable lookup (e.g. ``"OPENAI_"``).
|
||||
env_file_path: Path to ``.env`` file. Defaults to ``".env"`` when omitted.
|
||||
env_file_encoding: Encoding for reading the ``.env`` file. Defaults to ``"utf-8"``.
|
||||
required_fields: Field names that must resolve to a non-``None`` value.
|
||||
**overrides: Field values. ``None`` values are ignored so that callers can
|
||||
forward optional parameters without masking env-var / default resolution.
|
||||
|
||||
Returns:
|
||||
A populated dict matching *settings_type*.
|
||||
|
||||
Raises:
|
||||
SettingNotFoundError: If a required field could not be resolved from any source.
|
||||
ServiceInitializationError: If an override value has an incompatible type.
|
||||
"""
|
||||
encoding = env_file_encoding or "utf-8"
|
||||
|
||||
# Load .env file if it exists (existing env vars take precedence by default)
|
||||
env_path = env_file_path or ".env"
|
||||
if os.path.isfile(env_path):
|
||||
load_dotenv(dotenv_path=env_path, encoding=encoding)
|
||||
|
||||
# Filter out None overrides so defaults / env vars are preserved
|
||||
overrides = {k: v for k, v in overrides.items() if v is not None}
|
||||
|
||||
# Get field type hints from the TypedDict
|
||||
hints = get_type_hints(settings_type)
|
||||
required: set[str] = set(required_fields) if required_fields else set()
|
||||
|
||||
result: dict[str, Any] = {}
|
||||
for field_name, field_type in hints.items():
|
||||
# 1. Explicit override wins
|
||||
if field_name in overrides:
|
||||
override_value = overrides[field_name]
|
||||
_check_override_type(override_value, field_type, field_name)
|
||||
# Coerce plain str → SecretString if the annotation expects it
|
||||
if isinstance(override_value, str) and not isinstance(override_value, SecretString):
|
||||
with suppress(ValueError, TypeError):
|
||||
coerced = _coerce_value(override_value, field_type)
|
||||
if isinstance(coerced, SecretString):
|
||||
override_value = coerced
|
||||
result[field_name] = override_value
|
||||
continue
|
||||
|
||||
# 2. Environment variable
|
||||
env_var_name = f"{env_prefix}{field_name.upper()}"
|
||||
env_value = os.getenv(env_var_name)
|
||||
if env_value is not None:
|
||||
try:
|
||||
result[field_name] = _coerce_value(env_value, field_type)
|
||||
except (ValueError, TypeError):
|
||||
result[field_name] = env_value
|
||||
continue
|
||||
|
||||
# 3. Default from TypedDict class-level defaults, or None for optional fields
|
||||
if hasattr(settings_type, field_name):
|
||||
result[field_name] = getattr(settings_type, field_name)
|
||||
else:
|
||||
result[field_name] = None
|
||||
|
||||
# Validate required fields after all resolution
|
||||
if required:
|
||||
for field_name in required:
|
||||
if result.get(field_name) is None:
|
||||
env_var_name = f"{env_prefix}{field_name.upper()}"
|
||||
raise SettingNotFoundError(
|
||||
f"Required setting '{field_name}' was not provided. "
|
||||
f"Set it via the '{field_name}' parameter or the "
|
||||
f"'{env_var_name}' environment variable."
|
||||
)
|
||||
|
||||
return result # type: ignore[return-value]
|
||||
@@ -7,12 +7,13 @@ from collections.abc import Mapping
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Generic
|
||||
|
||||
from openai.lib.azure import AsyncAzureADTokenProvider, AsyncAzureOpenAI
|
||||
from pydantic import ValidationError
|
||||
|
||||
from .._settings import load_settings
|
||||
from ..exceptions import ServiceInitializationError
|
||||
from ..openai import OpenAIAssistantsClient
|
||||
from ..openai._assistants_client import OpenAIAssistantsOptions
|
||||
from ._shared import AzureOpenAISettings
|
||||
from ._entra_id_authentication import get_entra_auth_token
|
||||
from ._shared import DEFAULT_AZURE_TOKEN_ENDPOINT, AzureOpenAISettings, _apply_azure_defaults
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from azure.core.credentials import TokenCredential
|
||||
@@ -137,23 +138,21 @@ class AzureOpenAIAssistantsClient(
|
||||
client: AzureOpenAIAssistantsClient[MyOptions] = AzureOpenAIAssistantsClient()
|
||||
response = await client.get_response("Hello", options={"my_custom_option": "value"})
|
||||
"""
|
||||
try:
|
||||
azure_openai_settings = AzureOpenAISettings(
|
||||
# pydantic settings will see if there is a value, if not, will try the env var or .env file
|
||||
api_key=api_key, # type: ignore
|
||||
base_url=base_url, # type: ignore
|
||||
endpoint=endpoint, # type: ignore
|
||||
chat_deployment_name=deployment_name,
|
||||
api_version=api_version,
|
||||
env_file_path=env_file_path,
|
||||
env_file_encoding=env_file_encoding,
|
||||
token_endpoint=token_endpoint,
|
||||
default_api_version=self.DEFAULT_AZURE_API_VERSION,
|
||||
)
|
||||
except ValidationError as ex:
|
||||
raise ServiceInitializationError("Failed to create Azure OpenAI settings.", ex) from ex
|
||||
azure_openai_settings = load_settings(
|
||||
AzureOpenAISettings,
|
||||
env_prefix="AZURE_OPENAI_",
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
endpoint=endpoint,
|
||||
chat_deployment_name=deployment_name,
|
||||
api_version=api_version,
|
||||
env_file_path=env_file_path,
|
||||
env_file_encoding=env_file_encoding,
|
||||
token_endpoint=token_endpoint,
|
||||
)
|
||||
_apply_azure_defaults(azure_openai_settings, default_api_version=self.DEFAULT_AZURE_API_VERSION)
|
||||
|
||||
if not azure_openai_settings.chat_deployment_name:
|
||||
if not azure_openai_settings["chat_deployment_name"]:
|
||||
raise ServiceInitializationError(
|
||||
"Azure OpenAI deployment name is required. Set via 'deployment_name' parameter "
|
||||
"or 'AZURE_OPENAI_CHAT_DEPLOYMENT_NAME' environment variable."
|
||||
@@ -162,40 +161,41 @@ class AzureOpenAIAssistantsClient(
|
||||
# Handle authentication: try API key first, then AD token, then Entra ID
|
||||
if (
|
||||
not async_client
|
||||
and not azure_openai_settings.api_key
|
||||
and not azure_openai_settings["api_key"]
|
||||
and not ad_token
|
||||
and not ad_token_provider
|
||||
and azure_openai_settings.token_endpoint
|
||||
and azure_openai_settings["token_endpoint"]
|
||||
and credential
|
||||
):
|
||||
ad_token = azure_openai_settings.get_azure_auth_token(credential)
|
||||
token_ep = azure_openai_settings["token_endpoint"] or DEFAULT_AZURE_TOKEN_ENDPOINT
|
||||
ad_token = get_entra_auth_token(credential, token_ep)
|
||||
|
||||
if not async_client and not azure_openai_settings.api_key and not ad_token and not ad_token_provider:
|
||||
if not async_client and not azure_openai_settings["api_key"] and not ad_token and not ad_token_provider:
|
||||
raise ServiceInitializationError("The Azure OpenAI API key, ad_token, or ad_token_provider is required.")
|
||||
|
||||
# Create Azure client if not provided
|
||||
if not async_client:
|
||||
client_params: dict[str, Any] = {
|
||||
"api_version": azure_openai_settings.api_version,
|
||||
"api_version": azure_openai_settings["api_version"],
|
||||
"default_headers": default_headers,
|
||||
}
|
||||
|
||||
if azure_openai_settings.api_key:
|
||||
client_params["api_key"] = azure_openai_settings.api_key.get_secret_value()
|
||||
if azure_openai_settings["api_key"]:
|
||||
client_params["api_key"] = azure_openai_settings["api_key"].get_secret_value()
|
||||
elif ad_token:
|
||||
client_params["azure_ad_token"] = ad_token
|
||||
elif ad_token_provider:
|
||||
client_params["azure_ad_token_provider"] = ad_token_provider
|
||||
|
||||
if azure_openai_settings.base_url:
|
||||
client_params["base_url"] = str(azure_openai_settings.base_url)
|
||||
elif azure_openai_settings.endpoint:
|
||||
client_params["azure_endpoint"] = str(azure_openai_settings.endpoint)
|
||||
if azure_openai_settings["base_url"]:
|
||||
client_params["base_url"] = str(azure_openai_settings["base_url"])
|
||||
elif azure_openai_settings["endpoint"]:
|
||||
client_params["azure_endpoint"] = str(azure_openai_settings["endpoint"])
|
||||
|
||||
async_client = AsyncAzureOpenAI(**client_params)
|
||||
|
||||
super().__init__(
|
||||
model_id=azure_openai_settings.chat_deployment_name,
|
||||
model_id=azure_openai_settings["chat_deployment_name"],
|
||||
assistant_id=assistant_id,
|
||||
assistant_name=assistant_name,
|
||||
assistant_description=assistant_description,
|
||||
|
||||
@@ -12,7 +12,7 @@ from azure.core.credentials import TokenCredential
|
||||
from openai.lib.azure import AsyncAzureADTokenProvider, AsyncAzureOpenAI
|
||||
from openai.types.chat.chat_completion import Choice
|
||||
from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from pydantic import BaseModel
|
||||
|
||||
from agent_framework import (
|
||||
Annotation,
|
||||
@@ -28,9 +28,11 @@ from agent_framework.observability import ChatTelemetryLayer
|
||||
from agent_framework.openai import OpenAIChatOptions
|
||||
from agent_framework.openai._chat_client import RawOpenAIChatClient
|
||||
|
||||
from .._settings import load_settings
|
||||
from ._shared import (
|
||||
AzureOpenAIConfigMixin,
|
||||
AzureOpenAISettings,
|
||||
_apply_azure_defaults,
|
||||
)
|
||||
|
||||
if sys.version_info >= (3, 13):
|
||||
@@ -247,37 +249,35 @@ class AzureOpenAIChatClient( # type: ignore[misc]
|
||||
client: AzureOpenAIChatClient[MyOptions] = AzureOpenAIChatClient()
|
||||
response = await client.get_response("Hello", options={"my_custom_option": "value"})
|
||||
"""
|
||||
try:
|
||||
# Filter out any None values from the arguments
|
||||
azure_openai_settings = AzureOpenAISettings(
|
||||
# pydantic settings will see if there is a value, if not, will try the env var or .env file
|
||||
api_key=api_key, # type: ignore
|
||||
base_url=base_url, # type: ignore
|
||||
endpoint=endpoint, # type: ignore
|
||||
chat_deployment_name=deployment_name,
|
||||
api_version=api_version,
|
||||
env_file_path=env_file_path,
|
||||
env_file_encoding=env_file_encoding,
|
||||
token_endpoint=token_endpoint,
|
||||
)
|
||||
except ValidationError as exc:
|
||||
raise ServiceInitializationError(f"Failed to validate settings: {exc}") from exc
|
||||
azure_openai_settings = load_settings(
|
||||
AzureOpenAISettings,
|
||||
env_prefix="AZURE_OPENAI_",
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
endpoint=endpoint,
|
||||
chat_deployment_name=deployment_name,
|
||||
api_version=api_version,
|
||||
env_file_path=env_file_path,
|
||||
env_file_encoding=env_file_encoding,
|
||||
token_endpoint=token_endpoint,
|
||||
)
|
||||
_apply_azure_defaults(azure_openai_settings)
|
||||
|
||||
if not azure_openai_settings.chat_deployment_name:
|
||||
if not azure_openai_settings["chat_deployment_name"]:
|
||||
raise ServiceInitializationError(
|
||||
"Azure OpenAI deployment name is required. Set via 'deployment_name' parameter "
|
||||
"or 'AZURE_OPENAI_CHAT_DEPLOYMENT_NAME' environment variable."
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
deployment_name=azure_openai_settings.chat_deployment_name,
|
||||
endpoint=azure_openai_settings.endpoint,
|
||||
base_url=azure_openai_settings.base_url,
|
||||
api_version=azure_openai_settings.api_version, # type: ignore
|
||||
api_key=azure_openai_settings.api_key.get_secret_value() if azure_openai_settings.api_key else None,
|
||||
deployment_name=azure_openai_settings["chat_deployment_name"],
|
||||
endpoint=azure_openai_settings["endpoint"],
|
||||
base_url=azure_openai_settings["base_url"],
|
||||
api_version=azure_openai_settings["api_version"], # type: ignore
|
||||
api_key=azure_openai_settings["api_key"].get_secret_value() if azure_openai_settings["api_key"] else None,
|
||||
ad_token=ad_token,
|
||||
ad_token_provider=ad_token_provider,
|
||||
token_endpoint=azure_openai_settings.token_endpoint,
|
||||
token_endpoint=azure_openai_settings["token_endpoint"],
|
||||
credential=credential,
|
||||
default_headers=default_headers,
|
||||
client=async_client,
|
||||
|
||||
@@ -5,15 +5,15 @@ from __future__ import annotations
|
||||
import sys
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Generic
|
||||
from urllib.parse import urljoin
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
from azure.ai.projects.aio import AIProjectClient
|
||||
from azure.core.credentials import TokenCredential
|
||||
from openai import AsyncOpenAI
|
||||
from openai.lib.azure import AsyncAzureADTokenProvider
|
||||
from pydantic import ValidationError
|
||||
|
||||
from .._middleware import ChatMiddlewareLayer
|
||||
from .._settings import load_settings
|
||||
from .._telemetry import AGENT_FRAMEWORK_USER_AGENT
|
||||
from .._tools import FunctionInvocationConfiguration, FunctionInvocationLayer
|
||||
from ..exceptions import ServiceInitializationError
|
||||
@@ -22,6 +22,7 @@ from ..openai._responses_client import RawOpenAIResponsesClient
|
||||
from ._shared import (
|
||||
AzureOpenAIConfigMixin,
|
||||
AzureOpenAISettings,
|
||||
_apply_azure_defaults,
|
||||
)
|
||||
|
||||
if sys.version_info >= (3, 13):
|
||||
@@ -82,7 +83,8 @@ class AzureOpenAIResponsesClient( # type: ignore[misc]
|
||||
env_file_encoding: str | None = None,
|
||||
instruction_role: str | None = None,
|
||||
middleware: Sequence[MiddlewareTypes] | None = None,
|
||||
function_invocation_configuration: FunctionInvocationConfiguration | None = None,
|
||||
function_invocation_configuration: FunctionInvocationConfiguration
|
||||
| None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize an Azure OpenAI Responses client.
|
||||
@@ -188,54 +190,58 @@ class AzureOpenAIResponsesClient( # type: ignore[misc]
|
||||
deployment_name = str(model_id)
|
||||
|
||||
# Project client path: create OpenAI client from an Azure AI Foundry project
|
||||
if async_client is None and (project_client is not None or project_endpoint is not None):
|
||||
if async_client is None and (
|
||||
project_client is not None or project_endpoint is not None
|
||||
):
|
||||
async_client = self._create_client_from_project(
|
||||
project_client=project_client,
|
||||
project_endpoint=project_endpoint,
|
||||
credential=credential,
|
||||
)
|
||||
|
||||
try:
|
||||
azure_openai_settings = AzureOpenAISettings(
|
||||
# pydantic settings will see if there is a value, if not, will try the env var or .env file
|
||||
api_key=api_key, # type: ignore
|
||||
base_url=base_url, # type: ignore
|
||||
endpoint=endpoint, # type: ignore
|
||||
responses_deployment_name=deployment_name,
|
||||
api_version=api_version,
|
||||
env_file_path=env_file_path,
|
||||
env_file_encoding=env_file_encoding,
|
||||
token_endpoint=token_endpoint,
|
||||
default_api_version="preview",
|
||||
azure_openai_settings = load_settings(
|
||||
AzureOpenAISettings,
|
||||
env_prefix="AZURE_OPENAI_",
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
endpoint=endpoint,
|
||||
responses_deployment_name=deployment_name,
|
||||
api_version=api_version,
|
||||
env_file_path=env_file_path,
|
||||
env_file_encoding=env_file_encoding,
|
||||
token_endpoint=token_endpoint,
|
||||
)
|
||||
_apply_azure_defaults(azure_openai_settings, default_api_version="preview")
|
||||
# TODO(peterychang): This is a temporary hack to ensure that the base_url is set correctly
|
||||
# while this feature is in preview.
|
||||
# But we should only do this if we're on azure. Private deployments may not need this.
|
||||
if (
|
||||
not azure_openai_settings.get("base_url")
|
||||
and azure_openai_settings.get("endpoint")
|
||||
and (hostname := urlparse(str(azure_openai_settings["endpoint"])).hostname)
|
||||
and hostname.endswith(".openai.azure.com")
|
||||
):
|
||||
azure_openai_settings["base_url"] = urljoin(
|
||||
str(azure_openai_settings["endpoint"]), "/openai/v1/"
|
||||
)
|
||||
# TODO(peterychang): This is a temporary hack to ensure that the base_url is set correctly
|
||||
# while this feature is in preview.
|
||||
# But we should only do this if we're on azure. Private deployments may not need this.
|
||||
if (
|
||||
not azure_openai_settings.base_url
|
||||
and azure_openai_settings.endpoint
|
||||
and azure_openai_settings.endpoint.host
|
||||
and azure_openai_settings.endpoint.host.endswith(".openai.azure.com")
|
||||
):
|
||||
azure_openai_settings.base_url = urljoin(str(azure_openai_settings.endpoint), "/openai/v1/") # type: ignore
|
||||
except ValidationError as exc:
|
||||
raise ServiceInitializationError(f"Failed to validate settings: {exc}") from exc
|
||||
|
||||
if not azure_openai_settings.responses_deployment_name:
|
||||
if not azure_openai_settings["responses_deployment_name"]:
|
||||
raise ServiceInitializationError(
|
||||
"Azure OpenAI deployment name is required. Set via 'deployment_name' parameter "
|
||||
"or 'AZURE_OPENAI_RESPONSES_DEPLOYMENT_NAME' environment variable."
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
deployment_name=azure_openai_settings.responses_deployment_name,
|
||||
endpoint=azure_openai_settings.endpoint,
|
||||
base_url=azure_openai_settings.base_url,
|
||||
api_version=azure_openai_settings.api_version, # type: ignore
|
||||
api_key=azure_openai_settings.api_key.get_secret_value() if azure_openai_settings.api_key else None,
|
||||
deployment_name=azure_openai_settings["responses_deployment_name"],
|
||||
endpoint=azure_openai_settings["endpoint"],
|
||||
base_url=azure_openai_settings["base_url"],
|
||||
api_version=azure_openai_settings["api_version"], # type: ignore
|
||||
api_key=azure_openai_settings["api_key"].get_secret_value()
|
||||
if azure_openai_settings["api_key"]
|
||||
else None,
|
||||
ad_token=ad_token,
|
||||
ad_token_provider=ad_token_provider,
|
||||
token_endpoint=azure_openai_settings.token_endpoint,
|
||||
token_endpoint=azure_openai_settings["token_endpoint"],
|
||||
credential=credential,
|
||||
default_headers=default_headers,
|
||||
client=async_client,
|
||||
|
||||
@@ -11,28 +11,26 @@ from typing import Any, ClassVar, Final
|
||||
from azure.core.credentials import TokenCredential
|
||||
from openai import AsyncOpenAI
|
||||
from openai.lib.azure import AsyncAzureOpenAI
|
||||
from pydantic import SecretStr, model_validator
|
||||
|
||||
from .._pydantic import AFBaseSettings, HTTPsUrl
|
||||
from .._settings import SecretString
|
||||
from .._telemetry import APP_INFO, prepend_agent_framework_to_user_agent
|
||||
from ..exceptions import ServiceInitializationError
|
||||
from ..openai._shared import OpenAIBase
|
||||
from ._entra_id_authentication import get_entra_auth_token
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
from typing import Self # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import Self # pragma: no cover
|
||||
|
||||
|
||||
logger: logging.Logger = logging.getLogger(__name__)
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
from typing import TypedDict # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import TypedDict # type: ignore # pragma: no cover
|
||||
|
||||
|
||||
DEFAULT_AZURE_API_VERSION: Final[str] = "2024-10-21"
|
||||
DEFAULT_AZURE_TOKEN_ENDPOINT: Final[str] = "https://cognitiveservices.azure.com/.default" # noqa: S105
|
||||
|
||||
|
||||
class AzureOpenAISettings(AFBaseSettings):
|
||||
class AzureOpenAISettings(TypedDict, total=False):
|
||||
"""AzureOpenAI model settings.
|
||||
|
||||
The settings are first loaded from environment variables with the prefix 'AZURE_OPENAI_'.
|
||||
@@ -62,7 +60,7 @@ class AzureOpenAISettings(AFBaseSettings):
|
||||
found in the Keys & Endpoint section when examining your resource in
|
||||
the Azure portal. You can use either KEY1 or KEY2.
|
||||
Can be set via environment variable AZURE_OPENAI_API_KEY.
|
||||
api_version: The API version to use. The default value is `default_api_version`.
|
||||
api_version: The API version to use. The default value is `DEFAULT_AZURE_API_VERSION`.
|
||||
Can be set via environment variable AZURE_OPENAI_API_VERSION.
|
||||
base_url: The url of the Azure deployment. This value
|
||||
can be found in the Keys & Endpoint section when examining
|
||||
@@ -71,14 +69,8 @@ class AzureOpenAISettings(AFBaseSettings):
|
||||
use endpoint if you only want to supply the endpoint.
|
||||
Can be set via environment variable AZURE_OPENAI_BASE_URL.
|
||||
token_endpoint: The token endpoint to use to retrieve the authentication token.
|
||||
The default value is `default_token_endpoint`.
|
||||
The default value is `DEFAULT_AZURE_TOKEN_ENDPOINT`.
|
||||
Can be set via environment variable AZURE_OPENAI_TOKEN_ENDPOINT.
|
||||
default_api_version: The default API version to use if not specified.
|
||||
The default value is "2024-10-21".
|
||||
default_token_endpoint: The default token endpoint to use if not specified.
|
||||
The default value is "https://cognitiveservices.azure.com/.default".
|
||||
env_file_path: The path to the .env file to load settings from.
|
||||
env_file_encoding: The encoding of the .env file, defaults to 'utf-8'.
|
||||
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
@@ -89,60 +81,46 @@ class AzureOpenAISettings(AFBaseSettings):
|
||||
# Set AZURE_OPENAI_ENDPOINT=https://your-endpoint.openai.azure.com
|
||||
# Set AZURE_OPENAI_CHAT_DEPLOYMENT_NAME=gpt-4
|
||||
# Set AZURE_OPENAI_API_KEY=your-key
|
||||
settings = AzureOpenAISettings()
|
||||
settings = load_settings(AzureOpenAISettings, env_prefix="AZURE_OPENAI_")
|
||||
|
||||
# Or passing parameters directly
|
||||
settings = AzureOpenAISettings(
|
||||
endpoint="https://your-endpoint.openai.azure.com", chat_deployment_name="gpt-4", api_key="your-key"
|
||||
settings = load_settings(
|
||||
AzureOpenAISettings,
|
||||
env_prefix="AZURE_OPENAI_",
|
||||
endpoint="https://your-endpoint.openai.azure.com",
|
||||
chat_deployment_name="gpt-4",
|
||||
api_key="your-key",
|
||||
)
|
||||
|
||||
# Or loading from a .env file
|
||||
settings = AzureOpenAISettings(env_file_path="path/to/.env")
|
||||
settings = load_settings(AzureOpenAISettings, env_prefix="AZURE_OPENAI_", env_file_path="path/to/.env")
|
||||
"""
|
||||
|
||||
env_prefix: ClassVar[str] = "AZURE_OPENAI_"
|
||||
chat_deployment_name: str | None
|
||||
responses_deployment_name: str | None
|
||||
endpoint: str | None
|
||||
base_url: str | None
|
||||
api_key: SecretString | None
|
||||
api_version: str | None
|
||||
token_endpoint: str | None
|
||||
|
||||
chat_deployment_name: str | None = None
|
||||
responses_deployment_name: str | None = None
|
||||
endpoint: HTTPsUrl | None = None
|
||||
base_url: HTTPsUrl | None = None
|
||||
api_key: SecretStr | None = None
|
||||
api_version: str | None = None
|
||||
token_endpoint: str | None = None
|
||||
default_api_version: str = DEFAULT_AZURE_API_VERSION
|
||||
default_token_endpoint: str = DEFAULT_AZURE_TOKEN_ENDPOINT
|
||||
|
||||
def get_azure_auth_token(
|
||||
self, credential: TokenCredential, token_endpoint: str | None = None, **kwargs: Any
|
||||
) -> str | None:
|
||||
"""Retrieve a Microsoft Entra Auth Token for a given token endpoint for the use with Azure OpenAI.
|
||||
def _apply_azure_defaults(
|
||||
settings: AzureOpenAISettings,
|
||||
default_api_version: str = DEFAULT_AZURE_API_VERSION,
|
||||
default_token_endpoint: str = DEFAULT_AZURE_TOKEN_ENDPOINT,
|
||||
) -> None:
|
||||
"""Apply default values for api_version and token_endpoint after loading settings.
|
||||
|
||||
The required role for the token is `Cognitive Services OpenAI Contributor`.
|
||||
The token endpoint may be specified as an environment variable, via the .env
|
||||
file or as an argument. If the token endpoint is not provided, the default is None.
|
||||
The `token_endpoint` argument takes precedence over the `token_endpoint` attribute.
|
||||
|
||||
Args:
|
||||
credential: The Azure AD credential to use.
|
||||
token_endpoint: The token endpoint to use. Defaults to `https://cognitiveservices.azure.com/.default`.
|
||||
|
||||
Keyword Args:
|
||||
**kwargs: Additional keyword arguments to pass to the token retrieval method.
|
||||
|
||||
Returns:
|
||||
The Azure token or None if the token could not be retrieved.
|
||||
|
||||
Raises:
|
||||
ServiceInitializationError: If the token endpoint is not provided.
|
||||
"""
|
||||
endpoint_to_use = token_endpoint or self.token_endpoint or self.default_token_endpoint
|
||||
return get_entra_auth_token(credential, endpoint_to_use, **kwargs)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_fields(self) -> Self:
|
||||
self.api_version = self.api_version or self.default_api_version
|
||||
self.token_endpoint = self.token_endpoint or self.default_token_endpoint
|
||||
return self
|
||||
Args:
|
||||
settings: The loaded Azure OpenAI settings dict.
|
||||
default_api_version: The default API version to use if not set.
|
||||
default_token_endpoint: The default token endpoint to use if not set.
|
||||
"""
|
||||
if not settings.get("api_version"):
|
||||
settings["api_version"] = default_api_version
|
||||
if not settings.get("token_endpoint"):
|
||||
settings["token_endpoint"] = default_token_endpoint
|
||||
|
||||
|
||||
class AzureOpenAIConfigMixin(OpenAIBase):
|
||||
@@ -154,8 +132,8 @@ class AzureOpenAIConfigMixin(OpenAIBase):
|
||||
def __init__(
|
||||
self,
|
||||
deployment_name: str,
|
||||
endpoint: HTTPsUrl | None = None,
|
||||
base_url: HTTPsUrl | None = None,
|
||||
endpoint: str | None = None,
|
||||
base_url: str | None = None,
|
||||
api_version: str = DEFAULT_AZURE_API_VERSION,
|
||||
api_key: str | None = None,
|
||||
ad_token: str | None = None,
|
||||
@@ -170,7 +148,7 @@ class AzureOpenAIConfigMixin(OpenAIBase):
|
||||
"""Internal class for configuring a connection to an Azure OpenAI service.
|
||||
|
||||
The `validate_call` decorator is used with a configuration that allows arbitrary types.
|
||||
This is necessary for types like `HTTPsUrl` and `OpenAIModelTypes`.
|
||||
This is necessary for types like `str` and `OpenAIModelTypes`.
|
||||
|
||||
Args:
|
||||
deployment_name: Name of the deployment.
|
||||
|
||||
@@ -146,3 +146,9 @@ class ContentError(AgentFrameworkException):
|
||||
"""An error occurred while processing content."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class SettingNotFoundError(AgentFrameworkException):
|
||||
"""A required setting could not be resolved from any source."""
|
||||
|
||||
pass
|
||||
|
||||
@@ -18,11 +18,10 @@ from opentelemetry import metrics, trace
|
||||
from opentelemetry.sdk.resources import Resource
|
||||
from opentelemetry.semconv.attributes import service_attributes
|
||||
from opentelemetry.semconv_ai import Meters, SpanAttributes
|
||||
from pydantic import PrivateAttr
|
||||
|
||||
from . import __version__ as version_info
|
||||
from ._logging import get_logger
|
||||
from ._pydantic import AFBaseSettings
|
||||
from ._settings import load_settings
|
||||
|
||||
if sys.version_info >= (3, 13):
|
||||
from typing import TypeVar # type: ignore # pragma: no cover
|
||||
@@ -566,7 +565,16 @@ def create_metric_views() -> list[View]:
|
||||
]
|
||||
|
||||
|
||||
class ObservabilitySettings(AFBaseSettings):
|
||||
class _ObservabilitySettingsData(TypedDict, total=False):
|
||||
"""TypedDict schema for observability settings fields."""
|
||||
|
||||
enable_instrumentation: bool | None
|
||||
enable_sensitive_data: bool | None
|
||||
enable_console_exporters: bool | None
|
||||
vs_code_extension_port: int | None
|
||||
|
||||
|
||||
class ObservabilitySettings:
|
||||
"""Settings for Agent Framework Observability.
|
||||
|
||||
If the environment variables are not found, the settings can
|
||||
@@ -603,23 +611,27 @@ class ObservabilitySettings(AFBaseSettings):
|
||||
settings = ObservabilitySettings(enable_instrumentation=True, enable_console_exporters=True)
|
||||
"""
|
||||
|
||||
env_prefix: ClassVar[str] = ""
|
||||
|
||||
enable_instrumentation: bool = False
|
||||
enable_sensitive_data: bool = False
|
||||
enable_console_exporters: bool = False
|
||||
vs_code_extension_port: int | None = None
|
||||
_resource: Resource = PrivateAttr()
|
||||
_executed_setup: bool = PrivateAttr(default=False)
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
"""Initialize the settings and create the resource."""
|
||||
super().__init__(**kwargs)
|
||||
# Create resource with env file settings
|
||||
self._resource = create_resource(
|
||||
env_file_path=self.env_file_path,
|
||||
env_file_encoding=self.env_file_encoding,
|
||||
env_file_path = kwargs.pop("env_file_path", None)
|
||||
env_file_encoding = kwargs.pop("env_file_encoding", None)
|
||||
data = load_settings(
|
||||
_ObservabilitySettingsData,
|
||||
env_file_path=env_file_path,
|
||||
env_file_encoding=env_file_encoding,
|
||||
**kwargs,
|
||||
)
|
||||
self.enable_instrumentation: bool = data.get("enable_instrumentation") or False
|
||||
self.enable_sensitive_data: bool = data.get("enable_sensitive_data") or False
|
||||
self.enable_console_exporters: bool = data.get("enable_console_exporters") or False
|
||||
self.vs_code_extension_port: int | None = data.get("vs_code_extension_port")
|
||||
self.env_file_path = env_file_path
|
||||
self.env_file_encoding = env_file_encoding
|
||||
self._resource = create_resource(
|
||||
env_file_path=env_file_path,
|
||||
env_file_encoding=env_file_encoding,
|
||||
)
|
||||
self._executed_setup = False
|
||||
|
||||
@property
|
||||
def ENABLED(self) -> bool:
|
||||
|
||||
@@ -8,7 +8,9 @@ from typing import TYPE_CHECKING, Any, Generic, cast
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
from openai.types.beta.assistant import Assistant
|
||||
from pydantic import BaseModel, SecretStr, ValidationError
|
||||
from pydantic import BaseModel
|
||||
|
||||
from agent_framework._settings import SecretString, load_settings
|
||||
|
||||
from .._agents import Agent
|
||||
from .._memory import ContextProvider
|
||||
@@ -107,7 +109,7 @@ class OpenAIAssistantProvider(Generic[OptionsCoT]):
|
||||
self,
|
||||
client: AsyncOpenAI | None = None,
|
||||
*,
|
||||
api_key: str | SecretStr | Callable[[], str | Awaitable[str]] | None = None,
|
||||
api_key: str | SecretString | Callable[[], str | Awaitable[str]] | None = None,
|
||||
org_id: str | None = None,
|
||||
base_url: str | None = None,
|
||||
env_file_path: str | None = None,
|
||||
@@ -147,35 +149,34 @@ class OpenAIAssistantProvider(Generic[OptionsCoT]):
|
||||
|
||||
if client is None:
|
||||
# Load settings and create client
|
||||
try:
|
||||
settings = OpenAISettings(
|
||||
api_key=api_key, # type: ignore[reportArgumentType]
|
||||
org_id=org_id,
|
||||
base_url=base_url,
|
||||
env_file_path=env_file_path,
|
||||
env_file_encoding=env_file_encoding,
|
||||
)
|
||||
except ValidationError as ex:
|
||||
raise ServiceInitializationError("Failed to create OpenAI settings.", ex) from ex
|
||||
settings = load_settings(
|
||||
OpenAISettings,
|
||||
env_prefix="OPENAI_",
|
||||
api_key=api_key,
|
||||
org_id=org_id,
|
||||
base_url=base_url,
|
||||
env_file_path=env_file_path,
|
||||
env_file_encoding=env_file_encoding,
|
||||
)
|
||||
|
||||
if not settings.api_key:
|
||||
if not settings["api_key"]:
|
||||
raise ServiceInitializationError(
|
||||
"OpenAI API key is required. Set via 'api_key' parameter or 'OPENAI_API_KEY' environment variable."
|
||||
)
|
||||
|
||||
# Get API key value
|
||||
api_key_value: str | Callable[[], str | Awaitable[str]] | None
|
||||
if isinstance(settings.api_key, SecretStr):
|
||||
api_key_value = settings.api_key.get_secret_value()
|
||||
if isinstance(settings["api_key"], SecretString):
|
||||
api_key_value = settings["api_key"].get_secret_value()
|
||||
else:
|
||||
api_key_value = settings.api_key
|
||||
api_key_value = settings["api_key"]
|
||||
|
||||
# Create client
|
||||
client_args: dict[str, Any] = {"api_key": api_key_value}
|
||||
if settings.org_id:
|
||||
client_args["organization"] = settings.org_id
|
||||
if settings.base_url:
|
||||
client_args["base_url"] = settings.base_url
|
||||
if settings["org_id"]:
|
||||
client_args["organization"] = settings["org_id"]
|
||||
if settings["base_url"]:
|
||||
client_args["base_url"] = settings["base_url"]
|
||||
|
||||
self._client = AsyncOpenAI(**client_args)
|
||||
|
||||
|
||||
@@ -27,10 +27,11 @@ from openai.types.beta.threads import (
|
||||
from openai.types.beta.threads.run_create_params import AdditionalMessage
|
||||
from openai.types.beta.threads.run_submit_tool_outputs_params import ToolOutput
|
||||
from openai.types.beta.threads.runs import RunStep
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .._clients import BaseChatClient
|
||||
from .._middleware import ChatMiddlewareLayer
|
||||
from .._settings import load_settings
|
||||
from .._tools import (
|
||||
FunctionInvocationConfiguration,
|
||||
FunctionInvocationLayer,
|
||||
@@ -343,35 +344,34 @@ class OpenAIAssistantsClient( # type: ignore[misc]
|
||||
client: OpenAIAssistantsClient[MyOptions] = OpenAIAssistantsClient(model_id="gpt-4")
|
||||
response = await client.get_response("Hello", options={"my_custom_option": "value"})
|
||||
"""
|
||||
try:
|
||||
openai_settings = OpenAISettings(
|
||||
api_key=api_key, # type: ignore[reportArgumentType]
|
||||
base_url=base_url,
|
||||
org_id=org_id,
|
||||
chat_model_id=model_id,
|
||||
env_file_path=env_file_path,
|
||||
env_file_encoding=env_file_encoding,
|
||||
)
|
||||
except ValidationError as ex:
|
||||
raise ServiceInitializationError("Failed to create OpenAI settings.", ex) from ex
|
||||
openai_settings = load_settings(
|
||||
OpenAISettings,
|
||||
env_prefix="OPENAI_",
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
org_id=org_id,
|
||||
chat_model_id=model_id,
|
||||
env_file_path=env_file_path,
|
||||
env_file_encoding=env_file_encoding,
|
||||
)
|
||||
|
||||
if not async_client and not openai_settings.api_key:
|
||||
if not async_client and not openai_settings["api_key"]:
|
||||
raise ServiceInitializationError(
|
||||
"OpenAI API key is required. Set via 'api_key' parameter or 'OPENAI_API_KEY' environment variable."
|
||||
)
|
||||
if not openai_settings.chat_model_id:
|
||||
if not openai_settings["chat_model_id"]:
|
||||
raise ServiceInitializationError(
|
||||
"OpenAI model ID is required. "
|
||||
"Set via 'model_id' parameter or 'OPENAI_CHAT_MODEL_ID' environment variable."
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
model_id=openai_settings.chat_model_id,
|
||||
api_key=self._get_api_key(openai_settings.api_key),
|
||||
org_id=openai_settings.org_id,
|
||||
model_id=openai_settings["chat_model_id"],
|
||||
api_key=self._get_api_key(openai_settings["api_key"]),
|
||||
org_id=openai_settings["org_id"],
|
||||
default_headers=default_headers,
|
||||
client=async_client,
|
||||
base_url=openai_settings.base_url,
|
||||
base_url=openai_settings["base_url"],
|
||||
middleware=middleware,
|
||||
function_invocation_configuration=function_invocation_configuration,
|
||||
)
|
||||
|
||||
@@ -17,11 +17,12 @@ from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice
|
||||
from openai.types.chat.chat_completion_message_custom_tool_call import ChatCompletionMessageCustomToolCall
|
||||
from openai.types.chat.completion_create_params import WebSearchOptions
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .._clients import BaseChatClient
|
||||
from .._logging import get_logger
|
||||
from .._middleware import ChatAndFunctionMiddlewareTypes, ChatMiddlewareLayer
|
||||
from .._settings import load_settings
|
||||
from .._tools import (
|
||||
FunctionInvocationConfiguration,
|
||||
FunctionInvocationLayer,
|
||||
@@ -718,33 +719,32 @@ class OpenAIChatClient( # type: ignore[misc]
|
||||
client: OpenAIChatClient[MyOptions] = OpenAIChatClient(model_id="<model name>")
|
||||
response = await client.get_response("Hello", options={"my_custom_option": "value"})
|
||||
"""
|
||||
try:
|
||||
openai_settings = OpenAISettings(
|
||||
api_key=api_key, # type: ignore[reportArgumentType]
|
||||
base_url=base_url,
|
||||
org_id=org_id,
|
||||
chat_model_id=model_id,
|
||||
env_file_path=env_file_path,
|
||||
env_file_encoding=env_file_encoding,
|
||||
)
|
||||
except ValidationError as ex:
|
||||
raise ServiceInitializationError("Failed to create OpenAI settings.", ex) from ex
|
||||
openai_settings = load_settings(
|
||||
OpenAISettings,
|
||||
env_prefix="OPENAI_",
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
org_id=org_id,
|
||||
chat_model_id=model_id,
|
||||
env_file_path=env_file_path,
|
||||
env_file_encoding=env_file_encoding,
|
||||
)
|
||||
|
||||
if not async_client and not openai_settings.api_key:
|
||||
if not async_client and not openai_settings["api_key"]:
|
||||
raise ServiceInitializationError(
|
||||
"OpenAI API key is required. Set via 'api_key' parameter or 'OPENAI_API_KEY' environment variable."
|
||||
)
|
||||
if not openai_settings.chat_model_id:
|
||||
if not openai_settings["chat_model_id"]:
|
||||
raise ServiceInitializationError(
|
||||
"OpenAI model ID is required. "
|
||||
"Set via 'model_id' parameter or 'OPENAI_CHAT_MODEL_ID' environment variable."
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
model_id=openai_settings.chat_model_id,
|
||||
api_key=self._get_api_key(openai_settings.api_key),
|
||||
base_url=openai_settings.base_url if openai_settings.base_url else None,
|
||||
org_id=openai_settings.org_id,
|
||||
model_id=openai_settings["chat_model_id"],
|
||||
api_key=self._get_api_key(openai_settings["api_key"]),
|
||||
base_url=openai_settings["base_url"] if openai_settings["base_url"] else None,
|
||||
org_id=openai_settings["org_id"],
|
||||
default_headers=default_headers,
|
||||
client=async_client,
|
||||
instruction_role=instruction_role,
|
||||
|
||||
@@ -33,11 +33,12 @@ from openai.types.responses.tool_param import (
|
||||
Mcp,
|
||||
)
|
||||
from openai.types.responses.web_search_tool_param import WebSearchToolParam
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .._clients import BaseChatClient
|
||||
from .._logging import get_logger
|
||||
from .._middleware import ChatMiddlewareLayer
|
||||
from .._settings import load_settings
|
||||
from .._tools import (
|
||||
FunctionInvocationConfiguration,
|
||||
FunctionInvocationLayer,
|
||||
@@ -1810,36 +1811,35 @@ class OpenAIResponsesClient( # type: ignore[misc]
|
||||
client: OpenAIResponsesClient[MyOptions] = OpenAIResponsesClient(model_id="gpt-4o")
|
||||
response = await client.get_response("Hello", options={"my_custom_option": "value"})
|
||||
"""
|
||||
try:
|
||||
openai_settings = OpenAISettings(
|
||||
api_key=api_key, # type: ignore[reportArgumentType]
|
||||
org_id=org_id,
|
||||
base_url=base_url,
|
||||
responses_model_id=model_id,
|
||||
env_file_path=env_file_path,
|
||||
env_file_encoding=env_file_encoding,
|
||||
)
|
||||
except ValidationError as ex:
|
||||
raise ServiceInitializationError("Failed to create OpenAI settings.", ex) from ex
|
||||
openai_settings = load_settings(
|
||||
OpenAISettings,
|
||||
env_prefix="OPENAI_",
|
||||
api_key=api_key,
|
||||
org_id=org_id,
|
||||
base_url=base_url,
|
||||
responses_model_id=model_id,
|
||||
env_file_path=env_file_path,
|
||||
env_file_encoding=env_file_encoding,
|
||||
)
|
||||
|
||||
if not async_client and not openai_settings.api_key:
|
||||
if not async_client and not openai_settings["api_key"]:
|
||||
raise ServiceInitializationError(
|
||||
"OpenAI API key is required. Set via 'api_key' parameter or 'OPENAI_API_KEY' environment variable."
|
||||
)
|
||||
if not openai_settings.responses_model_id:
|
||||
if not openai_settings["responses_model_id"]:
|
||||
raise ServiceInitializationError(
|
||||
"OpenAI model ID is required. "
|
||||
"Set via 'model_id' parameter or 'OPENAI_RESPONSES_MODEL_ID' environment variable."
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
model_id=openai_settings.responses_model_id,
|
||||
api_key=self._get_api_key(openai_settings.api_key),
|
||||
org_id=openai_settings.org_id,
|
||||
model_id=openai_settings["responses_model_id"],
|
||||
api_key=self._get_api_key(openai_settings["api_key"]),
|
||||
org_id=openai_settings["org_id"],
|
||||
default_headers=default_headers,
|
||||
client=async_client,
|
||||
instruction_role=instruction_role,
|
||||
base_url=openai_settings.base_url,
|
||||
base_url=openai_settings["base_url"],
|
||||
middleware=middleware,
|
||||
function_invocation_configuration=function_invocation_configuration,
|
||||
**kwargs,
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from collections.abc import Awaitable, Callable, Mapping, MutableMapping, Sequence
|
||||
from copy import copy
|
||||
from typing import Any, ClassVar, Union
|
||||
@@ -20,11 +21,10 @@ from openai.types.images_response import ImagesResponse
|
||||
from openai.types.responses.response import Response
|
||||
from openai.types.responses.response_stream_event import ResponseStreamEvent
|
||||
from packaging.version import parse
|
||||
from pydantic import SecretStr
|
||||
|
||||
from .._logging import get_logger
|
||||
from .._pydantic import AFBaseSettings
|
||||
from .._serialization import SerializationMixin
|
||||
from .._settings import SecretString
|
||||
from .._telemetry import APP_INFO, USER_AGENT_KEY, prepend_agent_framework_to_user_agent
|
||||
from .._tools import FunctionTool
|
||||
from ..exceptions import ServiceInitializationError
|
||||
@@ -47,6 +47,11 @@ RESPONSE_TYPE = Union[
|
||||
|
||||
OPTION_TYPE = dict[str, Any]
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
from typing import TypedDict # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import TypedDict # type: ignore # pragma: no cover
|
||||
|
||||
|
||||
__all__ = ["OpenAISettings"]
|
||||
|
||||
@@ -74,7 +79,7 @@ def _check_openai_version_for_callable_api_key() -> None:
|
||||
logger.warning(f"Could not check OpenAI version for callable API key support: {e}")
|
||||
|
||||
|
||||
class OpenAISettings(AFBaseSettings):
|
||||
class OpenAISettings(TypedDict, total=False):
|
||||
"""OpenAI environment settings.
|
||||
|
||||
The settings are first loaded from environment variables with the prefix 'OPENAI_'.
|
||||
@@ -93,8 +98,6 @@ class OpenAISettings(AFBaseSettings):
|
||||
Can be set via environment variable OPENAI_CHAT_MODEL_ID.
|
||||
responses_model_id: The OpenAI responses model ID to use, for example, gpt-4o or o1.
|
||||
Can be set via environment variable OPENAI_RESPONSES_MODEL_ID.
|
||||
env_file_path: The path to the .env file to load settings from.
|
||||
env_file_encoding: The encoding of the .env file, defaults to 'utf-8'.
|
||||
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
@@ -104,22 +107,20 @@ class OpenAISettings(AFBaseSettings):
|
||||
# Using environment variables
|
||||
# Set OPENAI_API_KEY=sk-...
|
||||
# Set OPENAI_CHAT_MODEL_ID=gpt-4
|
||||
settings = OpenAISettings()
|
||||
settings = load_settings(OpenAISettings, env_prefix="OPENAI_")
|
||||
|
||||
# Or passing parameters directly
|
||||
settings = OpenAISettings(api_key="sk-...", chat_model_id="gpt-4")
|
||||
settings = load_settings(OpenAISettings, env_prefix="OPENAI_", api_key="sk-...", chat_model_id="gpt-4")
|
||||
|
||||
# Or loading from a .env file
|
||||
settings = OpenAISettings(env_file_path="path/to/.env")
|
||||
settings = load_settings(OpenAISettings, env_prefix="OPENAI_", env_file_path="path/to/.env")
|
||||
"""
|
||||
|
||||
env_prefix: ClassVar[str] = "OPENAI_"
|
||||
|
||||
api_key: SecretStr | Callable[[], str | Awaitable[str]] | None = None
|
||||
base_url: str | None = None
|
||||
org_id: str | None = None
|
||||
chat_model_id: str | None = None
|
||||
responses_model_id: str | None = None
|
||||
api_key: SecretString | Callable[[], str | Awaitable[str]] | None
|
||||
base_url: str | None
|
||||
org_id: str | None
|
||||
chat_model_id: str | None
|
||||
responses_model_id: str | None
|
||||
|
||||
|
||||
class OpenAIBase(SerializationMixin):
|
||||
@@ -181,19 +182,18 @@ class OpenAIBase(SerializationMixin):
|
||||
return self.client
|
||||
|
||||
def _get_api_key(
|
||||
self, api_key: str | SecretStr | Callable[[], str | Awaitable[str]] | None
|
||||
self, api_key: str | SecretString | Callable[[], str | Awaitable[str]] | None
|
||||
) -> str | Callable[[], str | Awaitable[str]] | None:
|
||||
"""Get the appropriate API key value for client initialization.
|
||||
|
||||
Args:
|
||||
api_key: The API key parameter which can be a string, SecretStr, callable, or None.
|
||||
api_key: The API key parameter which can be a string, SecretString, callable, or None.
|
||||
|
||||
Returns:
|
||||
For callable API keys: returns the callable directly.
|
||||
For SecretStr API keys: returns the string value.
|
||||
For string/None API keys: returns as-is.
|
||||
For SecretString/string/None API keys: returns as-is (SecretString is a str subclass).
|
||||
"""
|
||||
if isinstance(api_key, SecretStr):
|
||||
if isinstance(api_key, SecretString):
|
||||
return api_key.get_secret_value()
|
||||
|
||||
# Check version compatibility for callable API keys
|
||||
|
||||
@@ -26,7 +26,7 @@ dependencies = [
|
||||
# utilities
|
||||
"typing-extensions",
|
||||
"pydantic>=2,<3",
|
||||
"pydantic-settings>=2,<3",
|
||||
"python-dotenv>=1,<2",
|
||||
# telemetry
|
||||
"opentelemetry-api>=1.39.0",
|
||||
"opentelemetry-sdk>=1.39.0",
|
||||
|
||||
@@ -19,6 +19,7 @@ from agent_framework import (
|
||||
SupportsChatGetResponse,
|
||||
tool,
|
||||
)
|
||||
from agent_framework._settings import SecretString
|
||||
from agent_framework.azure import AzureOpenAIAssistantsClient
|
||||
from agent_framework.exceptions import ServiceInitializationError
|
||||
|
||||
@@ -556,19 +557,21 @@ def test_azure_assistants_client_entra_id_authentication() -> None:
|
||||
mock_credential = MagicMock()
|
||||
|
||||
with (
|
||||
patch("agent_framework.azure._assistants_client.AzureOpenAISettings") as mock_settings_class,
|
||||
patch("agent_framework.azure._assistants_client.load_settings") as mock_load_settings,
|
||||
patch("agent_framework.azure._assistants_client.get_entra_auth_token") as mock_get_token,
|
||||
patch("agent_framework.azure._assistants_client.AsyncAzureOpenAI") as mock_azure_client,
|
||||
patch("agent_framework.openai.OpenAIAssistantsClient.__init__", return_value=None),
|
||||
):
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.chat_deployment_name = "test-deployment"
|
||||
mock_settings.api_key = None # No API key to trigger Entra ID path
|
||||
mock_settings.token_endpoint = "https://login.microsoftonline.com/test"
|
||||
mock_settings.get_azure_auth_token.return_value = "entra-token-12345"
|
||||
mock_settings.api_version = "2024-05-01-preview"
|
||||
mock_settings.endpoint = "https://test-endpoint.openai.azure.com"
|
||||
mock_settings.base_url = None
|
||||
mock_settings_class.return_value = mock_settings
|
||||
mock_load_settings.return_value = {
|
||||
"chat_deployment_name": "test-deployment",
|
||||
"responses_deployment_name": None,
|
||||
"api_key": None,
|
||||
"token_endpoint": "https://login.microsoftonline.com/test",
|
||||
"api_version": "2024-05-01-preview",
|
||||
"endpoint": "https://test-endpoint.openai.azure.com",
|
||||
"base_url": None,
|
||||
}
|
||||
mock_get_token.return_value = "entra-token-12345"
|
||||
|
||||
client = AzureOpenAIAssistantsClient(
|
||||
deployment_name="test-deployment",
|
||||
@@ -579,7 +582,7 @@ def test_azure_assistants_client_entra_id_authentication() -> None:
|
||||
)
|
||||
|
||||
# Verify Entra ID token was requested
|
||||
mock_settings.get_azure_auth_token.assert_called_once_with(mock_credential)
|
||||
mock_get_token.assert_called_once_with(mock_credential, "https://login.microsoftonline.com/test")
|
||||
|
||||
# Verify client was created with the token
|
||||
mock_azure_client.assert_called_once()
|
||||
@@ -592,12 +595,16 @@ def test_azure_assistants_client_entra_id_authentication() -> None:
|
||||
|
||||
def test_azure_assistants_client_no_authentication_error() -> None:
|
||||
"""Test authentication validation error when no auth provided."""
|
||||
with patch("agent_framework.azure._assistants_client.AzureOpenAISettings") as mock_settings_class:
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.chat_deployment_name = "test-deployment"
|
||||
mock_settings.api_key = None # No API key
|
||||
mock_settings.token_endpoint = None # No token endpoint
|
||||
mock_settings_class.return_value = mock_settings
|
||||
with patch("agent_framework.azure._assistants_client.load_settings") as mock_load_settings:
|
||||
mock_load_settings.return_value = {
|
||||
"chat_deployment_name": "test-deployment",
|
||||
"responses_deployment_name": None,
|
||||
"api_key": None,
|
||||
"token_endpoint": None,
|
||||
"api_version": "2024-05-01-preview",
|
||||
"endpoint": "https://test-endpoint.openai.azure.com",
|
||||
"base_url": None,
|
||||
}
|
||||
|
||||
# Test missing authentication raises error
|
||||
with pytest.raises(ServiceInitializationError, match="API key, ad_token, or ad_token_provider is required"):
|
||||
@@ -611,17 +618,19 @@ def test_azure_assistants_client_no_authentication_error() -> None:
|
||||
def test_azure_assistants_client_ad_token_authentication() -> None:
|
||||
"""Test ad_token authentication client parameter path."""
|
||||
with (
|
||||
patch("agent_framework.azure._assistants_client.AzureOpenAISettings") as mock_settings_class,
|
||||
patch("agent_framework.azure._assistants_client.load_settings") as mock_load_settings,
|
||||
patch("agent_framework.azure._assistants_client.AsyncAzureOpenAI") as mock_azure_client,
|
||||
patch("agent_framework.openai.OpenAIAssistantsClient.__init__", return_value=None),
|
||||
):
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.chat_deployment_name = "test-deployment"
|
||||
mock_settings.api_key = None # No API key
|
||||
mock_settings.api_version = "2024-05-01-preview"
|
||||
mock_settings.endpoint = "https://test-endpoint.openai.azure.com"
|
||||
mock_settings.base_url = None
|
||||
mock_settings_class.return_value = mock_settings
|
||||
mock_load_settings.return_value = {
|
||||
"chat_deployment_name": "test-deployment",
|
||||
"responses_deployment_name": None,
|
||||
"api_key": None,
|
||||
"token_endpoint": None,
|
||||
"api_version": "2024-05-01-preview",
|
||||
"endpoint": "https://test-endpoint.openai.azure.com",
|
||||
"base_url": None,
|
||||
}
|
||||
|
||||
client = AzureOpenAIAssistantsClient(
|
||||
deployment_name="test-deployment",
|
||||
@@ -645,17 +654,19 @@ def test_azure_assistants_client_ad_token_provider_authentication() -> None:
|
||||
mock_token_provider = MagicMock(spec=AsyncAzureADTokenProvider)
|
||||
|
||||
with (
|
||||
patch("agent_framework.azure._assistants_client.AzureOpenAISettings") as mock_settings_class,
|
||||
patch("agent_framework.azure._assistants_client.load_settings") as mock_load_settings,
|
||||
patch("agent_framework.azure._assistants_client.AsyncAzureOpenAI") as mock_azure_client,
|
||||
patch("agent_framework.openai.OpenAIAssistantsClient.__init__", return_value=None),
|
||||
):
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.chat_deployment_name = "test-deployment"
|
||||
mock_settings.api_key = None # No API key
|
||||
mock_settings.api_version = "2024-05-01-preview"
|
||||
mock_settings.endpoint = "https://test-endpoint.openai.azure.com"
|
||||
mock_settings.base_url = None
|
||||
mock_settings_class.return_value = mock_settings
|
||||
mock_load_settings.return_value = {
|
||||
"chat_deployment_name": "test-deployment",
|
||||
"responses_deployment_name": None,
|
||||
"api_key": None,
|
||||
"token_endpoint": None,
|
||||
"api_version": "2024-05-01-preview",
|
||||
"endpoint": "https://test-endpoint.openai.azure.com",
|
||||
"base_url": None,
|
||||
}
|
||||
|
||||
client = AzureOpenAIAssistantsClient(
|
||||
deployment_name="test-deployment",
|
||||
@@ -675,17 +686,19 @@ def test_azure_assistants_client_ad_token_provider_authentication() -> None:
|
||||
def test_azure_assistants_client_base_url_configuration() -> None:
|
||||
"""Test base_url client parameter path."""
|
||||
with (
|
||||
patch("agent_framework.azure._assistants_client.AzureOpenAISettings") as mock_settings_class,
|
||||
patch("agent_framework.azure._assistants_client.load_settings") as mock_load_settings,
|
||||
patch("agent_framework.azure._assistants_client.AsyncAzureOpenAI") as mock_azure_client,
|
||||
patch("agent_framework.openai.OpenAIAssistantsClient.__init__", return_value=None),
|
||||
):
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.chat_deployment_name = "test-deployment"
|
||||
mock_settings.api_key.get_secret_value.return_value = "test-api-key"
|
||||
mock_settings.base_url = "https://custom-base-url.com"
|
||||
mock_settings.endpoint = None # No endpoint, should use base_url
|
||||
mock_settings.api_version = "2024-05-01-preview"
|
||||
mock_settings_class.return_value = mock_settings
|
||||
mock_load_settings.return_value = {
|
||||
"chat_deployment_name": "test-deployment",
|
||||
"responses_deployment_name": None,
|
||||
"api_key": SecretString("test-api-key"),
|
||||
"token_endpoint": None,
|
||||
"api_version": "2024-05-01-preview",
|
||||
"endpoint": None,
|
||||
"base_url": "https://custom-base-url.com",
|
||||
}
|
||||
|
||||
client = AzureOpenAIAssistantsClient(
|
||||
deployment_name="test-deployment", api_key="test-api-key", base_url="https://custom-base-url.com"
|
||||
@@ -704,17 +717,19 @@ def test_azure_assistants_client_base_url_configuration() -> None:
|
||||
def test_azure_assistants_client_azure_endpoint_configuration() -> None:
|
||||
"""Test azure_endpoint client parameter path."""
|
||||
with (
|
||||
patch("agent_framework.azure._assistants_client.AzureOpenAISettings") as mock_settings_class,
|
||||
patch("agent_framework.azure._assistants_client.load_settings") as mock_load_settings,
|
||||
patch("agent_framework.azure._assistants_client.AsyncAzureOpenAI") as mock_azure_client,
|
||||
patch("agent_framework.openai.OpenAIAssistantsClient.__init__", return_value=None),
|
||||
):
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.chat_deployment_name = "test-deployment"
|
||||
mock_settings.api_key.get_secret_value.return_value = "test-api-key"
|
||||
mock_settings.base_url = None # No base_url
|
||||
mock_settings.endpoint = "https://test-endpoint.openai.azure.com"
|
||||
mock_settings.api_version = "2024-05-01-preview"
|
||||
mock_settings_class.return_value = mock_settings
|
||||
mock_load_settings.return_value = {
|
||||
"chat_deployment_name": "test-deployment",
|
||||
"responses_deployment_name": None,
|
||||
"api_key": SecretString("test-api-key"),
|
||||
"token_endpoint": None,
|
||||
"api_version": "2024-05-01-preview",
|
||||
"endpoint": "https://test-endpoint.openai.azure.com",
|
||||
"base_url": None,
|
||||
}
|
||||
|
||||
client = AzureOpenAIAssistantsClient(
|
||||
deployment_name="test-deployment",
|
||||
|
||||
@@ -109,8 +109,11 @@ def test_init_with_empty_endpoint_and_base_url(azure_openai_unit_test_env: dict[
|
||||
|
||||
@pytest.mark.parametrize("override_env_param_dict", [{"AZURE_OPENAI_ENDPOINT": "http://test.com"}], indirect=True)
|
||||
def test_init_with_invalid_endpoint(azure_openai_unit_test_env: dict[str, str]) -> None:
|
||||
with pytest.raises(ServiceInitializationError):
|
||||
AzureOpenAIChatClient()
|
||||
# Note: URL scheme validation was previously handled by pydantic's HTTPsUrl type.
|
||||
# After migrating to load_settings with TypedDict, endpoint is a plain string and no longer
|
||||
# validated at the settings level. The Azure OpenAI SDK may reject invalid URLs at runtime.
|
||||
client = AzureOpenAIChatClient()
|
||||
assert client is not None
|
||||
|
||||
|
||||
@pytest.mark.parametrize("exclude_list", [["AZURE_OPENAI_BASE_URL"]], indirect=True)
|
||||
|
||||
@@ -0,0 +1,238 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Tests for load_settings() function."""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from typing import TypedDict
|
||||
|
||||
import pytest
|
||||
|
||||
from agent_framework._settings import SecretString, load_settings
|
||||
|
||||
|
||||
class SimpleSettings(TypedDict, total=False):
|
||||
api_key: str | None
|
||||
timeout: int | None
|
||||
enabled: bool | None
|
||||
rate_limit: float | None
|
||||
|
||||
|
||||
class RequiredFieldSettings(TypedDict, total=False):
|
||||
name: str | None
|
||||
optional_field: str | None
|
||||
|
||||
|
||||
class SecretSettings(TypedDict, total=False):
|
||||
api_key: SecretString | None
|
||||
username: str | None
|
||||
|
||||
|
||||
class TestLoadSettingsBasic:
|
||||
"""Test basic load_settings functionality."""
|
||||
|
||||
def test_fields_are_none_when_unset(self) -> None:
|
||||
settings = load_settings(SimpleSettings, env_prefix="TEST_APP_")
|
||||
|
||||
assert settings["api_key"] is None
|
||||
assert settings["timeout"] is None
|
||||
assert settings["enabled"] is None
|
||||
assert settings["rate_limit"] is None
|
||||
|
||||
def test_overrides(self) -> None:
|
||||
settings = load_settings(SimpleSettings, env_prefix="TEST_APP_", timeout=60, enabled=False)
|
||||
|
||||
assert settings["timeout"] == 60
|
||||
assert settings["enabled"] is False
|
||||
|
||||
def test_none_overrides_are_filtered(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setenv("TEST_APP_TIMEOUT", "120")
|
||||
|
||||
settings = load_settings(SimpleSettings, env_prefix="TEST_APP_", timeout=None)
|
||||
|
||||
# timeout=None is filtered, so env var wins
|
||||
assert settings["timeout"] == 120
|
||||
|
||||
def test_env_vars(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setenv("TEST_APP_API_KEY", "test-key-123")
|
||||
monkeypatch.setenv("TEST_APP_TIMEOUT", "120")
|
||||
monkeypatch.setenv("TEST_APP_ENABLED", "false")
|
||||
|
||||
settings = load_settings(SimpleSettings, env_prefix="TEST_APP_")
|
||||
|
||||
assert settings["api_key"] == "test-key-123"
|
||||
assert settings["timeout"] == 120
|
||||
assert settings["enabled"] is False
|
||||
|
||||
def test_overrides_beat_env_vars(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setenv("TEST_APP_TIMEOUT", "120")
|
||||
|
||||
settings = load_settings(SimpleSettings, env_prefix="TEST_APP_", timeout=60)
|
||||
|
||||
assert settings["timeout"] == 60
|
||||
|
||||
def test_no_prefix(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setenv("API_KEY", "no-prefix-key")
|
||||
|
||||
settings = load_settings(SimpleSettings, api_key=None)
|
||||
|
||||
assert settings["api_key"] == "no-prefix-key"
|
||||
|
||||
|
||||
class TestDotenvFile:
|
||||
"""Test .env file loading."""
|
||||
|
||||
def test_load_from_dotenv(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.delenv("TEST_APP_API_KEY", raising=False)
|
||||
monkeypatch.delenv("TEST_APP_TIMEOUT", raising=False)
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".env", delete=False) as f:
|
||||
f.write("TEST_APP_API_KEY=dotenv-key\n")
|
||||
f.write("TEST_APP_TIMEOUT=90\n")
|
||||
f.flush()
|
||||
env_path = f.name
|
||||
|
||||
try:
|
||||
settings = load_settings(SimpleSettings, env_prefix="TEST_APP_", env_file_path=env_path)
|
||||
|
||||
assert settings["api_key"] == "dotenv-key"
|
||||
assert settings["timeout"] == 90
|
||||
finally:
|
||||
os.unlink(env_path)
|
||||
|
||||
def test_env_vars_override_dotenv(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setenv("TEST_APP_API_KEY", "real-env-key")
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".env", delete=False) as f:
|
||||
f.write("TEST_APP_API_KEY=dotenv-key\n")
|
||||
f.flush()
|
||||
env_path = f.name
|
||||
|
||||
try:
|
||||
settings = load_settings(SimpleSettings, env_prefix="TEST_APP_", env_file_path=env_path)
|
||||
|
||||
assert settings["api_key"] == "real-env-key"
|
||||
finally:
|
||||
os.unlink(env_path)
|
||||
|
||||
def test_missing_dotenv_file(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.delenv("TEST_APP_API_KEY", raising=False)
|
||||
settings = load_settings(SimpleSettings, env_prefix="TEST_APP_", env_file_path="/nonexistent/.env")
|
||||
|
||||
assert settings["api_key"] is None
|
||||
|
||||
|
||||
class TestSecretString:
|
||||
"""Test SecretString type handling."""
|
||||
|
||||
def test_secretstring_from_env(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setenv("SECRET_API_KEY", "secret-value")
|
||||
|
||||
settings = load_settings(SecretSettings, env_prefix="SECRET_")
|
||||
|
||||
assert isinstance(settings["api_key"], SecretString)
|
||||
assert settings["api_key"] == "secret-value"
|
||||
|
||||
def test_secretstring_from_override(self) -> None:
|
||||
settings = load_settings(SecretSettings, env_prefix="SECRET_", api_key="kwarg-secret")
|
||||
|
||||
assert isinstance(settings["api_key"], SecretString)
|
||||
assert settings["api_key"] == "kwarg-secret"
|
||||
|
||||
def test_secretstring_masked_in_repr(self) -> None:
|
||||
s = SecretString("my-secret")
|
||||
assert "my-secret" not in repr(s)
|
||||
assert "**********" in repr(s)
|
||||
|
||||
def test_get_secret_value_compat(self) -> None:
|
||||
s = SecretString("my-secret")
|
||||
|
||||
assert s.get_secret_value() == "my-secret"
|
||||
assert isinstance(s.get_secret_value(), str)
|
||||
|
||||
|
||||
class TestTypeCoercion:
|
||||
"""Test type coercion from string values."""
|
||||
|
||||
def test_int_coercion(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setenv("TEST_APP_TIMEOUT", "42")
|
||||
|
||||
settings = load_settings(SimpleSettings, env_prefix="TEST_APP_")
|
||||
|
||||
assert settings["timeout"] == 42
|
||||
assert isinstance(settings["timeout"], int)
|
||||
|
||||
def test_float_coercion(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setenv("TEST_APP_RATE_LIMIT", "2.5")
|
||||
|
||||
settings = load_settings(SimpleSettings, env_prefix="TEST_APP_")
|
||||
|
||||
assert settings["rate_limit"] == 2.5
|
||||
assert isinstance(settings["rate_limit"], float)
|
||||
|
||||
def test_bool_coercion_true_values(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
for true_val in ["true", "True", "TRUE", "1", "yes", "on"]:
|
||||
monkeypatch.setenv("TEST_APP_ENABLED", true_val)
|
||||
settings = load_settings(SimpleSettings, env_prefix="TEST_APP_")
|
||||
assert settings["enabled"] is True, f"Failed for {true_val}"
|
||||
|
||||
def test_bool_coercion_false_values(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
for false_val in ["false", "False", "FALSE", "0", "no", "off"]:
|
||||
monkeypatch.setenv("TEST_APP_ENABLED", false_val)
|
||||
settings = load_settings(SimpleSettings, env_prefix="TEST_APP_")
|
||||
assert settings["enabled"] is False, f"Failed for {false_val}"
|
||||
|
||||
|
||||
class TestRequiredFields:
|
||||
"""Test required field validation."""
|
||||
|
||||
def test_required_field_provided(self) -> None:
|
||||
settings = load_settings(
|
||||
RequiredFieldSettings,
|
||||
env_prefix="TEST_",
|
||||
required_fields=["name"],
|
||||
name="my-app",
|
||||
)
|
||||
|
||||
assert settings["name"] == "my-app"
|
||||
assert settings["optional_field"] is None
|
||||
|
||||
def test_required_field_from_env(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setenv("TEST_NAME", "env-app")
|
||||
|
||||
settings = load_settings(RequiredFieldSettings, env_prefix="TEST_", required_fields=["name"])
|
||||
|
||||
assert settings["name"] == "env-app"
|
||||
|
||||
def test_required_field_missing_raises(self) -> None:
|
||||
from agent_framework.exceptions import SettingNotFoundError
|
||||
|
||||
with pytest.raises(SettingNotFoundError, match="Required setting 'name'"):
|
||||
load_settings(RequiredFieldSettings, env_prefix="TEST_", required_fields=["name"])
|
||||
|
||||
def test_without_required_fields_param_allows_none(self) -> None:
|
||||
settings = load_settings(RequiredFieldSettings, env_prefix="TEST_")
|
||||
|
||||
assert settings["name"] is None
|
||||
|
||||
|
||||
class TestOverrideTypeValidation:
|
||||
"""Test override type validation."""
|
||||
|
||||
def test_invalid_type_raises(self) -> None:
|
||||
from agent_framework.exceptions import ServiceInitializationError
|
||||
|
||||
with pytest.raises(ServiceInitializationError, match="Invalid type for setting 'api_key'"):
|
||||
load_settings(SimpleSettings, env_prefix="TEST_", api_key={"bad": "type"})
|
||||
|
||||
def test_valid_types_accepted(self) -> None:
|
||||
settings = load_settings(SimpleSettings, env_prefix="TEST_", timeout=42, enabled=True)
|
||||
|
||||
assert settings["timeout"] == 42
|
||||
assert settings["enabled"] is True
|
||||
|
||||
def test_str_accepted_for_secretstring(self) -> None:
|
||||
settings = load_settings(SecretSettings, env_prefix="TEST_", api_key="plain-string")
|
||||
|
||||
assert isinstance(settings["api_key"], SecretString)
|
||||
assert settings["api_key"] == "plain-string"
|
||||
@@ -131,9 +131,15 @@ class TestOpenAIAssistantProviderInit:
|
||||
"""Test initialization fails without API key when settings return None."""
|
||||
from unittest.mock import patch
|
||||
|
||||
# Mock OpenAISettings to return None for api_key
|
||||
with patch("agent_framework.openai._assistant_provider.OpenAISettings") as mock_settings:
|
||||
mock_settings.return_value.api_key = None
|
||||
# Mock load_settings to return a dict with None for api_key
|
||||
with patch("agent_framework.openai._assistant_provider.load_settings") as mock_load:
|
||||
mock_load.return_value = {
|
||||
"api_key": None,
|
||||
"org_id": None,
|
||||
"base_url": None,
|
||||
"chat_model_id": None,
|
||||
"responses_model_id": None,
|
||||
}
|
||||
|
||||
with pytest.raises(ServiceInitializationError) as exc_info:
|
||||
OpenAIAssistantProvider()
|
||||
|
||||
@@ -146,7 +146,7 @@ def test_init_auto_create_client(
|
||||
def test_init_validation_fail() -> None:
|
||||
"""Test OpenAIAssistantsClient initialization with validation failure."""
|
||||
with pytest.raises(ServiceInitializationError):
|
||||
# Force failure by providing invalid model ID type - this should cause validation to fail
|
||||
# Force failure by providing invalid model ID type
|
||||
OpenAIAssistantsClient(model_id=123, api_key="valid-key") # type: ignore
|
||||
|
||||
|
||||
|
||||
+17
-17
@@ -4,7 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, ClassVar, Generic
|
||||
from typing import Any, Generic
|
||||
|
||||
from agent_framework import (
|
||||
ChatAndFunctionMiddlewareTypes,
|
||||
@@ -13,7 +13,7 @@ from agent_framework import (
|
||||
FunctionInvocationConfiguration,
|
||||
FunctionInvocationLayer,
|
||||
)
|
||||
from agent_framework._pydantic import AFBaseSettings
|
||||
from agent_framework._settings import load_settings
|
||||
from agent_framework.exceptions import ServiceInitializationError
|
||||
from agent_framework.observability import ChatTelemetryLayer
|
||||
from agent_framework.openai._chat_client import RawOpenAIChatClient
|
||||
@@ -115,25 +115,19 @@ FoundryLocalChatOptionsT = TypeVar(
|
||||
# endregion
|
||||
|
||||
|
||||
class FoundryLocalSettings(AFBaseSettings):
|
||||
class FoundryLocalSettings(TypedDict, total=False):
|
||||
"""Foundry local model settings.
|
||||
|
||||
The settings are first loaded from environment variables with the prefix 'FOUNDRY_LOCAL_'.
|
||||
If the environment variables are not found, the settings can be loaded from a .env file
|
||||
with the encoding 'utf-8'. If the settings are not found in the .env file, the settings
|
||||
are ignored; however, validation will fail alerting that the settings are missing.
|
||||
with the encoding 'utf-8'.
|
||||
|
||||
Attributes:
|
||||
Keys:
|
||||
model_id: The name of the model deployment to use.
|
||||
(Env var FOUNDRY_LOCAL_MODEL_ID)
|
||||
Parameters:
|
||||
env_file_path: If provided, the .env settings are read from this file path location.
|
||||
env_file_encoding: The encoding of the .env file, defaults to 'utf-8'.
|
||||
"""
|
||||
|
||||
env_prefix: ClassVar[str] = "FOUNDRY_LOCAL_"
|
||||
|
||||
model_id: str
|
||||
model_id: str | None
|
||||
|
||||
|
||||
class FoundryLocalClient(
|
||||
@@ -247,21 +241,27 @@ class FoundryLocalClient(
|
||||
type that is not supported by the model, it will not be found.
|
||||
|
||||
"""
|
||||
settings = FoundryLocalSettings(
|
||||
model_id=model_id, # type: ignore
|
||||
settings = load_settings(
|
||||
FoundryLocalSettings,
|
||||
env_prefix="FOUNDRY_LOCAL_",
|
||||
required_fields=["model_id"],
|
||||
model_id=model_id,
|
||||
env_file_path=env_file_path,
|
||||
env_file_encoding=env_file_encoding,
|
||||
)
|
||||
manager = FoundryLocalManager(bootstrap=bootstrap, timeout=timeout)
|
||||
model_info = manager.get_model_info(
|
||||
alias_or_model_id=settings.model_id,
|
||||
alias_or_model_id=settings["model_id"],
|
||||
device=device,
|
||||
)
|
||||
if model_info is None:
|
||||
message = (
|
||||
f"Model with ID or alias '{settings.model_id}:{device.value}' not found in Foundry Local."
|
||||
f"Model with ID or alias '{settings['model_id']}:{device.value}' not found in Foundry Local."
|
||||
if device
|
||||
else f"Model with ID or alias '{settings.model_id}' for your current device not found in Foundry Local."
|
||||
else (
|
||||
f"Model with ID or alias '{settings['model_id']}' for your current device "
|
||||
"not found in Foundry Local."
|
||||
)
|
||||
)
|
||||
raise ServiceInitializationError(message)
|
||||
if prepare_model:
|
||||
|
||||
@@ -4,8 +4,8 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from agent_framework import SupportsChatGetResponse
|
||||
from agent_framework.exceptions import ServiceInitializationError
|
||||
from pydantic import ValidationError
|
||||
from agent_framework._settings import load_settings
|
||||
from agent_framework.exceptions import ServiceInitializationError, SettingNotFoundError
|
||||
|
||||
from agent_framework_foundry_local import FoundryLocalClient
|
||||
from agent_framework_foundry_local._foundry_local_client import FoundryLocalSettings
|
||||
@@ -15,31 +15,43 @@ from agent_framework_foundry_local._foundry_local_client import FoundryLocalSett
|
||||
|
||||
def test_foundry_local_settings_init_from_env(foundry_local_unit_test_env: dict[str, str]) -> None:
|
||||
"""Test FoundryLocalSettings initialization from environment variables."""
|
||||
settings = FoundryLocalSettings(env_file_path="test.env")
|
||||
settings = load_settings(FoundryLocalSettings, env_prefix="FOUNDRY_LOCAL_", env_file_path="test.env")
|
||||
|
||||
assert settings.model_id == foundry_local_unit_test_env["FOUNDRY_LOCAL_MODEL_ID"]
|
||||
assert settings["model_id"] == foundry_local_unit_test_env["FOUNDRY_LOCAL_MODEL_ID"]
|
||||
|
||||
|
||||
def test_foundry_local_settings_init_with_explicit_values() -> None:
|
||||
"""Test FoundryLocalSettings initialization with explicit values."""
|
||||
settings = FoundryLocalSettings(model_id="custom-model-id", env_file_path="test.env")
|
||||
settings = load_settings(
|
||||
FoundryLocalSettings,
|
||||
env_prefix="FOUNDRY_LOCAL_",
|
||||
model_id="custom-model-id",
|
||||
env_file_path="test.env",
|
||||
)
|
||||
|
||||
assert settings.model_id == "custom-model-id"
|
||||
assert settings["model_id"] == "custom-model-id"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("exclude_list", [["FOUNDRY_LOCAL_MODEL_ID"]], indirect=True)
|
||||
def test_foundry_local_settings_missing_model_id(foundry_local_unit_test_env: dict[str, str]) -> None:
|
||||
"""Test FoundryLocalSettings when model_id is missing raises ValidationError."""
|
||||
with pytest.raises(ValidationError):
|
||||
FoundryLocalSettings(env_file_path="test.env")
|
||||
"""Test FoundryLocalSettings when model_id is missing raises error."""
|
||||
with pytest.raises(SettingNotFoundError, match="Required setting 'model_id'"):
|
||||
load_settings(
|
||||
FoundryLocalSettings,
|
||||
env_prefix="FOUNDRY_LOCAL_",
|
||||
required_fields=["model_id"],
|
||||
env_file_path="test.env",
|
||||
)
|
||||
|
||||
|
||||
def test_foundry_local_settings_explicit_overrides_env(foundry_local_unit_test_env: dict[str, str]) -> None:
|
||||
"""Test that explicit values override environment variables."""
|
||||
settings = FoundryLocalSettings(model_id="override-model-id", env_file_path="test.env")
|
||||
settings = load_settings(
|
||||
FoundryLocalSettings, env_prefix="FOUNDRY_LOCAL_", model_id="override-model-id", env_file_path="test.env"
|
||||
)
|
||||
|
||||
assert settings.model_id == "override-model-id"
|
||||
assert settings.model_id != foundry_local_unit_test_env["FOUNDRY_LOCAL_MODEL_ID"]
|
||||
assert settings["model_id"] == "override-model-id"
|
||||
assert settings["model_id"] != foundry_local_unit_test_env["FOUNDRY_LOCAL_MODEL_ID"]
|
||||
|
||||
|
||||
# Client Initialization Tests
|
||||
|
||||
@@ -21,9 +21,10 @@ from agent_framework import (
|
||||
ResponseStream,
|
||||
normalize_messages,
|
||||
)
|
||||
from agent_framework._settings import load_settings
|
||||
from agent_framework._tools import FunctionTool
|
||||
from agent_framework._types import normalize_tools
|
||||
from agent_framework.exceptions import ServiceException, ServiceInitializationError
|
||||
from agent_framework.exceptions import ServiceException
|
||||
from copilot import CopilotClient, CopilotSession
|
||||
from copilot.generated.session_events import SessionEvent, SessionEventType
|
||||
from copilot.types import (
|
||||
@@ -38,7 +39,6 @@ from copilot.types import (
|
||||
ToolResult,
|
||||
)
|
||||
from copilot.types import Tool as CopilotTool
|
||||
from pydantic import ValidationError
|
||||
|
||||
from ._settings import GitHubCopilotSettings
|
||||
|
||||
@@ -207,17 +207,16 @@ class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
|
||||
on_permission_request: PermissionHandlerType | None = opts.pop("on_permission_request", None)
|
||||
mcp_servers: dict[str, MCPServerConfig] | None = opts.pop("mcp_servers", None)
|
||||
|
||||
try:
|
||||
self._settings = GitHubCopilotSettings(
|
||||
cli_path=cli_path,
|
||||
model=model,
|
||||
timeout=timeout,
|
||||
log_level=log_level,
|
||||
env_file_path=env_file_path,
|
||||
env_file_encoding=env_file_encoding,
|
||||
)
|
||||
except ValidationError as ex:
|
||||
raise ServiceInitializationError("Failed to create GitHub Copilot settings.", ex) from ex
|
||||
self._settings = load_settings(
|
||||
GitHubCopilotSettings,
|
||||
env_prefix="GITHUB_COPILOT_",
|
||||
cli_path=cli_path,
|
||||
model=model,
|
||||
timeout=timeout,
|
||||
log_level=log_level,
|
||||
env_file_path=env_file_path,
|
||||
env_file_encoding=env_file_encoding,
|
||||
)
|
||||
|
||||
self._tools = normalize_tools(tools)
|
||||
self._permission_handler = on_permission_request
|
||||
@@ -249,10 +248,10 @@ class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
|
||||
|
||||
if self._client is None:
|
||||
client_options: CopilotClientOptions = {}
|
||||
if self._settings.cli_path:
|
||||
client_options["cli_path"] = self._settings.cli_path
|
||||
if self._settings.log_level:
|
||||
client_options["log_level"] = self._settings.log_level # type: ignore[typeddict-item]
|
||||
if self._settings["cli_path"]:
|
||||
client_options["cli_path"] = self._settings["cli_path"]
|
||||
if self._settings["log_level"]:
|
||||
client_options["log_level"] = self._settings["log_level"] # type: ignore[typeddict-item]
|
||||
|
||||
self._client = CopilotClient(client_options if client_options else None)
|
||||
|
||||
@@ -355,7 +354,7 @@ class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
|
||||
thread = self.get_new_thread()
|
||||
|
||||
opts: dict[str, Any] = dict(options) if options else {}
|
||||
timeout = opts.pop("timeout", None) or self._settings.timeout or DEFAULT_TIMEOUT_SECONDS
|
||||
timeout = opts.pop("timeout", None) or self._settings["timeout"] or DEFAULT_TIMEOUT_SECONDS
|
||||
|
||||
session = await self._get_or_create_session(thread, streaming=False, runtime_options=opts)
|
||||
input_messages = normalize_messages(messages)
|
||||
@@ -578,7 +577,7 @@ class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
|
||||
opts = runtime_options or {}
|
||||
config: SessionConfig = {"streaming": streaming}
|
||||
|
||||
model = opts.get("model") or self._settings.model
|
||||
model = opts.get("model") or self._settings["model"]
|
||||
if model:
|
||||
config["model"] = model # type: ignore[typeddict-item]
|
||||
|
||||
|
||||
@@ -1,19 +1,16 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from typing import ClassVar
|
||||
|
||||
from agent_framework._pydantic import AFBaseSettings
|
||||
from typing import TypedDict
|
||||
|
||||
|
||||
class GitHubCopilotSettings(AFBaseSettings):
|
||||
class GitHubCopilotSettings(TypedDict, total=False):
|
||||
"""GitHub Copilot model settings.
|
||||
|
||||
The settings are first loaded from environment variables with the prefix 'GITHUB_COPILOT_'.
|
||||
If the environment variables are not found, the settings can be loaded from a .env file
|
||||
with the encoding 'utf-8'. If the settings are not found in the .env file, the settings
|
||||
are ignored; however, validation will fail alerting that the settings are missing.
|
||||
with the encoding 'utf-8'.
|
||||
|
||||
Keyword Args:
|
||||
Keys:
|
||||
cli_path: Path to the Copilot CLI executable.
|
||||
Can be set via environment variable GITHUB_COPILOT_CLI_PATH.
|
||||
model: Model to use (e.g., "gpt-5", "claude-sonnet-4").
|
||||
@@ -22,28 +19,9 @@ class GitHubCopilotSettings(AFBaseSettings):
|
||||
Can be set via environment variable GITHUB_COPILOT_TIMEOUT.
|
||||
log_level: CLI log level.
|
||||
Can be set via environment variable GITHUB_COPILOT_LOG_LEVEL.
|
||||
env_file_path: If provided, the .env settings are read from this file path location.
|
||||
env_file_encoding: The encoding of the .env file, defaults to 'utf-8'.
|
||||
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
|
||||
from agent_framework_github_copilot import GitHubCopilotSettings
|
||||
|
||||
# Using environment variables
|
||||
# Set GITHUB_COPILOT_MODEL=gpt-5
|
||||
settings = GitHubCopilotSettings()
|
||||
|
||||
# Or passing parameters directly
|
||||
settings = GitHubCopilotSettings(model="claude-sonnet-4", timeout=120)
|
||||
|
||||
# Or loading from a .env file
|
||||
settings = GitHubCopilotSettings(env_file_path="path/to/.env")
|
||||
"""
|
||||
|
||||
env_prefix: ClassVar[str] = "GITHUB_COPILOT_"
|
||||
|
||||
cli_path: str | None = None
|
||||
model: str | None = None
|
||||
timeout: float | None = None
|
||||
log_level: str | None = None
|
||||
cli_path: str | None
|
||||
model: str | None
|
||||
timeout: float | None
|
||||
log_level: str | None
|
||||
|
||||
@@ -122,8 +122,8 @@ class TestGitHubCopilotAgentInit:
|
||||
agent: GitHubCopilotAgent[GitHubCopilotOptions] = GitHubCopilotAgent(
|
||||
default_options={"model": "claude-sonnet-4", "timeout": 120}
|
||||
)
|
||||
assert agent._settings.model == "claude-sonnet-4" # type: ignore
|
||||
assert agent._settings.timeout == 120 # type: ignore
|
||||
assert agent._settings["model"] == "claude-sonnet-4" # type: ignore
|
||||
assert agent._settings["timeout"] == 120 # type: ignore
|
||||
|
||||
def test_init_with_tools(self) -> None:
|
||||
"""Test initialization with function tools."""
|
||||
|
||||
@@ -30,9 +30,8 @@ from agent_framework import (
|
||||
UsageDetails,
|
||||
get_logger,
|
||||
)
|
||||
from agent_framework._pydantic import AFBaseSettings
|
||||
from agent_framework._settings import load_settings
|
||||
from agent_framework.exceptions import (
|
||||
ServiceInitializationError,
|
||||
ServiceInvalidRequestError,
|
||||
ServiceResponseException,
|
||||
)
|
||||
@@ -42,7 +41,7 @@ from ollama import AsyncClient
|
||||
# Rename imported types to avoid naming conflicts with Agent Framework types
|
||||
from ollama._types import ChatResponse as OllamaChatResponse
|
||||
from ollama._types import Message as OllamaMessage
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from pydantic import BaseModel
|
||||
|
||||
if sys.version_info >= (3, 13):
|
||||
from typing import TypeVar # type: ignore # pragma: no cover
|
||||
@@ -275,13 +274,11 @@ OllamaChatOptionsT = TypeVar("OllamaChatOptionsT", bound=TypedDict, default="Oll
|
||||
# endregion
|
||||
|
||||
|
||||
class OllamaSettings(AFBaseSettings):
|
||||
class OllamaSettings(TypedDict, total=False):
|
||||
"""Ollama settings."""
|
||||
|
||||
env_prefix: ClassVar[str] = "OLLAMA_"
|
||||
|
||||
host: str | None = None
|
||||
model_id: str | None = None
|
||||
host: str | None
|
||||
model_id: str | None
|
||||
|
||||
|
||||
logger = get_logger("agent_framework.ollama")
|
||||
@@ -322,23 +319,19 @@ class OllamaChatClient(
|
||||
env_file_encoding: The encoding to use when reading the dotenv (.env) file. Defaults to 'utf-8'.
|
||||
**kwargs: Additional keyword arguments passed to BaseChatClient.
|
||||
"""
|
||||
try:
|
||||
ollama_settings = OllamaSettings(
|
||||
host=host,
|
||||
model_id=model_id,
|
||||
env_file_encoding=env_file_encoding,
|
||||
env_file_path=env_file_path,
|
||||
)
|
||||
except ValidationError as ex:
|
||||
raise ServiceInitializationError("Failed to create Ollama settings.", ex) from ex
|
||||
ollama_settings = load_settings(
|
||||
OllamaSettings,
|
||||
env_prefix="OLLAMA_",
|
||||
required_fields=["model_id"],
|
||||
host=host,
|
||||
model_id=model_id,
|
||||
env_file_encoding=env_file_encoding,
|
||||
env_file_path=env_file_path,
|
||||
)
|
||||
|
||||
if ollama_settings.model_id is None:
|
||||
raise ServiceInitializationError(
|
||||
"Ollama chat model ID must be provided via model_id or OLLAMA_MODEL_ID environment variable."
|
||||
)
|
||||
|
||||
self.model_id = ollama_settings.model_id
|
||||
self.client = client or AsyncClient(host=ollama_settings.host)
|
||||
self.model_id = ollama_settings["model_id"]
|
||||
# we can just pass in None for the host, the default is set by the Ollama package.
|
||||
self.client = client or AsyncClient(host=ollama_settings.get("host"))
|
||||
# Save Host URL for serialization with to_dict()
|
||||
self.host = str(self.client._client.base_url) # pyright: ignore[reportUnknownMemberType,reportPrivateUsage,reportUnknownArgumentType]
|
||||
|
||||
|
||||
@@ -15,9 +15,9 @@ from agent_framework import (
|
||||
tool,
|
||||
)
|
||||
from agent_framework.exceptions import (
|
||||
ServiceInitializationError,
|
||||
ServiceInvalidRequestError,
|
||||
ServiceResponseException,
|
||||
SettingNotFoundError,
|
||||
)
|
||||
from ollama import AsyncClient
|
||||
from ollama._types import ChatResponse as OllamaChatResponse
|
||||
@@ -182,7 +182,7 @@ def test_init_client(ollama_unit_test_env: dict[str, str]) -> None:
|
||||
|
||||
@pytest.mark.parametrize("exclude_list", [["OLLAMA_MODEL_ID"]], indirect=True)
|
||||
def test_with_invalid_settings(ollama_unit_test_env: dict[str, str]) -> None:
|
||||
with pytest.raises(ServiceInitializationError):
|
||||
with pytest.raises(SettingNotFoundError, match="Required setting 'model_id'"):
|
||||
OllamaChatClient(
|
||||
host="http://localhost:12345",
|
||||
model_id=None,
|
||||
|
||||
@@ -9,7 +9,7 @@ from ._exceptions import (
|
||||
PurviewServiceError,
|
||||
)
|
||||
from ._middleware import PurviewChatPolicyMiddleware, PurviewPolicyMiddleware
|
||||
from ._settings import PurviewAppLocation, PurviewLocationType, PurviewSettings
|
||||
from ._settings import PurviewAppLocation, PurviewLocationType, PurviewSettings, get_purview_scopes
|
||||
|
||||
__all__ = [
|
||||
"CacheProvider",
|
||||
@@ -23,4 +23,5 @@ __all__ = [
|
||||
"PurviewRequestError",
|
||||
"PurviewServiceError",
|
||||
"PurviewSettings",
|
||||
"get_purview_scopes",
|
||||
]
|
||||
|
||||
@@ -31,7 +31,7 @@ from ._models import (
|
||||
ProtectionScopesRequest,
|
||||
ProtectionScopesResponse,
|
||||
)
|
||||
from ._settings import PurviewSettings
|
||||
from ._settings import PurviewSettings, get_purview_scopes
|
||||
|
||||
logger = get_logger("agent_framework.purview")
|
||||
|
||||
@@ -52,7 +52,7 @@ class PurviewClient:
|
||||
):
|
||||
self._credential: TokenCredential | AsyncTokenCredential = credential
|
||||
self._settings = settings
|
||||
self._graph_uri = settings.graph_base_uri.rstrip("/")
|
||||
self._graph_uri = (settings.get("graph_base_uri") or "https://graph.microsoft.com/v1.0/").rstrip("/")
|
||||
self._timeout = timeout
|
||||
self._client = httpx.AsyncClient(timeout=timeout)
|
||||
|
||||
@@ -61,7 +61,7 @@ class PurviewClient:
|
||||
|
||||
async def _get_token(self, *, tenant_id: str | None = None) -> str:
|
||||
"""Acquire an access token using either async or sync credential."""
|
||||
scopes = self._settings.get_scopes()
|
||||
scopes = get_purview_scopes(self._settings)
|
||||
cred = self._credential
|
||||
token = cred.get_token(*scopes, tenant_id=tenant_id)
|
||||
token = await token if inspect.isawaitable(token) else token
|
||||
@@ -167,7 +167,7 @@ class PurviewClient:
|
||||
if resp.status_code in (401, 403):
|
||||
raise PurviewAuthenticationError(f"Auth failure {resp.status_code}: {resp.text}")
|
||||
if resp.status_code == 402:
|
||||
if self._settings.ignore_payment_required:
|
||||
if self._settings.get("ignore_payment_required", False):
|
||||
return response_type() # type: ignore[call-arg, no-any-return]
|
||||
raise PurviewPaymentRequiredError(f"Payment required {resp.status_code}: {resp.text}")
|
||||
if resp.status_code == 429:
|
||||
|
||||
@@ -78,18 +78,22 @@ class PurviewPolicyMiddleware(AgentMiddleware):
|
||||
from agent_framework import AgentResponse, Message
|
||||
|
||||
context.result = AgentResponse(
|
||||
messages=[Message(role="system", text=self._settings.blocked_prompt_message)]
|
||||
messages=[
|
||||
Message(
|
||||
role="system", text=self._settings.get("blocked_prompt_message", "Prompt blocked by policy")
|
||||
)
|
||||
]
|
||||
)
|
||||
raise MiddlewareTermination
|
||||
except MiddlewareTermination:
|
||||
raise
|
||||
except PurviewPaymentRequiredError as ex:
|
||||
logger.error(f"Purview payment required error in policy pre-check: {ex}")
|
||||
if not self._settings.ignore_payment_required:
|
||||
if not self._settings.get("ignore_payment_required", False):
|
||||
raise
|
||||
except Exception as ex:
|
||||
logger.error(f"Error in Purview policy pre-check: {ex}")
|
||||
if not self._settings.ignore_exceptions:
|
||||
if not self._settings.get("ignore_exceptions", False):
|
||||
raise
|
||||
|
||||
await call_next()
|
||||
@@ -111,18 +115,23 @@ class PurviewPolicyMiddleware(AgentMiddleware):
|
||||
from agent_framework import AgentResponse, Message
|
||||
|
||||
context.result = AgentResponse(
|
||||
messages=[Message(role="system", text=self._settings.blocked_response_message)]
|
||||
messages=[
|
||||
Message(
|
||||
role="system",
|
||||
text=self._settings.get("blocked_response_message", "Response blocked by policy"),
|
||||
)
|
||||
]
|
||||
)
|
||||
else:
|
||||
# Streaming responses are not supported for post-checks
|
||||
logger.debug("Streaming responses are not supported for Purview policy post-checks")
|
||||
except PurviewPaymentRequiredError as ex:
|
||||
logger.error(f"Purview payment required error in policy post-check: {ex}")
|
||||
if not self._settings.ignore_payment_required:
|
||||
if not self._settings.get("ignore_payment_required", False):
|
||||
raise
|
||||
except Exception as ex:
|
||||
logger.error(f"Error in Purview policy post-check: {ex}")
|
||||
if not self._settings.ignore_exceptions:
|
||||
if not self._settings.get("ignore_exceptions", False):
|
||||
raise
|
||||
|
||||
|
||||
@@ -173,18 +182,20 @@ class PurviewChatPolicyMiddleware(ChatMiddleware):
|
||||
if should_block_prompt:
|
||||
from agent_framework import ChatResponse, Message
|
||||
|
||||
blocked_message = Message(role="system", text=self._settings.blocked_prompt_message)
|
||||
blocked_message = Message(
|
||||
role="system", text=self._settings.get("blocked_prompt_message", "Prompt blocked by policy")
|
||||
)
|
||||
context.result = ChatResponse(messages=[blocked_message])
|
||||
raise MiddlewareTermination
|
||||
except MiddlewareTermination:
|
||||
raise
|
||||
except PurviewPaymentRequiredError as ex:
|
||||
logger.error(f"Purview payment required error in policy pre-check: {ex}")
|
||||
if not self._settings.ignore_payment_required:
|
||||
if not self._settings.get("ignore_payment_required", False):
|
||||
raise
|
||||
except Exception as ex:
|
||||
logger.error(f"Error in Purview policy pre-check: {ex}")
|
||||
if not self._settings.ignore_exceptions:
|
||||
if not self._settings.get("ignore_exceptions", False):
|
||||
raise
|
||||
|
||||
await call_next()
|
||||
@@ -205,15 +216,18 @@ class PurviewChatPolicyMiddleware(ChatMiddleware):
|
||||
if should_block_response:
|
||||
from agent_framework import ChatResponse, Message
|
||||
|
||||
blocked_message = Message(role="system", text=self._settings.blocked_response_message)
|
||||
blocked_message = Message(
|
||||
role="system",
|
||||
text=self._settings.get("blocked_response_message", "Response blocked by policy"),
|
||||
)
|
||||
context.result = ChatResponse(messages=[blocked_message])
|
||||
else:
|
||||
logger.debug("Streaming responses are not supported for Purview policy post-checks")
|
||||
except PurviewPaymentRequiredError as ex:
|
||||
logger.error(f"Purview payment required error in policy post-check: {ex}")
|
||||
if not self._settings.ignore_payment_required:
|
||||
if not self._settings.get("ignore_payment_required", False):
|
||||
raise
|
||||
except Exception as ex:
|
||||
logger.error(f"Error in Purview policy post-check: {ex}")
|
||||
if not self._settings.ignore_exceptions:
|
||||
if not self._settings.get("ignore_exceptions", False):
|
||||
raise
|
||||
|
||||
@@ -57,8 +57,11 @@ class ScopedContentProcessor:
|
||||
def __init__(self, client: PurviewClient, settings: PurviewSettings, cache_provider: CacheProvider | None = None):
|
||||
self._client = client
|
||||
self._settings = settings
|
||||
cache_ttl = settings.get("cache_ttl_seconds")
|
||||
max_cache = settings.get("max_cache_size_bytes")
|
||||
self._cache: CacheProvider = cache_provider or InMemoryCacheProvider(
|
||||
default_ttl_seconds=settings.cache_ttl_seconds, max_size_bytes=settings.max_cache_size_bytes
|
||||
default_ttl_seconds=cache_ttl if cache_ttl is not None else 14400,
|
||||
max_size_bytes=max_cache if max_cache is not None else 200 * 1024 * 1024,
|
||||
)
|
||||
self._background_tasks: set[asyncio.Task[Any]] = set()
|
||||
|
||||
@@ -116,10 +119,10 @@ class ScopedContentProcessor:
|
||||
results: list[ProcessContentRequest] = []
|
||||
token_info = None
|
||||
|
||||
if not (self._settings.tenant_id and self._settings.purview_app_location):
|
||||
token_info = await self._client.get_user_info_from_token(tenant_id=self._settings.tenant_id)
|
||||
if not (self._settings.get("tenant_id") and self._settings.get("purview_app_location")):
|
||||
token_info = await self._client.get_user_info_from_token(tenant_id=self._settings.get("tenant_id"))
|
||||
|
||||
tenant_id = (token_info or {}).get("tenant_id") or self._settings.tenant_id
|
||||
tenant_id = (token_info or {}).get("tenant_id") or self._settings.get("tenant_id")
|
||||
if not tenant_id or not _is_valid_guid(tenant_id):
|
||||
raise ValueError("Tenant id required or must be inferable from credential")
|
||||
|
||||
@@ -159,10 +162,11 @@ class ScopedContentProcessor:
|
||||
)
|
||||
activity_meta = ActivityMetadata(activity=activity)
|
||||
|
||||
if self._settings.purview_app_location:
|
||||
purview_app_location = self._settings.get("purview_app_location")
|
||||
if purview_app_location:
|
||||
policy_location = PolicyLocation(
|
||||
data_type=self._settings.purview_app_location.get_policy_location()["@odata.type"],
|
||||
value=self._settings.purview_app_location.location_value,
|
||||
data_type=purview_app_location.get_policy_location()["@odata.type"],
|
||||
value=purview_app_location.location_value,
|
||||
)
|
||||
elif token_info and token_info.get("client_id"):
|
||||
policy_location = PolicyLocation(
|
||||
@@ -172,13 +176,14 @@ class ScopedContentProcessor:
|
||||
else:
|
||||
raise ValueError("App location not provided or inferable")
|
||||
|
||||
app_version = self._settings.app_version or "Unknown"
|
||||
protected_app = ProtectedAppMetadata(
|
||||
name=self._settings.app_name,
|
||||
version=app_version,
|
||||
name=self._settings["app_name"],
|
||||
version=self._settings.get("app_version", "Unknown"),
|
||||
application_location=policy_location,
|
||||
)
|
||||
integrated_app = IntegratedAppMetadata(name=self._settings.app_name, version=app_version)
|
||||
integrated_app = IntegratedAppMetadata(
|
||||
name=self._settings["app_name"], version=self._settings.get("app_version", "Unknown")
|
||||
)
|
||||
device_meta = DeviceMetadata(
|
||||
operating_system_specifications=OperatingSystemSpecifications(
|
||||
operating_system_platform="Unknown", operating_system_version="Unknown"
|
||||
@@ -229,11 +234,13 @@ class ScopedContentProcessor:
|
||||
ps_resp = cached_ps_resp
|
||||
else:
|
||||
try:
|
||||
ttl = self._settings.get("cache_ttl_seconds")
|
||||
ttl_seconds = ttl if ttl is not None else 14400
|
||||
ps_resp = await self._client.get_protection_scopes(ps_req)
|
||||
await self._cache.set(cache_key, ps_resp, ttl_seconds=self._settings.cache_ttl_seconds)
|
||||
await self._cache.set(cache_key, ps_resp, ttl_seconds=ttl_seconds)
|
||||
except PurviewPaymentRequiredError as ex:
|
||||
# Cache the exception at tenant level so all subsequent requests for this tenant fail fast
|
||||
await self._cache.set(tenant_payment_cache_key, ex, ttl_seconds=self._settings.cache_ttl_seconds)
|
||||
await self._cache.set(tenant_payment_cache_key, ex, ttl_seconds=ttl_seconds)
|
||||
raise
|
||||
|
||||
if ps_resp.scope_identifier:
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import sys
|
||||
from enum import Enum
|
||||
|
||||
from agent_framework._pydantic import AFBaseSettings
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic_settings import SettingsConfigDict
|
||||
from pydantic import BaseModel
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
from typing import TypedDict # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import TypedDict # type: ignore # pragma: no cover
|
||||
|
||||
|
||||
class PurviewLocationType(str, Enum):
|
||||
@@ -18,8 +22,8 @@ class PurviewLocationType(str, Enum):
|
||||
class PurviewAppLocation(BaseModel):
|
||||
"""Identifier representing the app's location for Purview policy evaluation."""
|
||||
|
||||
location_type: PurviewLocationType = Field(..., description="The location type.")
|
||||
location_value: str = Field(..., description="The location value.")
|
||||
location_type: PurviewLocationType
|
||||
location_value: str
|
||||
|
||||
def get_policy_location(self) -> dict[str, str]:
|
||||
ns = "microsoft.graph"
|
||||
@@ -34,8 +38,8 @@ class PurviewAppLocation(BaseModel):
|
||||
return {"@odata.type": dt, "value": self.location_value}
|
||||
|
||||
|
||||
class PurviewSettings(AFBaseSettings):
|
||||
"""Settings for Purview integration.
|
||||
class PurviewSettings(TypedDict, total=False):
|
||||
"""Settings for Purview integration mirroring .NET PurviewSettings.
|
||||
|
||||
Attributes:
|
||||
app_name: Public app name.
|
||||
@@ -51,40 +55,30 @@ class PurviewSettings(AFBaseSettings):
|
||||
max_cache_size_bytes: Maximum cache size in bytes (default 200MB).
|
||||
"""
|
||||
|
||||
app_name: str = Field(...)
|
||||
app_version: str | None = Field(default=None)
|
||||
tenant_id: str | None = Field(default=None)
|
||||
purview_app_location: PurviewAppLocation | None = Field(default=None)
|
||||
graph_base_uri: str = Field(default="https://graph.microsoft.com/v1.0/")
|
||||
blocked_prompt_message: str = Field(
|
||||
default="Prompt blocked by policy",
|
||||
description="Message to return when a prompt is blocked by policy.",
|
||||
)
|
||||
blocked_response_message: str = Field(
|
||||
default="Response blocked by policy",
|
||||
description="Message to return when a response is blocked by policy.",
|
||||
)
|
||||
ignore_exceptions: bool = Field(
|
||||
default=False,
|
||||
description="If True, all Purview exceptions will be logged but not thrown in middleware.",
|
||||
)
|
||||
ignore_payment_required: bool = Field(
|
||||
default=False,
|
||||
description="If True, 402 payment required errors will be logged but not thrown.",
|
||||
)
|
||||
cache_ttl_seconds: int = Field(
|
||||
default=14400,
|
||||
description="Time to live for cache entries in seconds (default 14400 = 4 hours).",
|
||||
)
|
||||
max_cache_size_bytes: int = Field(
|
||||
default=200 * 1024 * 1024,
|
||||
description="Maximum cache size in bytes (default 200MB).",
|
||||
)
|
||||
app_name: str | None
|
||||
app_version: str | None
|
||||
tenant_id: str | None
|
||||
purview_app_location: PurviewAppLocation | None
|
||||
graph_base_uri: str | None
|
||||
blocked_prompt_message: str | None
|
||||
blocked_response_message: str | None
|
||||
ignore_exceptions: bool | None
|
||||
ignore_payment_required: bool | None
|
||||
cache_ttl_seconds: int | None
|
||||
max_cache_size_bytes: int | None
|
||||
|
||||
model_config = SettingsConfigDict(populate_by_name=True, validate_assignment=True)
|
||||
|
||||
def get_scopes(self) -> list[str]:
|
||||
from urllib.parse import urlparse
|
||||
def get_purview_scopes(settings: PurviewSettings) -> list[str]:
|
||||
"""Get the OAuth scopes for the Purview Graph API.
|
||||
|
||||
host = urlparse(self.graph_base_uri).hostname or "graph.microsoft.com"
|
||||
return [f"https://{host}/.default"]
|
||||
Args:
|
||||
settings: The Purview settings containing graph_base_uri.
|
||||
|
||||
Returns:
|
||||
A list of OAuth scope strings.
|
||||
"""
|
||||
from urllib.parse import urlparse
|
||||
|
||||
graph_base_uri = settings.get("graph_base_uri", "https://graph.microsoft.com/v1.0/")
|
||||
host = urlparse(str(graph_base_uri)).hostname or "graph.microsoft.com"
|
||||
return [f"https://{host}/.default"]
|
||||
|
||||
@@ -125,7 +125,7 @@ class TestPurviewChatPolicyMiddleware:
|
||||
) -> None:
|
||||
"""Test that exceptions in post-check are logged but don't affect result when ignore_exceptions=True."""
|
||||
# Set ignore_exceptions to True to test exception suppression
|
||||
middleware._settings.ignore_exceptions = True
|
||||
middleware._settings["ignore_exceptions"] = True
|
||||
|
||||
call_count = 0
|
||||
|
||||
|
||||
@@ -119,7 +119,7 @@ class TestPurviewPolicyMiddleware:
|
||||
) -> None:
|
||||
"""Test middleware handles result that doesn't have messages attribute."""
|
||||
# Set ignore_exceptions to True so AttributeError is caught and logged
|
||||
middleware._settings.ignore_exceptions = True
|
||||
middleware._settings["ignore_exceptions"] = True
|
||||
|
||||
context = AgentContext(agent=mock_agent, messages=[Message(role="user", text="Hello")])
|
||||
|
||||
@@ -216,7 +216,7 @@ class TestPurviewPolicyMiddleware:
|
||||
self, middleware: PurviewPolicyMiddleware, mock_agent: MagicMock
|
||||
) -> None:
|
||||
"""Test that post-check exceptions are propagated when ignore_exceptions=False."""
|
||||
middleware._settings.ignore_exceptions = False
|
||||
middleware._settings["ignore_exceptions"] = False
|
||||
|
||||
context = AgentContext(agent=mock_agent, messages=[Message(role="user", text="Hello")])
|
||||
|
||||
@@ -242,7 +242,7 @@ class TestPurviewPolicyMiddleware:
|
||||
) -> None:
|
||||
"""Test that exceptions in pre-check are logged but don't stop processing when ignore_exceptions=True."""
|
||||
# Set ignore_exceptions to True
|
||||
middleware._settings.ignore_exceptions = True
|
||||
middleware._settings["ignore_exceptions"] = True
|
||||
|
||||
context = AgentContext(agent=mock_agent, messages=[Message(role="user", text="Test")])
|
||||
|
||||
@@ -265,7 +265,7 @@ class TestPurviewPolicyMiddleware:
|
||||
) -> None:
|
||||
"""Test that exceptions in post-check are logged but don't affect result when ignore_exceptions=True."""
|
||||
# Set ignore_exceptions to True
|
||||
middleware._settings.ignore_exceptions = True
|
||||
middleware._settings["ignore_exceptions"] = True
|
||||
|
||||
context = AgentContext(agent=mock_agent, messages=[Message(role="user", text="Test")])
|
||||
|
||||
|
||||
@@ -636,7 +636,6 @@ class TestScopedContentProcessorCaching:
|
||||
return PurviewSettings(
|
||||
app_name="Test App",
|
||||
tenant_id="12345678-1234-1234-1234-123456789012",
|
||||
default_user_id="12345678-1234-1234-1234-123456789012",
|
||||
purview_app_location=location,
|
||||
)
|
||||
|
||||
|
||||
@@ -47,7 +47,7 @@ class TestPurviewClient:
|
||||
@pytest.fixture
|
||||
def settings(self) -> PurviewSettings:
|
||||
"""Create test settings."""
|
||||
return PurviewSettings(app_name="Test App", tenant_id="test-tenant", default_user_id="test-user")
|
||||
return PurviewSettings(app_name="Test App", tenant_id="test-tenant")
|
||||
|
||||
@pytest.fixture
|
||||
async def client(
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
|
||||
import pytest
|
||||
|
||||
from agent_framework_purview import PurviewAppLocation, PurviewLocationType, PurviewSettings
|
||||
from agent_framework_purview import PurviewAppLocation, PurviewLocationType, PurviewSettings, get_purview_scopes
|
||||
|
||||
|
||||
class TestPurviewSettings:
|
||||
@@ -14,10 +14,10 @@ class TestPurviewSettings:
|
||||
"""Test PurviewSettings with default values."""
|
||||
settings = PurviewSettings(app_name="Test App")
|
||||
|
||||
assert settings.app_name == "Test App"
|
||||
assert settings.graph_base_uri == "https://graph.microsoft.com/v1.0/"
|
||||
assert settings.tenant_id is None
|
||||
assert settings.purview_app_location is None
|
||||
assert settings["app_name"] == "Test App"
|
||||
assert settings.get("graph_base_uri") is None
|
||||
assert settings.get("tenant_id") is None
|
||||
assert settings.get("purview_app_location") is None
|
||||
|
||||
def test_settings_with_custom_values(self) -> None:
|
||||
"""Test PurviewSettings with custom values."""
|
||||
@@ -30,9 +30,9 @@ class TestPurviewSettings:
|
||||
purview_app_location=app_location,
|
||||
)
|
||||
|
||||
assert settings.graph_base_uri == "https://graph.microsoft-ppe.com"
|
||||
assert settings.tenant_id == "test-tenant-id"
|
||||
assert settings.purview_app_location.location_value == "app-123"
|
||||
assert settings["graph_base_uri"] == "https://graph.microsoft-ppe.com"
|
||||
assert settings["tenant_id"] == "test-tenant-id"
|
||||
assert settings["purview_app_location"].location_value == "app-123"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"graph_uri,expected_scope",
|
||||
@@ -44,7 +44,7 @@ class TestPurviewSettings:
|
||||
def test_get_scopes(self, graph_uri: str, expected_scope: str) -> None:
|
||||
"""Test get_scopes returns correct scope for different URIs."""
|
||||
settings = PurviewSettings(app_name="Test App", graph_base_uri=graph_uri)
|
||||
scopes = settings.get_scopes()
|
||||
scopes = get_purview_scopes(settings)
|
||||
|
||||
assert len(scopes) == 1
|
||||
assert expected_scope in scopes
|
||||
|
||||
Generated
+2
-2
@@ -331,7 +331,7 @@ dependencies = [
|
||||
{ name = "opentelemetry-semantic-conventions-ai", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
|
||||
{ name = "packaging", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
|
||||
{ name = "pydantic", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
|
||||
{ name = "pydantic-settings", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
|
||||
{ name = "python-dotenv", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
|
||||
{ name = "typing-extensions", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
|
||||
]
|
||||
|
||||
@@ -386,7 +386,7 @@ requires-dist = [
|
||||
{ name = "opentelemetry-semantic-conventions-ai", specifier = ">=0.4.13" },
|
||||
{ name = "packaging", specifier = ">=24.1" },
|
||||
{ name = "pydantic", specifier = ">=2,<3" },
|
||||
{ name = "pydantic-settings", specifier = ">=2,<3" },
|
||||
{ name = "python-dotenv", specifier = ">=1,<2" },
|
||||
{ name = "typing-extensions" },
|
||||
]
|
||||
provides-extras = ["all"]
|
||||
|
||||
Reference in New Issue
Block a user