mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: [BREAKING] Redesign Python exception hierarchy (#4082)
* [BREAKING] Redesign Python exception hierarchy Replace the flat ServiceException family with domain-scoped branches: - AgentException (with InvalidAuth, InvalidRequest, InvalidResponse, ContentFilter) - ChatClientException (same consistent suberrors) - IntegrationException (same + InitializationError) - WorkflowException (Runner, Convergence, Checkpoint, Validation, Action, Declarative) - ContentError (AdditionItemMismatch) - ToolException / ToolExecutionException (unchanged) - MiddlewareException / MiddlewareTermination (unchanged) Key changes: - All Service* exceptions removed (ServiceException, ServiceInitializationError, etc.) - AgentExecutionException split into AgentInvalidRequest/ResponseException - AgentInvocationError removed, split into AgentInvalidRequest/ResponseException - Workflow exceptions moved from _workflows/_exceptions.py into main exceptions.py - _workflows/__init__.py emptied; main __init__.py imports directly from submodules - Purview exceptions re-parented under IntegrationException hierarchy - Init validation errors use built-in ValueError/TypeError instead of custom exceptions - CODING_STANDARD.md updated with hierarchy design and rationale Fixes microsoft/agent-framework#3410 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Clarify ToolException vs ToolExecutionException docstrings ToolException: base class for all tool-related exceptions (preconditions, connection/init failures). ToolExecutionException: runtime call failures (tool call failed, reconnect failed, MCP errors). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Fix remaining stale imports from agent_framework._workflows - azurefunctions: _context.py, _app.py, _serialization.py, test_func_utils.py used 'from agent_framework._workflows import X' which broke after emptying _workflows/__init__.py; changed to direct submodule imports - azure-ai-search: test still referenced ServiceInitializationError; updated to ValueError to match production code Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
7f606a2e3a
commit
5ee06853a1
+100
-1
@@ -165,6 +165,105 @@ The package follows a flat import structure:
|
||||
from agent_framework.azure import AzureOpenAIChatClient
|
||||
```
|
||||
|
||||
## Exception Hierarchy
|
||||
|
||||
The Agent Framework defines a structured exception hierarchy rooted at `AgentFrameworkException`. Every AF-specific
|
||||
exception inherits from this base, so callers can catch `AgentFrameworkException` as a broad fallback. The hierarchy
|
||||
is organized into domain-specific L1 branches, each with a consistent set of leaf exceptions where applicable.
|
||||
|
||||
### Design Principles
|
||||
|
||||
- **Domain-scoped branches**: Exceptions are grouped by the subsystem that raises them (agent, chat client,
|
||||
integration, workflow, content, tool, middleware), not by HTTP status code or generic error category.
|
||||
- **Consistent suberror pattern**: The `AgentException`, `ChatClientException`, and `IntegrationException` branches
|
||||
share a parallel set of leaf exceptions (`InvalidAuth`, `InvalidRequest`, `InvalidResponse`, `ContentFilter`) so
|
||||
that callers can handle the same failure mode uniformly across domains.
|
||||
- **Built-ins for validation**: Configuration/parameter validation errors use Python built-in exceptions
|
||||
(`ValueError`, `TypeError`, `RuntimeError`) rather than AF-specific classes. AF exceptions are reserved for
|
||||
domain-level failures that callers may want to catch and handle distinctly from programming errors.
|
||||
- **No compatibility aliases**: When exceptions are renamed or removed, the old names are not kept as aliases.
|
||||
This is a deliberate trade-off for hierarchy clarity over backward compatibility.
|
||||
- **Suffix convention**: L1 branch classes use `...Exception` (e.g., `AgentException`). Leaf classes may use
|
||||
either `...Exception` or `...Error` depending on the domain convention (e.g., `ContentError`,
|
||||
`WorkflowValidationError`). Within a branch, the suffix is consistent.
|
||||
|
||||
### Full Hierarchy
|
||||
|
||||
```
|
||||
AgentFrameworkException # Base for all AF exceptions
|
||||
├── AgentException # Agent-scoped failures
|
||||
│ ├── AgentInvalidAuthException # Agent auth failures
|
||||
│ ├── AgentInvalidRequestException # Invalid request to agent (e.g., agent not found, bad input)
|
||||
│ ├── AgentInvalidResponseException # Invalid/unexpected response from agent
|
||||
│ └── AgentContentFilterException # Agent content filter triggered
|
||||
│
|
||||
├── ChatClientException # Chat client lifecycle and communication failures
|
||||
│ ├── ChatClientInvalidAuthException # Chat client auth failures
|
||||
│ ├── ChatClientInvalidRequestException # Invalid request to chat client
|
||||
│ ├── ChatClientInvalidResponseException # Invalid/unexpected response from chat client
|
||||
│ └── ChatClientContentFilterException # Chat client content filter triggered
|
||||
│
|
||||
├── IntegrationException # External service/dependency integration failures
|
||||
│ ├── IntegrationInitializationError # Wrapped dependency lifecycle failure during setup
|
||||
│ ├── IntegrationInvalidAuthException # Integration auth failures (e.g., 401/403)
|
||||
│ ├── IntegrationInvalidRequestException # Invalid request to integration
|
||||
│ ├── IntegrationInvalidResponseException # Invalid/unexpected response from integration
|
||||
│ └── IntegrationContentFilterException # Integration content filter triggered
|
||||
│
|
||||
├── ContentError # Content processing/validation failures
|
||||
│ └── AdditionItemMismatch # Type mismatch when merging content items
|
||||
│
|
||||
├── WorkflowException # Workflow engine failures
|
||||
│ ├── WorkflowRunnerException # Runtime execution failures
|
||||
│ │ ├── WorkflowConvergenceException # Runner exceeded max iterations
|
||||
│ │ └── WorkflowCheckpointException # Checkpoint save/restore/decode failures
|
||||
│ ├── WorkflowValidationError # Graph validation errors
|
||||
│ │ ├── EdgeDuplicationError # Duplicate edge in workflow graph
|
||||
│ │ ├── TypeCompatibilityError # Type mismatch between connected executors
|
||||
│ │ └── GraphConnectivityError # Graph connectivity issues
|
||||
│ ├── WorkflowActionError # User-level error from declarative ThrowException action
|
||||
│ └── DeclarativeWorkflowError # Declarative workflow definition/YAML errors
|
||||
│
|
||||
├── ToolException # Tool-related failures
|
||||
│ └── ToolExecutionException # Failure during tool execution
|
||||
│
|
||||
├── MiddlewareException # Middleware failures
|
||||
│ └── MiddlewareTermination # Control-flow: early middleware termination
|
||||
│
|
||||
└── SettingNotFoundError # Required setting not resolved from any source
|
||||
```
|
||||
|
||||
### When to Use AF Exceptions vs Built-ins
|
||||
|
||||
| Scenario | Exception to use |
|
||||
|---|---|
|
||||
| Missing or invalid constructor argument (e.g., `api_key` is `None`) | `ValueError` or `TypeError` |
|
||||
| Object in wrong state (e.g., client not initialized) | `RuntimeError` |
|
||||
| External service returns 401/403 | `IntegrationInvalidAuthException` (or `ChatClient`/`Agent` variant) |
|
||||
| External service returns unexpected response | `IntegrationInvalidResponseException` (or variant) |
|
||||
| Content filter blocks a request | `IntegrationContentFilterException` (or variant) |
|
||||
| Request validation fails before sending to service | `IntegrationInvalidRequestException` (or variant) |
|
||||
| Agent not found in registry | `AgentInvalidRequestException` |
|
||||
| Agent returned no/bad response | `AgentInvalidResponseException` |
|
||||
| Workflow runner exceeds max iterations | `WorkflowConvergenceException` |
|
||||
| Checkpoint serialization/deserialization failure | `WorkflowCheckpointException` |
|
||||
| Workflow graph has invalid structure | `WorkflowValidationError` (or specific subclass) |
|
||||
| Declarative YAML definition error | `DeclarativeWorkflowError` |
|
||||
| Tool execution failure | `ToolExecutionException` |
|
||||
| Content merge type mismatch | `AdditionItemMismatch` |
|
||||
|
||||
### Choosing Between Agent, ChatClient, and Integration Branches
|
||||
|
||||
- **`AgentException`**: The failure is scoped to agent-level logic — agent lookup, agent response handling,
|
||||
agent content filtering. Use when the agent itself is the source of the problem.
|
||||
- **`ChatClientException`**: The failure is scoped to the chat client (the LLM provider connection) — auth with
|
||||
the LLM provider, request/response format issues specific to the chat protocol, chat-level content filtering.
|
||||
- **`IntegrationException`**: The failure is in a non-chat external dependency — search services, vector stores,
|
||||
Purview, custom APIs, or any service that is not the primary LLM chat provider.
|
||||
|
||||
When in doubt: if the code is in a chat client constructor or method, use `ChatClient*`. If it's in an agent
|
||||
method, use `Agent*`. If it's talking to an external service that isn't the chat LLM, use `Integration*`.
|
||||
|
||||
## Package Structure
|
||||
|
||||
The project uses a monorepo structure with separate packages for each connector/extension:
|
||||
@@ -299,7 +398,7 @@ They should contain:
|
||||
- Returns are specified after a header called `Returns:` or `Yields:`, with the return type and explanation of the return value.
|
||||
- Keyword arguments are specified after a header called `Keyword Args:`, with each argument being specified in the same format as `Args:`.
|
||||
- A header for exceptions can be added, called `Raises:`, following these guidelines:
|
||||
- **Always document** Agent Framework specific exceptions (e.g., `AgentInitializationError`, `AgentExecutionException`)
|
||||
- **Always document** Agent Framework specific exceptions (e.g., `AgentInvalidRequestException`, `IntegrationInvalidAuthException`)
|
||||
- **Only document** standard Python exceptions (TypeError, ValueError, KeyError, etc.) when the condition is non-obvious or provides value to API users
|
||||
- Format: `ExceptionType`: Explanation of the exception.
|
||||
- If a longer explanation is needed, it should be placed on the next line, indented by 4 spaces.
|
||||
|
||||
@@ -40,7 +40,7 @@ from agent_framework._tools import (
|
||||
normalize_function_invocation_configuration,
|
||||
)
|
||||
from agent_framework._types import ResponseStream
|
||||
from agent_framework.exceptions import AgentExecutionException
|
||||
from agent_framework.exceptions import AgentInvalidResponseException
|
||||
|
||||
from ._message_adapters import normalize_agui_input_messages
|
||||
from ._orchestration._predictive_state import PredictiveStateHandler
|
||||
@@ -207,7 +207,7 @@ async def _normalize_response_stream(response_stream: Any) -> AsyncIterable[Any]
|
||||
if isinstance(resolved_stream, AsyncIterable):
|
||||
return cast(AsyncIterable[Any], resolved_stream)
|
||||
resolved_type = f"{type(resolved_stream).__module__}.{type(resolved_stream).__name__}"
|
||||
raise AgentExecutionException(
|
||||
raise AgentInvalidResponseException(
|
||||
"Agent did not return a streaming AsyncIterable response. "
|
||||
f"Awaitable resolved to unsupported type: {resolved_type}."
|
||||
)
|
||||
@@ -220,7 +220,7 @@ async def _normalize_response_stream(response_stream: Any) -> AsyncIterable[Any]
|
||||
return cast(AsyncIterable[Any], response_stream)
|
||||
|
||||
stream_type = f"{type(response_stream).__module__}.{type(response_stream).__name__}"
|
||||
raise AgentExecutionException(
|
||||
raise AgentInvalidResponseException(
|
||||
f"Agent did not return a streaming AsyncIterable response. Received unsupported type: {stream_type}."
|
||||
)
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ from ag_ui.core import (
|
||||
TextMessageStartEvent,
|
||||
)
|
||||
from agent_framework import AgentResponseUpdate, Content, Message, ResponseStream
|
||||
from agent_framework.exceptions import AgentExecutionException
|
||||
from agent_framework.exceptions import AgentInvalidResponseException
|
||||
|
||||
from agent_framework_ag_ui._run import (
|
||||
FlowState,
|
||||
@@ -226,7 +226,7 @@ class TestNormalizeResponseStream:
|
||||
|
||||
async def test_rejects_non_stream_values(self):
|
||||
"""Reject unsupported stream return values."""
|
||||
with pytest.raises(AgentExecutionException):
|
||||
with pytest.raises(AgentInvalidResponseException):
|
||||
await _normalize_response_stream("not-a-stream")
|
||||
|
||||
|
||||
|
||||
@@ -28,7 +28,6 @@ from agent_framework import (
|
||||
)
|
||||
from agent_framework._settings import SecretString, load_settings
|
||||
from agent_framework._types import _get_data_bytes_as_str # type: ignore
|
||||
from agent_framework.exceptions import ServiceInitializationError
|
||||
from agent_framework.observability import ChatTelemetryLayer
|
||||
from anthropic import AsyncAnthropic
|
||||
from anthropic.types.beta import (
|
||||
@@ -303,7 +302,7 @@ class AnthropicClient(
|
||||
|
||||
if anthropic_client is None:
|
||||
if not anthropic_settings["api_key"]:
|
||||
raise ServiceInitializationError(
|
||||
raise ValueError(
|
||||
"Anthropic API key is required. Set via 'api_key' parameter "
|
||||
"or 'ANTHROPIC_API_KEY' environment variable."
|
||||
)
|
||||
|
||||
@@ -14,7 +14,6 @@ from agent_framework import (
|
||||
tool,
|
||||
)
|
||||
from agent_framework._settings import load_settings
|
||||
from agent_framework.exceptions import ServiceInitializationError
|
||||
from anthropic.types.beta import (
|
||||
BetaMessage,
|
||||
BetaTextBlock,
|
||||
@@ -128,7 +127,7 @@ def test_anthropic_client_init_missing_api_key() -> None:
|
||||
with patch("agent_framework_anthropic._chat_client.load_settings") as mock_load:
|
||||
mock_load.return_value = {"api_key": None, "chat_model_id": "claude-3-5-sonnet-20241022"}
|
||||
|
||||
with pytest.raises(ServiceInitializationError, match="Anthropic API key is required"):
|
||||
with pytest.raises(ValueError, match="Anthropic API key is required"):
|
||||
AnthropicClient()
|
||||
|
||||
|
||||
|
||||
+2
-3
@@ -17,7 +17,6 @@ from agent_framework import AGENT_FRAMEWORK_USER_AGENT, Message
|
||||
from agent_framework._sessions import AgentSession, BaseContextProvider, SessionContext
|
||||
from agent_framework._settings import SecretString, load_settings
|
||||
from agent_framework.azure._entra_id_authentication import AzureCredentialTypes
|
||||
from agent_framework.exceptions import ServiceInitializationError
|
||||
from azure.core.credentials import AzureKeyCredential
|
||||
from azure.core.credentials_async import AsyncTokenCredential
|
||||
from azure.core.exceptions import ResourceNotFoundError
|
||||
@@ -219,7 +218,7 @@ class AzureAISearchContextProvider(BaseContextProvider):
|
||||
)
|
||||
|
||||
if mode == "agentic" and settings.get("index_name") and not model_deployment_name:
|
||||
raise ServiceInitializationError(
|
||||
raise ValueError(
|
||||
"model_deployment_name is required for agentic mode when creating Knowledge Base from index."
|
||||
)
|
||||
|
||||
@@ -231,7 +230,7 @@ class AzureAISearchContextProvider(BaseContextProvider):
|
||||
elif settings.get("api_key"):
|
||||
resolved_credential = AzureKeyCredential(settings["api_key"].get_secret_value()) # type: ignore[union-attr]
|
||||
else:
|
||||
raise ServiceInitializationError(
|
||||
raise ValueError(
|
||||
"Azure credential is required. Provide 'api_key' or 'credential' parameter "
|
||||
"or set 'AZURE_SEARCH_API_KEY' environment variable."
|
||||
)
|
||||
|
||||
@@ -8,7 +8,7 @@ from unittest.mock import AsyncMock, Mock, patch
|
||||
import pytest
|
||||
from agent_framework import Message
|
||||
from agent_framework._sessions import AgentSession, SessionContext
|
||||
from agent_framework.exceptions import ServiceInitializationError, SettingNotFoundError
|
||||
from agent_framework.exceptions import SettingNotFoundError
|
||||
from azure.core.credentials import AzureKeyCredential
|
||||
|
||||
from agent_framework_azure_ai_search._context_provider import AzureAISearchContextProvider
|
||||
@@ -180,7 +180,7 @@ class TestInitCredentialResolution:
|
||||
assert provider.credential is akc
|
||||
|
||||
def test_no_credential_raises(self) -> None:
|
||||
with pytest.raises(ServiceInitializationError, match="Azure credential is required"):
|
||||
with pytest.raises(ValueError, match="Azure credential is required"):
|
||||
AzureAISearchContextProvider(
|
||||
endpoint="https://test.search.windows.net",
|
||||
index_name="idx",
|
||||
@@ -216,7 +216,7 @@ class TestInitAgenticValidation:
|
||||
)
|
||||
|
||||
def test_missing_model_deployment_name_raises(self) -> None:
|
||||
with pytest.raises(ServiceInitializationError, match="model_deployment_name"):
|
||||
with pytest.raises(ValueError, match="model_deployment_name"):
|
||||
AzureAISearchContextProvider(
|
||||
source_id="s",
|
||||
endpoint="https://test.search.windows.net",
|
||||
|
||||
@@ -18,7 +18,6 @@ from agent_framework._mcp import MCPTool
|
||||
from agent_framework._settings import load_settings
|
||||
from agent_framework._tools import ToolTypes
|
||||
from agent_framework.azure._entra_id_authentication import AzureCredentialTypes
|
||||
from agent_framework.exceptions import ServiceInitializationError
|
||||
from azure.ai.agents.aio import AgentsClient
|
||||
from azure.ai.agents.models import Agent as AzureAgent
|
||||
from azure.ai.agents.models import ResponseFormatJsonSchema, ResponseFormatJsonSchemaType
|
||||
@@ -113,7 +112,7 @@ class AzureAIAgentsProvider(Generic[OptionsCoT]):
|
||||
env_file_encoding: Encoding of the .env file.
|
||||
|
||||
Raises:
|
||||
ServiceInitializationError: If required parameters are missing or invalid.
|
||||
ValueError: If required parameters are missing or invalid.
|
||||
"""
|
||||
self._settings = load_settings(
|
||||
AzureAISettings,
|
||||
@@ -130,12 +129,12 @@ class AzureAIAgentsProvider(Generic[OptionsCoT]):
|
||||
else:
|
||||
resolved_endpoint = self._settings.get("project_endpoint")
|
||||
if not resolved_endpoint:
|
||||
raise ServiceInitializationError(
|
||||
raise ValueError(
|
||||
"Azure AI project endpoint is required. Provide 'project_endpoint' parameter "
|
||||
"or set 'AZURE_AI_PROJECT_ENDPOINT' environment variable."
|
||||
)
|
||||
if not credential:
|
||||
raise ServiceInitializationError("Azure credential is required when agents_client is not provided.")
|
||||
raise ValueError("Azure credential is required when agents_client is not provided.")
|
||||
self._agents_client = AgentsClient(
|
||||
endpoint=resolved_endpoint,
|
||||
credential=credential, # type: ignore[arg-type]
|
||||
@@ -199,7 +198,7 @@ class AzureAIAgentsProvider(Generic[OptionsCoT]):
|
||||
Agent: A Agent instance configured with the created agent.
|
||||
|
||||
Raises:
|
||||
ServiceInitializationError: If model deployment name is not available.
|
||||
ValueError: If model deployment name is not available.
|
||||
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
@@ -212,7 +211,7 @@ class AzureAIAgentsProvider(Generic[OptionsCoT]):
|
||||
"""
|
||||
resolved_model = model or self._settings.get("model_deployment_name")
|
||||
if not resolved_model:
|
||||
raise ServiceInitializationError(
|
||||
raise ValueError(
|
||||
"Model deployment name is required. Provide 'model' parameter "
|
||||
"or set 'AZURE_AI_MODEL_DEPLOYMENT_NAME' environment variable."
|
||||
)
|
||||
@@ -290,7 +289,7 @@ class AzureAIAgentsProvider(Generic[OptionsCoT]):
|
||||
Agent: A Agent instance configured with the retrieved agent.
|
||||
|
||||
Raises:
|
||||
ServiceInitializationError: If required function tools are not provided.
|
||||
ValueError: If required function tools are not provided.
|
||||
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
@@ -340,7 +339,7 @@ class AzureAIAgentsProvider(Generic[OptionsCoT]):
|
||||
Agent: A Agent instance configured with the agent.
|
||||
|
||||
Raises:
|
||||
ServiceInitializationError: If required function tools are not provided.
|
||||
ValueError: If required function tools are not provided.
|
||||
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
@@ -449,7 +448,7 @@ class AzureAIAgentsProvider(Generic[OptionsCoT]):
|
||||
"""Validate that required function tools are provided.
|
||||
|
||||
Raises:
|
||||
ServiceInitializationError: If agent has function tools but user
|
||||
ValueError: If agent has function tools but user
|
||||
didn't provide implementations.
|
||||
"""
|
||||
if not agent_tools:
|
||||
@@ -483,7 +482,7 @@ class AzureAIAgentsProvider(Generic[OptionsCoT]):
|
||||
# Check for missing implementations
|
||||
missing = function_tool_names - provided_names
|
||||
if missing:
|
||||
raise ServiceInitializationError(
|
||||
raise ValueError(
|
||||
f"Agent has function tools that require implementations: {missing}. "
|
||||
"Provide these functions via the 'tools' parameter."
|
||||
)
|
||||
|
||||
@@ -36,7 +36,10 @@ from agent_framework import (
|
||||
from agent_framework._settings import load_settings
|
||||
from agent_framework._tools import ToolTypes
|
||||
from agent_framework.azure._entra_id_authentication import AzureCredentialTypes
|
||||
from agent_framework.exceptions import ServiceInitializationError, ServiceInvalidRequestError, ServiceResponseException
|
||||
from agent_framework.exceptions import (
|
||||
ChatClientException,
|
||||
ChatClientInvalidRequestException,
|
||||
)
|
||||
from agent_framework.observability import ChatTelemetryLayer
|
||||
from azure.ai.agents.aio import AgentsClient
|
||||
from azure.ai.agents.models import (
|
||||
@@ -498,20 +501,20 @@ class AzureAIAgentClient(
|
||||
if agents_client is None:
|
||||
resolved_endpoint = azure_ai_settings.get("project_endpoint")
|
||||
if not resolved_endpoint:
|
||||
raise ServiceInitializationError(
|
||||
raise ValueError(
|
||||
"Azure AI project endpoint is required. Set via 'project_endpoint' parameter "
|
||||
"or 'AZURE_AI_PROJECT_ENDPOINT' environment variable."
|
||||
)
|
||||
|
||||
if agent_id is None and not azure_ai_settings.get("model_deployment_name"):
|
||||
raise ServiceInitializationError(
|
||||
raise ValueError(
|
||||
"Azure AI model deployment name is required. Set via 'model_deployment_name' parameter "
|
||||
"or 'AZURE_AI_MODEL_DEPLOYMENT_NAME' environment variable."
|
||||
)
|
||||
|
||||
# Use provided credential
|
||||
if not credential:
|
||||
raise ServiceInitializationError("Azure credential is required when agents_client is not provided.")
|
||||
raise ValueError("Azure credential is required when agents_client is not provided.")
|
||||
agents_client = AgentsClient(
|
||||
endpoint=resolved_endpoint,
|
||||
credential=credential, # type: ignore[arg-type]
|
||||
@@ -606,7 +609,7 @@ class AzureAIAgentClient(
|
||||
# If no agent_id is provided, create a temporary agent
|
||||
if self.agent_id is None:
|
||||
if "model" not in run_options or not run_options["model"]:
|
||||
raise ServiceInitializationError(
|
||||
raise ValueError(
|
||||
"Model deployment name is required for agent creation, "
|
||||
"can also be passed to the get_response methods."
|
||||
)
|
||||
@@ -916,7 +919,7 @@ class AzureAIAgentClient(
|
||||
response_id=response_id,
|
||||
)
|
||||
case AgentStreamEvent.THREAD_RUN_FAILED:
|
||||
raise ServiceResponseException(event_data.last_error.message)
|
||||
raise ChatClientException(event_data.last_error.message)
|
||||
case _:
|
||||
yield ChatResponseUpdate(
|
||||
contents=[],
|
||||
@@ -1159,7 +1162,7 @@ class AzureAIAgentClient(
|
||||
# Runtime JSON schema dict - pass through as-is
|
||||
run_options["response_format"] = response_format
|
||||
else:
|
||||
raise ServiceInvalidRequestError(
|
||||
raise ChatClientInvalidRequestException(
|
||||
"response_format must be a Pydantic BaseModel class or a dict with runtime JSON schema."
|
||||
)
|
||||
|
||||
|
||||
@@ -24,7 +24,6 @@ from agent_framework import (
|
||||
from agent_framework._settings import load_settings
|
||||
from agent_framework._tools import ToolTypes
|
||||
from agent_framework.azure._entra_id_authentication import AzureCredentialTypes
|
||||
from agent_framework.exceptions import ServiceInitializationError
|
||||
from agent_framework.observability import ChatTelemetryLayer
|
||||
from agent_framework.openai import OpenAIResponsesOptions
|
||||
from agent_framework.openai._responses_client import RawOpenAIResponsesClient
|
||||
@@ -188,14 +187,14 @@ class RawAzureAIClient(RawOpenAIResponsesClient[AzureAIClientOptionsT], Generic[
|
||||
if project_client is None:
|
||||
resolved_endpoint = azure_ai_settings.get("project_endpoint")
|
||||
if not resolved_endpoint:
|
||||
raise ServiceInitializationError(
|
||||
raise ValueError(
|
||||
"Azure AI project endpoint is required. Set via 'project_endpoint' parameter "
|
||||
"or 'AZURE_AI_PROJECT_ENDPOINT' environment variable."
|
||||
)
|
||||
|
||||
# Use provided credential
|
||||
if not credential:
|
||||
raise ServiceInitializationError("Azure credential is required when project_client is not provided.")
|
||||
raise ValueError("Azure credential is required when project_client is not provided.")
|
||||
project_client = AIProjectClient(
|
||||
endpoint=resolved_endpoint,
|
||||
credential=credential, # type: ignore[arg-type]
|
||||
@@ -345,7 +344,7 @@ class RawAzureAIClient(RawOpenAIResponsesClient[AzureAIClientOptionsT], Generic[
|
||||
"""
|
||||
# Agent name must be explicitly provided by the user.
|
||||
if self.agent_name is None:
|
||||
raise ServiceInitializationError(
|
||||
raise ValueError(
|
||||
"Agent name is required. Provide 'agent_name' when initializing AzureAIClient "
|
||||
"or 'name' when initializing Agent."
|
||||
)
|
||||
@@ -363,7 +362,7 @@ class RawAzureAIClient(RawOpenAIResponsesClient[AzureAIClientOptionsT], Generic[
|
||||
return {"name": self.agent_name, "version": self.agent_version, "type": "agent_reference"}
|
||||
|
||||
if "model" not in run_options or not run_options["model"]:
|
||||
raise ServiceInitializationError(
|
||||
raise ValueError(
|
||||
"Model deployment name is required for agent creation, "
|
||||
"can also be passed to the get_response methods."
|
||||
)
|
||||
|
||||
@@ -19,7 +19,6 @@ from agent_framework._mcp import MCPTool
|
||||
from agent_framework._settings import load_settings
|
||||
from agent_framework._tools import ToolTypes
|
||||
from agent_framework.azure._entra_id_authentication import AzureCredentialTypes
|
||||
from agent_framework.exceptions import ServiceInitializationError
|
||||
from azure.ai.projects.aio import AIProjectClient
|
||||
from azure.ai.projects.models import (
|
||||
AgentReference,
|
||||
@@ -123,7 +122,7 @@ class AzureAIProjectAgentProvider(Generic[OptionsCoT]):
|
||||
env_file_encoding: Encoding of the environment file.
|
||||
|
||||
Raises:
|
||||
ServiceInitializationError: If required parameters are missing or invalid.
|
||||
ValueError: If required parameters are missing or invalid.
|
||||
"""
|
||||
self._settings = load_settings(
|
||||
AzureAISettings,
|
||||
@@ -140,13 +139,13 @@ class AzureAIProjectAgentProvider(Generic[OptionsCoT]):
|
||||
if project_client is None:
|
||||
resolved_endpoint = self._settings.get("project_endpoint")
|
||||
if not resolved_endpoint:
|
||||
raise ServiceInitializationError(
|
||||
raise ValueError(
|
||||
"Azure AI project endpoint is required. Set via 'project_endpoint' parameter "
|
||||
"or 'AZURE_AI_PROJECT_ENDPOINT' environment variable."
|
||||
)
|
||||
|
||||
if not credential:
|
||||
raise ServiceInitializationError("Azure credential is required when project_client is not provided.")
|
||||
raise ValueError("Azure credential is required when project_client is not provided.")
|
||||
|
||||
project_client = AIProjectClient(
|
||||
endpoint=resolved_endpoint,
|
||||
@@ -186,12 +185,12 @@ class AzureAIProjectAgentProvider(Generic[OptionsCoT]):
|
||||
Agent: A Agent instance configured with the created agent.
|
||||
|
||||
Raises:
|
||||
ServiceInitializationError: If required parameters are missing.
|
||||
ValueError: If required parameters are missing.
|
||||
"""
|
||||
# Resolve model from parameter or environment variable
|
||||
resolved_model = model or self._settings.get("model_deployment_name")
|
||||
if not resolved_model:
|
||||
raise ServiceInitializationError(
|
||||
raise ValueError(
|
||||
"Model deployment name is required. Provide 'model' parameter "
|
||||
"or set 'AZURE_AI_MODEL_DEPLOYMENT_NAME' environment variable."
|
||||
)
|
||||
|
||||
@@ -10,7 +10,7 @@ from typing import Any, cast
|
||||
from agent_framework import (
|
||||
FunctionTool,
|
||||
)
|
||||
from agent_framework.exceptions import ServiceInvalidRequestError
|
||||
from agent_framework.exceptions import IntegrationInvalidRequestException
|
||||
from azure.ai.agents.models import (
|
||||
CodeInterpreterToolDefinition,
|
||||
ToolDefinition,
|
||||
@@ -125,7 +125,7 @@ def to_azure_ai_agent_tools(
|
||||
List of Azure AI V1 SDK tool definitions.
|
||||
|
||||
Raises:
|
||||
ServiceInitializationError: If tool configuration is invalid.
|
||||
ValueError: If tool configuration is invalid.
|
||||
"""
|
||||
if not tools:
|
||||
return []
|
||||
@@ -458,7 +458,7 @@ def create_text_format_config(
|
||||
if format_type == "text":
|
||||
return ResponseTextFormatConfigurationText()
|
||||
|
||||
raise ServiceInvalidRequestError("response_format must be a Pydantic model or mapping.")
|
||||
raise IntegrationInvalidRequestException("response_format must be a Pydantic model or mapping.")
|
||||
|
||||
|
||||
def _convert_response_format(response_format: Mapping[str, Any]) -> dict[str, Any]:
|
||||
@@ -470,11 +470,11 @@ def _convert_response_format(response_format: Mapping[str, Any]) -> dict[str, An
|
||||
if format_type == "json_schema":
|
||||
schema_section = response_format.get("json_schema", response_format)
|
||||
if not isinstance(schema_section, Mapping):
|
||||
raise ServiceInvalidRequestError("json_schema response_format must be a mapping.")
|
||||
raise IntegrationInvalidRequestException("json_schema response_format must be a mapping.")
|
||||
schema_section_typed = cast("Mapping[str, Any]", schema_section)
|
||||
schema: Any = schema_section_typed.get("schema")
|
||||
if schema is None:
|
||||
raise ServiceInvalidRequestError("json_schema response_format requires a schema.")
|
||||
raise IntegrationInvalidRequestException("json_schema response_format requires a schema.")
|
||||
name: str = str(
|
||||
schema_section_typed.get("name")
|
||||
or schema_section_typed.get("title")
|
||||
@@ -495,4 +495,4 @@ def _convert_response_format(response_format: Mapping[str, Any]) -> dict[str, An
|
||||
if format_type in {"json_object", "text"}:
|
||||
return {"type": format_type}
|
||||
|
||||
raise ServiceInvalidRequestError("Unsupported response_format provided for Azure AI client.")
|
||||
raise IntegrationInvalidRequestException("Unsupported response_format provided for Azure AI client.")
|
||||
|
||||
@@ -9,7 +9,6 @@ from agent_framework import (
|
||||
Agent,
|
||||
tool,
|
||||
)
|
||||
from agent_framework.exceptions import ServiceInitializationError
|
||||
from azure.ai.agents.models import (
|
||||
Agent as AzureAgent,
|
||||
)
|
||||
@@ -37,7 +36,6 @@ skip_if_azure_ai_integration_tests_disabled = pytest.mark.skipif(
|
||||
else "Integration tests are disabled.",
|
||||
)
|
||||
|
||||
|
||||
# region Provider Initialization Tests
|
||||
|
||||
|
||||
@@ -90,7 +88,7 @@ def test_provider_init_missing_endpoint_raises(
|
||||
with patch("agent_framework_azure_ai._agent_provider.load_settings") as mock_load_settings:
|
||||
mock_load_settings.return_value = {"project_endpoint": None, "model_deployment_name": "test-model"}
|
||||
|
||||
with pytest.raises(ServiceInitializationError) as exc_info:
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
AzureAIAgentsProvider(credential=mock_azure_credential)
|
||||
|
||||
assert "project endpoint is required" in str(exc_info.value).lower()
|
||||
@@ -98,7 +96,7 @@ def test_provider_init_missing_endpoint_raises(
|
||||
|
||||
def test_provider_init_missing_credential_raises(azure_ai_unit_test_env: dict[str, str]) -> None:
|
||||
"""Test AzureAIAgentsProvider raises error when credential is missing."""
|
||||
with pytest.raises(ServiceInitializationError) as exc_info:
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
AzureAIAgentsProvider()
|
||||
|
||||
assert "credential is required" in str(exc_info.value).lower()
|
||||
@@ -106,7 +104,6 @@ def test_provider_init_missing_credential_raises(azure_ai_unit_test_env: dict[st
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
# region Context Manager Tests
|
||||
|
||||
|
||||
@@ -142,7 +139,6 @@ async def test_provider_context_manager_does_not_close_external_client(mock_agen
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
# region create_agent Tests
|
||||
|
||||
|
||||
@@ -272,7 +268,7 @@ async def test_create_agent_missing_model_raises(
|
||||
|
||||
provider = AzureAIAgentsProvider(agents_client=mock_agents_client)
|
||||
|
||||
with pytest.raises(ServiceInitializationError) as exc_info:
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await provider.create_agent(name="TestAgent")
|
||||
|
||||
assert "model deployment name is required" in str(exc_info.value).lower()
|
||||
@@ -280,7 +276,6 @@ async def test_create_agent_missing_model_raises(
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
# region get_agent Tests
|
||||
|
||||
|
||||
@@ -332,7 +327,7 @@ async def test_get_agent_with_function_tools(
|
||||
|
||||
provider = AzureAIAgentsProvider(agents_client=mock_agents_client)
|
||||
|
||||
with pytest.raises(ServiceInitializationError) as exc_info:
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await provider.get_agent("agent-with-tools")
|
||||
|
||||
assert "get_weather" in str(exc_info.value)
|
||||
@@ -374,7 +369,6 @@ async def test_get_agent_with_provided_function_tools(
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
# region as_agent Tests
|
||||
|
||||
|
||||
@@ -427,7 +421,7 @@ def test_as_agent_with_function_tools_validates(
|
||||
|
||||
provider = AzureAIAgentsProvider(agents_client=mock_agents_client)
|
||||
|
||||
with pytest.raises(ServiceInitializationError) as exc_info:
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
provider.as_agent(mock_agent)
|
||||
|
||||
assert "my_function" in str(exc_info.value)
|
||||
@@ -489,7 +483,7 @@ def test_as_agent_with_dict_function_tools_validates(
|
||||
|
||||
provider = AzureAIAgentsProvider(agents_client=mock_agents_client)
|
||||
|
||||
with pytest.raises(ServiceInitializationError) as exc_info:
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
provider.as_agent(mock_agent)
|
||||
|
||||
assert "dict_based_function" in str(exc_info.value)
|
||||
@@ -534,7 +528,6 @@ def test_as_agent_with_dict_function_tools_provided(
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
# region Tool Conversion Tests - to_azure_ai_agent_tools
|
||||
|
||||
|
||||
@@ -659,7 +652,6 @@ def test_to_azure_ai_agent_tools_unsupported_type() -> None:
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
# region Tool Conversion Tests - from_azure_ai_agent_tools
|
||||
|
||||
|
||||
@@ -784,7 +776,6 @@ def test_from_azure_ai_agent_tools_unknown_dict() -> None:
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
# region Integration Tests
|
||||
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ from agent_framework import (
|
||||
)
|
||||
from agent_framework._serialization import SerializationMixin
|
||||
from agent_framework._settings import load_settings
|
||||
from agent_framework.exceptions import ServiceInitializationError, ServiceInvalidRequestError
|
||||
from agent_framework.exceptions import ChatClientInvalidRequestException
|
||||
from azure.ai.agents.models import (
|
||||
AgentsNamedToolChoice,
|
||||
AgentsNamedToolChoiceType,
|
||||
@@ -165,7 +165,7 @@ def test_azure_ai_chat_client_init_missing_project_endpoint() -> None:
|
||||
with patch("agent_framework_azure_ai._chat_client.load_settings") as mock_load_settings:
|
||||
mock_load_settings.return_value = {"project_endpoint": None, "model_deployment_name": "test-model"}
|
||||
|
||||
with pytest.raises(ServiceInitializationError, match="project endpoint is required"):
|
||||
with pytest.raises(ValueError, match="project endpoint is required"):
|
||||
AzureAIAgentClient(
|
||||
agents_client=None,
|
||||
agent_id=None,
|
||||
@@ -181,7 +181,7 @@ def test_azure_ai_chat_client_init_missing_model_deployment_for_agent_creation()
|
||||
with patch("agent_framework_azure_ai._chat_client.load_settings") as mock_load_settings:
|
||||
mock_load_settings.return_value = {"project_endpoint": "https://test.com", "model_deployment_name": None}
|
||||
|
||||
with pytest.raises(ServiceInitializationError, match="model deployment name is required"):
|
||||
with pytest.raises(ValueError, match="model deployment name is required"):
|
||||
AzureAIAgentClient(
|
||||
agents_client=None,
|
||||
agent_id=None, # No existing agent
|
||||
@@ -193,9 +193,7 @@ def test_azure_ai_chat_client_init_missing_model_deployment_for_agent_creation()
|
||||
|
||||
def test_azure_ai_chat_client_init_missing_credential(azure_ai_unit_test_env: dict[str, str]) -> None:
|
||||
"""Test AzureAIAgentClient.__init__ when credential is missing and no agents_client provided."""
|
||||
with pytest.raises(
|
||||
ServiceInitializationError, match="Azure credential is required when agents_client is not provided"
|
||||
):
|
||||
with pytest.raises(ValueError, match="Azure credential is required when agents_client is not provided"):
|
||||
AzureAIAgentClient(
|
||||
agents_client=None,
|
||||
agent_id="existing-agent",
|
||||
@@ -325,7 +323,7 @@ async def test_azure_ai_chat_client_get_agent_id_or_create_missing_model(
|
||||
"""Test _get_agent_id_or_create when model_deployment_name is missing."""
|
||||
client = create_test_azure_ai_chat_client(mock_agents_client)
|
||||
|
||||
with pytest.raises(ServiceInitializationError, match="Model deployment name is required"):
|
||||
with pytest.raises(ValueError, match="Model deployment name is required"):
|
||||
await client._get_agent_id_or_create() # type: ignore
|
||||
|
||||
|
||||
@@ -2011,7 +2009,7 @@ async def test_azure_ai_chat_client_prepare_options_with_invalid_response_format
|
||||
# Invalid response_format (not BaseModel or Mapping)
|
||||
chat_options: ChatOptions = {"response_format": "invalid_format"} # type: ignore[typeddict-item]
|
||||
|
||||
with pytest.raises(ServiceInvalidRequestError, match="response_format must be a Pydantic BaseModel"):
|
||||
with pytest.raises(ChatClientInvalidRequestException, match="response_format must be a Pydantic BaseModel"):
|
||||
await client._prepare_options([], chat_options) # type: ignore
|
||||
|
||||
|
||||
|
||||
@@ -21,7 +21,6 @@ from agent_framework import (
|
||||
tool,
|
||||
)
|
||||
from agent_framework._settings import load_settings
|
||||
from agent_framework.exceptions import ServiceInitializationError
|
||||
from azure.ai.projects.aio import AIProjectClient
|
||||
from azure.ai.projects.models import (
|
||||
ApproximateLocation,
|
||||
@@ -213,15 +212,13 @@ def test_init_missing_project_endpoint() -> None:
|
||||
with patch("agent_framework_azure_ai._client.load_settings") as mock_load_settings:
|
||||
mock_load_settings.return_value = {"project_endpoint": None, "model_deployment_name": "test-model"}
|
||||
|
||||
with pytest.raises(ServiceInitializationError, match="Azure AI project endpoint is required"):
|
||||
with pytest.raises(ValueError, match="Azure AI project endpoint is required"):
|
||||
AzureAIClient(credential=MagicMock())
|
||||
|
||||
|
||||
def test_init_missing_credential(azure_ai_unit_test_env: dict[str, str]) -> None:
|
||||
"""Test AzureAIClient.__init__ when credential is missing and no project_client provided."""
|
||||
with pytest.raises(
|
||||
ServiceInitializationError, match="Azure credential is required when project_client is not provided"
|
||||
):
|
||||
with pytest.raises(ValueError, match="Azure credential is required when project_client is not provided"):
|
||||
AzureAIClient(
|
||||
project_endpoint=azure_ai_unit_test_env["AZURE_AI_PROJECT_ENDPOINT"],
|
||||
model_deployment_name=azure_ai_unit_test_env["AZURE_AI_MODEL_DEPLOYMENT_NAME"],
|
||||
@@ -245,7 +242,7 @@ async def test_get_agent_reference_or_create_missing_agent_name(
|
||||
"""Test _get_agent_reference_or_create raises when agent_name is missing."""
|
||||
client = create_test_azure_ai_client(mock_project_client, agent_name=None)
|
||||
|
||||
with pytest.raises(ServiceInitializationError, match="Agent name is required"):
|
||||
with pytest.raises(ValueError, match="Agent name is required"):
|
||||
await client._get_agent_reference_or_create({}, None) # type: ignore
|
||||
|
||||
|
||||
@@ -283,7 +280,7 @@ async def test_get_agent_reference_missing_model(
|
||||
"""Test _get_agent_reference_or_create when model is missing for agent creation."""
|
||||
client = create_test_azure_ai_client(mock_project_client, agent_name="test-agent")
|
||||
|
||||
with pytest.raises(ServiceInitializationError, match="Model deployment name is required for agent creation"):
|
||||
with pytest.raises(ValueError, match="Model deployment name is required for agent creation"):
|
||||
await client._get_agent_reference_or_create({}, None) # type: ignore
|
||||
|
||||
|
||||
@@ -1287,7 +1284,6 @@ def test_from_azure_ai_tools_web_search() -> None:
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
# region Integration Tests
|
||||
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import pytest
|
||||
from agent_framework import Agent, FunctionTool
|
||||
from agent_framework._mcp import MCPTool
|
||||
from agent_framework.exceptions import ServiceInitializationError
|
||||
from azure.ai.projects.aio import AIProjectClient
|
||||
from azure.ai.projects.models import (
|
||||
AgentReference,
|
||||
@@ -110,15 +109,13 @@ def test_provider_init_missing_endpoint() -> None:
|
||||
with patch("agent_framework_azure_ai._project_provider.load_settings") as mock_load_settings:
|
||||
mock_load_settings.return_value = {"project_endpoint": None, "model_deployment_name": "test-model"}
|
||||
|
||||
with pytest.raises(ServiceInitializationError, match="Azure AI project endpoint is required"):
|
||||
with pytest.raises(ValueError, match="Azure AI project endpoint is required"):
|
||||
AzureAIProjectAgentProvider(credential=MagicMock())
|
||||
|
||||
|
||||
def test_provider_init_missing_credential(azure_ai_unit_test_env: dict[str, str]) -> None:
|
||||
"""Test AzureAIProjectAgentProvider initialization when credential is missing."""
|
||||
with pytest.raises(
|
||||
ServiceInitializationError, match="Azure credential is required when project_client is not provided"
|
||||
):
|
||||
with pytest.raises(ValueError, match="Azure credential is required when project_client is not provided"):
|
||||
AzureAIProjectAgentProvider(
|
||||
project_endpoint=azure_ai_unit_test_env["AZURE_AI_PROJECT_ENDPOINT"],
|
||||
)
|
||||
@@ -208,7 +205,7 @@ async def test_provider_create_agent_missing_model(mock_project_client: MagicMoc
|
||||
|
||||
provider = AzureAIProjectAgentProvider(project_client=mock_project_client)
|
||||
|
||||
with pytest.raises(ServiceInitializationError, match="Model deployment name is required"):
|
||||
with pytest.raises(ValueError, match="Model deployment name is required"):
|
||||
await provider.create_agent(name="test-agent")
|
||||
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ import pytest
|
||||
from agent_framework import (
|
||||
FunctionTool,
|
||||
)
|
||||
from agent_framework.exceptions import ServiceInvalidRequestError
|
||||
from agent_framework.exceptions import IntegrationInvalidRequestException
|
||||
from azure.ai.agents.models import CodeInterpreterToolDefinition
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -387,7 +387,7 @@ def test_create_text_format_config_text() -> None:
|
||||
|
||||
def test_create_text_format_config_invalid_raises() -> None:
|
||||
"""Test invalid response_format raises error."""
|
||||
with pytest.raises(ServiceInvalidRequestError):
|
||||
with pytest.raises(IntegrationInvalidRequestException):
|
||||
create_text_format_config({"type": "invalid"})
|
||||
|
||||
|
||||
@@ -400,7 +400,7 @@ def test_convert_response_format_with_format_key() -> None:
|
||||
|
||||
def test_convert_response_format_json_schema_missing_schema_raises() -> None:
|
||||
"""Test json_schema without schema raises error."""
|
||||
with pytest.raises(ServiceInvalidRequestError, match="requires a schema"):
|
||||
with pytest.raises(IntegrationInvalidRequestException, match="requires a schema"):
|
||||
_convert_response_format({"type": "json_schema", "json_schema": {}})
|
||||
|
||||
|
||||
|
||||
@@ -272,7 +272,7 @@ class AgentFunctionApp(DFAppBase):
|
||||
Note: We use str type annotations instead of dict to work around
|
||||
Azure Functions worker type validation issues with dict[str, Any].
|
||||
"""
|
||||
from agent_framework._workflows import State
|
||||
from agent_framework._workflows._state import State
|
||||
|
||||
data = json.loads(inputData)
|
||||
message_data = data["message"]
|
||||
|
||||
@@ -19,7 +19,7 @@ from agent_framework import (
|
||||
WorkflowEvent,
|
||||
WorkflowMessage,
|
||||
)
|
||||
from agent_framework._workflows import State
|
||||
from agent_framework._workflows._state import State
|
||||
|
||||
|
||||
class CapturingRunnerContext(RunnerContext):
|
||||
|
||||
@@ -23,7 +23,7 @@ import logging
|
||||
from dataclasses import is_dataclass
|
||||
from typing import Any
|
||||
|
||||
from agent_framework._workflows import decode_checkpoint_value, encode_checkpoint_value
|
||||
from agent_framework._workflows._checkpoint_encoding import decode_checkpoint_value, encode_checkpoint_value
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -182,7 +182,7 @@ class TestCapturingRunnerContext:
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_checkpoint_raises_not_implemented(self, context: CapturingRunnerContext) -> None:
|
||||
"""Test that checkpointing methods raise NotImplementedError."""
|
||||
from agent_framework._workflows import State
|
||||
from agent_framework._workflows._state import State
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
await context.create_checkpoint("test_workflow", "abc123", State(), None, 1)
|
||||
|
||||
@@ -30,7 +30,7 @@ from agent_framework import (
|
||||
validate_tool_mode,
|
||||
)
|
||||
from agent_framework._settings import SecretString, load_settings
|
||||
from agent_framework.exceptions import ServiceInitializationError, ServiceInvalidResponseError
|
||||
from agent_framework.exceptions import ChatClientInvalidResponseException
|
||||
from agent_framework.observability import ChatTelemetryLayer
|
||||
from boto3.session import Session as Boto3Session
|
||||
from botocore.client import BaseClient
|
||||
@@ -362,13 +362,13 @@ class BedrockChatClient(
|
||||
) -> dict[str, Any]:
|
||||
model_id = options.get("model_id") or self.model_id
|
||||
if not model_id:
|
||||
raise ServiceInitializationError(
|
||||
raise ValueError(
|
||||
"Bedrock model_id is required. Set via chat options or BEDROCK_CHAT_MODEL_ID environment variable."
|
||||
)
|
||||
|
||||
system_prompts, conversation = self._prepare_bedrock_messages(messages)
|
||||
if not conversation:
|
||||
raise ServiceInitializationError("At least one non-system message is required for Bedrock requests.")
|
||||
raise ValueError("At least one non-system message is required for Bedrock requests.")
|
||||
# Prepend instructions from options if they exist
|
||||
if instructions := options.get("instructions"):
|
||||
system_prompts = [{"text": instructions}, *system_prompts]
|
||||
@@ -400,7 +400,7 @@ class BedrockChatClient(
|
||||
else:
|
||||
tool_config["toolChoice"] = {"any": {}}
|
||||
case _:
|
||||
raise ServiceInitializationError(f"Unsupported tool mode for Bedrock: {tool_mode.get('mode')}")
|
||||
raise ValueError(f"Unsupported tool mode for Bedrock: {tool_mode.get('mode')}")
|
||||
if tool_config:
|
||||
run_options["toolConfig"] = tool_config
|
||||
|
||||
@@ -629,7 +629,9 @@ class BedrockChatClient(
|
||||
if isinstance(tool_use, MutableMapping):
|
||||
tool_name = tool_use.get("name")
|
||||
if not tool_name:
|
||||
raise ServiceInvalidResponseError("Bedrock response missing required tool name in toolUse block.")
|
||||
raise ChatClientInvalidResponseException(
|
||||
"Bedrock response missing required tool name in toolUse block."
|
||||
)
|
||||
contents.append(
|
||||
Content.from_function_call(
|
||||
call_id=tool_use.get("toolUseId") or self._generate_tool_call_id(),
|
||||
|
||||
@@ -6,7 +6,6 @@ from typing import Any
|
||||
|
||||
import pytest
|
||||
from agent_framework import Content, Message
|
||||
from agent_framework.exceptions import ServiceInitializationError
|
||||
|
||||
from agent_framework_bedrock import BedrockChatClient
|
||||
|
||||
@@ -64,5 +63,5 @@ def test_build_request_requires_non_system_messages() -> None:
|
||||
|
||||
messages = [Message(role="system", contents=[Content.from_text(text="Only system text")])]
|
||||
|
||||
with pytest.raises(ServiceInitializationError):
|
||||
with pytest.raises(ValueError):
|
||||
client._prepare_options(messages, {})
|
||||
|
||||
@@ -26,7 +26,7 @@ from agent_framework import (
|
||||
normalize_messages,
|
||||
normalize_tools,
|
||||
)
|
||||
from agent_framework.exceptions import ServiceException
|
||||
from agent_framework.exceptions import AgentException
|
||||
from claude_agent_sdk import (
|
||||
AssistantMessage,
|
||||
ClaudeSDKClient,
|
||||
@@ -360,7 +360,7 @@ class ClaudeAgent(BaseAgent, Generic[OptionsT]):
|
||||
as an async context manager.
|
||||
|
||||
Raises:
|
||||
ServiceException: If the client fails to start.
|
||||
AgentException: If the client fails to start.
|
||||
"""
|
||||
await self._ensure_session()
|
||||
|
||||
@@ -407,7 +407,7 @@ class ClaudeAgent(BaseAgent, Generic[OptionsT]):
|
||||
self._current_session_id = session_id
|
||||
except Exception as ex:
|
||||
self._client = None
|
||||
raise ServiceException(f"Failed to start Claude SDK client: {ex}") from ex
|
||||
raise AgentException(f"Failed to start Claude SDK client: {ex}") from ex
|
||||
|
||||
def _prepare_client_options(self, resume_session_id: str | None = None) -> SDKOptions:
|
||||
"""Prepare SDK options for client initialization.
|
||||
@@ -639,7 +639,7 @@ class ClaudeAgent(BaseAgent, Generic[OptionsT]):
|
||||
await self._ensure_session(session.service_session_id)
|
||||
|
||||
if not self._client:
|
||||
raise ServiceException("Claude SDK client not initialized.")
|
||||
raise RuntimeError("Claude SDK client not initialized.")
|
||||
|
||||
prompt = self._format_prompt(normalize_messages(messages))
|
||||
|
||||
@@ -693,12 +693,12 @@ class ClaudeAgent(BaseAgent, Generic[OptionsT]):
|
||||
if isinstance(block, TextBlock):
|
||||
error_msg = f"{error_msg}: {block.text}"
|
||||
break
|
||||
raise ServiceException(error_msg)
|
||||
raise AgentException(error_msg)
|
||||
elif isinstance(message, ResultMessage):
|
||||
# Check for errors in result message
|
||||
if message.is_error:
|
||||
error_msg = message.result or "Unknown error from Claude API"
|
||||
raise ServiceException(f"Claude API error: {error_msg}")
|
||||
raise AgentException(f"Claude API error: {error_msg}")
|
||||
session_id = message.session_id
|
||||
|
||||
# Update session with session ID
|
||||
|
||||
@@ -379,8 +379,8 @@ class TestClaudeAgentRunStream:
|
||||
assert updates[1].text == "response"
|
||||
|
||||
async def test_run_stream_raises_on_assistant_message_error(self) -> None:
|
||||
"""Test run raises ServiceException when AssistantMessage has an error."""
|
||||
from agent_framework.exceptions import ServiceException
|
||||
"""Test run raises AgentException when AssistantMessage has an error."""
|
||||
from agent_framework.exceptions import AgentException
|
||||
from claude_agent_sdk import AssistantMessage, ResultMessage, TextBlock
|
||||
|
||||
messages = [
|
||||
@@ -402,15 +402,15 @@ class TestClaudeAgentRunStream:
|
||||
|
||||
with patch("agent_framework_claude._agent.ClaudeSDKClient", return_value=mock_client):
|
||||
agent = ClaudeAgent()
|
||||
with pytest.raises(ServiceException) as exc_info:
|
||||
with pytest.raises(AgentException) as exc_info:
|
||||
async for _ in agent.run("Hello", stream=True):
|
||||
pass
|
||||
assert "Invalid request to Claude API" in str(exc_info.value)
|
||||
assert "Error details from API" in str(exc_info.value)
|
||||
|
||||
async def test_run_stream_raises_on_result_message_error(self) -> None:
|
||||
"""Test run raises ServiceException when ResultMessage.is_error is True."""
|
||||
from agent_framework.exceptions import ServiceException
|
||||
"""Test run raises AgentException when ResultMessage.is_error is True."""
|
||||
from agent_framework.exceptions import AgentException
|
||||
from claude_agent_sdk import ResultMessage
|
||||
|
||||
messages = [
|
||||
@@ -428,7 +428,7 @@ class TestClaudeAgentRunStream:
|
||||
|
||||
with patch("agent_framework_claude._agent.ClaudeSDKClient", return_value=mock_client):
|
||||
agent = ClaudeAgent()
|
||||
with pytest.raises(ServiceException) as exc_info:
|
||||
with pytest.raises(AgentException) as exc_info:
|
||||
async for _ in agent.run("Hello", stream=True):
|
||||
pass
|
||||
assert "Model 'claude-sonnet-4.5' not found" in str(exc_info.value)
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from agent_framework.exceptions import ServiceException
|
||||
from agent_framework.exceptions import AgentException
|
||||
from msal import PublicClientApplication
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -39,13 +39,13 @@ def acquire_token(
|
||||
The access token string.
|
||||
|
||||
Raises:
|
||||
ServiceException: If authentication token cannot be acquired.
|
||||
AgentException: If authentication token cannot be acquired.
|
||||
"""
|
||||
if not client_id:
|
||||
raise ServiceException("Client ID is required for token acquisition.")
|
||||
raise ValueError("Client ID is required for token acquisition.")
|
||||
|
||||
if not tenant_id:
|
||||
raise ServiceException("Tenant ID is required for token acquisition.")
|
||||
raise ValueError("Tenant ID is required for token acquisition.")
|
||||
|
||||
authority = f"https://login.microsoftonline.com/{tenant_id}"
|
||||
target_scopes = scopes or DEFAULT_SCOPES
|
||||
@@ -87,9 +87,9 @@ def acquire_token(
|
||||
)
|
||||
except Exception as ex:
|
||||
logger.error("Interactive token acquisition failed with exception: %s", ex)
|
||||
raise ServiceException(f"Failed to acquire authentication token: {ex}") from ex
|
||||
raise AgentException(f"Failed to acquire authentication token: {ex}") from ex
|
||||
|
||||
if not token:
|
||||
raise ServiceException("Authentication token cannot be acquired.")
|
||||
raise AgentException("Authentication token cannot be acquired.")
|
||||
|
||||
return token
|
||||
|
||||
@@ -19,7 +19,7 @@ from agent_framework import (
|
||||
)
|
||||
from agent_framework._settings import load_settings
|
||||
from agent_framework._types import AgentRunInputs
|
||||
from agent_framework.exceptions import ServiceException, ServiceInitializationError
|
||||
from agent_framework.exceptions import AgentException
|
||||
from microsoft_agents.copilotstudio.client import AgentType, ConnectionSettings, CopilotClient, PowerPlatformCloud
|
||||
|
||||
from ._acquire_token import acquire_token
|
||||
@@ -113,7 +113,7 @@ class CopilotStudioAgent(BaseAgent):
|
||||
env_file_encoding: Encoding of the .env file, defaults to 'utf-8'.
|
||||
|
||||
Raises:
|
||||
ServiceInitializationError: If required configuration is missing or invalid.
|
||||
ValueError: If required configuration is missing or invalid.
|
||||
"""
|
||||
super().__init__(
|
||||
id=id,
|
||||
@@ -136,12 +136,12 @@ class CopilotStudioAgent(BaseAgent):
|
||||
|
||||
if not settings:
|
||||
if not copilot_studio_settings["environmentid"]:
|
||||
raise ServiceInitializationError(
|
||||
raise ValueError(
|
||||
"Copilot Studio environment ID is required. Set via 'environment_id' parameter "
|
||||
"or 'COPILOTSTUDIOAGENT__ENVIRONMENTID' environment variable."
|
||||
)
|
||||
if not copilot_studio_settings["schemaname"]:
|
||||
raise ServiceInitializationError(
|
||||
raise ValueError(
|
||||
"Copilot Studio agent identifier/schema name is required. Set via 'agent_identifier' parameter "
|
||||
"or 'COPILOTSTUDIOAGENT__SCHEMANAME' environment variable."
|
||||
)
|
||||
@@ -156,13 +156,13 @@ class CopilotStudioAgent(BaseAgent):
|
||||
|
||||
if not token:
|
||||
if not copilot_studio_settings["agentappid"]:
|
||||
raise ServiceInitializationError(
|
||||
raise ValueError(
|
||||
"Copilot Studio client ID is required. Set via 'client_id' parameter "
|
||||
"or 'COPILOTSTUDIOAGENT__AGENTAPPID' environment variable."
|
||||
)
|
||||
|
||||
if not copilot_studio_settings["tenantid"]:
|
||||
raise ServiceInitializationError(
|
||||
raise ValueError(
|
||||
"Copilot Studio tenant ID is required. Set via 'tenant_id' parameter "
|
||||
"or 'COPILOTSTUDIOAGENT__TENANTID' environment variable."
|
||||
)
|
||||
@@ -303,7 +303,7 @@ class CopilotStudioAgent(BaseAgent):
|
||||
The conversation ID for the new conversation.
|
||||
|
||||
Raises:
|
||||
ServiceException: If the conversation could not be started.
|
||||
AgentException: If the conversation could not be started.
|
||||
"""
|
||||
conversation_id: str | None = None
|
||||
|
||||
@@ -312,7 +312,7 @@ class CopilotStudioAgent(BaseAgent):
|
||||
conversation_id = activity.conversation.id
|
||||
|
||||
if not conversation_id:
|
||||
raise ServiceException("Failed to start a new conversation.")
|
||||
raise AgentException("Failed to start a new conversation.")
|
||||
|
||||
return conversation_id
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from agent_framework.exceptions import ServiceException
|
||||
from agent_framework.exceptions import AgentException
|
||||
|
||||
from agent_framework_copilotstudio._acquire_token import DEFAULT_SCOPES, acquire_token
|
||||
|
||||
@@ -12,23 +12,23 @@ class TestAcquireToken:
|
||||
"""Test class for token acquisition functionality."""
|
||||
|
||||
def test_acquire_token_missing_client_id(self) -> None:
|
||||
"""Test that acquire_token raises ServiceException when client_id is missing."""
|
||||
with pytest.raises(ServiceException, match="Client ID is required for token acquisition"):
|
||||
"""Test that acquire_token raises ValueError when client_id is missing."""
|
||||
with pytest.raises(ValueError, match="Client ID is required for token acquisition"):
|
||||
acquire_token(client_id="", tenant_id="test-tenant-id")
|
||||
|
||||
def test_acquire_token_missing_tenant_id(self) -> None:
|
||||
"""Test that acquire_token raises ServiceException when tenant_id is missing."""
|
||||
with pytest.raises(ServiceException, match="Tenant ID is required for token acquisition"):
|
||||
"""Test that acquire_token raises ValueError when tenant_id is missing."""
|
||||
with pytest.raises(ValueError, match="Tenant ID is required for token acquisition"):
|
||||
acquire_token(client_id="test-client-id", tenant_id="")
|
||||
|
||||
def test_acquire_token_none_client_id(self) -> None:
|
||||
"""Test that acquire_token raises ServiceException when client_id is None."""
|
||||
with pytest.raises(ServiceException, match="Client ID is required for token acquisition"):
|
||||
"""Test that acquire_token raises ValueError when client_id is None."""
|
||||
with pytest.raises(ValueError, match="Client ID is required for token acquisition"):
|
||||
acquire_token(client_id=None, tenant_id="test-tenant-id") # type: ignore
|
||||
|
||||
def test_acquire_token_none_tenant_id(self) -> None:
|
||||
"""Test that acquire_token raises ServiceException when tenant_id is None."""
|
||||
with pytest.raises(ServiceException, match="Tenant ID is required for token acquisition"):
|
||||
"""Test that acquire_token raises ValueError when tenant_id is None."""
|
||||
with pytest.raises(ValueError, match="Tenant ID is required for token acquisition"):
|
||||
acquire_token(client_id="test-client-id", tenant_id=None) # type: ignore
|
||||
|
||||
@patch("agent_framework_copilotstudio._acquire_token.PublicClientApplication")
|
||||
@@ -186,7 +186,7 @@ class TestAcquireToken:
|
||||
mock_error_response = {"error": "access_denied", "error_description": "User denied consent"}
|
||||
mock_pca.acquire_token_interactive.return_value = mock_error_response
|
||||
|
||||
with pytest.raises(ServiceException, match="Authentication token cannot be acquired"):
|
||||
with pytest.raises(AgentException, match="Authentication token cannot be acquired"):
|
||||
acquire_token(
|
||||
client_id="test-client-id",
|
||||
tenant_id="test-tenant-id",
|
||||
@@ -203,7 +203,7 @@ class TestAcquireToken:
|
||||
# Interactive acquisition throws exception
|
||||
mock_pca.acquire_token_interactive.side_effect = Exception("Authentication service unavailable")
|
||||
|
||||
with pytest.raises(ServiceException, match="Failed to acquire authentication token"):
|
||||
with pytest.raises(AgentException, match="Failed to acquire authentication token"):
|
||||
acquire_token(
|
||||
client_id="test-client-id",
|
||||
tenant_id="test-tenant-id",
|
||||
|
||||
@@ -5,7 +5,7 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from agent_framework import AgentResponse, AgentResponseUpdate, AgentSession, Content, Message
|
||||
from agent_framework.exceptions import ServiceException, ServiceInitializationError
|
||||
from agent_framework.exceptions import AgentException
|
||||
from microsoft_agents.copilotstudio.client import CopilotClient
|
||||
|
||||
from agent_framework_copilotstudio import CopilotStudioAgent
|
||||
@@ -48,7 +48,7 @@ class TestCopilotStudioAgent:
|
||||
"agentappid": "test-client",
|
||||
}
|
||||
|
||||
with pytest.raises(ServiceInitializationError, match="environment ID is required"):
|
||||
with pytest.raises(ValueError, match="environment ID is required"):
|
||||
CopilotStudioAgent()
|
||||
|
||||
@patch("agent_framework_copilotstudio._acquire_token.acquire_token")
|
||||
@@ -62,7 +62,7 @@ class TestCopilotStudioAgent:
|
||||
"agentappid": "test-client",
|
||||
}
|
||||
|
||||
with pytest.raises(ServiceInitializationError, match="agent identifier"):
|
||||
with pytest.raises(ValueError, match="agent identifier"):
|
||||
CopilotStudioAgent()
|
||||
|
||||
@patch("agent_framework_copilotstudio._acquire_token.acquire_token")
|
||||
@@ -76,7 +76,7 @@ class TestCopilotStudioAgent:
|
||||
"agentappid": "test-client",
|
||||
}
|
||||
|
||||
with pytest.raises(ServiceInitializationError, match="tenant ID is required"):
|
||||
with pytest.raises(ValueError, match="tenant ID is required"):
|
||||
CopilotStudioAgent()
|
||||
|
||||
@patch("agent_framework_copilotstudio._acquire_token.acquire_token")
|
||||
@@ -90,7 +90,7 @@ class TestCopilotStudioAgent:
|
||||
"agentappid": None,
|
||||
}
|
||||
|
||||
with pytest.raises(ServiceInitializationError, match="client ID is required"):
|
||||
with pytest.raises(ValueError, match="client ID is required"):
|
||||
CopilotStudioAgent()
|
||||
|
||||
def test_init_with_client(self, mock_copilot_client: MagicMock) -> None:
|
||||
@@ -109,7 +109,7 @@ class TestCopilotStudioAgent:
|
||||
"agentappid": "test-client",
|
||||
}
|
||||
|
||||
with pytest.raises(ServiceInitializationError, match="environment ID is required"):
|
||||
with pytest.raises(ValueError, match="environment ID is required"):
|
||||
CopilotStudioAgent()
|
||||
|
||||
@patch("agent_framework_copilotstudio._acquire_token.acquire_token")
|
||||
@@ -123,7 +123,7 @@ class TestCopilotStudioAgent:
|
||||
"agentappid": "test-client",
|
||||
}
|
||||
|
||||
with pytest.raises(ServiceInitializationError, match="agent identifier"):
|
||||
with pytest.raises(ValueError, match="agent identifier"):
|
||||
CopilotStudioAgent()
|
||||
|
||||
async def test_run_with_string_message(self, mock_copilot_client: MagicMock, mock_activity: MagicMock) -> None:
|
||||
@@ -188,7 +188,7 @@ class TestCopilotStudioAgent:
|
||||
|
||||
mock_copilot_client.start_conversation.return_value = create_async_generator([])
|
||||
|
||||
with pytest.raises(ServiceException, match="Failed to start a new conversation"):
|
||||
with pytest.raises(AgentException, match="Failed to start a new conversation"):
|
||||
await agent.run("test message")
|
||||
|
||||
async def test_run_streaming_with_string_message(self, mock_copilot_client: MagicMock) -> None:
|
||||
@@ -315,6 +315,6 @@ class TestCopilotStudioAgent:
|
||||
|
||||
mock_copilot_client.start_conversation.return_value = create_async_generator([])
|
||||
|
||||
with pytest.raises(ServiceException, match="Failed to start a new conversation"):
|
||||
with pytest.raises(AgentException, match="Failed to start a new conversation"):
|
||||
async for _ in agent.run("test message", stream=True):
|
||||
pass
|
||||
|
||||
@@ -106,62 +106,78 @@ from ._types import (
|
||||
validate_tool_mode,
|
||||
validate_tools,
|
||||
)
|
||||
from ._workflows import (
|
||||
DEFAULT_MAX_ITERATIONS,
|
||||
from ._workflows._agent import WorkflowAgent
|
||||
from ._workflows._agent_executor import (
|
||||
AgentExecutor,
|
||||
AgentExecutorRequest,
|
||||
AgentExecutorResponse,
|
||||
Case,
|
||||
)
|
||||
from ._workflows._agent_utils import resolve_agent_id
|
||||
from ._workflows._checkpoint import (
|
||||
CheckpointStorage,
|
||||
FileCheckpointStorage,
|
||||
InMemoryCheckpointStorage,
|
||||
WorkflowCheckpoint,
|
||||
)
|
||||
from ._workflows._const import (
|
||||
DEFAULT_MAX_ITERATIONS,
|
||||
)
|
||||
from ._workflows._edge import (
|
||||
Case,
|
||||
Default,
|
||||
Edge,
|
||||
EdgeCondition,
|
||||
EdgeDuplicationError,
|
||||
Executor,
|
||||
FanInEdgeGroup,
|
||||
FanOutEdgeGroup,
|
||||
FileCheckpointStorage,
|
||||
FunctionExecutor,
|
||||
GraphConnectivityError,
|
||||
InMemoryCheckpointStorage,
|
||||
InProcRunnerContext,
|
||||
Runner,
|
||||
RunnerContext,
|
||||
SingleEdgeGroup,
|
||||
SubWorkflowRequestMessage,
|
||||
SubWorkflowResponseMessage,
|
||||
SwitchCaseEdgeGroup,
|
||||
SwitchCaseEdgeGroupCase,
|
||||
SwitchCaseEdgeGroupDefault,
|
||||
TypeCompatibilityError,
|
||||
ValidationTypeEnum,
|
||||
Workflow,
|
||||
WorkflowAgent,
|
||||
WorkflowBuilder,
|
||||
WorkflowCheckpoint,
|
||||
WorkflowCheckpointException,
|
||||
WorkflowContext,
|
||||
WorkflowConvergenceException,
|
||||
)
|
||||
from ._workflows._edge_runner import create_edge_runner
|
||||
from ._workflows._events import (
|
||||
WorkflowErrorDetails,
|
||||
WorkflowEvent,
|
||||
WorkflowEventSource,
|
||||
WorkflowEventType,
|
||||
WorkflowException,
|
||||
WorkflowExecutor,
|
||||
WorkflowMessage,
|
||||
WorkflowRunnerException,
|
||||
WorkflowRunResult,
|
||||
WorkflowRunState,
|
||||
WorkflowValidationError,
|
||||
WorkflowViz,
|
||||
create_edge_runner,
|
||||
executor,
|
||||
)
|
||||
from ._workflows._executor import (
|
||||
Executor,
|
||||
handler,
|
||||
resolve_agent_id,
|
||||
response_handler,
|
||||
)
|
||||
from ._workflows._function_executor import FunctionExecutor, executor
|
||||
from ._workflows._request_info_mixin import response_handler
|
||||
from ._workflows._runner import Runner
|
||||
from ._workflows._runner_context import (
|
||||
InProcRunnerContext,
|
||||
RunnerContext,
|
||||
WorkflowMessage,
|
||||
)
|
||||
from ._workflows._validation import (
|
||||
EdgeDuplicationError,
|
||||
GraphConnectivityError,
|
||||
TypeCompatibilityError,
|
||||
ValidationTypeEnum,
|
||||
WorkflowValidationError,
|
||||
validate_workflow_graph,
|
||||
)
|
||||
from .exceptions import MiddlewareException
|
||||
from ._workflows._viz import WorkflowViz
|
||||
from ._workflows._workflow import Workflow, WorkflowRunResult
|
||||
from ._workflows._workflow_builder import WorkflowBuilder
|
||||
from ._workflows._workflow_context import WorkflowContext
|
||||
from ._workflows._workflow_executor import (
|
||||
SubWorkflowRequestMessage,
|
||||
SubWorkflowResponseMessage,
|
||||
WorkflowExecutor,
|
||||
)
|
||||
from .exceptions import (
|
||||
MiddlewareException,
|
||||
WorkflowCheckpointException,
|
||||
WorkflowConvergenceException,
|
||||
WorkflowException,
|
||||
WorkflowRunnerException,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AGENT_FRAMEWORK_USER_AGENT",
|
||||
|
||||
@@ -51,7 +51,7 @@ from ._types import (
|
||||
map_chat_to_agent_update,
|
||||
normalize_messages,
|
||||
)
|
||||
from .exceptions import AgentExecutionException
|
||||
from .exceptions import AgentInvalidResponseException
|
||||
from .observability import AgentTelemetryLayer
|
||||
|
||||
if sys.version_info >= (3, 13):
|
||||
@@ -843,7 +843,7 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
)
|
||||
|
||||
if not response:
|
||||
raise AgentExecutionException("Chat client did not return a response.")
|
||||
raise AgentInvalidResponseException("Chat client did not return a response.")
|
||||
|
||||
await self._finalize_response(
|
||||
response=response,
|
||||
|
||||
@@ -118,7 +118,7 @@ def _coerce_value(value: str, target_type: type) -> Any:
|
||||
def _check_override_type(value: Any, field_type: type, field_name: str) -> None:
|
||||
"""Validate that *value* is compatible with *field_type*.
|
||||
|
||||
Raises ``ServiceInitializationError`` when the override is clearly
|
||||
Raises ``ValueError`` when the override is clearly
|
||||
incompatible (e.g. a ``dict`` passed where ``str`` is expected).
|
||||
Callable values and ``None`` are always accepted.
|
||||
"""
|
||||
@@ -155,10 +155,8 @@ def _check_override_type(value: Any, field_type: type, field_name: str) -> None:
|
||||
if isinstance(value, int) and float in allowed:
|
||||
return
|
||||
|
||||
from .exceptions import ServiceInitializationError
|
||||
|
||||
allowed_names = ", ".join(t.__name__ for t in allowed)
|
||||
raise ServiceInitializationError(
|
||||
raise ValueError(
|
||||
f"Invalid type for setting '{field_name}': expected {allowed_names}, got {type(value).__name__}."
|
||||
)
|
||||
|
||||
@@ -207,7 +205,7 @@ def load_settings(
|
||||
FileNotFoundError: If *env_file_path* was provided but the file does not exist.
|
||||
SettingNotFoundError: If a required field could not be resolved from any
|
||||
source, or if a mutually exclusive constraint is violated.
|
||||
ServiceInitializationError: If an override value has an incompatible type.
|
||||
ValueError: If an override value has an incompatible type.
|
||||
"""
|
||||
encoding = env_file_encoding or "utf-8"
|
||||
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
# Get Started with Microsoft Agent Framework Workflows
|
||||
|
||||
Workflow capabilities ship with the core `agent-framework` package.
|
||||
|
||||
```bash
|
||||
pip install agent-framework --pre
|
||||
```
|
||||
|
||||
Optional visualization support is still available via the `viz` extra:
|
||||
|
||||
```bash
|
||||
pip install agent-framework[viz] --pre
|
||||
```
|
||||
|
||||
See the [project README](https://github.com/microsoft/agent-framework/tree/main/python/README.md) for more information.
|
||||
@@ -1,150 +1 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Workflow namespace for built-in Agent Framework orchestration primitives.
|
||||
|
||||
This module re-exports objects from workflow implementation modules under
|
||||
``agent_framework._workflows``.
|
||||
|
||||
Supported classes include:
|
||||
- Workflow
|
||||
- WorkflowBuilder
|
||||
- AgentExecutor
|
||||
- Runner
|
||||
- WorkflowExecutor
|
||||
"""
|
||||
|
||||
from ._agent import WorkflowAgent
|
||||
from ._agent_executor import (
|
||||
AgentExecutor,
|
||||
AgentExecutorRequest,
|
||||
AgentExecutorResponse,
|
||||
)
|
||||
from ._agent_utils import resolve_agent_id
|
||||
from ._checkpoint import (
|
||||
CheckpointStorage,
|
||||
FileCheckpointStorage,
|
||||
InMemoryCheckpointStorage,
|
||||
WorkflowCheckpoint,
|
||||
)
|
||||
from ._checkpoint_encoding import (
|
||||
decode_checkpoint_value,
|
||||
encode_checkpoint_value,
|
||||
)
|
||||
from ._const import (
|
||||
DEFAULT_MAX_ITERATIONS,
|
||||
)
|
||||
from ._edge import (
|
||||
Case,
|
||||
Default,
|
||||
Edge,
|
||||
EdgeCondition,
|
||||
FanInEdgeGroup,
|
||||
FanOutEdgeGroup,
|
||||
SingleEdgeGroup,
|
||||
SwitchCaseEdgeGroup,
|
||||
SwitchCaseEdgeGroupCase,
|
||||
SwitchCaseEdgeGroupDefault,
|
||||
)
|
||||
from ._edge_runner import create_edge_runner
|
||||
from ._events import (
|
||||
WorkflowErrorDetails,
|
||||
WorkflowEvent,
|
||||
WorkflowEventSource,
|
||||
WorkflowEventType,
|
||||
WorkflowRunState,
|
||||
)
|
||||
from ._exceptions import (
|
||||
WorkflowCheckpointException,
|
||||
WorkflowConvergenceException,
|
||||
WorkflowException,
|
||||
WorkflowRunnerException,
|
||||
)
|
||||
from ._executor import (
|
||||
Executor,
|
||||
handler,
|
||||
)
|
||||
from ._function_executor import FunctionExecutor, executor
|
||||
from ._request_info_mixin import response_handler
|
||||
from ._runner import Runner
|
||||
from ._runner_context import (
|
||||
InProcRunnerContext,
|
||||
RunnerContext,
|
||||
WorkflowMessage,
|
||||
)
|
||||
from ._state import State
|
||||
from ._validation import (
|
||||
EdgeDuplicationError,
|
||||
GraphConnectivityError,
|
||||
TypeCompatibilityError,
|
||||
ValidationTypeEnum,
|
||||
WorkflowValidationError,
|
||||
validate_workflow_graph,
|
||||
)
|
||||
from ._viz import WorkflowViz
|
||||
from ._workflow import Workflow, WorkflowRunResult
|
||||
from ._workflow_builder import WorkflowBuilder
|
||||
from ._workflow_context import WorkflowContext
|
||||
from ._workflow_executor import (
|
||||
SubWorkflowRequestMessage,
|
||||
SubWorkflowResponseMessage,
|
||||
WorkflowExecutor,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"DEFAULT_MAX_ITERATIONS",
|
||||
"AgentExecutor",
|
||||
"AgentExecutorRequest",
|
||||
"AgentExecutorResponse",
|
||||
"Case",
|
||||
"CheckpointStorage",
|
||||
"Default",
|
||||
"Edge",
|
||||
"EdgeCondition",
|
||||
"EdgeDuplicationError",
|
||||
"Executor",
|
||||
"FanInEdgeGroup",
|
||||
"FanOutEdgeGroup",
|
||||
"FileCheckpointStorage",
|
||||
"FunctionExecutor",
|
||||
"GraphConnectivityError",
|
||||
"InMemoryCheckpointStorage",
|
||||
"InProcRunnerContext",
|
||||
"Runner",
|
||||
"RunnerContext",
|
||||
"SingleEdgeGroup",
|
||||
"State",
|
||||
"SubWorkflowRequestMessage",
|
||||
"SubWorkflowResponseMessage",
|
||||
"SwitchCaseEdgeGroup",
|
||||
"SwitchCaseEdgeGroupCase",
|
||||
"SwitchCaseEdgeGroupDefault",
|
||||
"TypeCompatibilityError",
|
||||
"ValidationTypeEnum",
|
||||
"Workflow",
|
||||
"WorkflowAgent",
|
||||
"WorkflowBuilder",
|
||||
"WorkflowCheckpoint",
|
||||
"WorkflowCheckpointException",
|
||||
"WorkflowContext",
|
||||
"WorkflowConvergenceException",
|
||||
"WorkflowErrorDetails",
|
||||
"WorkflowEvent",
|
||||
"WorkflowEventSource",
|
||||
"WorkflowEventType",
|
||||
"WorkflowException",
|
||||
"WorkflowExecutor",
|
||||
"WorkflowMessage",
|
||||
"WorkflowRunResult",
|
||||
"WorkflowRunState",
|
||||
"WorkflowRunnerException",
|
||||
"WorkflowValidationError",
|
||||
"WorkflowViz",
|
||||
"create_edge_runner",
|
||||
"decode_checkpoint_value",
|
||||
"encode_checkpoint_value",
|
||||
"executor",
|
||||
"handler",
|
||||
"resolve_agent_id",
|
||||
"response_handler",
|
||||
"validate_workflow_graph",
|
||||
]
|
||||
|
||||
@@ -29,7 +29,7 @@ from .._types import (
|
||||
UsageDetails,
|
||||
add_usage_details,
|
||||
)
|
||||
from ..exceptions import AgentExecutionException
|
||||
from ..exceptions import AgentInvalidRequestException, AgentInvalidResponseException
|
||||
from ._checkpoint import CheckpointStorage
|
||||
from ._events import (
|
||||
WorkflowEvent,
|
||||
@@ -456,7 +456,7 @@ class WorkflowAgent(BaseAgent):
|
||||
# We cannot support AgentResponseUpdate in non-streaming mode. This is because the message
|
||||
# sequence cannot be guaranteed when there are streaming updates in between non-streaming
|
||||
# responses.
|
||||
raise AgentExecutionException(
|
||||
raise AgentInvalidRequestException(
|
||||
"Output event with AgentResponseUpdate data cannot be emitted in non-streaming mode. "
|
||||
"Please ensure executors emit AgentResponse for non-streaming workflows."
|
||||
)
|
||||
@@ -669,24 +669,24 @@ class WorkflowAgent(BaseAgent):
|
||||
try:
|
||||
parsed_args = self.RequestInfoFunctionArgs.from_json(arguments_payload)
|
||||
except ValueError as exc:
|
||||
raise AgentExecutionException(
|
||||
raise AgentInvalidResponseException(
|
||||
"FunctionApprovalResponseContent arguments must decode to a mapping."
|
||||
) from exc
|
||||
elif isinstance(arguments_payload, dict):
|
||||
parsed_args = self.RequestInfoFunctionArgs.from_dict(arguments_payload)
|
||||
else:
|
||||
raise AgentExecutionException(
|
||||
raise AgentInvalidResponseException(
|
||||
"FunctionApprovalResponseContent arguments must be a mapping or JSON string."
|
||||
)
|
||||
|
||||
request_id = parsed_args.request_id or content.id # type: ignore[attr-defined]
|
||||
if not content.approved: # type: ignore[attr-defined]
|
||||
raise AgentExecutionException(f"Request '{request_id}' was not approved by the caller.")
|
||||
raise AgentInvalidResponseException(f"Request '{request_id}' was not approved by the caller.")
|
||||
|
||||
if request_id in self.pending_requests:
|
||||
function_responses[request_id] = parsed_args.data
|
||||
elif bool(self.pending_requests):
|
||||
raise AgentExecutionException(
|
||||
raise AgentInvalidRequestException(
|
||||
"Only responses for pending requests are allowed when there are outstanding approvals."
|
||||
)
|
||||
elif content.type == "function_result":
|
||||
@@ -695,12 +695,14 @@ class WorkflowAgent(BaseAgent):
|
||||
response_data = content.result if hasattr(content, "result") else str(content) # type: ignore[attr-defined]
|
||||
function_responses[request_id] = response_data
|
||||
elif bool(self.pending_requests):
|
||||
raise AgentExecutionException(
|
||||
raise AgentInvalidRequestException(
|
||||
"Only function responses for pending requests are allowed while requests are outstanding."
|
||||
)
|
||||
else:
|
||||
if bool(self.pending_requests):
|
||||
raise AgentExecutionException("Unexpected content type while awaiting request info responses.")
|
||||
raise AgentInvalidResponseException(
|
||||
"Unexpected content type while awaiting request info responses."
|
||||
)
|
||||
return function_responses
|
||||
|
||||
def _extract_contents(self, data: Any) -> list[Content]:
|
||||
|
||||
@@ -14,7 +14,7 @@ from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Protocol, TypeAlias
|
||||
|
||||
from ._exceptions import WorkflowCheckpointException
|
||||
from ..exceptions import WorkflowCheckpointException
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -324,12 +324,12 @@ class FileCheckpointStorage:
|
||||
|
||||
encoded_checkpoint = await asyncio.to_thread(_read)
|
||||
|
||||
from ._checkpoint_encoding import CheckpointDecodingError, decode_checkpoint_value
|
||||
from ._checkpoint_encoding import decode_checkpoint_value
|
||||
|
||||
try:
|
||||
decoded_checkpoint_dict = decode_checkpoint_value(encoded_checkpoint)
|
||||
except CheckpointDecodingError as exc:
|
||||
raise WorkflowCheckpointException(f"Failed to decode checkpoint {checkpoint_id}: {exc}") from exc
|
||||
except WorkflowCheckpointException:
|
||||
raise
|
||||
checkpoint = WorkflowCheckpoint.from_dict(decoded_checkpoint_dict)
|
||||
logger.info(f"Loaded checkpoint {checkpoint_id} from {file_path}")
|
||||
return checkpoint
|
||||
|
||||
@@ -7,6 +7,8 @@ import logging
|
||||
import pickle # nosec # noqa: S403
|
||||
from typing import Any
|
||||
|
||||
from ..exceptions import WorkflowCheckpointException
|
||||
|
||||
"""Checkpoint encoding using JSON structure with pickle+base64 for arbitrary data.
|
||||
|
||||
This hybrid approach provides:
|
||||
@@ -29,10 +31,6 @@ _TYPE_MARKER = "__type__"
|
||||
_JSON_NATIVE_TYPES = (str, int, float, bool, type(None))
|
||||
|
||||
|
||||
class CheckpointDecodingError(Exception):
|
||||
"""Raised when checkpoint decoding fails due to type mismatch or corruption."""
|
||||
|
||||
|
||||
def encode_checkpoint_value(value: Any) -> Any:
|
||||
"""Encode a Python value for checkpoint storage.
|
||||
|
||||
@@ -68,7 +66,7 @@ def decode_checkpoint_value(value: Any) -> Any:
|
||||
The original Python value.
|
||||
|
||||
Raises:
|
||||
CheckpointDecodingError: If the unpickled object's type doesn't match
|
||||
WorkflowCheckpointException: If the unpickled object's type doesn't match
|
||||
the recorded type, indicating corruption, or if the base64/pickle
|
||||
data is malformed.
|
||||
"""
|
||||
@@ -133,11 +131,11 @@ def _verify_type(obj: Any, expected_type_key: str) -> None:
|
||||
expected_type_key: The recorded type key (module:qualname format).
|
||||
|
||||
Raises:
|
||||
CheckpointDecodingError: If the types don't match.
|
||||
WorkflowCheckpointException: If the types don't match.
|
||||
"""
|
||||
actual_type_key = _type_to_key(type(obj)) # type: ignore
|
||||
if actual_type_key != expected_type_key:
|
||||
raise CheckpointDecodingError(
|
||||
raise WorkflowCheckpointException(
|
||||
f"Type mismatch during checkpoint decoding: "
|
||||
f"expected '{expected_type_key}', got '{actual_type_key}'. "
|
||||
f"The checkpoint may be corrupted or tampered with."
|
||||
@@ -154,14 +152,14 @@ def _base64_to_unpickle(encoded: str) -> Any:
|
||||
"""Decode base64 string and unpickle.
|
||||
|
||||
Raises:
|
||||
CheckpointDecodingError: If the base64 data is corrupted or the pickle
|
||||
WorkflowCheckpointException: If the base64 data is corrupted or the pickle
|
||||
format is incompatible.
|
||||
"""
|
||||
try:
|
||||
pickled = base64.b64decode(encoded.encode("ascii"))
|
||||
return pickle.loads(pickled) # nosec # noqa: S301
|
||||
except Exception as exc:
|
||||
raise CheckpointDecodingError(f"Failed to decode pickled checkpoint data: {exc}") from exc
|
||||
raise WorkflowCheckpointException(f"Failed to decode pickled checkpoint data: {exc}") from exc
|
||||
|
||||
|
||||
def _type_to_key(t: type) -> str:
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from ..exceptions import AgentFrameworkException
|
||||
|
||||
|
||||
class WorkflowException(AgentFrameworkException):
|
||||
"""Base exception for workflow errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class WorkflowRunnerException(WorkflowException):
|
||||
"""Base exception for workflow runner errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class WorkflowConvergenceException(WorkflowRunnerException):
|
||||
"""Exception raised when a workflow runner fails to converge within the maximum iterations."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class WorkflowCheckpointException(WorkflowRunnerException):
|
||||
"""Exception raised for errors related to workflow checkpoints."""
|
||||
|
||||
pass
|
||||
@@ -7,16 +7,16 @@ from collections import defaultdict
|
||||
from collections.abc import AsyncGenerator, Sequence
|
||||
from typing import Any
|
||||
|
||||
from ..exceptions import (
|
||||
WorkflowCheckpointException,
|
||||
WorkflowConvergenceException,
|
||||
WorkflowRunnerException,
|
||||
)
|
||||
from ._checkpoint import CheckpointID, CheckpointStorage, WorkflowCheckpoint
|
||||
from ._const import EXECUTOR_STATE_KEY
|
||||
from ._edge import EdgeGroup
|
||||
from ._edge_runner import EdgeRunner, create_edge_runner
|
||||
from ._events import WorkflowEvent
|
||||
from ._exceptions import (
|
||||
WorkflowCheckpointException,
|
||||
WorkflowConvergenceException,
|
||||
WorkflowRunnerException,
|
||||
)
|
||||
from ._executor import Executor
|
||||
from ._runner_context import (
|
||||
RunnerContext,
|
||||
|
||||
@@ -7,6 +7,7 @@ from collections.abc import Sequence
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from ..exceptions import WorkflowException
|
||||
from ._edge import Edge, EdgeGroup, FanInEdgeGroup, InternalEdgeGroup
|
||||
from ._executor import Executor
|
||||
from ._typing_utils import is_type_compatible
|
||||
@@ -26,7 +27,7 @@ class ValidationTypeEnum(Enum):
|
||||
OUTPUT_VALIDATION = "OUTPUT_VALIDATION"
|
||||
|
||||
|
||||
class WorkflowValidationError(Exception):
|
||||
class WorkflowValidationError(WorkflowException):
|
||||
"""Base exception for workflow validation errors."""
|
||||
|
||||
def __init__(self, message: str, validation_type: ValidationTypeEnum):
|
||||
|
||||
@@ -9,7 +9,6 @@ from typing import Any, ClassVar, Generic
|
||||
from openai.lib.azure import AsyncAzureOpenAI
|
||||
|
||||
from .._settings import load_settings
|
||||
from ..exceptions import ServiceInitializationError
|
||||
from ..openai import OpenAIAssistantsClient
|
||||
from ..openai._assistants_client import OpenAIAssistantsOptions
|
||||
from ._entra_id_authentication import AzureCredentialTypes, AzureTokenProvider, resolve_credential_to_token_provider
|
||||
@@ -147,7 +146,7 @@ class AzureOpenAIAssistantsClient(
|
||||
_apply_azure_defaults(azure_openai_settings, default_api_version=self.DEFAULT_AZURE_API_VERSION)
|
||||
|
||||
if not azure_openai_settings["chat_deployment_name"]:
|
||||
raise ServiceInitializationError(
|
||||
raise ValueError(
|
||||
"Azure OpenAI deployment name is required. Set via 'deployment_name' parameter "
|
||||
"or 'AZURE_OPENAI_CHAT_DEPLOYMENT_NAME' environment variable."
|
||||
)
|
||||
@@ -160,7 +159,7 @@ class AzureOpenAIAssistantsClient(
|
||||
)
|
||||
|
||||
if not async_client and not azure_openai_settings["api_key"] and not ad_token_provider:
|
||||
raise ServiceInitializationError("Please provide either api_key, credential, or a client.")
|
||||
raise ValueError("Please provide either api_key, credential, or a client.")
|
||||
|
||||
# Create Azure client if not provided
|
||||
if not async_client:
|
||||
|
||||
@@ -22,7 +22,6 @@ from agent_framework import (
|
||||
FunctionInvocationConfiguration,
|
||||
FunctionInvocationLayer,
|
||||
)
|
||||
from agent_framework.exceptions import ServiceInitializationError
|
||||
from agent_framework.observability import ChatTelemetryLayer
|
||||
from agent_framework.openai import OpenAIChatOptions
|
||||
from agent_framework.openai._chat_client import RawOpenAIChatClient
|
||||
@@ -262,7 +261,7 @@ class AzureOpenAIChatClient( # type: ignore[misc]
|
||||
_apply_azure_defaults(azure_openai_settings)
|
||||
|
||||
if not azure_openai_settings["chat_deployment_name"]:
|
||||
raise ServiceInitializationError(
|
||||
raise ValueError(
|
||||
"Azure OpenAI deployment name is required. Set via 'deployment_name' parameter "
|
||||
"or 'AZURE_OPENAI_CHAT_DEPLOYMENT_NAME' environment variable."
|
||||
)
|
||||
|
||||
@@ -9,7 +9,7 @@ from typing import Union
|
||||
from azure.core.credentials import TokenCredential
|
||||
from azure.core.credentials_async import AsyncTokenCredential
|
||||
|
||||
from ..exceptions import ServiceInvalidAuthError
|
||||
from ..exceptions import ChatClientInvalidAuthException
|
||||
|
||||
logger: logging.Logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -54,7 +54,7 @@ def resolve_credential_to_token_provider(
|
||||
return credential
|
||||
|
||||
if not token_endpoint:
|
||||
raise ServiceInvalidAuthError(
|
||||
raise ChatClientInvalidAuthException(
|
||||
"A token endpoint must be provided either in settings, as an environment variable, or as an argument."
|
||||
)
|
||||
|
||||
|
||||
@@ -14,7 +14,6 @@ from .._middleware import ChatMiddlewareLayer
|
||||
from .._settings import load_settings
|
||||
from .._telemetry import AGENT_FRAMEWORK_USER_AGENT
|
||||
from .._tools import FunctionInvocationConfiguration, FunctionInvocationLayer
|
||||
from ..exceptions import ServiceInitializationError
|
||||
from ..observability import ChatTelemetryLayer
|
||||
from ..openai._responses_client import RawOpenAIResponsesClient
|
||||
from ._entra_id_authentication import AzureCredentialTypes, AzureTokenProvider
|
||||
@@ -217,7 +216,7 @@ class AzureOpenAIResponsesClient( # type: ignore[misc]
|
||||
azure_openai_settings["base_url"] = urljoin(str(azure_openai_settings["endpoint"]), "/openai/v1/")
|
||||
|
||||
if not azure_openai_settings["responses_deployment_name"]:
|
||||
raise ServiceInitializationError(
|
||||
raise ValueError(
|
||||
"Azure OpenAI deployment name is required. Set via 'deployment_name' parameter "
|
||||
"or 'AZURE_OPENAI_RESPONSES_DEPLOYMENT_NAME' environment variable."
|
||||
)
|
||||
@@ -255,20 +254,16 @@ class AzureOpenAIResponsesClient( # type: ignore[misc]
|
||||
An AsyncAzureOpenAI client obtained from the project client.
|
||||
|
||||
Raises:
|
||||
ServiceInitializationError: If required parameters are missing or
|
||||
ValueError: If required parameters are missing or
|
||||
the azure-ai-projects package is not installed.
|
||||
"""
|
||||
if project_client is not None:
|
||||
return project_client.get_openai_client()
|
||||
|
||||
if not project_endpoint:
|
||||
raise ServiceInitializationError(
|
||||
"Azure AI project endpoint is required when project_client is not provided."
|
||||
)
|
||||
raise ValueError("Azure AI project endpoint is required when project_client is not provided.")
|
||||
if not credential:
|
||||
raise ServiceInitializationError(
|
||||
"Azure credential is required when using project_endpoint without a project_client."
|
||||
)
|
||||
raise ValueError("Azure credential is required when using project_endpoint without a project_client.")
|
||||
project_client = AIProjectClient(
|
||||
endpoint=project_endpoint,
|
||||
credential=credential, # type: ignore[arg-type]
|
||||
|
||||
@@ -13,7 +13,6 @@ from openai.lib.azure import AsyncAzureOpenAI
|
||||
|
||||
from .._settings import SecretString
|
||||
from .._telemetry import APP_INFO, prepend_agent_framework_to_user_agent
|
||||
from ..exceptions import ServiceInitializationError
|
||||
from ..openai._shared import OpenAIBase
|
||||
from ._entra_id_authentication import AzureCredentialTypes, AzureTokenProvider, resolve_credential_to_token_provider
|
||||
|
||||
@@ -175,10 +174,10 @@ class AzureOpenAIConfigMixin(OpenAIBase):
|
||||
ad_token_provider = resolve_credential_to_token_provider(credential, token_endpoint)
|
||||
|
||||
if not api_key and not ad_token_provider:
|
||||
raise ServiceInitializationError("Please provide either api_key, credential, or a client.")
|
||||
raise ValueError("Please provide either api_key, credential, or a client.")
|
||||
|
||||
if not endpoint and not base_url:
|
||||
raise ServiceInitializationError("Please provide an endpoint or a base_url")
|
||||
raise ValueError("Please provide an endpoint or a base_url")
|
||||
|
||||
args: dict[str, Any] = {
|
||||
"default_headers": merged_headers,
|
||||
|
||||
@@ -21,7 +21,6 @@ _IMPORTS = [
|
||||
"AgentFactory",
|
||||
"AgentExternalInputRequest",
|
||||
"AgentExternalInputResponse",
|
||||
"AgentInvocationError",
|
||||
"DeclarativeLoaderError",
|
||||
"DeclarativeWorkflowError",
|
||||
"ExternalInputRequest",
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Exception hierarchy used across Agent Framework core and connectors."""
|
||||
"""Exception hierarchy used across Agent Framework core and connectors.
|
||||
|
||||
See python/CODING_STANDARD.md § Exception Hierarchy for design rationale
|
||||
and guidance on choosing the correct exception class.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Literal
|
||||
@@ -9,7 +13,7 @@ logger = logging.getLogger("agent_framework")
|
||||
|
||||
|
||||
class AgentFrameworkException(Exception):
|
||||
"""Base exceptions for the Agent Framework.
|
||||
"""Base exception for the Agent Framework.
|
||||
|
||||
Automatically logs the message as debug.
|
||||
"""
|
||||
@@ -33,115 +37,118 @@ class AgentFrameworkException(Exception):
|
||||
super().__init__(message, *args) # type: ignore
|
||||
|
||||
|
||||
# region Agent Exceptions
|
||||
|
||||
|
||||
class AgentException(AgentFrameworkException):
|
||||
"""Base class for all agent exceptions."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class AgentExecutionException(AgentException):
|
||||
"""An error occurred while executing the agent."""
|
||||
class AgentInvalidAuthException(AgentException):
|
||||
"""An authentication error occurred in an agent."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class AgentInitializationError(AgentException):
|
||||
"""An error occurred while initializing the agent."""
|
||||
class AgentInvalidRequestException(AgentException):
|
||||
"""An invalid request was made to an agent."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class AgentSessionException(AgentException):
|
||||
"""An error occurred while managing the agent session."""
|
||||
class AgentInvalidResponseException(AgentException):
|
||||
"""An invalid or unexpected response was received from an agent."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class AgentContentFilterException(AgentException):
|
||||
"""A content filter was triggered by an agent."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# region Chat Client Exceptions
|
||||
|
||||
|
||||
class ChatClientException(AgentFrameworkException):
|
||||
"""An error occurred while dealing with a chat client."""
|
||||
"""Base class for all chat client exceptions."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ChatClientInitializationError(ChatClientException):
|
||||
"""An error occurred while initializing the chat client."""
|
||||
class ChatClientInvalidAuthException(ChatClientException):
|
||||
"""An authentication error occurred in a chat client."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
# region Service Exceptions
|
||||
|
||||
|
||||
class ServiceException(AgentFrameworkException):
|
||||
"""Base class for all service exceptions."""
|
||||
class ChatClientInvalidRequestException(ChatClientException):
|
||||
"""An invalid request was made to a chat client."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ServiceInitializationError(ServiceException):
|
||||
"""An error occurred while initializing the service."""
|
||||
class ChatClientInvalidResponseException(ChatClientException):
|
||||
"""An invalid or unexpected response was received from a chat client."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ServiceResponseException(ServiceException):
|
||||
"""Base class for all service response exceptions."""
|
||||
class ChatClientContentFilterException(ChatClientException):
|
||||
"""A content filter was triggered by a chat client."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ServiceContentFilterException(ServiceResponseException):
|
||||
"""An error was raised by the content filter of the service."""
|
||||
# endregion
|
||||
|
||||
# region Integration Exceptions
|
||||
|
||||
|
||||
class IntegrationException(AgentFrameworkException):
|
||||
"""Base class for all external service/dependency integration exceptions."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ServiceInvalidAuthError(ServiceException):
|
||||
"""An error occurred while authenticating the service."""
|
||||
class IntegrationInitializationError(IntegrationException):
|
||||
"""A wrapped dependency/service lifecycle failure occurred during setup."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ServiceInvalidExecutionSettingsError(ServiceResponseException):
|
||||
"""An error occurred while validating the execution settings of the service."""
|
||||
class IntegrationInvalidAuthException(IntegrationException):
|
||||
"""An authentication error occurred in an external integration."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ServiceInvalidRequestError(ServiceResponseException):
|
||||
"""An error occurred while validating the request to the service."""
|
||||
class IntegrationInvalidRequestException(IntegrationException):
|
||||
"""An invalid request was made to an external integration."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ServiceInvalidResponseError(ServiceResponseException):
|
||||
"""An error occurred while validating the response from the service."""
|
||||
class IntegrationInvalidResponseException(IntegrationException):
|
||||
"""An invalid or unexpected response was received from an external integration."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ToolException(AgentFrameworkException):
|
||||
"""An error occurred while executing a tool."""
|
||||
class IntegrationContentFilterException(IntegrationException):
|
||||
"""A content filter was triggered by an external integration."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ToolExecutionException(ToolException):
|
||||
"""An error occurred while executing a tool."""
|
||||
# endregion
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class AdditionItemMismatch(AgentFrameworkException):
|
||||
"""An error occurred while adding two types."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class MiddlewareException(AgentFrameworkException):
|
||||
"""An error occurred during middleware execution."""
|
||||
|
||||
pass
|
||||
# region Content Exceptions
|
||||
|
||||
|
||||
class ContentError(AgentFrameworkException):
|
||||
@@ -150,7 +157,78 @@ class ContentError(AgentFrameworkException):
|
||||
pass
|
||||
|
||||
|
||||
class AdditionItemMismatch(ContentError):
|
||||
"""A type mismatch occurred while merging content items."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# region Tool Exceptions
|
||||
|
||||
|
||||
class ToolException(AgentFrameworkException):
|
||||
"""Base class for all tool-related exceptions."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ToolExecutionException(ToolException):
|
||||
"""A tool or prompt call failed at runtime."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# region Middleware Exceptions
|
||||
|
||||
|
||||
class MiddlewareException(AgentFrameworkException):
|
||||
"""An error occurred during middleware execution."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# region Settings Exceptions
|
||||
|
||||
|
||||
class SettingNotFoundError(AgentFrameworkException):
|
||||
"""A required setting could not be resolved from any source."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# region Workflow Exceptions
|
||||
|
||||
|
||||
class WorkflowException(AgentFrameworkException):
|
||||
"""Base exception for workflow errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class WorkflowRunnerException(WorkflowException):
|
||||
"""Base exception for workflow runner errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class WorkflowConvergenceException(WorkflowRunnerException):
|
||||
"""Exception raised when a workflow runner fails to converge within the maximum iterations."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class WorkflowCheckpointException(WorkflowRunnerException):
|
||||
"""Exception raised for errors related to workflow checkpoints."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
@@ -16,7 +16,6 @@ from .._agents import Agent
|
||||
from .._middleware import MiddlewareTypes
|
||||
from .._sessions import BaseContextProvider
|
||||
from .._tools import FunctionTool, ToolTypes, normalize_tools
|
||||
from ..exceptions import ServiceInitializationError
|
||||
from ._assistants_client import OpenAIAssistantsClient
|
||||
from ._shared import OpenAISettings, from_assistant_tools, to_assistant_tools
|
||||
|
||||
@@ -120,7 +119,7 @@ class OpenAIAssistantProvider(Generic[OptionsCoT]):
|
||||
env_file_encoding: Encoding of the .env file.
|
||||
|
||||
Raises:
|
||||
ServiceInitializationError: If no client is provided and API key is missing.
|
||||
ValueError: If no client is provided and API key is missing.
|
||||
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
@@ -151,7 +150,7 @@ class OpenAIAssistantProvider(Generic[OptionsCoT]):
|
||||
)
|
||||
|
||||
if not settings["api_key"]:
|
||||
raise ServiceInitializationError(
|
||||
raise ValueError(
|
||||
"OpenAI API key is required. Set via 'api_key' parameter or 'OPENAI_API_KEY' environment variable."
|
||||
)
|
||||
|
||||
@@ -227,7 +226,7 @@ class OpenAIAssistantProvider(Generic[OptionsCoT]):
|
||||
A Agent instance wrapping the created assistant.
|
||||
|
||||
Raises:
|
||||
ServiceInitializationError: If assistant creation fails.
|
||||
ValueError: If assistant creation fails.
|
||||
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
@@ -286,7 +285,7 @@ class OpenAIAssistantProvider(Generic[OptionsCoT]):
|
||||
|
||||
# Create the assistant
|
||||
if not self._client:
|
||||
raise ServiceInitializationError("OpenAI client is not initialized.")
|
||||
raise RuntimeError("OpenAI client is not initialized.")
|
||||
|
||||
assistant = await self._client.beta.assistants.create(**create_params)
|
||||
|
||||
@@ -333,7 +332,7 @@ class OpenAIAssistantProvider(Generic[OptionsCoT]):
|
||||
A Agent instance wrapping the retrieved assistant.
|
||||
|
||||
Raises:
|
||||
ServiceInitializationError: If the assistant cannot be retrieved.
|
||||
RuntimeError: If the assistant cannot be retrieved.
|
||||
ValueError: If required function tools are missing.
|
||||
|
||||
Examples:
|
||||
@@ -352,7 +351,7 @@ class OpenAIAssistantProvider(Generic[OptionsCoT]):
|
||||
"""
|
||||
# Fetch the assistant
|
||||
if not self._client:
|
||||
raise ServiceInitializationError("OpenAI client is not initialized.")
|
||||
raise RuntimeError("OpenAI client is not initialized.")
|
||||
|
||||
assistant = await self._client.beta.assistants.retrieve(assistant_id)
|
||||
|
||||
|
||||
@@ -47,7 +47,6 @@ from .._types import (
|
||||
ResponseStream,
|
||||
UsageDetails,
|
||||
)
|
||||
from ..exceptions import ServiceInitializationError
|
||||
from ..observability import ChatTelemetryLayer
|
||||
from ._shared import OpenAIConfigMixin, OpenAISettings
|
||||
|
||||
@@ -350,11 +349,11 @@ class OpenAIAssistantsClient( # type: ignore[misc]
|
||||
)
|
||||
|
||||
if not async_client and not openai_settings["api_key"]:
|
||||
raise ServiceInitializationError(
|
||||
raise ValueError(
|
||||
"OpenAI API key is required. Set via 'api_key' parameter or 'OPENAI_API_KEY' environment variable."
|
||||
)
|
||||
if not openai_settings["chat_model_id"]:
|
||||
raise ServiceInitializationError(
|
||||
raise ValueError(
|
||||
"OpenAI model ID is required. "
|
||||
"Set via 'model_id' parameter or 'OPENAI_CHAT_MODEL_ID' environment variable."
|
||||
)
|
||||
@@ -452,7 +451,7 @@ class OpenAIAssistantsClient( # type: ignore[misc]
|
||||
# If no assistant is provided, create a temporary assistant
|
||||
if self.assistant_id is None:
|
||||
if not self.model_id:
|
||||
raise ServiceInitializationError("Parameter 'model_id' is required for assistant creation.")
|
||||
raise ValueError("Parameter 'model_id' is required for assistant creation.")
|
||||
|
||||
client = await self._ensure_client()
|
||||
created_assistant = await client.beta.assistants.create(
|
||||
|
||||
@@ -41,9 +41,8 @@ from .._types import (
|
||||
UsageDetails,
|
||||
)
|
||||
from ..exceptions import (
|
||||
ServiceInitializationError,
|
||||
ServiceInvalidRequestError,
|
||||
ServiceResponseException,
|
||||
ChatClientException,
|
||||
ChatClientInvalidRequestException,
|
||||
)
|
||||
from ..observability import ChatTelemetryLayer
|
||||
from ._exceptions import OpenAIContentFilterException
|
||||
@@ -234,12 +233,12 @@ class RawOpenAIChatClient( # type: ignore[misc]
|
||||
f"{type(self)} service encountered a content error: {ex}",
|
||||
inner_exception=ex,
|
||||
) from ex
|
||||
raise ServiceResponseException(
|
||||
raise ChatClientException(
|
||||
f"{type(self)} service failed to complete the prompt: {ex}",
|
||||
inner_exception=ex,
|
||||
) from ex
|
||||
except Exception as ex:
|
||||
raise ServiceResponseException(
|
||||
raise ChatClientException(
|
||||
f"{type(self)} service failed to complete the prompt: {ex}",
|
||||
inner_exception=ex,
|
||||
) from ex
|
||||
@@ -259,12 +258,12 @@ class RawOpenAIChatClient( # type: ignore[misc]
|
||||
f"{type(self)} service encountered a content error: {ex}",
|
||||
inner_exception=ex,
|
||||
) from ex
|
||||
raise ServiceResponseException(
|
||||
raise ChatClientException(
|
||||
f"{type(self)} service failed to complete the prompt: {ex}",
|
||||
inner_exception=ex,
|
||||
) from ex
|
||||
except Exception as ex:
|
||||
raise ServiceResponseException(
|
||||
raise ChatClientException(
|
||||
f"{type(self)} service failed to complete the prompt: {ex}",
|
||||
inner_exception=ex,
|
||||
) from ex
|
||||
@@ -320,7 +319,7 @@ class RawOpenAIChatClient( # type: ignore[misc]
|
||||
if messages and "messages" not in run_options:
|
||||
run_options["messages"] = self._prepare_messages_for_openai(messages)
|
||||
if "messages" not in run_options:
|
||||
raise ServiceInvalidRequestError("Messages are required for chat completions")
|
||||
raise ChatClientInvalidRequestException("Messages are required for chat completions")
|
||||
|
||||
# Translation between options keys and Chat Completion API
|
||||
for old_key, new_key in OPTION_TRANSLATIONS.items():
|
||||
@@ -732,11 +731,11 @@ class OpenAIChatClient( # type: ignore[misc]
|
||||
)
|
||||
|
||||
if not async_client and not openai_settings["api_key"]:
|
||||
raise ServiceInitializationError(
|
||||
raise ValueError(
|
||||
"OpenAI API key is required. Set via 'api_key' parameter or 'OPENAI_API_KEY' environment variable."
|
||||
)
|
||||
if not openai_settings["chat_model_id"]:
|
||||
raise ServiceInitializationError(
|
||||
raise ValueError(
|
||||
"OpenAI model ID is required. "
|
||||
"Set via 'model_id' parameter or 'OPENAI_CHAT_MODEL_ID' environment variable."
|
||||
)
|
||||
|
||||
@@ -8,7 +8,7 @@ from typing import Any
|
||||
|
||||
from openai import BadRequestError
|
||||
|
||||
from ..exceptions import ServiceContentFilterException
|
||||
from ..exceptions import ChatClientContentFilterException
|
||||
|
||||
|
||||
class ContentFilterResultSeverity(Enum):
|
||||
@@ -54,7 +54,7 @@ class ContentFilterCodes(Enum):
|
||||
|
||||
|
||||
@dataclass
|
||||
class OpenAIContentFilterException(ServiceContentFilterException):
|
||||
class OpenAIContentFilterException(ChatClientContentFilterException):
|
||||
"""AI exception for an error from Azure OpenAI's content filter."""
|
||||
|
||||
# The parameter that caused the error.
|
||||
|
||||
@@ -61,9 +61,8 @@ from .._types import (
|
||||
validate_tool_mode,
|
||||
)
|
||||
from ..exceptions import (
|
||||
ServiceInitializationError,
|
||||
ServiceInvalidRequestError,
|
||||
ServiceResponseException,
|
||||
ChatClientException,
|
||||
ChatClientInvalidRequestException,
|
||||
)
|
||||
from ..observability import ChatTelemetryLayer
|
||||
from ._exceptions import OpenAIContentFilterException
|
||||
@@ -264,7 +263,7 @@ class RawOpenAIResponsesClient( # type: ignore[misc]
|
||||
f"{type(self)} service encountered a content error: {ex}",
|
||||
inner_exception=ex,
|
||||
) from ex
|
||||
raise ServiceResponseException(
|
||||
raise ChatClientException(
|
||||
f"{type(self)} service failed to complete the prompt: {ex}",
|
||||
inner_exception=ex,
|
||||
) from ex
|
||||
@@ -352,7 +351,7 @@ class RawOpenAIResponsesClient( # type: ignore[misc]
|
||||
) -> tuple[type[BaseModel] | None, dict[str, Any] | None]:
|
||||
"""Normalize response_format into Responses text configuration and parse target."""
|
||||
if text_config is not None and not isinstance(text_config, MutableMapping):
|
||||
raise ServiceInvalidRequestError("text must be a mapping when provided.")
|
||||
raise ChatClientInvalidRequestException("text must be a mapping when provided.")
|
||||
text_config = cast(dict[str, Any], text_config) if isinstance(text_config, MutableMapping) else None
|
||||
|
||||
if response_format is None:
|
||||
@@ -360,7 +359,7 @@ class RawOpenAIResponsesClient( # type: ignore[misc]
|
||||
|
||||
if isinstance(response_format, type) and issubclass(response_format, BaseModel):
|
||||
if text_config and "format" in text_config:
|
||||
raise ServiceInvalidRequestError("response_format cannot be combined with explicit text.format.")
|
||||
raise ChatClientInvalidRequestException("response_format cannot be combined with explicit text.format.")
|
||||
return response_format, text_config
|
||||
|
||||
if isinstance(response_format, Mapping):
|
||||
@@ -368,11 +367,11 @@ class RawOpenAIResponsesClient( # type: ignore[misc]
|
||||
if text_config is None:
|
||||
text_config = {}
|
||||
elif "format" in text_config and text_config["format"] != format_config:
|
||||
raise ServiceInvalidRequestError("Conflicting response_format definitions detected.")
|
||||
raise ChatClientInvalidRequestException("Conflicting response_format definitions detected.")
|
||||
text_config["format"] = format_config
|
||||
return None, text_config
|
||||
|
||||
raise ServiceInvalidRequestError("response_format must be a Pydantic model or mapping.")
|
||||
raise ChatClientInvalidRequestException("response_format must be a Pydantic model or mapping.")
|
||||
|
||||
def _convert_response_format(self, response_format: Mapping[str, Any]) -> dict[str, Any]:
|
||||
"""Convert Chat style response_format into Responses text format config."""
|
||||
@@ -383,11 +382,11 @@ class RawOpenAIResponsesClient( # type: ignore[misc]
|
||||
if format_type == "json_schema":
|
||||
schema_section = response_format.get("json_schema", response_format)
|
||||
if not isinstance(schema_section, Mapping):
|
||||
raise ServiceInvalidRequestError("json_schema response_format must be a mapping.")
|
||||
raise ChatClientInvalidRequestException("json_schema response_format must be a mapping.")
|
||||
schema_section_typed = cast("Mapping[str, Any]", schema_section)
|
||||
schema: Any = schema_section_typed.get("schema")
|
||||
if schema is None:
|
||||
raise ServiceInvalidRequestError("json_schema response_format requires a schema.")
|
||||
raise ChatClientInvalidRequestException("json_schema response_format requires a schema.")
|
||||
name: str = str(
|
||||
schema_section_typed.get("name")
|
||||
or schema_section_typed.get("title")
|
||||
@@ -408,7 +407,7 @@ class RawOpenAIResponsesClient( # type: ignore[misc]
|
||||
if format_type in {"json_object", "text"}:
|
||||
return {"type": format_type}
|
||||
|
||||
raise ServiceInvalidRequestError("Unsupported response_format provided for Responses client.")
|
||||
raise ChatClientInvalidRequestException("Unsupported response_format provided for Responses client.")
|
||||
|
||||
def _get_conversation_id(
|
||||
self, response: OpenAIResponse | ParsedResponse[BaseModel], store: bool | None
|
||||
@@ -787,10 +786,8 @@ class RawOpenAIResponsesClient( # type: ignore[misc]
|
||||
# Continuation turn: instructions already exist in conversation context, skip prepending
|
||||
request_input = self._prepare_messages_for_openai(messages)
|
||||
if not request_input:
|
||||
raise ServiceInvalidRequestError("Messages are required for chat completions")
|
||||
|
||||
raise ChatClientInvalidRequestException("Messages are required for chat completions")
|
||||
conversation_id = self._get_current_conversation_id(options, **kwargs)
|
||||
|
||||
run_options["input"] = request_input
|
||||
|
||||
# model id
|
||||
@@ -1876,11 +1873,11 @@ class OpenAIResponsesClient( # type: ignore[misc]
|
||||
)
|
||||
|
||||
if not async_client and not openai_settings["api_key"]:
|
||||
raise ServiceInitializationError(
|
||||
raise ValueError(
|
||||
"OpenAI API key is required. Set via 'api_key' parameter or 'OPENAI_API_KEY' environment variable."
|
||||
)
|
||||
if not openai_settings["responses_model_id"]:
|
||||
raise ServiceInitializationError(
|
||||
raise ValueError(
|
||||
"OpenAI model ID is required. "
|
||||
"Set via 'model_id' parameter or 'OPENAI_RESPONSES_MODEL_ID' environment variable."
|
||||
)
|
||||
|
||||
@@ -26,7 +26,6 @@ from .._serialization import SerializationMixin
|
||||
from .._settings import SecretString
|
||||
from .._telemetry import APP_INFO, USER_AGENT_KEY, prepend_agent_framework_to_user_agent
|
||||
from .._tools import FunctionTool
|
||||
from ..exceptions import ServiceInitializationError
|
||||
|
||||
logger: logging.Logger = logging.getLogger("agent_framework.openai")
|
||||
|
||||
@@ -56,20 +55,20 @@ def _check_openai_version_for_callable_api_key() -> None:
|
||||
"""Check if OpenAI version supports callable API keys.
|
||||
|
||||
Callable API keys require OpenAI >= 1.106.0.
|
||||
If the version is too old, raise a ServiceInitializationError with helpful message.
|
||||
If the version is too old, raise a ValueError with helpful message.
|
||||
"""
|
||||
try:
|
||||
current_version = parse(openai.__version__)
|
||||
min_required_version = parse("1.106.0")
|
||||
|
||||
if current_version < min_required_version:
|
||||
raise ServiceInitializationError(
|
||||
raise ValueError(
|
||||
f"Callable API keys require OpenAI SDK >= 1.106.0, but you have {openai.__version__}. "
|
||||
f"Please upgrade with 'pip install openai>=1.106.0' or provide a string API key instead. "
|
||||
f"Note: If you're using mem0ai, you may need to upgrade to mem0ai>=1.0.0 "
|
||||
f"to allow newer OpenAI versions."
|
||||
)
|
||||
except ServiceInitializationError:
|
||||
except ValueError:
|
||||
raise # Re-raise our own exception
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not check OpenAI version for callable API key support: {e}")
|
||||
@@ -172,7 +171,7 @@ class OpenAIBase(SerializationMixin):
|
||||
"""Ensure OpenAI client is initialized."""
|
||||
await self._initialize_client()
|
||||
if self.client is None:
|
||||
raise ServiceInitializationError("OpenAI client is not initialized")
|
||||
raise RuntimeError("OpenAI client is not initialized")
|
||||
|
||||
return self.client
|
||||
|
||||
@@ -247,7 +246,7 @@ class OpenAIConfigMixin(OpenAIBase):
|
||||
|
||||
if not client:
|
||||
if not api_key:
|
||||
raise ServiceInitializationError("Please provide an api_key")
|
||||
raise ValueError("Please provide an api_key")
|
||||
args: dict[str, Any] = {"api_key": api_key_value, "default_headers": merged_headers}
|
||||
if org_id:
|
||||
args["organization"] = org_id
|
||||
|
||||
@@ -21,7 +21,6 @@ from agent_framework import (
|
||||
)
|
||||
from agent_framework._settings import SecretString
|
||||
from agent_framework.azure import AzureOpenAIAssistantsClient
|
||||
from agent_framework.exceptions import ServiceInitializationError
|
||||
|
||||
skip_if_azure_integration_tests_disabled = pytest.mark.skipif(
|
||||
os.getenv("RUN_INTEGRATION_TESTS", "false").lower() != "true"
|
||||
@@ -120,7 +119,7 @@ def test_azure_assistants_client_init_auto_create_client(
|
||||
|
||||
def test_azure_assistants_client_init_validation_fail() -> None:
|
||||
"""Test AzureOpenAIAssistantsClient initialization with validation failure."""
|
||||
with pytest.raises(ServiceInitializationError):
|
||||
with pytest.raises(ValueError):
|
||||
# Force failure by providing invalid deployment name type - this should cause validation to fail
|
||||
AzureOpenAIAssistantsClient(deployment_name=123, api_key="valid-key") # type: ignore
|
||||
|
||||
@@ -128,7 +127,7 @@ def test_azure_assistants_client_init_validation_fail() -> None:
|
||||
@pytest.mark.parametrize("exclude_list", [["AZURE_OPENAI_CHAT_DEPLOYMENT_NAME"]], indirect=True)
|
||||
def test_azure_assistants_client_init_missing_deployment_name(azure_openai_unit_test_env: dict[str, str]) -> None:
|
||||
"""Test AzureOpenAIAssistantsClient initialization with missing deployment name."""
|
||||
with pytest.raises(ServiceInitializationError):
|
||||
with pytest.raises(ValueError):
|
||||
AzureOpenAIAssistantsClient(api_key=azure_openai_unit_test_env.get("AZURE_OPENAI_API_KEY", "test-key"))
|
||||
|
||||
|
||||
@@ -607,7 +606,7 @@ def test_azure_assistants_client_no_authentication_error() -> None:
|
||||
}
|
||||
|
||||
# Test missing authentication raises error
|
||||
with pytest.raises(ServiceInitializationError, match="api_key, credential, or a client"):
|
||||
with pytest.raises(ValueError, match="api_key, credential, or a client"):
|
||||
AzureOpenAIAssistantsClient(
|
||||
deployment_name="test-deployment",
|
||||
endpoint="https://test-endpoint.openai.azure.com",
|
||||
|
||||
@@ -28,7 +28,7 @@ from agent_framework import (
|
||||
)
|
||||
from agent_framework._telemetry import USER_AGENT_KEY
|
||||
from agent_framework.azure import AzureOpenAIChatClient
|
||||
from agent_framework.exceptions import ServiceInitializationError, ServiceResponseException
|
||||
from agent_framework.exceptions import ChatClientException
|
||||
from agent_framework.openai import (
|
||||
ContentFilterResultSeverity,
|
||||
OpenAIContentFilterException,
|
||||
@@ -93,13 +93,13 @@ def test_init_endpoint(azure_openai_unit_test_env: dict[str, str]) -> None:
|
||||
|
||||
@pytest.mark.parametrize("exclude_list", [["AZURE_OPENAI_CHAT_DEPLOYMENT_NAME"]], indirect=True)
|
||||
def test_init_with_empty_deployment_name(azure_openai_unit_test_env: dict[str, str]) -> None:
|
||||
with pytest.raises(ServiceInitializationError):
|
||||
with pytest.raises(ValueError):
|
||||
AzureOpenAIChatClient()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("exclude_list", [["AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_BASE_URL"]], indirect=True)
|
||||
def test_init_with_empty_endpoint_and_base_url(azure_openai_unit_test_env: dict[str, str]) -> None:
|
||||
with pytest.raises(ServiceInitializationError):
|
||||
with pytest.raises(ValueError):
|
||||
AzureOpenAIChatClient()
|
||||
|
||||
|
||||
@@ -554,7 +554,7 @@ async def test_bad_request_non_content_filter(
|
||||
|
||||
azure_chat_client = AzureOpenAIChatClient()
|
||||
|
||||
with pytest.raises(ServiceResponseException, match="service failed to complete the prompt"):
|
||||
with pytest.raises(ChatClientException, match="service failed to complete the prompt"):
|
||||
await azure_chat_client.get_response(
|
||||
messages=chat_history,
|
||||
)
|
||||
|
||||
@@ -21,7 +21,6 @@ from agent_framework import (
|
||||
tool,
|
||||
)
|
||||
from agent_framework.azure import AzureOpenAIResponsesClient
|
||||
from agent_framework.exceptions import ServiceInitializationError
|
||||
|
||||
skip_if_azure_integration_tests_disabled = pytest.mark.skipif(
|
||||
os.getenv("RUN_INTEGRATION_TESTS", "false").lower() != "true"
|
||||
@@ -81,7 +80,7 @@ def test_init(azure_openai_unit_test_env: dict[str, str]) -> None:
|
||||
|
||||
def test_init_validation_fail() -> None:
|
||||
# Test successful initialization
|
||||
with pytest.raises(ServiceInitializationError):
|
||||
with pytest.raises(ValueError):
|
||||
AzureOpenAIResponsesClient(api_key="34523", deployment_name={"test": "dict"}) # type: ignore
|
||||
|
||||
|
||||
@@ -113,7 +112,7 @@ def test_init_with_default_header(azure_openai_unit_test_env: dict[str, str]) ->
|
||||
|
||||
@pytest.mark.parametrize("exclude_list", [["AZURE_OPENAI_RESPONSES_DEPLOYMENT_NAME"]], indirect=True)
|
||||
def test_init_with_empty_model_id(azure_openai_unit_test_env: dict[str, str]) -> None:
|
||||
with pytest.raises(ServiceInitializationError):
|
||||
with pytest.raises(ValueError):
|
||||
AzureOpenAIResponsesClient()
|
||||
|
||||
|
||||
@@ -212,7 +211,7 @@ def test_create_client_from_project_with_endpoint() -> None:
|
||||
|
||||
def test_create_client_from_project_missing_endpoint() -> None:
|
||||
"""Test _create_client_from_project raises error when endpoint is missing."""
|
||||
with pytest.raises(ServiceInitializationError, match="project endpoint is required"):
|
||||
with pytest.raises(ValueError, match="project endpoint is required"):
|
||||
AzureOpenAIResponsesClient._create_client_from_project(
|
||||
project_client=None,
|
||||
project_endpoint=None,
|
||||
@@ -222,7 +221,7 @@ def test_create_client_from_project_missing_endpoint() -> None:
|
||||
|
||||
def test_create_client_from_project_missing_credential() -> None:
|
||||
"""Test _create_client_from_project raises error when credential is missing."""
|
||||
with pytest.raises(ServiceInitializationError, match="credential is required"):
|
||||
with pytest.raises(ValueError, match="credential is required"):
|
||||
AzureOpenAIResponsesClient._create_client_from_project(
|
||||
project_client=None,
|
||||
project_endpoint="https://test-project.services.ai.azure.com",
|
||||
|
||||
@@ -9,7 +9,7 @@ from azure.core.credentials_async import AsyncTokenCredential
|
||||
from agent_framework.azure._entra_id_authentication import (
|
||||
resolve_credential_to_token_provider,
|
||||
)
|
||||
from agent_framework.exceptions import ServiceInvalidAuthError
|
||||
from agent_framework.exceptions import ChatClientInvalidAuthException
|
||||
|
||||
TOKEN_ENDPOINT = "https://cognitiveservices.azure.com/.default"
|
||||
|
||||
@@ -51,11 +51,11 @@ def test_resolve_callable_provider_passthrough() -> None:
|
||||
|
||||
|
||||
def test_resolve_missing_endpoint_raises() -> None:
|
||||
"""Test that missing token endpoint raises ServiceInvalidAuthError."""
|
||||
"""Test that missing token endpoint raises ChatClientInvalidAuthException."""
|
||||
mock_credential = MagicMock(spec=TokenCredential)
|
||||
|
||||
with pytest.raises(ServiceInvalidAuthError, match="A token endpoint must be provided"):
|
||||
with pytest.raises(ChatClientInvalidAuthException, match="A token endpoint must be provided"):
|
||||
resolve_credential_to_token_provider(mock_credential, "")
|
||||
|
||||
with pytest.raises(ServiceInvalidAuthError, match="A token endpoint must be provided"):
|
||||
with pytest.raises(ChatClientInvalidAuthException, match="A token endpoint must be provided"):
|
||||
resolve_credential_to_token_provider(mock_credential, None) # type: ignore[arg-type]
|
||||
|
||||
@@ -245,9 +245,8 @@ class TestOverrideTypeValidation:
|
||||
"""Test override type validation."""
|
||||
|
||||
def test_invalid_type_raises(self) -> None:
|
||||
from agent_framework.exceptions import ServiceInitializationError
|
||||
|
||||
with pytest.raises(ServiceInitializationError, match="Invalid type for setting 'api_key'"):
|
||||
with pytest.raises(ValueError, match="Invalid type for setting 'api_key'"):
|
||||
load_settings(SimpleSettings, env_prefix="TEST_", api_key={"bad": "type"})
|
||||
|
||||
def test_valid_types_accepted(self) -> None:
|
||||
|
||||
@@ -9,7 +9,6 @@ from openai.types.beta.assistant import Assistant
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from agent_framework import Agent, normalize_tools, tool
|
||||
from agent_framework.exceptions import ServiceInitializationError
|
||||
from agent_framework.openai import OpenAIAssistantProvider, OpenAIAssistantsClient
|
||||
from agent_framework.openai._shared import from_assistant_tools, to_assistant_tools
|
||||
|
||||
@@ -99,7 +98,6 @@ class WeatherResponse(BaseModel):
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
# region Initialization Tests
|
||||
|
||||
|
||||
@@ -141,7 +139,7 @@ class TestOpenAIAssistantProviderInit:
|
||||
"responses_model_id": None,
|
||||
}
|
||||
|
||||
with pytest.raises(ServiceInitializationError) as exc_info:
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
OpenAIAssistantProvider()
|
||||
|
||||
assert "API key is required" in str(exc_info.value)
|
||||
@@ -191,7 +189,6 @@ class TestOpenAIAssistantProviderContextManager:
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
# region create_agent Tests
|
||||
|
||||
|
||||
@@ -366,7 +363,6 @@ class TestOpenAIAssistantProviderCreateAgent:
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
# region get_agent Tests
|
||||
|
||||
|
||||
@@ -454,7 +450,6 @@ class TestOpenAIAssistantProviderGetAgent:
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
# region as_agent Tests
|
||||
|
||||
|
||||
@@ -540,7 +535,6 @@ class TestOpenAIAssistantProviderAsAgent:
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
# region Tool Conversion Tests
|
||||
|
||||
|
||||
@@ -643,7 +637,6 @@ class TestToolConversion:
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
# region Tool Validation Tests
|
||||
|
||||
|
||||
@@ -702,7 +695,6 @@ class TestToolValidation:
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
# region Tool Merging Tests
|
||||
|
||||
|
||||
@@ -760,10 +752,8 @@ class TestToolMerging:
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
# region Integration Tests
|
||||
|
||||
|
||||
skip_if_openai_integration_tests_disabled = pytest.mark.skipif(
|
||||
os.getenv("RUN_INTEGRATION_TESTS", "false").lower() != "true"
|
||||
or os.getenv("OPENAI_API_KEY", "") in ("", "test-dummy-key"),
|
||||
|
||||
@@ -22,7 +22,6 @@ from agent_framework import (
|
||||
SupportsChatGetResponse,
|
||||
tool,
|
||||
)
|
||||
from agent_framework.exceptions import ServiceInitializationError
|
||||
from agent_framework.openai import OpenAIAssistantsClient
|
||||
|
||||
skip_if_openai_integration_tests_disabled = pytest.mark.skipif(
|
||||
@@ -145,7 +144,7 @@ def test_init_auto_create_client(
|
||||
|
||||
def test_init_validation_fail() -> None:
|
||||
"""Test OpenAIAssistantsClient initialization with validation failure."""
|
||||
with pytest.raises(ServiceInitializationError):
|
||||
with pytest.raises(ValueError):
|
||||
# Force failure by providing invalid model ID type
|
||||
OpenAIAssistantsClient(model_id=123, api_key="valid-key") # type: ignore
|
||||
|
||||
@@ -153,14 +152,14 @@ def test_init_validation_fail() -> None:
|
||||
@pytest.mark.parametrize("exclude_list", [["OPENAI_CHAT_MODEL_ID"]], indirect=True)
|
||||
def test_init_missing_model_id(openai_unit_test_env: dict[str, str]) -> None:
|
||||
"""Test OpenAIAssistantsClient initialization with missing model ID."""
|
||||
with pytest.raises(ServiceInitializationError):
|
||||
with pytest.raises(ValueError):
|
||||
OpenAIAssistantsClient(api_key=openai_unit_test_env.get("OPENAI_API_KEY", "test-key"))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("exclude_list", [["OPENAI_API_KEY"]], indirect=True)
|
||||
def test_init_missing_api_key(openai_unit_test_env: dict[str, str]) -> None:
|
||||
"""Test OpenAIAssistantsClient initialization with missing API key."""
|
||||
with pytest.raises(ServiceInitializationError):
|
||||
with pytest.raises(ValueError):
|
||||
OpenAIAssistantsClient(model_id="gpt-4")
|
||||
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ from agent_framework import (
|
||||
SupportsChatGetResponse,
|
||||
tool,
|
||||
)
|
||||
from agent_framework.exceptions import ServiceInitializationError, ServiceResponseException
|
||||
from agent_framework.exceptions import ChatClientException
|
||||
from agent_framework.openai import OpenAIChatClient
|
||||
from agent_framework.openai._exceptions import OpenAIContentFilterException
|
||||
|
||||
@@ -42,7 +42,7 @@ def test_init(openai_unit_test_env: dict[str, str]) -> None:
|
||||
|
||||
def test_init_validation_fail() -> None:
|
||||
# Test successful initialization
|
||||
with pytest.raises(ServiceInitializationError):
|
||||
with pytest.raises(ValueError):
|
||||
OpenAIChatClient(api_key="34523", model_id={"test": "dict"}) # type: ignore
|
||||
|
||||
|
||||
@@ -96,7 +96,7 @@ def test_init_base_url_from_settings_env() -> None:
|
||||
|
||||
@pytest.mark.parametrize("exclude_list", [["OPENAI_CHAT_MODEL_ID"]], indirect=True)
|
||||
def test_init_with_empty_model_id(openai_unit_test_env: dict[str, str]) -> None:
|
||||
with pytest.raises(ServiceInitializationError):
|
||||
with pytest.raises(ValueError):
|
||||
OpenAIChatClient()
|
||||
|
||||
|
||||
@@ -104,7 +104,7 @@ def test_init_with_empty_model_id(openai_unit_test_env: dict[str, str]) -> None:
|
||||
def test_init_with_empty_api_key(openai_unit_test_env: dict[str, str]) -> None:
|
||||
model_id = "test_model_id"
|
||||
|
||||
with pytest.raises(ServiceInitializationError):
|
||||
with pytest.raises(ValueError):
|
||||
OpenAIChatClient(
|
||||
model_id=model_id,
|
||||
)
|
||||
@@ -235,7 +235,7 @@ async def test_exception_message_includes_original_error_details() -> None:
|
||||
|
||||
with (
|
||||
patch.object(client.client.chat.completions, "create", side_effect=mock_error),
|
||||
pytest.raises(ServiceResponseException) as exc_info,
|
||||
pytest.raises(ChatClientException) as exc_info,
|
||||
):
|
||||
await client._inner_get_response(messages=messages, options={}) # type: ignore
|
||||
|
||||
@@ -779,11 +779,11 @@ def test_prepare_options_without_model_id(openai_unit_test_env: dict[str, str])
|
||||
|
||||
def test_prepare_options_without_messages(openai_unit_test_env: dict[str, str]) -> None:
|
||||
"""Test that prepare_options raises error when messages are missing."""
|
||||
from agent_framework.exceptions import ServiceInvalidRequestError
|
||||
from agent_framework.exceptions import ChatClientInvalidRequestException
|
||||
|
||||
client = OpenAIChatClient()
|
||||
|
||||
with pytest.raises(ServiceInvalidRequestError, match="Messages are required"):
|
||||
with pytest.raises(ChatClientInvalidRequestException, match="Messages are required"):
|
||||
client._prepare_options([], {})
|
||||
|
||||
|
||||
@@ -932,7 +932,7 @@ async def test_streaming_exception_handling(openai_unit_test_env: dict[str, str]
|
||||
|
||||
with (
|
||||
patch.object(client.client.chat.completions, "create", side_effect=mock_error),
|
||||
pytest.raises(ServiceResponseException),
|
||||
pytest.raises(ChatClientException),
|
||||
):
|
||||
async for _ in client._inner_get_response(messages=messages, stream=True, options={}): # type: ignore
|
||||
pass
|
||||
|
||||
@@ -15,9 +15,7 @@ from openai.types.chat.chat_completion_message import ChatCompletionMessage
|
||||
from pydantic import BaseModel
|
||||
|
||||
from agent_framework import ChatResponseUpdate, Message
|
||||
from agent_framework.exceptions import (
|
||||
ServiceResponseException,
|
||||
)
|
||||
from agent_framework.exceptions import ChatClientException
|
||||
from agent_framework.openai import OpenAIChatClient
|
||||
|
||||
|
||||
@@ -182,7 +180,7 @@ async def test_cmc_general_exception(
|
||||
chat_history.append(Message(role="user", text="hello world"))
|
||||
|
||||
openai_chat_completion = OpenAIChatClient()
|
||||
with pytest.raises(ServiceResponseException):
|
||||
with pytest.raises(ChatClientException):
|
||||
await openai_chat_completion.get_response(
|
||||
messages=chat_history,
|
||||
)
|
||||
|
||||
@@ -35,11 +35,7 @@ from agent_framework import (
|
||||
SupportsChatGetResponse,
|
||||
tool,
|
||||
)
|
||||
from agent_framework.exceptions import (
|
||||
ServiceInitializationError,
|
||||
ServiceInvalidRequestError,
|
||||
ServiceResponseException,
|
||||
)
|
||||
from agent_framework.exceptions import ChatClientException, ChatClientInvalidRequestException
|
||||
from agent_framework.openai import OpenAIResponsesClient
|
||||
from agent_framework.openai._exceptions import OpenAIContentFilterException
|
||||
|
||||
@@ -106,7 +102,7 @@ def test_init(openai_unit_test_env: dict[str, str]) -> None:
|
||||
|
||||
def test_init_validation_fail() -> None:
|
||||
# Test successful initialization
|
||||
with pytest.raises(ServiceInitializationError):
|
||||
with pytest.raises(ValueError):
|
||||
OpenAIResponsesClient(api_key="34523", model_id={"test": "dict"}) # type: ignore
|
||||
|
||||
|
||||
@@ -138,7 +134,7 @@ def test_init_with_default_header(openai_unit_test_env: dict[str, str]) -> None:
|
||||
|
||||
@pytest.mark.parametrize("exclude_list", [["OPENAI_RESPONSES_MODEL_ID"]], indirect=True)
|
||||
def test_init_with_empty_model_id(openai_unit_test_env: dict[str, str]) -> None:
|
||||
with pytest.raises(ServiceInitializationError):
|
||||
with pytest.raises(ValueError):
|
||||
OpenAIResponsesClient()
|
||||
|
||||
|
||||
@@ -146,7 +142,7 @@ def test_init_with_empty_model_id(openai_unit_test_env: dict[str, str]) -> None:
|
||||
def test_init_with_empty_api_key(openai_unit_test_env: dict[str, str]) -> None:
|
||||
model_id = "test_model_id"
|
||||
|
||||
with pytest.raises(ServiceInitializationError):
|
||||
with pytest.raises(ValueError):
|
||||
OpenAIResponsesClient(
|
||||
model_id=model_id,
|
||||
)
|
||||
@@ -192,8 +188,8 @@ async def test_get_response_with_invalid_input() -> None:
|
||||
|
||||
client = OpenAIResponsesClient(model_id="invalid-model", api_key="test-key")
|
||||
|
||||
# Test with empty messages which should trigger ServiceInvalidRequestError
|
||||
with pytest.raises(ServiceInvalidRequestError, match="Messages are required"):
|
||||
# Test with empty messages which should trigger ChatClientInvalidRequestException
|
||||
with pytest.raises(ChatClientInvalidRequestException, match="Messages are required"):
|
||||
await client.get_response(messages=[])
|
||||
|
||||
|
||||
@@ -201,7 +197,7 @@ async def test_get_response_with_all_parameters() -> None:
|
||||
"""Test get_response with all possible parameters to cover parameter handling logic."""
|
||||
client = OpenAIResponsesClient(model_id="test-model", api_key="test-key")
|
||||
# Test with comprehensive parameter set - should fail due to invalid API key
|
||||
with pytest.raises(ServiceResponseException):
|
||||
with pytest.raises(ChatClientException):
|
||||
await client.get_response(
|
||||
messages=[Message(role="user", text="Test message")],
|
||||
options={
|
||||
@@ -244,7 +240,7 @@ async def test_web_search_tool_with_location() -> None:
|
||||
)
|
||||
|
||||
# Should raise an authentication error due to invalid API key
|
||||
with pytest.raises(ServiceResponseException):
|
||||
with pytest.raises(ChatClientException):
|
||||
await client.get_response(
|
||||
messages=[Message(role="user", text="What's the weather?")],
|
||||
options={"tools": [web_search_tool], "tool_choice": "auto"},
|
||||
@@ -258,7 +254,7 @@ async def test_code_interpreter_tool_variations() -> None:
|
||||
# Test code interpreter using static method
|
||||
code_tool = OpenAIResponsesClient.get_code_interpreter_tool()
|
||||
|
||||
with pytest.raises(ServiceResponseException):
|
||||
with pytest.raises(ChatClientException):
|
||||
await client.get_response(
|
||||
messages=[Message("user", ["Run some code"])],
|
||||
options={"tools": [code_tool]},
|
||||
@@ -267,7 +263,7 @@ async def test_code_interpreter_tool_variations() -> None:
|
||||
# Test code interpreter with files using static method
|
||||
code_tool_with_files = OpenAIResponsesClient.get_code_interpreter_tool(file_ids=["file1", "file2"])
|
||||
|
||||
with pytest.raises(ServiceResponseException):
|
||||
with pytest.raises(ChatClientException):
|
||||
await client.get_response(
|
||||
messages=[Message(role="user", text="Process these files")],
|
||||
options={"tools": [code_tool_with_files]},
|
||||
@@ -303,7 +299,7 @@ async def test_hosted_file_search_tool_validation() -> None:
|
||||
file_search_tool = OpenAIResponsesClient.get_file_search_tool(vector_store_ids=["vs_123"])
|
||||
|
||||
# Test using file search tool - may raise various exceptions depending on API response
|
||||
with pytest.raises((ValueError, ServiceInvalidRequestError, ServiceResponseException)):
|
||||
with pytest.raises((ValueError, ChatClientInvalidRequestException, ChatClientException)):
|
||||
await client.get_response(
|
||||
messages=[Message("user", ["Test"])],
|
||||
options={"tools": [file_search_tool]},
|
||||
@@ -331,7 +327,7 @@ async def test_chat_message_parsing_with_function_calls() -> None:
|
||||
]
|
||||
|
||||
# This should exercise the message parsing logic - will fail due to invalid API key
|
||||
with pytest.raises(ServiceResponseException):
|
||||
with pytest.raises(ChatClientException):
|
||||
await client.get_response(messages=messages)
|
||||
|
||||
|
||||
@@ -401,7 +397,7 @@ async def test_bad_request_error_non_content_filter() -> None:
|
||||
mock_error.code = "invalid_request"
|
||||
|
||||
with patch.object(client.client.responses, "parse", side_effect=mock_error):
|
||||
with pytest.raises(ServiceResponseException) as exc_info:
|
||||
with pytest.raises(ChatClientException) as exc_info:
|
||||
await client.get_response(
|
||||
messages=[Message(role="user", text="Test message")],
|
||||
options={"response_format": OutputStruct},
|
||||
@@ -996,7 +992,7 @@ def test_response_format_with_conflicting_definitions() -> None:
|
||||
response_format = {"type": "json_schema", "format": {"type": "json_schema", "name": "Test", "schema": {}}}
|
||||
text_config = {"format": {"type": "json_object"}}
|
||||
|
||||
with pytest.raises(ServiceInvalidRequestError, match="Conflicting response_format definitions"):
|
||||
with pytest.raises(ChatClientInvalidRequestException, match="Conflicting response_format definitions"):
|
||||
client._prepare_response_and_text_format(response_format=response_format, text_config=text_config)
|
||||
|
||||
|
||||
@@ -1092,7 +1088,7 @@ def test_response_format_json_schema_missing_schema() -> None:
|
||||
|
||||
response_format = {"type": "json_schema", "json_schema": {"name": "NoSchema"}}
|
||||
|
||||
with pytest.raises(ServiceInvalidRequestError, match="json_schema response_format requires a schema"):
|
||||
with pytest.raises(ChatClientInvalidRequestException, match="json_schema response_format requires a schema"):
|
||||
client._prepare_response_and_text_format(response_format=response_format, text_config=None)
|
||||
|
||||
|
||||
@@ -1102,7 +1098,7 @@ def test_response_format_unsupported_type() -> None:
|
||||
|
||||
response_format = {"type": "unsupported_format"}
|
||||
|
||||
with pytest.raises(ServiceInvalidRequestError, match="Unsupported response_format"):
|
||||
with pytest.raises(ChatClientInvalidRequestException, match="Unsupported response_format"):
|
||||
client._prepare_response_and_text_format(response_format=response_format, text_config=None)
|
||||
|
||||
|
||||
@@ -1112,7 +1108,7 @@ def test_response_format_invalid_type() -> None:
|
||||
|
||||
response_format = "invalid" # Not a Pydantic model or mapping
|
||||
|
||||
with pytest.raises(ServiceInvalidRequestError, match="response_format must be a Pydantic model or mapping"):
|
||||
with pytest.raises(ChatClientInvalidRequestException, match="response_format must be a Pydantic model or mapping"):
|
||||
client._prepare_response_and_text_format(response_format=response_format, text_config=None) # type: ignore
|
||||
|
||||
|
||||
@@ -1689,7 +1685,7 @@ def test_streaming_annotation_added_with_unknown_type() -> None:
|
||||
|
||||
|
||||
async def test_service_response_exception_includes_original_error_details() -> None:
|
||||
"""Test that ServiceResponseException messages include original error details in the new format."""
|
||||
"""Test that ChatClientException messages include original error details in the new format."""
|
||||
client = OpenAIResponsesClient(model_id="test-model", api_key="test-key")
|
||||
messages = [Message(role="user", text="test message")]
|
||||
|
||||
@@ -1704,7 +1700,7 @@ async def test_service_response_exception_includes_original_error_details() -> N
|
||||
|
||||
with (
|
||||
patch.object(client.client.responses, "parse", side_effect=mock_error),
|
||||
pytest.raises(ServiceResponseException) as exc_info,
|
||||
pytest.raises(ChatClientException) as exc_info,
|
||||
):
|
||||
await client.get_response(messages=messages, options={"response_format": OutputStruct})
|
||||
|
||||
@@ -1719,7 +1715,7 @@ async def test_get_response_streaming_with_response_format() -> None:
|
||||
messages = [Message(role="user", text="Test streaming with format")]
|
||||
|
||||
# It will fail due to invalid API key, but exercises the code path
|
||||
with pytest.raises(ServiceResponseException):
|
||||
with pytest.raises(ChatClientException):
|
||||
|
||||
async def run_streaming():
|
||||
async for _ in client.get_response(
|
||||
|
||||
@@ -6,9 +6,9 @@ from typing import Any, cast
|
||||
|
||||
import pytest
|
||||
|
||||
from agent_framework import WorkflowCheckpointException
|
||||
from agent_framework._workflows._checkpoint_encoding import (
|
||||
_TYPE_MARKER, # type: ignore
|
||||
CheckpointDecodingError,
|
||||
decode_checkpoint_value,
|
||||
encode_checkpoint_value,
|
||||
)
|
||||
@@ -178,13 +178,13 @@ def test_decode_plain_list() -> None:
|
||||
|
||||
|
||||
def test_decode_raises_on_type_mismatch() -> None:
|
||||
"""Test that decoding raises CheckpointDecodingError when type doesn't match."""
|
||||
"""Test that decoding raises WorkflowCheckpointException when type doesn't match."""
|
||||
# Encode a SampleRequest but tamper with the type marker
|
||||
encoded = encode_checkpoint_value(SampleRequest(request_id="r1", prompt="p1"))
|
||||
assert isinstance(encoded, dict)
|
||||
encoded[_TYPE_MARKER] = "nonexistent.module:FakeClass"
|
||||
|
||||
with pytest.raises(CheckpointDecodingError, match="Type mismatch"):
|
||||
with pytest.raises(WorkflowCheckpointException, match="Type mismatch"):
|
||||
decode_checkpoint_value(encoded)
|
||||
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ from ._loader import AgentFactory, DeclarativeLoaderError, ProviderLookupError,
|
||||
from ._workflows import (
|
||||
AgentExternalInputRequest,
|
||||
AgentExternalInputResponse,
|
||||
AgentInvocationError,
|
||||
DeclarativeWorkflowError,
|
||||
ExternalInputRequest,
|
||||
ExternalInputResponse,
|
||||
@@ -23,7 +22,6 @@ __all__ = [
|
||||
"AgentExternalInputRequest",
|
||||
"AgentExternalInputResponse",
|
||||
"AgentFactory",
|
||||
"AgentInvocationError",
|
||||
"DeclarativeLoaderError",
|
||||
"DeclarativeWorkflowError",
|
||||
"ExternalInputRequest",
|
||||
|
||||
@@ -16,7 +16,7 @@ from agent_framework import (
|
||||
FunctionTool as AFFunctionTool,
|
||||
)
|
||||
from agent_framework._tools import _create_model_from_json_schema # type: ignore
|
||||
from agent_framework.exceptions import AgentFrameworkException
|
||||
from agent_framework.exceptions import AgentException
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from ._models import (
|
||||
@@ -104,7 +104,7 @@ PROVIDER_TYPE_OBJECT_MAPPING: dict[str, ProviderTypeMapping] = {
|
||||
}
|
||||
|
||||
|
||||
class DeclarativeLoaderError(AgentFrameworkException):
|
||||
class DeclarativeLoaderError(AgentException):
|
||||
"""Exception raised for errors in the declarative loader."""
|
||||
|
||||
pass
|
||||
|
||||
@@ -31,7 +31,6 @@ from ._executors_agents import (
|
||||
TOOL_REGISTRY_KEY,
|
||||
AgentExternalInputRequest,
|
||||
AgentExternalInputResponse,
|
||||
AgentInvocationError,
|
||||
AgentResult,
|
||||
ExternalLoopState,
|
||||
InvokeAzureAgentExecutor,
|
||||
@@ -92,7 +91,6 @@ __all__ = [
|
||||
"ActionTrigger",
|
||||
"AgentExternalInputRequest",
|
||||
"AgentExternalInputResponse",
|
||||
"AgentInvocationError",
|
||||
"AgentResult",
|
||||
"AppendValueExecutor",
|
||||
"BreakLoopExecutor",
|
||||
|
||||
+3
-1
@@ -13,6 +13,8 @@ import logging
|
||||
from collections.abc import AsyncGenerator
|
||||
from dataclasses import dataclass
|
||||
|
||||
from agent_framework.exceptions import WorkflowException
|
||||
|
||||
from ._handlers import (
|
||||
ActionContext,
|
||||
WorkflowEvent,
|
||||
@@ -22,7 +24,7 @@ from ._handlers import (
|
||||
logger = logging.getLogger("agent_framework.declarative")
|
||||
|
||||
|
||||
class WorkflowActionError(Exception):
|
||||
class WorkflowActionError(WorkflowException):
|
||||
"""Exception raised by ThrowException action."""
|
||||
|
||||
def __init__(self, message: str, code: str | None = None):
|
||||
|
||||
+1
-1
@@ -32,7 +32,7 @@ from dataclasses import dataclass
|
||||
from decimal import Decimal as _Decimal
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
from agent_framework._workflows import (
|
||||
from agent_framework import (
|
||||
Executor,
|
||||
WorkflowContext,
|
||||
)
|
||||
|
||||
+1
-1
@@ -15,7 +15,7 @@ from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from agent_framework._workflows import (
|
||||
from agent_framework import (
|
||||
Workflow,
|
||||
WorkflowBuilder,
|
||||
)
|
||||
|
||||
+9
-19
@@ -26,6 +26,7 @@ from agent_framework import (
|
||||
handler,
|
||||
response_handler,
|
||||
)
|
||||
from agent_framework.exceptions import AgentInvalidRequestException, AgentInvalidResponseException
|
||||
|
||||
from ._declarative_base import (
|
||||
ActionComplete,
|
||||
@@ -243,19 +244,6 @@ TOOL_REGISTRY_KEY = "_tool_registry"
|
||||
EXTERNAL_LOOP_STATE_KEY = "_external_loop_state"
|
||||
|
||||
|
||||
class AgentInvocationError(Exception):
|
||||
"""Raised when an agent invocation fails.
|
||||
|
||||
Attributes:
|
||||
agent_name: Name of the agent that failed
|
||||
message: Error description
|
||||
"""
|
||||
|
||||
def __init__(self, agent_name: str, message: str) -> None:
|
||||
self.agent_name = agent_name
|
||||
super().__init__(f"Agent '{agent_name}' invocation failed: {message}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentResult:
|
||||
"""Result from an agent invocation."""
|
||||
@@ -807,7 +795,7 @@ class InvokeAzureAgentExecutor(DeclarativeActionExecutor):
|
||||
state.set("Agent.error", error_msg)
|
||||
if result_property:
|
||||
state.set(result_property, {"error": error_msg})
|
||||
raise AgentInvocationError(agent_name, "not found in registry")
|
||||
raise AgentInvalidRequestException(f"Agent '{agent_name}' invocation failed: not found in registry")
|
||||
|
||||
iteration = 0
|
||||
|
||||
@@ -824,14 +812,14 @@ class InvokeAzureAgentExecutor(DeclarativeActionExecutor):
|
||||
auto_send=auto_send,
|
||||
messages_path=messages_path,
|
||||
)
|
||||
except AgentInvocationError:
|
||||
except (AgentInvalidRequestException, AgentInvalidResponseException):
|
||||
raise # Re-raise our own errors
|
||||
except Exception as e:
|
||||
logger.error(f"InvokeAzureAgent: error invoking agent '{agent_name}': {e}")
|
||||
state.set("Agent.error", str(e))
|
||||
if result_property:
|
||||
state.set(result_property, {"error": str(e)})
|
||||
raise AgentInvocationError(agent_name, str(e)) from e
|
||||
raise AgentInvalidResponseException(f"Agent '{agent_name}' invocation failed: {e}") from e
|
||||
|
||||
# Check external loop condition
|
||||
if external_loop_when:
|
||||
@@ -948,7 +936,9 @@ class InvokeAzureAgentExecutor(DeclarativeActionExecutor):
|
||||
|
||||
if agent is None:
|
||||
logger.error(f"InvokeAzureAgent: agent '{agent_name}' not found during loop resumption")
|
||||
raise AgentInvocationError(agent_name, "not found during loop resumption")
|
||||
raise AgentInvalidRequestException(
|
||||
f"Agent '{agent_name}' invocation failed: not found during loop resumption"
|
||||
)
|
||||
|
||||
try:
|
||||
accumulated_response, all_messages, tool_calls = await self._invoke_agent_and_store_results(
|
||||
@@ -963,12 +953,12 @@ class InvokeAzureAgentExecutor(DeclarativeActionExecutor):
|
||||
auto_send=loop_state.auto_send,
|
||||
messages_path=loop_state.messages_path,
|
||||
)
|
||||
except AgentInvocationError:
|
||||
except (AgentInvalidRequestException, AgentInvalidResponseException):
|
||||
raise # Re-raise our own errors
|
||||
except Exception as e:
|
||||
logger.error(f"InvokeAzureAgent: error invoking agent '{agent_name}' during loop: {e}")
|
||||
state.set("Agent.error", str(e))
|
||||
raise AgentInvocationError(agent_name, str(e)) from e
|
||||
raise AgentInvalidResponseException(f"Agent '{agent_name}' invocation failed: {e}") from e
|
||||
|
||||
# Re-evaluate the condition AFTER the agent responds
|
||||
# This is critical: the agent's response may have set NeedsTicket=true or IsResolved=true
|
||||
|
||||
+1
-1
@@ -8,7 +8,7 @@ Each action becomes a node in the workflow graph.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from agent_framework._workflows import (
|
||||
from agent_framework import (
|
||||
WorkflowContext,
|
||||
handler,
|
||||
)
|
||||
|
||||
+1
-1
@@ -16,7 +16,7 @@ The key insight is that control flow becomes GRAPH STRUCTURE, not executor logic
|
||||
|
||||
from typing import Any, cast
|
||||
|
||||
from agent_framework._workflows import (
|
||||
from agent_framework import (
|
||||
WorkflowContext,
|
||||
handler,
|
||||
)
|
||||
|
||||
+1
-1
@@ -11,7 +11,7 @@ import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from agent_framework._workflows import (
|
||||
from agent_framework import (
|
||||
WorkflowContext,
|
||||
handler,
|
||||
response_handler,
|
||||
|
||||
@@ -24,6 +24,7 @@ from agent_framework import (
|
||||
SupportsAgentRun,
|
||||
Workflow,
|
||||
)
|
||||
from agent_framework.exceptions import WorkflowException
|
||||
|
||||
from .._loader import AgentFactory
|
||||
from ._declarative_builder import DeclarativeWorkflowBuilder
|
||||
@@ -31,7 +32,7 @@ from ._declarative_builder import DeclarativeWorkflowBuilder
|
||||
logger = logging.getLogger("agent_framework.declarative")
|
||||
|
||||
|
||||
class DeclarativeWorkflowError(Exception):
|
||||
class DeclarativeWorkflowError(WorkflowException):
|
||||
"""Exception raised for errors in declarative workflow processing."""
|
||||
|
||||
pass
|
||||
|
||||
@@ -1851,9 +1851,10 @@ class TestAgentExternalLoopCoverage:
|
||||
assert request.agent_name == "TestAgent"
|
||||
|
||||
async def test_agent_executor_agent_error_handling(self, mock_context, mock_state):
|
||||
"""Test agent executor raises AgentInvocationError on failure."""
|
||||
"""Test agent executor raises AgentInvalidResponseException on failure."""
|
||||
from agent_framework.exceptions import AgentInvalidResponseException
|
||||
|
||||
from agent_framework_declarative._workflows._executors_agents import (
|
||||
AgentInvocationError,
|
||||
InvokeAzureAgentExecutor,
|
||||
)
|
||||
|
||||
@@ -1871,7 +1872,7 @@ class TestAgentExternalLoopCoverage:
|
||||
}
|
||||
executor = InvokeAzureAgentExecutor(action_def, agents={"TestAgent": mock_agent})
|
||||
|
||||
with pytest.raises(AgentInvocationError) as exc_info:
|
||||
with pytest.raises(AgentInvalidResponseException) as exc_info:
|
||||
await executor.handle_action(ActionTrigger(), mock_context)
|
||||
|
||||
assert "TestAgent" in str(exc_info.value)
|
||||
@@ -2375,11 +2376,12 @@ class TestAgentExecutorExternalLoop:
|
||||
|
||||
async def test_handle_external_input_response_agent_not_found(self, mock_context, mock_state):
|
||||
"""Test handling external input raises error when agent not found during resumption."""
|
||||
from agent_framework.exceptions import AgentInvalidRequestException
|
||||
|
||||
from agent_framework_declarative._workflows._executors_agents import (
|
||||
EXTERNAL_LOOP_STATE_KEY,
|
||||
AgentExternalInputRequest,
|
||||
AgentExternalInputResponse,
|
||||
AgentInvocationError,
|
||||
ExternalLoopState,
|
||||
InvokeAzureAgentExecutor,
|
||||
)
|
||||
@@ -2411,7 +2413,7 @@ class TestAgentExecutorExternalLoop:
|
||||
)
|
||||
response = AgentExternalInputResponse(user_input="continue")
|
||||
|
||||
with pytest.raises(AgentInvocationError) as exc_info:
|
||||
with pytest.raises(AgentInvalidRequestException) as exc_info:
|
||||
await executor.handle_external_input_response(original_request, response, mock_context)
|
||||
|
||||
assert "NonExistentAgent" in str(exc_info.value)
|
||||
|
||||
@@ -415,8 +415,9 @@ class TestAgentExecutors:
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_agent_not_found(self, mock_context, mock_state):
|
||||
"""Test InvokeAzureAgentExecutor raises error when agent not found."""
|
||||
from agent_framework.exceptions import AgentInvalidRequestException
|
||||
|
||||
from agent_framework_declarative._workflows import (
|
||||
AgentInvocationError,
|
||||
InvokeAzureAgentExecutor,
|
||||
)
|
||||
|
||||
@@ -430,8 +431,8 @@ class TestAgentExecutors:
|
||||
}
|
||||
executor = InvokeAzureAgentExecutor(action_def)
|
||||
|
||||
# Execute - should raise AgentInvocationError
|
||||
with pytest.raises(AgentInvocationError) as exc_info:
|
||||
# Execute - should raise AgentInvalidRequestException
|
||||
with pytest.raises(AgentInvalidRequestException) as exc_info:
|
||||
await executor.handle_action(ActionTrigger(), mock_context)
|
||||
|
||||
assert "non_existent_agent" in str(exc_info.value)
|
||||
|
||||
+2
-3
@@ -14,7 +14,6 @@ from agent_framework import (
|
||||
FunctionInvocationLayer,
|
||||
)
|
||||
from agent_framework._settings import load_settings
|
||||
from agent_framework.exceptions import ServiceInitializationError
|
||||
from agent_framework.observability import ChatTelemetryLayer
|
||||
from agent_framework.openai._chat_client import RawOpenAIChatClient
|
||||
from foundry_local import FoundryLocalManager
|
||||
@@ -236,7 +235,7 @@ class FoundryLocalClient(
|
||||
response = await client.get_response("Hello", options={"my_custom_option": "value"})
|
||||
|
||||
Raises:
|
||||
ServiceInitializationError: If the specified model ID or alias is not found.
|
||||
ValueError: If the specified model ID or alias is not found.
|
||||
Sometimes a model might be available but if you have specified a device
|
||||
type that is not supported by the model, it will not be found.
|
||||
|
||||
@@ -263,7 +262,7 @@ class FoundryLocalClient(
|
||||
"not found in Foundry Local."
|
||||
)
|
||||
)
|
||||
raise ServiceInitializationError(message)
|
||||
raise ValueError(message)
|
||||
if prepare_model:
|
||||
manager.download_model(alias_or_model_id=model_info.id, device=device)
|
||||
manager.load_model(alias_or_model_id=model_info.id, device=device)
|
||||
|
||||
@@ -5,7 +5,7 @@ from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
from agent_framework import SupportsChatGetResponse
|
||||
from agent_framework._settings import load_settings
|
||||
from agent_framework.exceptions import ServiceInitializationError, SettingNotFoundError
|
||||
from agent_framework.exceptions import SettingNotFoundError
|
||||
|
||||
from agent_framework_foundry_local import FoundryLocalClient
|
||||
from agent_framework_foundry_local._foundry_local_client import FoundryLocalSettings
|
||||
@@ -103,7 +103,7 @@ def test_foundry_local_client_init_model_not_found(mock_foundry_local_manager: M
|
||||
"agent_framework_foundry_local._foundry_local_client.FoundryLocalManager",
|
||||
return_value=mock_foundry_local_manager,
|
||||
),
|
||||
pytest.raises(ServiceInitializationError, match="not found in Foundry Local"),
|
||||
pytest.raises(ValueError, match="not found in Foundry Local"),
|
||||
):
|
||||
FoundryLocalClient(model_id="unknown-model")
|
||||
|
||||
@@ -171,7 +171,7 @@ def test_foundry_local_client_init_model_not_found_with_device(mock_foundry_loca
|
||||
"agent_framework_foundry_local._foundry_local_client.FoundryLocalManager",
|
||||
return_value=mock_foundry_local_manager,
|
||||
),
|
||||
pytest.raises(ServiceInitializationError, match="unknown-model:GPU.*not found"),
|
||||
pytest.raises(ValueError, match="unknown-model:GPU.*not found"),
|
||||
):
|
||||
FoundryLocalClient(model_id="unknown-model", device=DeviceType.GPU)
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ from agent_framework import (
|
||||
from agent_framework._settings import load_settings
|
||||
from agent_framework._tools import FunctionTool, ToolTypes
|
||||
from agent_framework._types import AgentRunInputs, normalize_tools
|
||||
from agent_framework.exceptions import ServiceException
|
||||
from agent_framework.exceptions import AgentException
|
||||
from copilot import CopilotClient, CopilotSession
|
||||
from copilot.generated.session_events import SessionEvent, SessionEventType
|
||||
from copilot.types import (
|
||||
@@ -199,7 +199,7 @@ class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
|
||||
env_file_encoding: Encoding of the .env file, defaults to 'utf-8'.
|
||||
|
||||
Raises:
|
||||
ServiceInitializationError: If required configuration is missing or invalid.
|
||||
ValueError: If required configuration is missing or invalid.
|
||||
"""
|
||||
super().__init__(
|
||||
id=id,
|
||||
@@ -259,7 +259,7 @@ class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
|
||||
agent as an async context manager.
|
||||
|
||||
Raises:
|
||||
ServiceException: If the client fails to start.
|
||||
AgentException: If the client fails to start.
|
||||
"""
|
||||
if self._started:
|
||||
return
|
||||
@@ -277,7 +277,7 @@ class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
|
||||
await self._client.start()
|
||||
self._started = True
|
||||
except Exception as ex:
|
||||
raise ServiceException(f"Failed to start GitHub Copilot client: {ex}") from ex
|
||||
raise AgentException(f"Failed to start GitHub Copilot client: {ex}") from ex
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the Copilot client and clean up resources.
|
||||
@@ -343,7 +343,7 @@ class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
|
||||
When stream=True: A ResponseStream of AgentResponseUpdate items.
|
||||
|
||||
Raises:
|
||||
ServiceException: If the request fails.
|
||||
AgentException: If the request fails.
|
||||
"""
|
||||
if stream:
|
||||
|
||||
@@ -381,7 +381,7 @@ class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
|
||||
try:
|
||||
response_event = await copilot_session.send_and_wait({"prompt": prompt}, timeout=timeout)
|
||||
except Exception as ex:
|
||||
raise ServiceException(f"GitHub Copilot request failed: {ex}") from ex
|
||||
raise AgentException(f"GitHub Copilot request failed: {ex}") from ex
|
||||
|
||||
response_messages: list[Message] = []
|
||||
response_id: str | None = None
|
||||
@@ -426,7 +426,7 @@ class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
|
||||
AgentResponseUpdate items.
|
||||
|
||||
Raises:
|
||||
ServiceException: If the request fails.
|
||||
AgentException: If the request fails.
|
||||
"""
|
||||
if not self._started:
|
||||
await self.start()
|
||||
@@ -457,7 +457,7 @@ class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
|
||||
queue.put_nowait(None)
|
||||
elif event.type == SessionEventType.SESSION_ERROR:
|
||||
error_msg = event.data.message or "Unknown error"
|
||||
queue.put_nowait(ServiceException(f"GitHub Copilot session error: {error_msg}"))
|
||||
queue.put_nowait(AgentException(f"GitHub Copilot session error: {error_msg}"))
|
||||
|
||||
unsubscribe = copilot_session.on(event_handler)
|
||||
|
||||
@@ -565,10 +565,10 @@ class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
|
||||
A CopilotSession instance.
|
||||
|
||||
Raises:
|
||||
ServiceException: If the session cannot be created.
|
||||
AgentException: If the session cannot be created.
|
||||
"""
|
||||
if not self._client:
|
||||
raise ServiceException("GitHub Copilot client not initialized. Call start() first.")
|
||||
raise RuntimeError("GitHub Copilot client not initialized. Call start() first.")
|
||||
|
||||
try:
|
||||
if agent_session.service_session_id:
|
||||
@@ -578,7 +578,7 @@ class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
|
||||
agent_session.service_session_id = session.session_id
|
||||
return session
|
||||
except Exception as ex:
|
||||
raise ServiceException(f"Failed to create GitHub Copilot session: {ex}") from ex
|
||||
raise AgentException(f"Failed to create GitHub Copilot session: {ex}") from ex
|
||||
|
||||
async def _create_session(
|
||||
self,
|
||||
@@ -592,7 +592,7 @@ class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
|
||||
runtime_options: Runtime options that take precedence over default_options.
|
||||
"""
|
||||
if not self._client:
|
||||
raise ServiceException("GitHub Copilot client not initialized. Call start() first.")
|
||||
raise RuntimeError("GitHub Copilot client not initialized. Call start() first.")
|
||||
|
||||
opts = runtime_options or {}
|
||||
config: SessionConfig = {"streaming": streaming}
|
||||
@@ -621,7 +621,7 @@ class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
|
||||
async def _resume_session(self, session_id: str, streaming: bool) -> CopilotSession:
|
||||
"""Resume an existing Copilot session by ID."""
|
||||
if not self._client:
|
||||
raise ServiceException("GitHub Copilot client not initialized. Call start() first.")
|
||||
raise RuntimeError("GitHub Copilot client not initialized. Call start() first.")
|
||||
|
||||
config: ResumeSessionConfig = {"streaming": streaming}
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ from agent_framework import (
|
||||
Content,
|
||||
Message,
|
||||
)
|
||||
from agent_framework.exceptions import ServiceException
|
||||
from agent_framework.exceptions import AgentException
|
||||
from copilot.generated.session_events import Data, SessionEvent, SessionEventType
|
||||
|
||||
from agent_framework_github_copilot import GitHubCopilotAgent, GitHubCopilotOptions
|
||||
@@ -430,7 +430,7 @@ class TestGitHubCopilotAgentRunStreaming:
|
||||
|
||||
agent = GitHubCopilotAgent(client=mock_client)
|
||||
|
||||
with pytest.raises(ServiceException, match="session error"):
|
||||
with pytest.raises(AgentException, match="session error"):
|
||||
async for _ in agent.run("Hello", stream=True):
|
||||
pass
|
||||
|
||||
@@ -835,12 +835,12 @@ class TestGitHubCopilotAgentErrorHandling:
|
||||
"""Test cases for error handling."""
|
||||
|
||||
async def test_start_raises_on_client_error(self, mock_client: MagicMock) -> None:
|
||||
"""Test that start raises ServiceException when client fails to start."""
|
||||
"""Test that start raises AgentException when client fails to start."""
|
||||
mock_client.start.side_effect = Exception("Connection failed")
|
||||
|
||||
agent = GitHubCopilotAgent(client=mock_client)
|
||||
|
||||
with pytest.raises(ServiceException, match="Failed to start GitHub Copilot client"):
|
||||
with pytest.raises(AgentException, match="Failed to start GitHub Copilot client"):
|
||||
await agent.start()
|
||||
|
||||
async def test_run_raises_on_send_error(
|
||||
@@ -848,33 +848,33 @@ class TestGitHubCopilotAgentErrorHandling:
|
||||
mock_client: MagicMock,
|
||||
mock_session: MagicMock,
|
||||
) -> None:
|
||||
"""Test that run raises ServiceException when send_and_wait fails."""
|
||||
"""Test that run raises AgentException when send_and_wait fails."""
|
||||
mock_session.send_and_wait.side_effect = Exception("Request timeout")
|
||||
|
||||
agent = GitHubCopilotAgent(client=mock_client)
|
||||
|
||||
with pytest.raises(ServiceException, match="GitHub Copilot request failed"):
|
||||
with pytest.raises(AgentException, match="GitHub Copilot request failed"):
|
||||
await agent.run("Hello")
|
||||
|
||||
async def test_get_or_create_session_raises_on_create_error(
|
||||
self,
|
||||
mock_client: MagicMock,
|
||||
) -> None:
|
||||
"""Test that _get_or_create_session raises ServiceException when create_session fails."""
|
||||
"""Test that _get_or_create_session raises AgentException when create_session fails."""
|
||||
mock_client.create_session.side_effect = Exception("Session creation failed")
|
||||
|
||||
agent = GitHubCopilotAgent(client=mock_client)
|
||||
await agent.start()
|
||||
|
||||
with pytest.raises(ServiceException, match="Failed to create GitHub Copilot session"):
|
||||
with pytest.raises(AgentException, match="Failed to create GitHub Copilot session"):
|
||||
await agent._get_or_create_session(AgentSession()) # type: ignore
|
||||
|
||||
async def test_get_or_create_session_raises_when_client_not_initialized(self) -> None:
|
||||
"""Test that _get_or_create_session raises ServiceException when client is not initialized."""
|
||||
"""Test that _get_or_create_session raises RuntimeError when client is not initialized."""
|
||||
agent = GitHubCopilotAgent()
|
||||
# Don't call start() - client remains None
|
||||
|
||||
with pytest.raises(ServiceException, match="GitHub Copilot client not initialized"):
|
||||
with pytest.raises(RuntimeError, match="GitHub Copilot client not initialized"):
|
||||
await agent._get_or_create_session(AgentSession()) # type: ignore
|
||||
|
||||
|
||||
|
||||
@@ -14,7 +14,6 @@ from typing import TYPE_CHECKING, Any, ClassVar
|
||||
|
||||
from agent_framework import Message
|
||||
from agent_framework._sessions import AgentSession, BaseContextProvider, SessionContext
|
||||
from agent_framework.exceptions import ServiceInitializationError
|
||||
from mem0 import AsyncMemory, AsyncMemoryClient
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
@@ -172,9 +171,7 @@ class Mem0ContextProvider(BaseContextProvider):
|
||||
def _validate_filters(self) -> None:
|
||||
"""Validates that at least one filter is provided."""
|
||||
if not self.agent_id and not self.user_id and not self.application_id:
|
||||
raise ServiceInitializationError(
|
||||
"At least one of the filters: agent_id, user_id, or application_id is required."
|
||||
)
|
||||
raise ValueError("At least one of the filters: agent_id, user_id, or application_id is required.")
|
||||
|
||||
def _build_filters(self) -> dict[str, Any]:
|
||||
"""Build search filters from initialization parameters."""
|
||||
|
||||
@@ -8,7 +8,6 @@ from unittest.mock import AsyncMock, patch
|
||||
import pytest
|
||||
from agent_framework import AgentResponse, Message
|
||||
from agent_framework._sessions import AgentSession, SessionContext
|
||||
from agent_framework.exceptions import ServiceInitializationError
|
||||
|
||||
from agent_framework_mem0._context_provider import Mem0ContextProvider
|
||||
|
||||
@@ -136,15 +135,13 @@ class TestBeforeRun:
|
||||
assert "mem0" not in ctx.context_messages
|
||||
|
||||
async def test_validates_filters_before_search(self, mock_mem0_client: AsyncMock) -> None:
|
||||
"""Raises ServiceInitializationError when no filters."""
|
||||
"""Raises ValueError when no filters."""
|
||||
provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client)
|
||||
session = AgentSession(session_id="test-session")
|
||||
ctx = SessionContext(input_messages=[Message(role="user", text="test")], session_id="s1")
|
||||
|
||||
with pytest.raises(ServiceInitializationError, match="At least one of the filters"):
|
||||
await provider.before_run(
|
||||
agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {})
|
||||
) # type: ignore[arg-type]
|
||||
with pytest.raises(ValueError, match="At least one of the filters"):
|
||||
await provider.before_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type]
|
||||
|
||||
async def test_v1_1_response_format(self, mock_mem0_client: AsyncMock) -> None:
|
||||
"""Search response in v1.1 dict format with 'results' key."""
|
||||
@@ -312,16 +309,14 @@ class TestAfterRun:
|
||||
assert "run_id" not in mock_mem0_client.add.call_args.kwargs
|
||||
|
||||
async def test_validates_filters(self, mock_mem0_client: AsyncMock) -> None:
|
||||
"""Raises ServiceInitializationError when no filters."""
|
||||
"""Raises ValueError when no filters."""
|
||||
provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client)
|
||||
session = AgentSession(session_id="test-session")
|
||||
ctx = SessionContext(input_messages=[Message(role="user", text="hi")], session_id="s1")
|
||||
ctx._response = AgentResponse(messages=[Message(role="assistant", text="hey")])
|
||||
|
||||
with pytest.raises(ServiceInitializationError, match="At least one of the filters"):
|
||||
await provider.after_run(
|
||||
agent=None, session=session, context=ctx, state=session.state.setdefault(provider.source_id, {})
|
||||
) # type: ignore[arg-type]
|
||||
with pytest.raises(ValueError, match="At least one of the filters"):
|
||||
await provider.after_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type]
|
||||
|
||||
async def test_stores_with_application_id_metadata(self, mock_mem0_client: AsyncMock) -> None:
|
||||
"""application_id is passed in metadata."""
|
||||
@@ -347,7 +342,7 @@ class TestValidateFilters:
|
||||
|
||||
def test_raises_when_no_filters(self, mock_mem0_client: AsyncMock) -> None:
|
||||
provider = Mem0ContextProvider(source_id="mem0", mem0_client=mock_mem0_client)
|
||||
with pytest.raises(ServiceInitializationError, match="At least one of the filters"):
|
||||
with pytest.raises(ValueError, match="At least one of the filters"):
|
||||
provider._validate_filters()
|
||||
|
||||
def test_passes_with_user_id(self, mock_mem0_client: AsyncMock) -> None:
|
||||
|
||||
@@ -32,8 +32,8 @@ from agent_framework import (
|
||||
)
|
||||
from agent_framework._settings import load_settings
|
||||
from agent_framework.exceptions import (
|
||||
ServiceInvalidRequestError,
|
||||
ServiceResponseException,
|
||||
ChatClientException,
|
||||
ChatClientInvalidRequestException,
|
||||
)
|
||||
from agent_framework.observability import ChatTelemetryLayer
|
||||
from ollama import AsyncClient
|
||||
@@ -363,7 +363,7 @@ class OllamaChatClient(
|
||||
**kwargs,
|
||||
)
|
||||
except Exception as ex:
|
||||
raise ServiceResponseException(f"Ollama streaming chat request failed : {ex}", ex) from ex
|
||||
raise ChatClientException(f"Ollama streaming chat request failed : {ex}", ex) from ex
|
||||
|
||||
async for part in response_object:
|
||||
yield self._parse_streaming_response_from_ollama(part)
|
||||
@@ -381,7 +381,7 @@ class OllamaChatClient(
|
||||
**kwargs,
|
||||
)
|
||||
except Exception as ex:
|
||||
raise ServiceResponseException(f"Ollama chat request failed : {ex}", ex) from ex
|
||||
raise ChatClientException(f"Ollama chat request failed : {ex}", ex) from ex
|
||||
|
||||
return self._parse_response_from_ollama(response)
|
||||
|
||||
@@ -423,7 +423,7 @@ class OllamaChatClient(
|
||||
if messages and "messages" not in run_options:
|
||||
run_options["messages"] = self._prepare_messages_for_ollama(messages)
|
||||
if "messages" not in run_options:
|
||||
raise ServiceInvalidRequestError("Messages are required for chat completions")
|
||||
raise ChatClientInvalidRequestException("Messages are required for chat completions")
|
||||
|
||||
# model id
|
||||
if not run_options.get("model"):
|
||||
@@ -457,7 +457,7 @@ class OllamaChatClient(
|
||||
|
||||
def _format_user_message(self, message: Message) -> list[OllamaMessage]:
|
||||
if not any(c.type in {"text", "data"} for c in message.contents) and not message.text:
|
||||
raise ServiceInvalidRequestError(
|
||||
raise ChatClientInvalidRequestException(
|
||||
"Ollama connector currently only supports user messages with TextContent or DataContent."
|
||||
)
|
||||
|
||||
@@ -468,7 +468,9 @@ class OllamaChatClient(
|
||||
data_contents = [c for c in message.contents if c.type == "data"]
|
||||
if data_contents:
|
||||
if not any(c.has_top_level_media_type("image") for c in data_contents):
|
||||
raise ServiceInvalidRequestError("Only image data content is supported for user messages in Ollama.")
|
||||
raise ChatClientInvalidRequestException(
|
||||
"Only image data content is supported for user messages in Ollama."
|
||||
)
|
||||
# Ollama expects base64 strings without prefix
|
||||
user_message["images"] = [c.uri.split(",")[1] for c in data_contents if c.uri]
|
||||
return [user_message]
|
||||
|
||||
@@ -14,11 +14,7 @@ from agent_framework import (
|
||||
chat_middleware,
|
||||
tool,
|
||||
)
|
||||
from agent_framework.exceptions import (
|
||||
ServiceInvalidRequestError,
|
||||
ServiceResponseException,
|
||||
SettingNotFoundError,
|
||||
)
|
||||
from agent_framework.exceptions import ChatClientException, ChatClientInvalidRequestException, SettingNotFoundError
|
||||
from ollama import AsyncClient
|
||||
from ollama._types import ChatResponse as OllamaChatResponse
|
||||
from ollama._types import Message as OllamaMessage
|
||||
@@ -234,7 +230,7 @@ async def test_empty_messages() -> None:
|
||||
host="http://localhost:12345",
|
||||
model_id="test-model",
|
||||
)
|
||||
with pytest.raises(ServiceInvalidRequestError):
|
||||
with pytest.raises(ChatClientInvalidRequestException):
|
||||
await ollama_chat_client.get_response(messages=[])
|
||||
|
||||
|
||||
@@ -284,7 +280,7 @@ async def test_cmc_chat_failure(
|
||||
|
||||
ollama_client = OllamaChatClient()
|
||||
|
||||
with pytest.raises(ServiceResponseException) as exc_info:
|
||||
with pytest.raises(ChatClientException) as exc_info:
|
||||
await ollama_client.get_response(messages=chat_history)
|
||||
|
||||
assert "Ollama chat request failed" in str(exc_info.value)
|
||||
@@ -339,7 +335,7 @@ async def test_cmc_streaming_chat_failure(
|
||||
|
||||
ollama_client = OllamaChatClient()
|
||||
|
||||
with pytest.raises(ServiceResponseException) as exc_info:
|
||||
with pytest.raises(ChatClientException) as exc_info:
|
||||
async for _ in ollama_client.get_response(messages=chat_history, stream=True):
|
||||
pass
|
||||
|
||||
@@ -436,7 +432,7 @@ async def test_cmc_with_invalid_data_content_media_type(
|
||||
chat_history: list[Message],
|
||||
mock_streaming_chat_completion_response: AsyncStream[OllamaChatResponse],
|
||||
) -> None:
|
||||
with pytest.raises(ServiceInvalidRequestError):
|
||||
with pytest.raises(ChatClientInvalidRequestException):
|
||||
mock_chat.return_value = mock_streaming_chat_completion_response
|
||||
# Remote Uris are not supported by Ollama client
|
||||
chat_history.append(
|
||||
@@ -459,7 +455,7 @@ async def test_cmc_with_invalid_content_type(
|
||||
chat_history: list[Message],
|
||||
mock_chat_completion_response: AsyncStream[OllamaChatResponse],
|
||||
) -> None:
|
||||
with pytest.raises(ServiceInvalidRequestError):
|
||||
with pytest.raises(ChatClientInvalidRequestException):
|
||||
mock_chat.return_value = mock_chat_completion_response
|
||||
# Remote Uris are not supported by Ollama client
|
||||
chat_history.append(
|
||||
|
||||
@@ -20,10 +20,10 @@ Integration with Microsoft Purview for data governance and policy enforcement.
|
||||
|
||||
### Exceptions
|
||||
|
||||
- **`PurviewAuthenticationError`** - Authentication failures
|
||||
- **`PurviewRateLimitError`** - Rate limit exceeded
|
||||
- **`PurviewRequestError`** / **`PurviewServiceError`** - Request/service errors
|
||||
- **`PurviewPaymentRequiredError`** - Payment required
|
||||
- **`PurviewAuthenticationError`** - Authentication failures (inherits from `IntegrationInvalidAuthException`)
|
||||
- **`PurviewRateLimitError`** - Rate limit exceeded (inherits from `IntegrationException` via `PurviewServiceError`)
|
||||
- **`PurviewRequestError`** / **`PurviewServiceError`** - Request/service errors (inherit from `IntegrationException`)
|
||||
- **`PurviewPaymentRequiredError`** - Payment required (inherits from `IntegrationException` via `PurviewServiceError`)
|
||||
|
||||
## Usage
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
"""Purview specific exceptions (minimal error shaping)."""
|
||||
"""Purview specific exceptions mapped to the Integration exception hierarchy."""
|
||||
|
||||
from agent_framework.exceptions import ServiceResponseException
|
||||
from agent_framework.exceptions import IntegrationException, IntegrationInvalidAuthException
|
||||
|
||||
__all__ = [
|
||||
"PurviewAuthenticationError",
|
||||
@@ -12,11 +12,11 @@ __all__ = [
|
||||
]
|
||||
|
||||
|
||||
class PurviewServiceError(ServiceResponseException):
|
||||
class PurviewServiceError(IntegrationException):
|
||||
"""Base exception for Purview errors."""
|
||||
|
||||
|
||||
class PurviewAuthenticationError(PurviewServiceError):
|
||||
class PurviewAuthenticationError(IntegrationInvalidAuthException):
|
||||
"""Authentication / authorization failure (401/403)."""
|
||||
|
||||
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
|
||||
"""Tests for Purview exceptions."""
|
||||
|
||||
from agent_framework.exceptions import IntegrationException, IntegrationInvalidAuthException
|
||||
|
||||
from agent_framework_purview import (
|
||||
PurviewAuthenticationError,
|
||||
PurviewPaymentRequiredError,
|
||||
@@ -18,28 +20,28 @@ class TestPurviewExceptions:
|
||||
"""Test PurviewServiceError base exception."""
|
||||
error = PurviewServiceError("Service error occurred")
|
||||
assert str(error) == "Service error occurred"
|
||||
assert isinstance(error, Exception)
|
||||
assert isinstance(error, IntegrationException)
|
||||
|
||||
def test_purview_authentication_error(self) -> None:
|
||||
"""Test PurviewAuthenticationError exception."""
|
||||
error = PurviewAuthenticationError("Authentication failed")
|
||||
assert str(error) == "Authentication failed"
|
||||
assert isinstance(error, PurviewServiceError)
|
||||
assert isinstance(error, IntegrationInvalidAuthException)
|
||||
|
||||
def test_purview_payment_required_error(self) -> None:
|
||||
"""Test PurviewPaymentRequiredError exception."""
|
||||
error = PurviewPaymentRequiredError("Payment required")
|
||||
assert str(error) == "Payment required"
|
||||
assert isinstance(error, PurviewServiceError)
|
||||
assert isinstance(error, IntegrationException)
|
||||
|
||||
def test_purview_rate_limit_error(self) -> None:
|
||||
"""Test PurviewRateLimitError exception."""
|
||||
error = PurviewRateLimitError("Rate limit exceeded")
|
||||
assert str(error) == "Rate limit exceeded"
|
||||
assert isinstance(error, PurviewServiceError)
|
||||
assert isinstance(error, IntegrationException)
|
||||
|
||||
def test_purview_request_error(self) -> None:
|
||||
"""Test PurviewRequestError exception."""
|
||||
error = PurviewRequestError("Request failed")
|
||||
assert str(error) == "Request failed"
|
||||
assert isinstance(error, PurviewServiceError)
|
||||
assert isinstance(error, IntegrationException)
|
||||
|
||||
@@ -19,8 +19,7 @@ from agent_framework import Message
|
||||
from agent_framework._sessions import AgentSession, BaseContextProvider, SessionContext
|
||||
from agent_framework.exceptions import (
|
||||
AgentException,
|
||||
ServiceInitializationError,
|
||||
ServiceInvalidRequestError,
|
||||
IntegrationInvalidRequestException,
|
||||
)
|
||||
from redisvl.index import AsyncSearchIndex
|
||||
from redisvl.query import HybridQuery, TextQuery
|
||||
@@ -285,7 +284,7 @@ class RedisContextProvider(BaseContextProvider):
|
||||
existing_sig = _schema_signature(existing_schema)
|
||||
current_sig = _schema_signature(current_schema)
|
||||
if existing_sig != current_sig:
|
||||
raise ServiceInitializationError(
|
||||
raise ValueError(
|
||||
"Existing Redis index schema is incompatible with the current configuration.\n"
|
||||
f"Existing (significant): {json.dumps(existing_sig, indent=2, sort_keys=True)}\n"
|
||||
f"Current (significant): {json.dumps(current_sig, indent=2, sort_keys=True)}\n"
|
||||
@@ -313,7 +312,7 @@ class RedisContextProvider(BaseContextProvider):
|
||||
d.setdefault("thread_id", session_id)
|
||||
d.setdefault("conversation_id", session_id)
|
||||
if "content" not in d:
|
||||
raise ServiceInvalidRequestError("add() requires a 'content' field in data")
|
||||
raise IntegrationInvalidRequestException("add() requires a 'content' field in data")
|
||||
if self.vector_field_name:
|
||||
d.setdefault(self.vector_field_name, None)
|
||||
prepared.append(d)
|
||||
@@ -345,7 +344,7 @@ class RedisContextProvider(BaseContextProvider):
|
||||
|
||||
q = (text or "").strip()
|
||||
if not q:
|
||||
raise ServiceInvalidRequestError("text_search() requires non-empty text")
|
||||
raise IntegrationInvalidRequestException("text_search() requires non-empty text")
|
||||
num_results = max(int(num_results or 10), 1)
|
||||
|
||||
combined_filter = self._build_filter_from_dict({
|
||||
@@ -394,14 +393,12 @@ class RedisContextProvider(BaseContextProvider):
|
||||
text_results = await self.redis_index.query(query)
|
||||
return cast(list[dict[str, Any]], text_results)
|
||||
except Exception as exc: # pragma: no cover
|
||||
raise ServiceInvalidRequestError(f"Redis text search failed: {exc}") from exc
|
||||
raise IntegrationInvalidRequestException(f"Redis text search failed: {exc}") from exc
|
||||
|
||||
def _validate_filters(self) -> None:
|
||||
"""Validates that at least one filter is provided."""
|
||||
if not self.agent_id and not self.user_id and not self.application_id:
|
||||
raise ServiceInitializationError(
|
||||
"At least one of the filters: agent_id, user_id, or application_id is required."
|
||||
)
|
||||
raise ValueError("At least one of the filters: agent_id, user_id, or application_id is required.")
|
||||
|
||||
async def search_all(self, page_size: int = 200) -> list[dict[str, Any]]:
|
||||
"""Returns all documents in the index."""
|
||||
|
||||
@@ -10,7 +10,6 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import pytest
|
||||
from agent_framework import AgentResponse, Message
|
||||
from agent_framework._sessions import AgentSession, SessionContext
|
||||
from agent_framework.exceptions import ServiceInitializationError
|
||||
|
||||
from agent_framework_redis._context_provider import RedisContextProvider
|
||||
from agent_framework_redis._history_provider import RedisHistoryProvider
|
||||
@@ -108,7 +107,7 @@ class TestRedisContextProviderInit:
|
||||
class TestRedisContextProviderValidateFilters:
|
||||
def test_no_filters_raises(self, patch_index_from_dict: MagicMock): # noqa: ARG002
|
||||
provider = RedisContextProvider(source_id="ctx")
|
||||
with pytest.raises(ServiceInitializationError, match="(?i)at least one"):
|
||||
with pytest.raises(ValueError, match="(?i)at least one"):
|
||||
provider._validate_filters()
|
||||
|
||||
def test_any_single_filter_ok(self, patch_index_from_dict: MagicMock): # noqa: ARG002
|
||||
|
||||
Reference in New Issue
Block a user