mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Merge branch 'main' into local-branch-fix-workflow-as-agent-pending-request-handling
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
|
||||
import base64
|
||||
import logging
|
||||
import uuid
|
||||
from asyncio import CancelledError
|
||||
from collections.abc import Mapping
|
||||
from functools import partial
|
||||
@@ -181,9 +182,18 @@ class A2AExecutor(AgentExecutor):
|
||||
"""Run the agent in streaming mode and publish updates to the task updater."""
|
||||
response_stream = self._agent.run(query, session=session, stream=True, **self._run_kwargs)
|
||||
streamed_artifact_ids: set[str] = set()
|
||||
# Generate a stable artifact ID for the entire stream so all chunks share the same ID.
|
||||
# This ensures clients can coalesce streaming tokens into a single artifact/message
|
||||
# per the A2A spec (TaskArtifactUpdateEvent with append=True on same artifactId).
|
||||
default_artifact_id = str(uuid.uuid4())
|
||||
await (
|
||||
response_stream.with_transform_hook(
|
||||
partial(self.handle_events, updater=updater, streamed_artifact_ids=streamed_artifact_ids)
|
||||
partial(
|
||||
self.handle_events,
|
||||
updater=updater,
|
||||
streamed_artifact_ids=streamed_artifact_ids,
|
||||
default_artifact_id=default_artifact_id,
|
||||
)
|
||||
)
|
||||
).get_final_response()
|
||||
|
||||
@@ -199,7 +209,11 @@ class A2AExecutor(AgentExecutor):
|
||||
await self.handle_events(message, updater)
|
||||
|
||||
async def handle_events(
|
||||
self, item: Message | AgentResponseUpdate, updater: TaskUpdater, streamed_artifact_ids: set[str] | None = None
|
||||
self,
|
||||
item: Message | AgentResponseUpdate,
|
||||
updater: TaskUpdater,
|
||||
streamed_artifact_ids: set[str] | None = None,
|
||||
default_artifact_id: str | None = None,
|
||||
) -> None:
|
||||
"""Convert agent response items (Messages or Updates) to A2A protocol events.
|
||||
|
||||
@@ -213,7 +227,10 @@ class A2AExecutor(AgentExecutor):
|
||||
item: The agent response item (Message or AgentResponseUpdate) to process.
|
||||
updater: The task updater to publish events to.
|
||||
streamed_artifact_ids: A set of artifact IDs that have already been streamed.
|
||||
Used to prevent duplicate updates for the same artifact.
|
||||
Used to track which artifacts need append=True on subsequent chunks.
|
||||
default_artifact_id: A stable artifact ID to use when the item does not provide one.
|
||||
This ensures all streaming chunks for a single response share the same artifact ID,
|
||||
allowing clients to coalesce them into a single message.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
@@ -224,6 +241,7 @@ class A2AExecutor(AgentExecutor):
|
||||
item: Message | AgentResponseUpdate,
|
||||
updater: TaskUpdater,
|
||||
streamed_artifact_ids: set[str] | None = None,
|
||||
default_artifact_id: str | None = None,
|
||||
) -> None:
|
||||
# Custom logic to transform item contents
|
||||
if item.role == "assistant" and item.contents:
|
||||
@@ -260,19 +278,20 @@ class A2AExecutor(AgentExecutor):
|
||||
|
||||
if parts:
|
||||
if isinstance(item, AgentResponseUpdate):
|
||||
# Resolve artifact ID: use item's message_id if available, otherwise fall back
|
||||
# to the stable default_artifact_id so all streaming chunks share the same ID.
|
||||
artifact_id = item.message_id or default_artifact_id
|
||||
# For streaming updates, we send TaskArtifactUpdateEvent via add_artifact
|
||||
await updater.add_artifact(
|
||||
parts=parts,
|
||||
artifact_id=item.message_id,
|
||||
artifact_id=artifact_id,
|
||||
metadata=metadata,
|
||||
append=(
|
||||
True
|
||||
if streamed_artifact_ids is not None and item.message_id in (streamed_artifact_ids or set())
|
||||
else None
|
||||
True if streamed_artifact_ids is not None and artifact_id in streamed_artifact_ids else None
|
||||
),
|
||||
)
|
||||
if item.message_id and streamed_artifact_ids is not None:
|
||||
streamed_artifact_ids.add(item.message_id)
|
||||
if artifact_id and streamed_artifact_ids is not None:
|
||||
streamed_artifact_ids.add(artifact_id)
|
||||
else:
|
||||
# For final messages, we send TaskStatusUpdateEvent with 'working' state
|
||||
await updater.update_status(
|
||||
|
||||
@@ -492,6 +492,7 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
|
||||
contents=contents,
|
||||
role="assistant" if msg.role == A2ARole.ROLE_AGENT else "user",
|
||||
response_id=msg.message_id or str(uuid.uuid4()),
|
||||
message_id=msg.message_id,
|
||||
additional_properties={"a2a_metadata": metadata} if metadata else None,
|
||||
raw_representation=msg,
|
||||
)
|
||||
@@ -732,6 +733,7 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
|
||||
contents=contents,
|
||||
role="assistant" if message.role == A2ARole.ROLE_AGENT else "user",
|
||||
response_id=update_event.task_id,
|
||||
message_id=message.message_id,
|
||||
additional_properties={"a2a_metadata": merged_metadata} if merged_metadata else None,
|
||||
raw_representation=update_event,
|
||||
)
|
||||
|
||||
@@ -420,6 +420,7 @@ async def test_run_streaming_with_message_response(a2a_agent: A2AAgent, mock_a2a
|
||||
assert content.text == "Streaming response from agent!"
|
||||
|
||||
assert updates[0].response_id == "msg-stream-123"
|
||||
assert updates[0].message_id == "msg-stream-123"
|
||||
assert mock_a2a_client.call_count == 1
|
||||
|
||||
|
||||
@@ -1422,7 +1423,7 @@ async def test_streaming_status_update_event_yields_content(
|
||||
status=TaskStatus(
|
||||
state=TaskState.TASK_STATE_COMPLETED,
|
||||
message=A2AMessage(
|
||||
message_id=str(uuid4()),
|
||||
message_id="msg-status-done",
|
||||
role=A2ARole.ROLE_AGENT,
|
||||
parts=[Part(text="Done")],
|
||||
),
|
||||
@@ -1437,6 +1438,7 @@ async def test_streaming_status_update_event_yields_content(
|
||||
assert len(updates) == 1
|
||||
assert updates[0].text == "Done"
|
||||
assert updates[0].role == "assistant"
|
||||
assert updates[0].message_id == "msg-status-done"
|
||||
assert updates[0].raw_representation == update_event
|
||||
|
||||
|
||||
@@ -1449,7 +1451,7 @@ async def test_streaming_input_required_emits_content(a2a_agent: A2AAgent, mock_
|
||||
status=TaskStatus(
|
||||
state=TaskState.TASK_STATE_INPUT_REQUIRED,
|
||||
message=A2AMessage(
|
||||
message_id=str(uuid4()),
|
||||
message_id="msg-input-req",
|
||||
role=A2ARole.ROLE_AGENT,
|
||||
parts=[Part(text="What is your name?")],
|
||||
),
|
||||
@@ -1463,6 +1465,7 @@ async def test_streaming_input_required_emits_content(a2a_agent: A2AAgent, mock_
|
||||
|
||||
assert len(updates) == 1
|
||||
assert updates[0].text == "What is your name?"
|
||||
assert updates[0].message_id == "msg-input-req"
|
||||
|
||||
|
||||
@mark.asyncio
|
||||
|
||||
@@ -803,6 +803,15 @@ class RawAnthropicClient(
|
||||
}
|
||||
a_content.append(mcp_result)
|
||||
case "text_reasoning":
|
||||
if content.text is None:
|
||||
if (
|
||||
content.protected_data
|
||||
and a_content
|
||||
and a_content[-1].get("type") == "thinking"
|
||||
and "signature" not in a_content[-1]
|
||||
):
|
||||
a_content[-1]["signature"] = content.protected_data
|
||||
continue
|
||||
thinking_block: dict[str, Any] = {"type": "thinking", "thinking": content.text}
|
||||
if content.protected_data:
|
||||
thinking_block["signature"] = content.protected_data
|
||||
|
||||
@@ -485,6 +485,48 @@ def test_prepare_message_for_anthropic_text_reasoning_with_signature(
|
||||
assert result["content"][0]["signature"] == "sig_abc123"
|
||||
|
||||
|
||||
def test_prepare_message_for_anthropic_attaches_signature_only_reasoning(
    mock_anthropic_client: MagicMock,
) -> None:
    """A signature-only reasoning item is merged into the preceding thinking block."""
    anthropic_client = create_test_anthropic_client(mock_anthropic_client)
    # First item carries the reasoning text, second one only the signature.
    reasoning_text = Content.from_text_reasoning(text="Let me think about this...")
    reasoning_signature = Content.from_text_reasoning(text=None, protected_data="sig_abc123")
    assistant_message = Message(role="assistant", contents=[reasoning_text, reasoning_signature])

    prepared = anthropic_client._prepare_message_for_anthropic(assistant_message)

    expected_blocks = [
        {"type": "thinking", "thinking": "Let me think about this...", "signature": "sig_abc123"},
    ]
    assert prepared["content"] == expected_blocks
|
||||
|
||||
|
||||
def test_prepare_message_for_anthropic_skips_orphan_signature_only_reasoning(
    mock_anthropic_client: MagicMock,
) -> None:
    """A signature-only reasoning item with no preceding thinking block is dropped."""
    anthropic_client = create_test_anthropic_client(mock_anthropic_client)
    # The signature comes first, so there is no thinking block to attach it to.
    orphan_signature = Content.from_text_reasoning(text=None, protected_data="sig_abc123")
    weather_call = Content.from_function_call(
        call_id="call_123",
        name="get_weather",
        arguments={"location": "San Francisco"},
    )
    assistant_message = Message(role="assistant", contents=[orphan_signature, weather_call])

    prepared = anthropic_client._prepare_message_for_anthropic(assistant_message)

    blocks = prepared["content"]
    assert len(blocks) == 1
    only_block = blocks[0]
    assert only_block["type"] == "tool_use"
    assert only_block["id"] == "call_123"
|
||||
|
||||
|
||||
def test_prepare_message_for_anthropic_mcp_server_tool_call(
|
||||
mock_anthropic_client: MagicMock,
|
||||
) -> None:
|
||||
|
||||
@@ -365,6 +365,14 @@ def _start_function_app(sample_path: Path, port: int) -> subprocess.Popen[Any]:
|
||||
# use the task hub name to separate orchestration state.
|
||||
env["TASKHUB_NAME"] = f"test{uuid.uuid4().hex[:8]}"
|
||||
|
||||
# The Azure Functions Python worker's dependency isolation mechanism crashes
|
||||
# on Python 3.13 with a SIGSEGV in the protobuf C extension (google._upb).
|
||||
# Disabling isolation lets the worker load dependencies from the app's own
|
||||
# environment, which avoids the crash.
|
||||
# See: https://github.com/Azure/azure-functions-python-worker/issues/1797
|
||||
if sys.version_info >= (3, 13):
|
||||
env.setdefault("PYTHON_ISOLATE_WORKER_DEPENDENCIES", "0")
|
||||
|
||||
# On Windows, use CREATE_NEW_PROCESS_GROUP to allow proper termination
|
||||
# shell=True only on Windows to handle PATH resolution
|
||||
if sys.platform == "win32":
|
||||
@@ -375,8 +383,15 @@ def _start_function_app(sample_path: Path, port: int) -> subprocess.Popen[Any]:
|
||||
shell=True,
|
||||
env=env,
|
||||
)
|
||||
# On Unix, don't use shell=True to avoid shell wrapper issues
|
||||
return subprocess.Popen(["func", "start", "--port", str(port)], cwd=str(sample_path), env=env)
|
||||
# On Unix, use start_new_session=True to isolate the process group from the
|
||||
# pytest-xdist worker. Without this, signals (e.g. from test-timeout) can
|
||||
# propagate to the func host and vice-versa, potentially killing the worker.
|
||||
return subprocess.Popen(
|
||||
["func", "start", "--port", str(port)],
|
||||
cwd=str(sample_path),
|
||||
env=env,
|
||||
start_new_session=True,
|
||||
)
|
||||
|
||||
|
||||
def _wait_for_function_app_ready(func_process: subprocess.Popen[Any], port: int, max_wait: int = 60) -> None:
|
||||
@@ -533,18 +548,33 @@ def function_app_for_test(request: pytest.FixtureRequest) -> Iterator[dict[str,
|
||||
_load_and_validate_env(sample_path)
|
||||
|
||||
max_attempts = 3
|
||||
# The overall budget MUST be shorter than the pytest-timeout value
|
||||
# (--timeout=120 by default) so that the fixture finishes cleanly instead
|
||||
# of being killed by os._exit() which crashes the xdist worker.
|
||||
overall_budget = 100 # seconds – leaves headroom below the 120 s test timeout
|
||||
last_error: Exception | None = None
|
||||
func_process: subprocess.Popen[Any] | None = None
|
||||
base_url = ""
|
||||
port = 0
|
||||
overall_start = time.monotonic()
|
||||
attempts_made = 0
|
||||
|
||||
for _ in range(max_attempts):
|
||||
remaining = overall_budget - (time.monotonic() - overall_start)
|
||||
if remaining < 10:
|
||||
# Not enough time for another attempt; bail out.
|
||||
break
|
||||
|
||||
attempts_made += 1
|
||||
port = _find_available_port()
|
||||
base_url = _build_base_url(port)
|
||||
func_process = _start_function_app(sample_path, port)
|
||||
|
||||
try:
|
||||
_wait_for_function_app_ready(func_process, port)
|
||||
# Cap each attempt's wait to the remaining budget minus a small
|
||||
# buffer for cleanup.
|
||||
per_attempt_wait = min(60, int(remaining) - 5)
|
||||
_wait_for_function_app_ready(func_process, port, max_wait=max(per_attempt_wait, 10))
|
||||
last_error = None
|
||||
break
|
||||
except FunctionAppStartupError as exc:
|
||||
@@ -553,7 +583,8 @@ def function_app_for_test(request: pytest.FixtureRequest) -> Iterator[dict[str,
|
||||
func_process = None
|
||||
|
||||
if func_process is None:
|
||||
error_message = f"Function app failed to start after {max_attempts} attempt(s)."
|
||||
elapsed = int(time.monotonic() - overall_start)
|
||||
error_message = f"Function app failed to start after {attempts_made} attempt(s) ({elapsed}s elapsed)."
|
||||
if last_error is not None:
|
||||
error_message += f" Last error: {last_error}"
|
||||
pytest.fail(error_message)
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
@@ -36,6 +37,7 @@ from agent_framework.observability import ChatTelemetryLayer
|
||||
from boto3.session import Session as Boto3Session
|
||||
from botocore.client import BaseClient
|
||||
from botocore.config import Config as BotoConfig
|
||||
from botocore.exceptions import ClientError
|
||||
from pydantic import BaseModel
|
||||
|
||||
if sys.version_info >= (3, 13):
|
||||
@@ -115,13 +117,20 @@ class BedrockChatOptions(ChatOptions[ResponseModelT], Generic[ResponseModelT], t
|
||||
translates to ``toolConfig.tools``.
|
||||
tool_choice: How the model should use tools,
|
||||
translates to ``toolConfig.toolChoice``.
|
||||
response_format: Structured output format. Accepts a Pydantic BaseModel
|
||||
subclass or an OpenAI-style dict schema
|
||||
(``{"json_schema": {"name": ..., "schema": ...}}``).
|
||||
When provided, the Converse API request includes
|
||||
``outputConfig.textFormat`` with the schema serialized as a JSON
|
||||
string. ``ChatResponse.value`` will be populated with the parsed
|
||||
model instance. Only supported on models that support
|
||||
``outputConfig.textFormat``. Unsupported models raise a ValueError.
|
||||
|
||||
# Options not supported in Bedrock Converse API:
|
||||
seed: Not supported.
|
||||
frequency_penalty: Not supported.
|
||||
presence_penalty: Not supported.
|
||||
allow_multiple_tool_calls: Not supported (models handle parallel calls automatically).
|
||||
response_format: Not directly supported (use model-specific prompting).
|
||||
user: Not supported.
|
||||
store: Not supported.
|
||||
logit_bias: Not supported.
|
||||
@@ -161,9 +170,6 @@ class BedrockChatOptions(ChatOptions[ResponseModelT], Generic[ResponseModelT], t
|
||||
allow_multiple_tool_calls: None # type: ignore[misc]
|
||||
"""Not supported. Bedrock models handle parallel tool calls automatically."""
|
||||
|
||||
response_format: None # type: ignore[misc]
|
||||
"""Not directly supported. Use model-specific prompting for JSON output."""
|
||||
|
||||
user: None # type: ignore[misc]
|
||||
"""Not supported in Bedrock Converse API."""
|
||||
|
||||
@@ -324,10 +330,28 @@ class BedrockChatClient(
|
||||
return Boto3Session(**session_kwargs)
|
||||
|
||||
def _invoke_converse(self, request: Mapping[str, Any]) -> dict[str, Any]:
|
||||
response = self._bedrock_client.converse(**request)
|
||||
if not isinstance(response, Mapping):
|
||||
raise ChatClientInvalidResponseException("Bedrock converse response must be a mapping.")
|
||||
return response
|
||||
try:
|
||||
response = self._bedrock_client.converse(**request)
|
||||
if not isinstance(response, Mapping):
|
||||
raise ChatClientInvalidResponseException("Bedrock converse response must be a mapping.")
|
||||
return response
|
||||
except ClientError as e:
|
||||
error_details = e.response.get("Error", {})
|
||||
error_code = error_details.get("Code", "")
|
||||
error_message = error_details.get("Message", "")
|
||||
# "outputConfig" in error_message catches cases where Bedrock explicitly
|
||||
# rejects the outputConfig field (unsupported model). Other ValidationExceptions
|
||||
# (e.g. malformed schema shape, invalid property values) will not mention
|
||||
# "outputConfig" and will bubble up as raw ClientError without being misdiagnosed.
|
||||
if error_code == "ValidationException" and (
|
||||
"outputconfig" in error_message.lower() or "outputconfig" in str(e).lower()
|
||||
):
|
||||
raise ValueError(
|
||||
f"Model '{self.model}' does not support structured output via outputConfig.textFormat. "
|
||||
"Check the model's Bedrock Converse outputConfig/textFormat support. "
|
||||
f"AWS error Code: {error_code}. AWS error Message: {error_message}"
|
||||
) from e
|
||||
raise
|
||||
|
||||
@override
|
||||
def _inner_get_response(
|
||||
@@ -344,7 +368,7 @@ class BedrockChatClient(
|
||||
# Streaming mode - simulate streaming by yielding a single update
|
||||
async def _stream() -> AsyncIterable[ChatResponseUpdate]:
|
||||
response = await asyncio.to_thread(self._invoke_converse, request)
|
||||
parsed_response = self._process_converse_response(response)
|
||||
parsed_response = self._process_converse_response(response, options)
|
||||
contents = list(parsed_response.messages[0].contents if parsed_response.messages else [])
|
||||
if parsed_response.usage_details:
|
||||
contents.append(Content.from_usage(usage_details=parsed_response.usage_details)) # type: ignore[arg-type]
|
||||
@@ -360,12 +384,12 @@ class BedrockChatClient(
|
||||
raw_representation=parsed_response.raw_representation,
|
||||
)
|
||||
|
||||
return self._build_response_stream(_stream())
|
||||
return self._build_response_stream(_stream(), response_format=options.get("response_format"))
|
||||
|
||||
# Non-streaming mode
|
||||
async def _get_response() -> ChatResponse:
|
||||
raw_response = await asyncio.to_thread(self._invoke_converse, request)
|
||||
return self._process_converse_response(raw_response)
|
||||
return self._process_converse_response(raw_response, options)
|
||||
|
||||
return _get_response()
|
||||
|
||||
@@ -430,6 +454,9 @@ class BedrockChatClient(
|
||||
if tool_config:
|
||||
run_options["toolConfig"] = tool_config
|
||||
|
||||
if output_config := self._prepare_output_config(options.get("response_format")):
|
||||
run_options["outputConfig"] = output_config
|
||||
|
||||
return run_options
|
||||
|
||||
def _prepare_bedrock_messages(
|
||||
@@ -628,7 +655,9 @@ class BedrockChatClient(
|
||||
def _generate_tool_call_id() -> str:
|
||||
return f"tool-call-{uuid4().hex}"
|
||||
|
||||
def _process_converse_response(self, response: dict[str, Any]) -> ChatResponse:
|
||||
def _process_converse_response(
|
||||
self, response: dict[str, Any], options: Mapping[str, Any] | None = None
|
||||
) -> ChatResponse:
|
||||
"""Convert Bedrock Converse API response to ChatResponse."""
|
||||
output = response.get("output") or {}
|
||||
message = output.get("message") or {}
|
||||
@@ -646,6 +675,7 @@ class BedrockChatClient(
|
||||
usage_details=usage_details,
|
||||
model=model,
|
||||
finish_reason=finish_reason,
|
||||
response_format=options.get("response_format") if options else None,
|
||||
raw_representation=response,
|
||||
)
|
||||
|
||||
@@ -728,6 +758,108 @@ class BedrockChatClient(
|
||||
return None
|
||||
return FINISH_REASON_MAP.get(reason.lower())
|
||||
|
||||
def _prepare_output_config(self, response_format: Any | None) -> dict[str, Any] | None:
|
||||
"""Convert response_format into the AWS Bedrock outputConfig wire format.
|
||||
|
||||
Args:
|
||||
response_format: A Pydantic model class or a dict schema, or None.
|
||||
|
||||
Returns:
|
||||
A dict for the Converse API ``outputConfig`` parameter, or None if
|
||||
response_format is not set.
|
||||
"""
|
||||
if response_format is None:
|
||||
return None
|
||||
|
||||
if isinstance(response_format, Mapping):
|
||||
if "json_schema" in response_format:
|
||||
# Shape A — OpenAI-style wrapper
|
||||
json_schema_config = response_format["json_schema"]
|
||||
schema_src = json_schema_config.get("schema", {})
|
||||
name = json_schema_config.get("name", "output_schema")
|
||||
elif "schema" in response_format:
|
||||
# Shape B — inner shape directly {"name": ..., "schema": ...}
|
||||
schema_src = response_format["schema"]
|
||||
name = response_format.get("name", "output_schema")
|
||||
else:
|
||||
# Shape C — assume entire dict is the raw schema
|
||||
logger.warning(
|
||||
"response_format dict has no 'json_schema' or 'schema' key; "
|
||||
"treating entire dict as raw JSON schema."
|
||||
)
|
||||
schema_src = dict(response_format)
|
||||
name = "output_schema"
|
||||
|
||||
if isinstance(schema_src, str):
|
||||
schema_src = json.loads(schema_src)
|
||||
schema = copy.deepcopy(schema_src)
|
||||
else:
|
||||
if not isinstance(response_format, type) or not issubclass(response_format, BaseModel):
|
||||
raise TypeError(
|
||||
"response_format must be None, a dict JSON schema, "
|
||||
"or a Pydantic BaseModel subclass."
|
||||
)
|
||||
# response_format is a Pydantic model class
|
||||
schema = response_format.model_json_schema()
|
||||
name = response_format.__name__
|
||||
|
||||
self._set_additional_properties_false(schema)
|
||||
|
||||
json_schema: dict[str, Any] = {
|
||||
"name": name,
|
||||
"schema": json.dumps(schema),
|
||||
}
|
||||
|
||||
description = getattr(response_format, "__doc__", None) if not isinstance(response_format, Mapping) else None
|
||||
if description and isinstance(description, str) and description.strip():
|
||||
json_schema["description"] = description.strip()
|
||||
|
||||
return {
|
||||
"textFormat": {
|
||||
"type": "json_schema",
|
||||
"structure": {
|
||||
"jsonSchema": json_schema
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
def _set_additional_properties_false(self, schema: dict[str, Any]) -> None:
|
||||
"""Recursively set additionalProperties: false on all object types in a JSON schema.
|
||||
|
||||
AWS requires strict schema enforcement. This mirrors the approach used by
|
||||
AnthropicChatClient._prepare_response_format().
|
||||
|
||||
Args:
|
||||
schema: The JSON schema dict to modify in-place.
|
||||
"""
|
||||
visited: set[int] = set()
|
||||
|
||||
def walk(node: Any) -> None:
|
||||
if isinstance(node, dict):
|
||||
node_id = id(node)
|
||||
if node_id in visited:
|
||||
return
|
||||
visited.add(node_id)
|
||||
if node.get("type") == "object" or (
|
||||
"properties" in node and "type" not in node
|
||||
):
|
||||
existing = node.get("additionalProperties")
|
||||
if existing is None or existing is True:
|
||||
node["additionalProperties"] = False
|
||||
for value in node.values():
|
||||
if isinstance(value, (dict, list)):
|
||||
walk(value)
|
||||
elif isinstance(node, list):
|
||||
node_id = id(node)
|
||||
if node_id in visited:
|
||||
return
|
||||
visited.add(node_id)
|
||||
for item in node:
|
||||
if isinstance(item, (dict, list)):
|
||||
walk(item)
|
||||
|
||||
walk(schema)
|
||||
|
||||
def service_url(self) -> str:
|
||||
"""Returns the service URL for the Bedrock runtime in the configured AWS region.
|
||||
|
||||
|
||||
@@ -0,0 +1,382 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import json
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from agent_framework import Content, Message
|
||||
from botocore.exceptions import ClientError
|
||||
from pydantic import BaseModel
|
||||
|
||||
from agent_framework_bedrock import BedrockChatClient
|
||||
|
||||
# region Test models
|
||||
|
||||
|
||||
class WeatherReport(BaseModel):
    # Flat structured-output model used by the tests.
    # Documented with comments on purpose: the client copies a model's
    # docstring into the emitted jsonSchema "description".
    city: str
    temperature: float
    summary: str
|
||||
|
||||
|
||||
class NestedAddress(BaseModel):
    # Inner model referenced by Person; exercises nested object schemas.
    street: str
    city: str
    zip_code: str
|
||||
|
||||
|
||||
class Person(BaseModel):
    # Outer model whose `address` field nests another model.
    name: str
    age: int
    address: NestedAddress
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
# region Helpers
|
||||
|
||||
|
||||
class _StubBedrockRuntime:
|
||||
"""Stub that records calls and returns a canned response."""
|
||||
|
||||
def __init__(self, response_text: str = "Bedrock says hi") -> None:
|
||||
self.calls: list[dict[str, Any]] = []
|
||||
self._response_text = response_text
|
||||
|
||||
def converse(self, **kwargs: Any) -> dict[str, Any]:
|
||||
self.calls.append(kwargs)
|
||||
return {
|
||||
"modelId": kwargs["modelId"],
|
||||
"responseId": "resp-structured",
|
||||
"usage": {"inputTokens": 10, "outputTokens": 20, "totalTokens": 30},
|
||||
"output": {
|
||||
"completionReason": "end_turn",
|
||||
"message": {
|
||||
"id": "msg-structured",
|
||||
"role": "assistant",
|
||||
"content": [{"text": self._response_text}],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _make_client(response_text: str = "Bedrock says hi") -> tuple[BedrockChatClient, _StubBedrockRuntime]:
    """Build a ``BedrockChatClient`` backed by a recording stub runtime.

    Returns:
        The client and the stub, so tests can inspect the recorded calls.
    """
    runtime_stub = _StubBedrockRuntime(response_text)
    chat_client = BedrockChatClient(
        model="us.anthropic.claude-haiku-4-5-v1:0",
        region="us-east-1",
        client=runtime_stub,
    )
    return chat_client, runtime_stub
|
||||
|
||||
|
||||
def _user_messages() -> list[Message]:
    """Return a one-message conversation containing a single user text prompt."""
    prompt = Content.from_text(text="Give me a weather report")
    return [Message(role="user", contents=[prompt])]
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
# region Tests
|
||||
|
||||
|
||||
def test_prepare_output_config_correct_wire_shape() -> None:
    """A Pydantic model yields the ``textFormat -> structure -> jsonSchema`` nesting.

    ``textFormat`` must declare ``type: 'json_schema'`` and ``jsonSchema``
    must carry the model name together with a ``schema`` entry.
    """
    bedrock_client, _ = _make_client()

    config = bedrock_client._prepare_output_config(WeatherReport)

    assert config is not None
    fmt = config["textFormat"]
    assert fmt["type"] == "json_schema"
    assert "structure" in fmt
    definition = fmt["structure"]["jsonSchema"]
    assert definition["name"] == "WeatherReport"
    assert "schema" in definition
|
||||
|
||||
|
||||
def test_prepare_output_config_schema_is_json_string() -> None:
    """The ``schema`` entry of ``jsonSchema`` is a JSON-encoded string, not a dict."""
    bedrock_client, _ = _make_client()

    config = bedrock_client._prepare_output_config(WeatherReport)

    assert config is not None
    definition = config["textFormat"]["structure"]["jsonSchema"]
    schema_value = definition["schema"]
    assert isinstance(schema_value, str), f"Expected str, got {type(schema_value)}"
    # The string must decode to an object schema.
    decoded = json.loads(schema_value)
    assert isinstance(decoded, dict)
    assert decoded["type"] == "object"
|
||||
|
||||
|
||||
def test_additional_properties_false_set_recursively() -> None:
    """Nested object schemas get ``additionalProperties: false`` as well as the root."""
    bedrock_client, _ = _make_client()

    config = bedrock_client._prepare_output_config(Person)

    assert config is not None
    decoded = json.loads(config["textFormat"]["structure"]["jsonSchema"]["schema"])

    # Root object.
    assert decoded.get("additionalProperties") is False

    # The nested model is emitted under $defs.
    definitions = decoded.get("$defs", {})
    assert "NestedAddress" in definitions, "Expected NestedAddress to be present in $defs"
    nested = definitions["NestedAddress"]
    assert nested.get("additionalProperties") is False, (
        "Expected additionalProperties=False on nested NestedAddress schema"
    )
|
||||
|
||||
|
||||
def test_no_output_config_when_response_format_none() -> None:
    """Without a ``response_format`` option the request carries no ``outputConfig``."""
    bedrock_client, stub = _make_client()
    conversation = _user_messages()
    options = {"max_tokens": 100}

    request = bedrock_client._prepare_options(conversation, options)

    assert "outputConfig" not in request, (
        f"outputConfig should not be present when response_format is None, got: {request.get('outputConfig')}"
    )
|
||||
|
||||
|
||||
async def test_chat_response_value_populated() -> None:
    """With ``response_format`` set, ``.value`` holds the parsed Pydantic model."""
    payload = {"city": "Seattle", "temperature": 72.5, "summary": "Sunny and warm"}
    json_response = json.dumps(payload)
    bedrock_client, stub = _make_client(response_text=json_response)
    conversation = _user_messages()

    response = await bedrock_client.get_response(
        messages=conversation,
        options={"max_tokens": 100, "response_format": WeatherReport},
    )

    assert response.text == json_response
    parsed = response.value
    assert parsed is not None
    assert isinstance(parsed, WeatherReport)
    assert parsed.city == "Seattle"
    assert parsed.temperature == 72.5
    assert parsed.summary == "Sunny and warm"

    # Exactly one API call, and it carried the structured-output config.
    assert len(stub.calls) == 1
    sent = stub.calls[0]
    assert "outputConfig" in sent
    assert sent["outputConfig"]["textFormat"]["type"] == "json_schema"
|
||||
|
||||
|
||||
def test_dict_schema_response_format() -> None:
    """A dict ``response_format`` in the OpenAI-style wrapper layout is accepted."""
    bedrock_client, _ = _make_client()

    object_schema = {
        "type": "object",
        "properties": {
            "city": {"type": "string"},
            "temp": {"type": "number"},
        },
    }
    dict_schema = {"json_schema": {"name": "weather_output", "schema": object_schema}}

    config = bedrock_client._prepare_output_config(dict_schema)

    assert config is not None
    definition = config["textFormat"]["structure"]["jsonSchema"]
    assert definition["name"] == "weather_output"
    decoded = json.loads(definition["schema"])
    assert decoded["type"] == "object"
    assert "city" in decoded["properties"]
|
||||
|
||||
|
||||
def test_prepare_output_config_none_returns_none() -> None:
    """Passing None as the response format produces no output config."""
    bedrock_client, _ = _make_client()

    assert bedrock_client._prepare_output_config(None) is None
|
||||
|
||||
|
||||
async def test_chat_response_value_populated_streaming() -> None:
    """Streaming mode also populates ``.value`` on the final response."""
    payload = {"city": "Portland", "temperature": 68.0, "summary": "Cloudy"}
    bedrock_client, stub = _make_client(response_text=json.dumps(payload))
    conversation = _user_messages()

    stream = bedrock_client.get_response(
        messages=conversation,
        stream=True,
        options={"max_tokens": 100, "response_format": WeatherReport},
    )

    # Drain the stream before asking for the final response.
    async for _update in stream:
        pass
    final = await stream.get_final_response()

    parsed = final.value
    assert parsed is not None
    assert isinstance(parsed, WeatherReport)
    assert parsed.city == "Portland"

    # Exactly one API call, and it carried the structured-output config.
    assert len(stub.calls) == 1
    assert "outputConfig" in stub.calls[0]
|
||||
|
||||
|
||||
async def test_unsupported_model_validation_exception() -> None:
    """A ValidationException that mentions outputConfig surfaces as a clear ValueError."""

    class _FailingStubBedrockRuntime:
        def converse(self, **kwargs: Any) -> dict[str, Any]:
            # Mimic the botocore error raised for a rejected outputConfig field.
            aws_error = {"Code": "ValidationException", "Message": "Invalid field outputConfig"}
            raise ClientError({"Error": aws_error}, "Converse")

    failing_runtime = _FailingStubBedrockRuntime()
    bedrock_client = BedrockChatClient(
        model="us.anthropic.claude-v2",
        region="us-east-1",
        client=failing_runtime,
    )

    with pytest.raises(ValueError) as exc:
        await bedrock_client.get_response(
            messages=_user_messages(),
            options={"response_format": WeatherReport},
        )

    raised_text = str(exc.value)
    assert "does not support structured output via outputConfig.textFormat" in raised_text
    assert "Check the model's Bedrock Converse outputConfig/textFormat support." in raised_text
|
||||
|
||||
|
||||
def test_invalid_response_format_type_raises() -> None:
    """A response_format that is neither a mapping nor a BaseModel subclass is rejected with TypeError."""
    client, _stub = _make_client()
    bogus_format = "not_a_valid_format"
    with pytest.raises(TypeError, match="Pydantic BaseModel subclass"):
        client._prepare_output_config(bogus_format)
|
||||
|
||||
|
||||
def test_mapping_response_format_accepted() -> None:
    """A Mapping that is not a ``dict`` is accepted as response_format.

    The resulting outputConfig must carry the schema name and a schema with
    ``additionalProperties`` set to ``False``; no TypeError may be raised.
    """
    from collections.abc import MutableMapping

    class _DictBackedMapping(MutableMapping):
        """Minimal MutableMapping that delegates to a private dict (not a dict subclass)."""

        def __init__(self, data):
            self._store = dict(data)

        def __getitem__(self, key):
            return self._store[key]

        def __setitem__(self, key, value):
            self._store[key] = value

        def __delitem__(self, key):
            del self._store[key]

        def __iter__(self):
            return iter(self._store)

        def __len__(self):
            return len(self._store)

    client, _stub = _make_client()
    inner_schema = {
        "type": "object",
        "properties": {"result": {"type": "string"}},
    }
    mapping_format = _DictBackedMapping({"json_schema": {"name": "test_output", "schema": inner_schema}})

    output_config = client._prepare_output_config(mapping_format)

    assert output_config is not None
    wire_schema = output_config["textFormat"]["structure"]["jsonSchema"]
    assert wire_schema["name"] == "test_output"
    # The schema travels as a JSON string, so decode it before inspecting.
    decoded = json.loads(wire_schema["schema"])
    assert decoded.get("additionalProperties") is False
|
||||
|
||||
|
||||
def test_shape_b_dict_schema_wire_format() -> None:
    """A dict response_format given in Shape B (the inner name/schema shape directly) yields a valid outputConfig."""
    client, _stub = _make_client()

    properties = {
        "city": {"type": "string"},
        "temperature": {"type": "number"},
    }
    response_format = {
        "name": "weather_output",
        "schema": {"type": "object", "properties": properties},
    }

    output_config = client._prepare_output_config(response_format)
    assert output_config is not None

    text_format = output_config["textFormat"]
    assert text_format["type"] == "json_schema"

    wire_schema = text_format["structure"]["jsonSchema"]
    assert wire_schema["name"] == "weather_output"

    # The schema travels as a JSON string, so decode it before inspecting.
    decoded = json.loads(wire_schema["schema"])
    assert decoded.get("additionalProperties") is False
|
||||
|
||||
|
||||
def test_dict_schema_not_mutated() -> None:
    """``_prepare_output_config`` must leave the caller-supplied dict untouched."""
    client, _stub = _make_client()
    inner_schema = {
        "type": "object",
        "properties": {"a": {"type": "string"}},
    }
    original_schema = {"json_schema": {"name": "test", "schema": inner_schema}}

    # Deep copy taken before the call so nested mutations are detected too.
    snapshot = copy.deepcopy(original_schema)
    client._prepare_output_config(original_schema)

    assert original_schema == snapshot, "Original dict schema was mutated"
|
||||
|
||||
|
||||
async def test_non_outputconfig_validation_exception_propagates() -> None:
    """A ValidationException unrelated to outputConfig reaches the caller as the raw ClientError."""
    client, _stub = _make_client()
    unrelated_error = ClientError(
        {"Error": {"Code": "ValidationException", "Message": "Invalid message format"}},
        "Converse",
    )
    # Swap the underlying runtime for a mock whose ``converse`` raises the error.
    runtime_patch = patch.object(
        client,
        "_bedrock_client",
        **{"converse.side_effect": unrelated_error},
    )
    with runtime_patch, pytest.raises(ClientError):
        await client.get_response(
            messages=_user_messages(),
            options={"max_tokens": 100},
        )
|
||||
|
||||
|
||||
# endregion
|
||||
@@ -71,6 +71,7 @@ from ._evaluation import (
|
||||
Evaluator,
|
||||
ExpectedToolCall,
|
||||
LocalEvaluator,
|
||||
RubricScore,
|
||||
evaluate_agent,
|
||||
evaluate_workflow,
|
||||
evaluator,
|
||||
@@ -460,6 +461,7 @@ __all__ = [
|
||||
"ResponseStream",
|
||||
"Role",
|
||||
"RoleLiteral",
|
||||
"RubricScore",
|
||||
"RunContext",
|
||||
"Runner",
|
||||
"RunnerContext",
|
||||
|
||||
@@ -311,12 +311,15 @@ class EvalScoreResult:
|
||||
score: Numeric score from the evaluator.
|
||||
passed: Whether the item passed this evaluator's threshold.
|
||||
sample: Optional raw evaluator output (rationale, metadata).
|
||||
dimensions: Per-dimension scores when this evaluator is a rubric
|
||||
evaluator. ``None`` for non-rubric (e.g. built-in) evaluators.
|
||||
"""
|
||||
|
||||
name: str
|
||||
score: float
|
||||
passed: bool | None = None
|
||||
sample: dict[str, Any] | None = None
|
||||
dimensions: list[RubricScore] | None = None
|
||||
|
||||
|
||||
@experimental(feature_id=ExperimentalFeature.EVALS)
|
||||
@@ -496,6 +499,179 @@ class EvalResults:
|
||||
detail += f" Errored items: {', '.join(summaries)}."
|
||||
raise EvalNotPassedError(detail)
|
||||
|
||||
def assert_score_at_least(
    self,
    min_score: float,
    *,
    evaluator: str | None = None,
    msg: str | None = None,
) -> None:
    """Assert that no matching score falls below ``min_score``.

    Intended as a CI gate for generated rubric evaluators (e.g.
    ``results.assert_score_at_least(0.80)``). Scores of this result set
    and of all nested ``sub_results`` are inspected.

    Args:
        min_score: Lowest acceptable score; a score equal to it passes.
        evaluator: When given, only scores whose ``EvalScoreResult.name``
            equals this value are checked.
        msg: Optional message that replaces the generated failure text.

    Raises:
        EvalNotPassedError: When at least one matching score is below ``min_score``.
    """
    offenders: list[str] = []

    # Explicit stack, visiting a result set before its sub-results, so the
    # offender order is parent first and then children in mapping order.
    pending = [self]
    while pending:
        current = pending.pop()
        for item in current.items:
            offenders.extend(
                f"{item.item_id}/{entry.name}={entry.score:.3f}"
                for entry in item.scores
                if (evaluator is None or entry.name == evaluator) and entry.score < min_score
            )
        pending.extend(reversed(list(current.sub_results.values())))

    if not offenders:
        return
    if msg:
        raise EvalNotPassedError(msg)

    scope = " for " + evaluator if evaluator else ""
    listed = ", ".join(offenders[:5])
    # Only the first five offenders are spelled out; the rest are counted.
    overflow = f" (+{len(offenders) - 5} more)" if len(offenders) > 5 else ""
    raise EvalNotPassedError(f"{len(offenders)} score(s) below threshold {min_score}{scope}: {listed}{overflow}")
|
||||
|
||||
def assert_dimension_score_at_least(
    self,
    dimension_id: str,
    min_score: float,
    *,
    evaluator: str | None = None,
    require_applicable: bool = False,
    msg: str | None = None,
) -> None:
    """Assert that a rubric *dimension* scores ``>= min_score`` on every item.

    Looks through ``EvalScoreResult.dimensions`` of every item in this
    result set and in all nested ``sub_results``. Dimension entries marked
    non-applicable are ignored unless ``require_applicable=True``, in which
    case an item without any applicable entry for the dimension is a failure.

    Args:
        dimension_id: Id of the dimension to check.
        min_score: Lowest acceptable dimension score; a score equal to it passes.
        evaluator: When given, only scores whose ``EvalScoreResult.name``
            equals this value are considered.
        require_applicable: When ``True``, items with no applicable score
            for the dimension are reported. Defaults to ``False``.
        msg: Optional message that replaces the generated failure text.

    Raises:
        EvalNotPassedError: When a dimension score is ``None`` or below
            ``min_score``, or when ``require_applicable`` is set and an item
            has no applicable score for the dimension.
    """
    offenders: list[str] = []
    missing_items: list[str] = []

    # Explicit stack, visiting a result set before its sub-results.
    pending = [self]
    while pending:
        current = pending.pop()
        for item in current.items:
            # Every (score, dimension) pair that matches the filter and is applicable.
            applicable_hits = [
                (entry, dim)
                for entry in item.scores
                if evaluator is None or entry.name == evaluator
                for dim in (entry.dimensions or ())
                if dim.id == dimension_id and dim.applicable
            ]
            for entry, dim in applicable_hits:
                if dim.score is None or dim.score < min_score:
                    shown = "None" if dim.score is None else dim.score
                    offenders.append(f"{item.item_id}/{entry.name}/{dimension_id}={shown}")
            if require_applicable and not applicable_hits:
                missing_items.append(item.item_id)
        pending.extend(reversed(list(current.sub_results.values())))

    problems: list[str] = []
    if offenders:
        listed = ", ".join(offenders[:5])
        overflow = f" (+{len(offenders) - 5} more)" if len(offenders) > 5 else ""
        problems.append(
            f"{len(offenders)} dimension score(s) for '{dimension_id}' below {min_score}: {listed}{overflow}"
        )
    if missing_items:
        listed_missing = ", ".join(missing_items[:5])
        problems.append(
            f"Dimension '{dimension_id}' not applicable on {len(missing_items)} item(s): {listed_missing}"
        )
    if not problems:
        return
    raise EvalNotPassedError(msg or "; ".join(problems))
|
||||
|
||||
def assert_no_failed_items(self, msg: str | None = None) -> None:
    """Assert that no item has ``is_failed`` or ``is_error`` set.

    Items of this result set and of all nested ``sub_results`` are inspected.

    Args:
        msg: Optional message that replaces the generated failure text.

    Raises:
        EvalNotPassedError: When any item failed or errored.
    """
    bad: list[str] = []

    # Explicit stack, visiting a result set before its sub-results.
    pending = [self]
    while pending:
        current = pending.pop()
        bad.extend(f"{item.item_id}:{item.status}" for item in current.items if item.is_failed or item.is_error)
        pending.extend(reversed(list(current.sub_results.values())))

    if not bad:
        return
    if msg:
        raise EvalNotPassedError(msg)

    listed = ", ".join(bad[:5])
    # Only the first five entries are spelled out; the rest are counted.
    overflow = f" (+{len(bad) - 5} more)" if len(bad) > 5 else ""
    raise EvalNotPassedError(f"{len(bad)} item(s) failed or errored: {listed}{overflow}")
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# region Generated rubric evaluators
|
||||
|
||||
|
||||
@experimental(feature_id=ExperimentalFeature.EVALS)
@dataclass(frozen=True)
class RubricScore:
    """Result of one rubric dimension for a single evaluated item.

    A rubric evaluator produces one ``RubricScore`` per dimension per item.
    Instances are attached to :class:`EvalScoreResult` and give a typed view
    of the raw ``properties.rubric_scores`` payload returned by providers
    such as Foundry's generated rubric evaluators.

    Attributes:
        id: Dimension id, matching the rubric definition.
        score: Numeric score, or ``None`` when the dimension was marked
            non-applicable for this item.
        applicable: Whether the dimension applied to this item.
        weight: Dimension weight, mirroring the rubric definition.
        reason: Short rationale produced by the evaluator.
    """

    # Identifier of the rubric dimension.
    id: str
    # ``None`` when the dimension was marked non-applicable.
    score: int | None
    applicable: bool
    weight: int
    reason: str
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
@@ -14,12 +14,13 @@ import logging
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from .._agents import Agent
|
||||
from .._agents import Agent, SupportsAgentRun
|
||||
from .._clients import SupportsWebSearchTool
|
||||
from .._compaction import CompactionProvider, ContextWindowCompactionStrategy, ToolResultCompactionStrategy
|
||||
from .._feature_stage import ExperimentalFeature, experimental
|
||||
from .._sessions import ContextProvider, HistoryProvider, InMemoryHistoryProvider
|
||||
from .._skills import SkillsProvider
|
||||
from ._background_agents import BackgroundAgentsProvider
|
||||
from ._memory import MemoryContextProvider, MemoryStore
|
||||
from ._mode import AgentModeProvider
|
||||
from ._todo import TodoProvider
|
||||
@@ -103,6 +104,8 @@ def _assemble_context_providers(
|
||||
memory_store: MemoryStore | None,
|
||||
skills_provider: SkillsProvider | None,
|
||||
skills_paths: Sequence[str] | None,
|
||||
background_agents: Sequence[SupportsAgentRun] | None,
|
||||
background_agents_instructions: str | None,
|
||||
extra_context_providers: Sequence[ContextProvider] | None,
|
||||
) -> list[ContextProvider]:
|
||||
"""Assemble the ordered list of context providers."""
|
||||
@@ -130,6 +133,10 @@ def _assemble_context_providers(
|
||||
if skills_paths:
|
||||
providers.append(SkillsProvider.from_paths(*skills_paths))
|
||||
|
||||
# Background agents are opt-in: only added when agents are provided.
|
||||
if background_agents:
|
||||
providers.append(BackgroundAgentsProvider(background_agents, instructions=background_agents_instructions))
|
||||
|
||||
# Append any user-supplied additional providers.
|
||||
if extra_context_providers:
|
||||
providers.extend(extra_context_providers)
|
||||
@@ -165,6 +172,8 @@ def create_harness_agent(
|
||||
memory_store: MemoryStore | None = None,
|
||||
skills_provider: SkillsProvider | None = None,
|
||||
skills_paths: Sequence[str] | None = None,
|
||||
background_agents: Sequence[SupportsAgentRun] | None = None,
|
||||
background_agents_instructions: str | None = None,
|
||||
disable_web_search: bool = False,
|
||||
otel_provider_name: str | None = None,
|
||||
context_providers: Sequence[ContextProvider] | None = None,
|
||||
@@ -182,6 +191,7 @@ def create_harness_agent(
|
||||
- **AgentModeProvider** — plan/execute mode tracking
|
||||
- **MemoryContextProvider** — file-based durable memory (when ``memory_store`` provided)
|
||||
- **SkillsProvider** — skill discovery and progressive loading
|
||||
- **BackgroundAgentsProvider** — delegate work to background sub-agents
|
||||
- **OpenTelemetry** — observability via ``AgentTelemetryLayer``
|
||||
|
||||
Each feature can be disabled or customized via keyword arguments.
|
||||
@@ -253,6 +263,13 @@ def create_harness_agent(
|
||||
skills_paths: Paths for file-based skill discovery (looks for SKILL.md files).
|
||||
Can be combined with ``skills_provider``. When neither ``skills_provider``
|
||||
nor ``skills_paths`` is provided, no SkillsProvider is added.
|
||||
background_agents: Collection of agents available for background task delegation.
|
||||
When provided, a ``BackgroundAgentsProvider`` is automatically included,
|
||||
enabling the agent to start, monitor, and retrieve results from background tasks.
|
||||
Each agent must have a non-empty, unique name (case-insensitive).
|
||||
background_agents_instructions: Optional instruction override for the
|
||||
``BackgroundAgentsProvider``. May include ``{background_agents}`` placeholder
|
||||
which will be replaced with the agent listing.
|
||||
disable_web_search: When True, skip automatic web search tool inclusion.
|
||||
When False (default), the web search tool is automatically added if the
|
||||
client implements SupportsWebSearchTool. A warning is logged if the client
|
||||
@@ -302,6 +319,8 @@ def create_harness_agent(
|
||||
memory_store=memory_store,
|
||||
skills_provider=skills_provider,
|
||||
skills_paths=skills_paths,
|
||||
background_agents=background_agents,
|
||||
background_agents_instructions=background_agents_instructions,
|
||||
extra_context_providers=context_providers,
|
||||
)
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ import logging
|
||||
import re
|
||||
import sys
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Callable, Collection, Coroutine, Sequence
|
||||
from collections.abc import Callable, Collection, Coroutine, Mapping, Sequence
|
||||
from contextlib import AsyncExitStack, _AsyncGeneratorContextManager # type: ignore
|
||||
from datetime import timedelta
|
||||
from functools import partial
|
||||
@@ -142,6 +142,13 @@ def _inject_otel_into_mcp_meta(meta: dict[str, Any] | None = None) -> dict[str,
|
||||
return meta
|
||||
|
||||
|
||||
def _url_origin(url: Any) -> tuple[str, str, int | None]:
|
||||
port = url.port
|
||||
if port is None:
|
||||
port = 443 if url.scheme == "https" else 80 if url.scheme == "http" else None
|
||||
return (url.scheme, url.host or "", port)
|
||||
|
||||
|
||||
def streamable_http_client(*args: Any, **kwargs: Any) -> _AsyncGeneratorContextManager[Any, None]:
|
||||
"""Lazily import the MCP streamable HTTP transport."""
|
||||
try:
|
||||
@@ -255,6 +262,7 @@ class MCPTool:
|
||||
self._exit_stack = AsyncExitStack()
|
||||
self._lifecycle_lock = asyncio.Lock()
|
||||
self._lifecycle_request_lock = asyncio.Lock()
|
||||
self._function_load_lock = asyncio.Lock()
|
||||
self._lifecycle_queue: asyncio.Queue[tuple[str, bool, bool, asyncio.Future[None]]] | None = None
|
||||
self._lifecycle_owner_task: asyncio.Task[None] | None = None
|
||||
self.session = session
|
||||
@@ -655,6 +663,11 @@ class MCPTool:
|
||||
raise
|
||||
except asyncio.CancelledError:
|
||||
logger.warning("Could not cleanly close MCP exit stack because the lifecycle owner task was cancelled.")
|
||||
except Exception as e:
|
||||
if type(e).__name__ == "ExceptionGroup":
|
||||
logger.warning("Could not cleanly close MCP exit stack due to cleanup error group. Error: %s", e)
|
||||
else:
|
||||
raise
|
||||
|
||||
async def _close_and_check_cancelled(self, ex: BaseException) -> bool:
|
||||
"""Close the exit stack and return True if *ex* is a genuine task cancellation.
|
||||
@@ -1018,6 +1031,10 @@ class MCPTool:
|
||||
Raises:
|
||||
ToolExecutionException: If the MCP server is not connected.
|
||||
"""
|
||||
async with self._function_load_lock:
|
||||
await self._load_prompts_locked()
|
||||
|
||||
async def _load_prompts_locked(self) -> None:
|
||||
from anyio import ClosedResourceError
|
||||
from mcp import types
|
||||
|
||||
@@ -1100,6 +1117,10 @@ class MCPTool:
|
||||
Raises:
|
||||
ToolExecutionException: If the MCP server is not connected.
|
||||
"""
|
||||
async with self._function_load_lock:
|
||||
await self._load_tools_locked()
|
||||
|
||||
async def _load_tools_locked(self) -> None:
|
||||
from anyio import ClosedResourceError
|
||||
from mcp import types
|
||||
|
||||
@@ -1109,7 +1130,7 @@ class MCPTool:
|
||||
|
||||
# Track existing function names to prevent duplicates
|
||||
existing_names = {func.name for func in self._functions}
|
||||
self._tool_call_meta_by_name.clear()
|
||||
tool_call_meta_by_name: dict[str, dict[str, Any]] = {}
|
||||
|
||||
params: types.PaginatedRequestParams | None = None
|
||||
while True:
|
||||
@@ -1145,7 +1166,7 @@ class MCPTool:
|
||||
|
||||
for tool in tool_list.tools:
|
||||
if tool.meta is not None:
|
||||
self._tool_call_meta_by_name[tool.name] = dict(tool.meta)
|
||||
tool_call_meta_by_name[tool.name] = dict(tool.meta)
|
||||
|
||||
normalized_name = _normalize_mcp_name(tool.name)
|
||||
local_name = _build_prefixed_mcp_name(normalized_name, self.tool_name_prefix)
|
||||
@@ -1194,6 +1215,8 @@ class MCPTool:
|
||||
break
|
||||
params = types.PaginatedRequestParams(cursor=tool_list.nextCursor)
|
||||
|
||||
self._tool_call_meta_by_name = tool_call_meta_by_name
|
||||
|
||||
async def _close_on_owner(self) -> None:
|
||||
# Cancel any pending reload tasks before tearing down the session.
|
||||
tasks = list(self._pending_reload_tasks)
|
||||
@@ -1276,7 +1299,11 @@ class MCPTool:
|
||||
tool_name: The name of the tool to call.
|
||||
|
||||
Keyword Args:
|
||||
kwargs: Arguments to pass to the tool.
|
||||
_meta: Optional ``dict[str, Any]`` of MCP request metadata. This reserved key is passed as the
|
||||
``meta`` parameter of the underlying ``session.call_tool`` call rather than as a tool argument.
|
||||
User-supplied keys override metadata from ``tools/list``; OpenTelemetry propagation fills in
|
||||
non-conflicting keys.
|
||||
kwargs: Remaining arguments to pass to the tool.
|
||||
|
||||
Returns:
|
||||
A list of Content items representing the tool output. The default
|
||||
@@ -1294,6 +1321,19 @@ class MCPTool:
|
||||
raise ToolExecutionException(
|
||||
"Tools are not loaded for this server, please set load_tools=True in the constructor."
|
||||
)
|
||||
|
||||
raw_user_meta: object | None = kwargs.get("_meta")
|
||||
user_meta: dict[str, Any] | None = None
|
||||
if raw_user_meta is not None and not isinstance(raw_user_meta, dict):
|
||||
raise ToolExecutionException("MCP tool metadata provided via _meta must be a dict.")
|
||||
if isinstance(raw_user_meta, dict):
|
||||
raw_user_meta_dict = cast(Mapping[object, object], raw_user_meta)
|
||||
user_meta = {}
|
||||
for key, value in raw_user_meta_dict.items():
|
||||
if not isinstance(key, str):
|
||||
raise ToolExecutionException("MCP tool metadata provided via _meta must use string keys.")
|
||||
user_meta[key] = value
|
||||
|
||||
# Filter out framework kwargs that cannot be serialized by the MCP SDK.
|
||||
# These are internal objects passed through the function invocation pipeline
|
||||
# that should not be forwarded to external MCP servers.
|
||||
@@ -1313,12 +1353,16 @@ class MCPTool:
|
||||
"conversation_id",
|
||||
"options",
|
||||
"response_format",
|
||||
"_meta",
|
||||
}
|
||||
}
|
||||
|
||||
# Some MCP proxies require their tools/list metadata to be echoed on tools/call.
|
||||
tool_meta = self._tool_call_meta_by_name.get(tool_name)
|
||||
meta = _inject_otel_into_mcp_meta(dict(tool_meta) if tool_meta is not None else None)
|
||||
request_meta = dict(tool_meta) if tool_meta is not None else None
|
||||
if user_meta is not None:
|
||||
request_meta = {**(request_meta or {}), **user_meta}
|
||||
meta = _inject_otel_into_mcp_meta(request_meta)
|
||||
|
||||
parser = self.parse_tool_results or self._parse_tool_result_from_mcp
|
||||
# Try the operation, reconnecting once if the connection is closed
|
||||
@@ -1336,28 +1380,33 @@ class MCPTool:
|
||||
return parser(result)
|
||||
except ToolExecutionException:
|
||||
raise
|
||||
except ClosedResourceError as cl_ex:
|
||||
except (ClosedResourceError, McpError) as call_ex:
|
||||
is_session_terminated = (
|
||||
isinstance(call_ex, McpError) and "session terminated" in call_ex.error.message.lower()
|
||||
)
|
||||
is_connection_lost = isinstance(call_ex, ClosedResourceError) or is_session_terminated
|
||||
if not is_connection_lost:
|
||||
error_message = call_ex.error.message if isinstance(call_ex, McpError) else str(call_ex)
|
||||
raise ToolExecutionException(error_message, inner_exception=call_ex) from call_ex
|
||||
|
||||
if attempt == 0:
|
||||
# First attempt failed, try reconnecting
|
||||
logger.info("MCP connection closed unexpectedly. Reconnecting...")
|
||||
# First attempt failed, try reconnecting.
|
||||
logger.info("MCP connection closed or terminated unexpectedly. Reconnecting...")
|
||||
try:
|
||||
await self.connect(reset=True)
|
||||
continue # Retry the operation
|
||||
continue
|
||||
except Exception as reconn_ex:
|
||||
raise ToolExecutionException(
|
||||
"Failed to reconnect to MCP server.",
|
||||
inner_exception=reconn_ex,
|
||||
) from reconn_ex
|
||||
else:
|
||||
# Second attempt also failed, give up
|
||||
logger.error(f"MCP connection closed unexpectedly after reconnection: {cl_ex}")
|
||||
raise ToolExecutionException(
|
||||
f"Failed to call tool '{tool_name}' - connection lost.",
|
||||
inner_exception=cl_ex,
|
||||
) from cl_ex
|
||||
except McpError as mcp_exc:
|
||||
error_message = mcp_exc.error.message
|
||||
raise ToolExecutionException(error_message, inner_exception=mcp_exc) from mcp_exc
|
||||
|
||||
# Second attempt also failed, give up.
|
||||
logger.error("MCP connection closed unexpectedly after reconnection: %s", call_ex)
|
||||
raise ToolExecutionException(
|
||||
f"Failed to call tool '{tool_name}' - connection lost.",
|
||||
inner_exception=call_ex,
|
||||
) from call_ex
|
||||
except Exception as ex:
|
||||
raise ToolExecutionException(f"Failed to call tool '{tool_name}'.", inner_exception=ex) from ex
|
||||
raise ToolExecutionException(f"Failed to call tool '{tool_name}' after retries.")
|
||||
@@ -1718,10 +1767,11 @@ class MCPStreamableHTTPTool(MCPTool):
|
||||
Returns:
|
||||
An async context manager for the streamable HTTP client transport.
|
||||
"""
|
||||
from httpx import AsyncClient, Request, Timeout
|
||||
from httpx import URL, AsyncClient, Request, Timeout
|
||||
|
||||
http_client = self._httpx_client
|
||||
if self._header_provider is not None:
|
||||
target_origin = _url_origin(URL(self.url))
|
||||
if http_client is None:
|
||||
http_client = AsyncClient(
|
||||
follow_redirects=True,
|
||||
@@ -1732,6 +1782,8 @@ class MCPStreamableHTTPTool(MCPTool):
|
||||
if not hasattr(self, "_inject_headers_hook"):
|
||||
|
||||
async def _inject_headers(request: Request) -> None: # noqa: RUF029
|
||||
if _url_origin(request.url) != target_origin:
|
||||
return
|
||||
headers = _mcp_call_headers.get({})
|
||||
for key, value in headers.items():
|
||||
request.headers[key] = value
|
||||
|
||||
@@ -36,11 +36,10 @@ if TYPE_CHECKING:
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ._agents import SupportsAgentRun
|
||||
from ._clients import SupportsChatGetResponse
|
||||
from ._compaction import CompactionStrategy, TokenizerProtocol
|
||||
from ._sessions import AgentSession
|
||||
from ._tools import FunctionTool, ToolTypes
|
||||
from ._types import ChatOptions, ChatResponse, ChatResponseUpdate
|
||||
from ._types import ChatOptions
|
||||
|
||||
ResponseModelBoundT = TypeVar("ResponseModelBoundT", bound=BaseModel)
|
||||
|
||||
|
||||
@@ -7,6 +7,8 @@ import json
|
||||
import logging
|
||||
import re
|
||||
from collections.abc import Mapping, MutableMapping
|
||||
from dataclasses import asdict, is_dataclass
|
||||
from datetime import date, datetime
|
||||
from typing import Any, ClassVar, Protocol, TypeVar, runtime_checkable
|
||||
|
||||
logger = logging.getLogger("agent_framework")
|
||||
@@ -614,3 +616,46 @@ class SerializationMixin:
|
||||
# Fallback and default
|
||||
# Convert class name to snake_case
|
||||
return _CAMEL_TO_SNAKE_PATTERN.sub("_", cls.__name__).lower()
|
||||
|
||||
|
||||
def make_json_safe(obj: Any) -> Any:
    """Recursively convert an object to a JSON-serializable form.

    Conversion rules, applied in this order:

    1. ``None``, ``str``, ``int``, ``float`` and ``bool`` are returned as-is.
    2. ``datetime`` / ``date`` values become ISO 8601 strings.
    3. Dataclass instances are expanded with ``dataclasses.asdict``.
    4. Objects exposing a callable ``model_dump``, ``to_dict`` or ``dict``
       are converted through that hook; a hook that raises ``TypeError``
       is skipped and the next rule is tried.
    5. Any ``Mapping`` (not only ``dict``) becomes a dict with ``str`` keys.
    6. Lists and tuples become lists.
    7. Objects with a ``__dict__`` become a dict of their attributes.
    8. Everything else is replaced by ``str(obj)``.

    Self-referential structures are not detected and recurse without bound.

    Args:
        obj: Object to make JSON safe.

    Returns:
        A JSON-serializable version of the object.
    """
    if obj is None or isinstance(obj, (str, int, float, bool)):
        return obj
    if isinstance(obj, (datetime, date)):
        return obj.isoformat()
    if is_dataclass(obj) and not isinstance(obj, type):
        return make_json_safe(asdict(obj))  # type: ignore[arg-type]

    # Conventional "export as plain data" hooks, tried in priority order.
    # TypeError is swallowed on purpose: it is raised, for example, when obj
    # is a class and the hook is an unbound instance method.
    for hook_name in ("model_dump", "to_dict", "dict"):
        hook = getattr(obj, hook_name, None)
        if callable(hook):
            try:
                return make_json_safe(hook())
            except TypeError:
                continue

    # Accept every Mapping, so read-only views and custom mapping classes are
    # serialized by their items rather than by their internals or their repr.
    if isinstance(obj, Mapping):
        return {str(key): make_json_safe(value) for key, value in obj.items()}  # type: ignore[misc]
    if isinstance(obj, (list, tuple)):
        return [make_json_safe(item) for item in obj]  # type: ignore[misc]
    if hasattr(obj, "__dict__"):
        return {key: make_json_safe(value) for key, value in vars(obj).items()}  # type: ignore[misc]
    return str(obj)
|
||||
|
||||
@@ -1973,11 +1973,97 @@ def _coalesce_text_content(contents: list[Content], type_str: Literal["text", "t
|
||||
contents.extend(coalesced_contents)
|
||||
|
||||
|
||||
def _content_items_text(items: Any) -> str | None:
|
||||
"""Return concatenated text when a content item list only contains text."""
|
||||
if not isinstance(items, list):
|
||||
return None
|
||||
text_parts: list[str] = []
|
||||
content_items = cast(list[object], items)
|
||||
for item in content_items:
|
||||
if not isinstance(item, Content) or item.type != "text":
|
||||
return None
|
||||
text_parts.append(item.text or "")
|
||||
return "".join(text_parts)
|
||||
|
||||
|
||||
def _merge_content_item_lists(existing: Any, incoming: Any) -> Any:
|
||||
"""Merge streamed nested content lists, replacing deltas with a later full value when present."""
|
||||
if incoming is None:
|
||||
return existing
|
||||
if existing is None:
|
||||
return deepcopy(incoming)
|
||||
|
||||
existing_text = _content_items_text(existing)
|
||||
incoming_text = _content_items_text(incoming)
|
||||
if existing_text is not None and incoming_text is not None:
|
||||
if incoming_text.startswith(existing_text):
|
||||
return deepcopy(incoming)
|
||||
if existing_text.startswith(incoming_text):
|
||||
return existing
|
||||
|
||||
existing_items = cast(list[Content], existing)
|
||||
merged = deepcopy(existing_items[0])
|
||||
merged.text = existing_text + incoming_text
|
||||
return [merged]
|
||||
|
||||
if isinstance(existing, list) and isinstance(incoming, list):
|
||||
existing_list = cast(list[object], existing)
|
||||
incoming_list = cast(list[object], incoming)
|
||||
return [*existing_list, *deepcopy(incoming_list)]
|
||||
return deepcopy(incoming)
|
||||
|
||||
|
||||
def _merge_code_interpreter_content(existing: Content, incoming: Content) -> None:
    """Fold ``incoming`` into ``existing`` in place; both belong to the same code interpreter call.

    Inputs and outputs are merged as nested content lists, annotations and raw
    representations are combined, and the additional properties of ``incoming``
    override those of ``existing`` on key collisions.
    """
    merged_inputs = _merge_content_item_lists(existing.inputs, incoming.inputs)
    existing.inputs = merged_inputs

    merged_outputs = _merge_content_item_lists(existing.outputs, incoming.outputs)
    existing.outputs = merged_outputs

    existing.annotations = _combine_annotations(existing.annotations, incoming.annotations)

    # Later chunk wins on duplicate property keys.
    merged_properties = {**existing.additional_properties, **incoming.additional_properties}
    existing.additional_properties = merged_properties

    existing.raw_representation = _combine_raw_representations(
        existing.raw_representation,
        incoming.raw_representation,
    )
|
||||
|
||||
|
||||
def _code_interpreter_key(content: Content) -> tuple[str, str] | None:
|
||||
"""Return the aggregation key for code interpreter call/result content."""
|
||||
if content.type not in {"code_interpreter_tool_call", "code_interpreter_tool_result"}:
|
||||
return None
|
||||
call_id = content.call_id or content.additional_properties.get("item_id")
|
||||
if not isinstance(call_id, str) or not call_id:
|
||||
return None
|
||||
return content.type, call_id
|
||||
|
||||
|
||||
def _coalesce_code_interpreter_content(contents: list[Content]) -> None:
|
||||
"""Coalesce streaming code interpreter chunks by call id."""
|
||||
if not contents:
|
||||
return
|
||||
|
||||
coalesced_contents: list[Content] = []
|
||||
seen: dict[tuple[str, str], Content] = {}
|
||||
for content in contents:
|
||||
key = _code_interpreter_key(content)
|
||||
if key is None:
|
||||
coalesced_contents.append(content)
|
||||
continue
|
||||
|
||||
existing = seen.get(key)
|
||||
if existing is None:
|
||||
copied = deepcopy(content)
|
||||
seen[key] = copied
|
||||
coalesced_contents.append(copied)
|
||||
continue
|
||||
|
||||
_merge_code_interpreter_content(existing, content)
|
||||
|
||||
contents.clear()
|
||||
contents.extend(coalesced_contents)
|
||||
|
||||
|
||||
def _finalize_response(response: ChatResponse | AgentResponse) -> None:
|
||||
"""Finalizes the response by performing any necessary post-processing."""
|
||||
for msg in response.messages:
|
||||
_coalesce_text_content(msg.contents, "text")
|
||||
_coalesce_text_content(msg.contents, "text_reasoning")
|
||||
_coalesce_code_interpreter_content(msg.contents)
|
||||
|
||||
|
||||
# region ContinuationToken
|
||||
|
||||
@@ -12,6 +12,7 @@ from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast, overload
|
||||
|
||||
from .._agents import BaseAgent
|
||||
from .._serialization import make_json_safe
|
||||
from .._sessions import (
|
||||
AgentSession,
|
||||
ContextProvider,
|
||||
@@ -62,7 +63,7 @@ class WorkflowAgent(BaseAgent):
|
||||
data: Any
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
    """Return the request id and a JSON-safe form of the request data.

    ``data`` is passed through ``make_json_safe`` so the resulting dict can be
    handed to ``json.dumps`` (as ``to_json`` does).
    """
    return {"request_id": self.request_id, "data": make_json_safe(self.data)}
|
||||
|
||||
def to_json(self) -> str:
    """Return ``to_dict()`` serialized as a JSON string."""
    return json.dumps(self.to_dict())
|
||||
|
||||
@@ -47,6 +47,7 @@ from copy import deepcopy
|
||||
from typing import Any, Generic, Literal, TypeVar, overload
|
||||
|
||||
from .._feature_stage import ExperimentalFeature, experimental
|
||||
from .._serialization import make_json_safe
|
||||
from .._types import AgentResponse, AgentResponseUpdate, ResponseStream
|
||||
from ..observability import OtelAttr, capture_exception, create_workflow_span
|
||||
from ._checkpoint import CheckpointStorage, WorkflowCheckpoint
|
||||
@@ -1515,7 +1516,7 @@ class FunctionalWorkflowAgent:
|
||||
function_call = Content.from_function_call(
|
||||
call_id=request_id,
|
||||
name=self.REQUEST_INFO_FUNCTION_NAME,
|
||||
arguments={"request_id": request_id, "data": event.data},
|
||||
arguments={"request_id": request_id, "data": make_json_safe(event.data)},
|
||||
)
|
||||
return Content.from_function_approval_request(
|
||||
id=request_id,
|
||||
|
||||
@@ -34,6 +34,7 @@ _IMPORTS: dict[str, tuple[str, str]] = {
|
||||
"FoundryLocalChatOptions": ("agent_framework_foundry_local", "agent-framework-foundry-local"),
|
||||
"FoundryLocalClient": ("agent_framework_foundry_local", "agent-framework-foundry-local"),
|
||||
"FoundryLocalSettings": ("agent_framework_foundry_local", "agent-framework-foundry-local"),
|
||||
"GeneratedEvaluatorRef": ("agent_framework_foundry", "agent-framework-foundry"),
|
||||
"RawAnthropicFoundryClient": ("agent_framework_anthropic", "agent-framework-anthropic"),
|
||||
"RawFoundryAgent": ("agent_framework_foundry", "agent-framework-foundry"),
|
||||
"RawFoundryAgentChatClient": ("agent_framework_foundry", "agent-framework-foundry"),
|
||||
|
||||
@@ -20,6 +20,7 @@ from agent_framework_foundry import (
|
||||
FoundryEmbeddingSettings,
|
||||
FoundryEvals,
|
||||
FoundryMemoryProvider,
|
||||
GeneratedEvaluatorRef,
|
||||
RawFoundryAgent,
|
||||
RawFoundryAgentChatClient,
|
||||
RawFoundryChatClient,
|
||||
@@ -52,6 +53,7 @@ __all__ = [
|
||||
"FoundryLocalClient",
|
||||
"FoundryLocalSettings",
|
||||
"FoundryMemoryProvider",
|
||||
"GeneratedEvaluatorRef",
|
||||
"RawAnthropicFoundryClient",
|
||||
"RawFoundryAgent",
|
||||
"RawFoundryAgentChatClient",
|
||||
|
||||
@@ -394,3 +394,94 @@ def test_create_harness_agent_logs_warning_when_no_web_search(caplog: pytest.Log
|
||||
max_output_tokens=16_384,
|
||||
)
|
||||
assert any("SupportsWebSearchTool" in msg for msg in caplog.messages)
|
||||
|
||||
|
||||
# --- Background Agents Tests ---
|
||||
|
||||
|
||||
class _FakeBackgroundAgent:
|
||||
"""Minimal agent stub satisfying SupportsAgentRun for background agents tests."""
|
||||
|
||||
def __init__(self, name: str, description: str | None = None):
|
||||
self.id = f"agent-{name}"
|
||||
self.name = name
|
||||
self.description = description
|
||||
|
||||
def create_session(self, *, session_id: str | None = None) -> AgentSession:
|
||||
return AgentSession(session_id=session_id)
|
||||
|
||||
def get_session(self, service_session_id: str, *, session_id: str | None = None) -> AgentSession:
|
||||
return AgentSession(service_session_id=service_session_id, session_id=session_id)
|
||||
|
||||
async def run(self, messages: Any = None, *, stream: bool = False, session: Any = None, **kwargs: Any) -> Any:
|
||||
from agent_framework import AgentResponse
|
||||
|
||||
return AgentResponse(messages=[], response_id="fake-bg-response")
|
||||
|
||||
|
||||
def test_create_harness_agent_no_background_agents_by_default() -> None:
    """No BackgroundAgentsProvider should be included when background_agents is not provided."""
    from agent_framework._harness._background_agents import BackgroundAgentsProvider

    agent = create_harness_agent(
        client=_FakeChatClient(),  # type: ignore[arg-type]
        max_context_window_tokens=128_000,
        max_output_tokens=16_384,
        disable_web_search=True,
    )
    providers = agent.context_providers or []
    assert not any(isinstance(p, BackgroundAgentsProvider) for p in providers)
|
||||
|
||||
|
||||
def test_create_harness_agent_adds_background_agents_provider() -> None:
    """BackgroundAgentsProvider should be included when background_agents are provided."""
    from agent_framework._harness._background_agents import BackgroundAgentsProvider

    bg_agent = _FakeBackgroundAgent("WebSearcher", "Searches the web")
    agent = create_harness_agent(
        client=_FakeChatClient(),  # type: ignore[arg-type]
        max_context_window_tokens=128_000,
        max_output_tokens=16_384,
        disable_web_search=True,
        background_agents=[bg_agent],
    )
    providers = agent.context_providers or []
    bg_providers = [p for p in providers if isinstance(p, BackgroundAgentsProvider)]
    # Exactly one provider is expected, not one per background agent.
    assert len(bg_providers) == 1
|
||||
|
||||
|
||||
def test_create_harness_agent_background_agents_custom_instructions() -> None:
    """Custom instructions should be passed to BackgroundAgentsProvider."""
    from agent_framework._harness._background_agents import BackgroundAgentsProvider

    custom_instructions = "## Custom\n\nUse agents wisely.\n\n{background_agents}"
    bg_agent = _FakeBackgroundAgent("Helper", "A helper agent")
    agent = create_harness_agent(
        client=_FakeChatClient(),  # type: ignore[arg-type]
        max_context_window_tokens=128_000,
        max_output_tokens=16_384,
        disable_web_search=True,
        background_agents=[bg_agent],
        background_agents_instructions=custom_instructions,
    )
    providers = agent.context_providers or []
    bg_providers = [p for p in providers if isinstance(p, BackgroundAgentsProvider)]
    assert len(bg_providers) == 1
    # Verify the custom instructions were used (placeholder replaced with agent list).
    assert "Custom" in bg_providers[0]._instructions
    assert "Helper" in bg_providers[0]._instructions
|
||||
|
||||
|
||||
def test_create_harness_agent_empty_background_agents_list() -> None:
    """An empty background_agents list should NOT add a BackgroundAgentsProvider."""
    from agent_framework._harness._background_agents import BackgroundAgentsProvider

    agent = create_harness_agent(
        client=_FakeChatClient(),  # type: ignore[arg-type]
        max_context_window_tokens=128_000,
        max_output_tokens=16_384,
        disable_web_search=True,
        background_agents=[],
    )
    providers = agent.context_providers or []
    assert not any(isinstance(p, BackgroundAgentsProvider) for p in providers)
|
||||
|
||||
@@ -11,8 +11,13 @@ import pytest
|
||||
from agent_framework._evaluation import (
|
||||
CheckResult,
|
||||
EvalItem,
|
||||
EvalItemResult,
|
||||
EvalNotPassedError,
|
||||
EvalResults,
|
||||
EvalScoreResult,
|
||||
ExpectedToolCall,
|
||||
LocalEvaluator,
|
||||
RubricScore,
|
||||
_coerce_result,
|
||||
evaluator,
|
||||
keyword_check,
|
||||
@@ -1010,19 +1015,101 @@ class TestAllPassedSubResults:
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# r5 review: _build_overall_item with empty outputs
|
||||
# Rubric assertions (EvalResults.assert_*)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBuildOverallItemEmpty:
|
||||
"""Test _build_overall_item returns None for empty workflow outputs."""
|
||||
def _rubric_results(*scores_per_item: list[EvalScoreResult]) -> EvalResults:
    """Build an ``EvalResults`` fixture with one passing item per list of scores.

    Args:
        *scores_per_item: One list of ``EvalScoreResult`` per item; item ids are
            assigned as ``item-0``, ``item-1``, ... in argument order.

    Returns:
        An ``EvalResults`` whose counts report every item as passed.
    """
    items = [
        EvalItemResult(item_id=f"item-{i}", status="pass", scores=scores) for i, scores in enumerate(scores_per_item)
    ]
    return EvalResults(
        provider="test",
        eval_id="ev1",
        run_id="run1",
        result_counts={"passed": len(items), "failed": 0, "errored": 0, "total": len(items)},
        items=items,
    )
|
||||
|
||||
def test_returns_none_for_empty_outputs(self):
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from agent_framework._evaluation import _build_overall_item
|
||||
class TestRubricAssertions:
|
||||
"""Tests for EvalResults.assert_dimension_score_at_least."""
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.get_outputs.return_value = []
|
||||
item = _build_overall_item("Hello", mock_result)
|
||||
assert item is None
|
||||
def test_dimension_at_or_above_threshold_passes(self) -> None:
    """A "clarity" score of 4 satisfies a minimum of 3 without raising."""
    results = _rubric_results(
        [
            EvalScoreResult(
                name="policy",
                score=0.9,
                dimensions=[RubricScore(id="clarity", score=4, applicable=True, weight=1, reason="")],
            )
        ],
    )
    # Should not raise.
    results.assert_dimension_score_at_least("clarity", 3)
|
||||
|
||||
def test_dimension_below_threshold_raises(self) -> None:
    """A "clarity" score of 2 against a minimum of 3 raises EvalNotPassedError."""
    results = _rubric_results(
        [
            EvalScoreResult(
                name="policy",
                score=0.5,
                dimensions=[RubricScore(id="clarity", score=2, applicable=True, weight=1, reason="")],
            )
        ],
    )
    with pytest.raises(EvalNotPassedError):
        results.assert_dimension_score_at_least("clarity", 3)
|
||||
|
||||
def test_non_applicable_skipped_by_default(self) -> None:
    """A dimension marked not applicable does not fail the assertion by default."""
    results = _rubric_results(
        [
            EvalScoreResult(
                name="policy",
                score=1.0,
                dimensions=[RubricScore(id="clarity", score=None, applicable=False, weight=1, reason="n/a")],
            )
        ],
    )
    # No applicable scores; default behaviour is to skip silently.
    results.assert_dimension_score_at_least("clarity", 3)
|
||||
|
||||
def test_require_applicable_raises_when_dimension_absent(self) -> None:
    """With require_applicable=True, a missing dimension raises "not applicable"."""
    results = _rubric_results(
        [EvalScoreResult(name="policy", score=1.0, dimensions=[])],
    )
    with pytest.raises(EvalNotPassedError, match="not applicable"):
        results.assert_dimension_score_at_least("clarity", 3, require_applicable=True)
|
||||
|
||||
def test_require_applicable_raises_when_filtered_evaluator_missing(self) -> None:
    """With require_applicable=True, filtering on an evaluator that produced no scores raises."""
    # Regression: previously the (not evaluator or found_any) guard caused
    # this case to silently pass even with require_applicable=True.
    results = _rubric_results(
        [
            EvalScoreResult(
                name="other",
                score=0.9,
                dimensions=[RubricScore(id="clarity", score=4, applicable=True, weight=1, reason="")],
            )
        ],
    )
    with pytest.raises(EvalNotPassedError, match="not applicable"):
        results.assert_dimension_score_at_least("clarity", 3, evaluator="policy", require_applicable=True)
|
||||
|
||||
def test_evaluator_filter_isolates_offenders(self) -> None:
    """The evaluator filter restricts the check to the named evaluator's scores."""
    results = _rubric_results(
        [
            EvalScoreResult(
                name="other",
                score=0.1,
                dimensions=[RubricScore(id="clarity", score=1, applicable=True, weight=1, reason="")],
            ),
            EvalScoreResult(
                name="policy",
                score=0.9,
                dimensions=[RubricScore(id="clarity", score=4, applicable=True, weight=1, reason="")],
            ),
        ],
    )
    # The low-scoring "other" evaluator is filtered out; "policy" passes.
    results.assert_dimension_score_at_least("clarity", 3, evaluator="policy")
|
||||
|
||||
@@ -1161,6 +1161,43 @@ async def test_local_mcp_server_function_execution_error():
|
||||
await func.invoke(param="test_value")
|
||||
|
||||
|
||||
async def test_mcp_tool_reconnects_after_session_terminated_error():
    """Session termination errors should reconnect once and retry the tool call."""

    class TestServer(MCPTool):
        """MCPTool stub whose first session fails and whose later sessions succeed."""

        def __init__(self, **kwargs: Any) -> None:
            super().__init__(**kwargs)
            self.connect_count = 0
            self.sessions: list[Any] = []

        async def connect(self, *, reset: bool = False) -> None:
            # Every connect installs a fresh mock session and records it for later inspection.
            self.connect_count += 1
            self.session = Mock(spec=ClientSession)
            self.sessions.append(self.session)
            if self.connect_count == 1:
                self.session.call_tool = AsyncMock(
                    side_effect=McpError(types.ErrorData(code=-32000, message="Session terminated"))
                )
            else:
                self.session.call_tool = AsyncMock(
                    return_value=types.CallToolResult(content=[types.TextContent(type="text", text="recovered")])
                )
            self.is_connected = True

        def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
            return None

    server = TestServer(name="test_server")
    await server.connect()

    result = await server.call_tool("test_tool", param="test_value")

    assert _mcp_result_to_text(result) == "recovered"
    assert server.connect_count == 2
    # Each session saw exactly one call: the failing one, then the retry.
    assert server.sessions[0].call_tool.await_count == 1
    assert server.sessions[1].call_tool.await_count == 1
|
||||
|
||||
|
||||
async def test_mcp_tool_call_tool_raises_on_is_error():
|
||||
"""Test that call_tool raises ToolExecutionException when MCP returns isError=True."""
|
||||
|
||||
@@ -3260,6 +3297,68 @@ async def test_load_prompts_pagination_with_duplicates():
|
||||
assert [f.name for f in tool._functions] == ["prompt_1", "prompt_2"]
|
||||
|
||||
|
||||
async def test_load_tools_concurrent_reload_does_not_duplicate_tools_and_preserves_meta():
    """Concurrent tool reloads should not duplicate functions or lose tools/list metadata."""
    tool = MCPTool(name="test_tool")
    mock_session = AsyncMock()
    tool.session = mock_session
    tool.load_tools_flag = True

    page = Mock()
    page.tools = [
        types.Tool(
            name="tool_1",
            description="First tool",
            inputSchema={"type": "object", "properties": {"param": {"type": "string"}}},
            _meta={"echo": "tool_1"},
        ),
    ]
    page.nextCursor = None

    async def mock_list_tools(params: Any = None) -> Any:
        assert params is None
        # Yield to the event loop so the two load_tools() calls interleave.
        await asyncio.sleep(0)
        return page

    mock_session.list_tools = AsyncMock(side_effect=mock_list_tools)

    await asyncio.wait_for(asyncio.gather(tool.load_tools(), tool.load_tools()), timeout=1)

    assert mock_session.list_tools.call_count == 2
    assert [f.name for f in tool._functions] == ["tool_1"]
    assert tool._tool_call_meta_by_name == {"tool_1": {"echo": "tool_1"}}
|
||||
|
||||
|
||||
async def test_load_prompts_concurrent_reload_does_not_duplicate_prompts():
    """Concurrent prompt reloads should not duplicate functions."""
    tool = MCPTool(name="test_tool")
    mock_session = AsyncMock()
    tool.session = mock_session
    tool.load_prompts_flag = True

    page = Mock()
    page.prompts = [
        types.Prompt(
            name="prompt_1",
            description="First prompt",
            arguments=[types.PromptArgument(name="arg1", description="Arg 1", required=True)],
        ),
    ]
    page.nextCursor = None

    async def mock_list_prompts(params: Any = None) -> Any:
        assert params is None
        # Yield to the event loop so the two load_prompts() calls interleave.
        await asyncio.sleep(0)
        return page

    mock_session.list_prompts = AsyncMock(side_effect=mock_list_prompts)

    await asyncio.wait_for(asyncio.gather(tool.load_prompts(), tool.load_prompts()), timeout=1)

    assert mock_session.list_prompts.call_count == 2
    assert [f.name for f in tool._functions] == ["prompt_1"]
|
||||
|
||||
|
||||
async def test_load_tools_pagination_exception_handling():
|
||||
"""Test that load_tools handles exceptions during pagination gracefully."""
|
||||
from unittest.mock import AsyncMock
|
||||
@@ -3891,6 +3990,31 @@ async def test_mcp_tool_safe_close_handles_cancelled_error():
|
||||
mock_exit_stack.aclose.assert_called_once()
|
||||
|
||||
|
||||
async def test_mcp_tool_safe_close_handles_cleanup_exception_group():
    """Cleanup task groups should not hide the original connect failure."""
    import builtins
    from contextlib import AsyncExitStack

    # Look the type up dynamically so the module still imports where it does not exist.
    exception_group_type = getattr(builtins, "ExceptionGroup", None)
    if exception_group_type is None:
        pytest.skip("ExceptionGroup is not available on this Python version")

    tool = MCPStreamableHTTPTool(
        name="test",
        url="http://example.com/mcp",
        load_tools=False,
        load_prompts=False,
    )

    mock_exit_stack = AsyncMock(spec=AsyncExitStack)
    mock_exit_stack.aclose = AsyncMock(side_effect=exception_group_type("cleanup failed", [RuntimeError("reader")]))
    tool._exit_stack = mock_exit_stack

    # Must complete without propagating the exception group raised by aclose().
    await tool._safe_close_exit_stack()

    mock_exit_stack.aclose.assert_called_once()
|
||||
|
||||
|
||||
async def test_connect_sets_logging_level_when_logger_level_is_set():
|
||||
"""Test that connect() sets the MCP server logging level when the logger level is not NOTSET."""
|
||||
|
||||
@@ -4389,6 +4513,52 @@ async def test_mcp_tool_call_tool_forwards_tool_list_meta():
|
||||
assert server.session.call_tool.call_args.kwargs["meta"] == tool_meta
|
||||
|
||||
|
||||
async def test_mcp_tool_call_tool_user_meta_merges_with_tool_list_meta():
    """User-provided _meta should be sent as MCP request metadata, not tool arguments."""
    from opentelemetry import trace

    tool_meta = {"from_tool": "tool-value", "shared": "tool-value"}
    user_meta = {"from_user": "user-value", "shared": "user-value"}

    class TestServer(MCPTool):
        """MCPTool stub exposing one tool whose tools/list entry carries ``tool_meta``."""

        async def connect(self) -> None:
            self.session = Mock(spec=ClientSession)
            self.session.list_tools = AsyncMock(
                return_value=types.ListToolsResult(
                    tools=[
                        types.Tool(
                            name="test_tool",
                            description="Test tool",
                            inputSchema={"type": "object", "properties": {"param": {"type": "string"}}},
                            _meta=tool_meta,
                        )
                    ]
                )
            )
            self.session.call_tool = AsyncMock(
                return_value=types.CallToolResult(content=[types.TextContent(type="text", text="result")])
            )

        def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
            return None

    server = TestServer(name="test_server")
    async with server:
        await server.load_tools()

        with trace.use_span(trace.NonRecordingSpan(trace.INVALID_SPAN_CONTEXT)):
            await server.call_tool("test_tool", param="test_value", _meta=user_meta)

        # Inspect the mock while the server context is still open.
        call_kwargs = server.session.call_tool.call_args.kwargs
        assert call_kwargs["arguments"] == {"param": "test_value"}
        # On a key present in both, the user-provided value is the one sent.
        assert call_kwargs["meta"] == {
            "from_tool": "tool-value",
            "from_user": "user-value",
            "shared": "user-value",
        }
        # The caller's dict must not have been mutated by the merge.
        assert user_meta == {"from_user": "user-value", "shared": "user-value"}
|
||||
|
||||
|
||||
async def test_mcp_streamable_http_tool_hook_not_duplicated_on_repeated_get_mcp_client():
|
||||
"""Test that calling get_mcp_client multiple times does not accumulate duplicate hooks."""
|
||||
tool = MCPStreamableHTTPTool(
|
||||
@@ -4641,6 +4811,42 @@ async def test_mcp_streamable_http_tool_header_provider_with_httpx_event_hook():
|
||||
await tool._httpx_client.aclose()
|
||||
|
||||
|
||||
async def test_mcp_streamable_http_tool_header_provider_skips_cross_origin_redirect():
    """The request hook must not re-add caller headers after a cross-origin redirect."""
    import httpx

    from agent_framework._mcp import _mcp_call_headers

    tool = MCPStreamableHTTPTool(
        name="test",
        url="http://example.com/mcp",
        header_provider=lambda kw: {"Authorization": f"Bearer {kw.get('token', '')}"},
    )

    try:
        with patch("agent_framework._mcp.streamable_http_client"):
            tool.get_mcp_client()

        assert tool._httpx_client is not None
        hooks = tool._httpx_client.event_hooks.get("request", [])
        assert len(hooks) == 1

        token = _mcp_call_headers.set({"Authorization": "Bearer secret"})
        try:
            # Same host as the tool URL: the header is injected.
            same_origin = httpx.Request("POST", "http://example.com/redirected")
            await hooks[0](same_origin)
            assert same_origin.headers.get("Authorization") == "Bearer secret"

            # Different host: the header must not be attached.
            cross_origin = httpx.Request("POST", "http://attacker.example/capture")
            await hooks[0](cross_origin)
            assert "Authorization" not in cross_origin.headers
        finally:
            _mcp_call_headers.reset(token)
    finally:
        if getattr(tool, "_httpx_client", None) is not None:
            await tool._httpx_client.aclose()
|
||||
|
||||
|
||||
async def test_mcp_streamable_http_tool_header_provider_with_user_httpx_client():
|
||||
"""Test that header_provider works when the user provides their own httpx client."""
|
||||
import httpx
|
||||
|
||||
@@ -1691,6 +1691,65 @@ def test_to_otel_part_function_call():
|
||||
}
|
||||
|
||||
|
||||
def test_to_otel_part_function_call_reuses_prepared_arguments():
    """Test _to_otel_part does not re-serialize function-call arguments in the observability hot path."""
    from agent_framework import Content
    from agent_framework.observability import _to_otel_part

    # A bare object() payload lets the identity check below detect any copy or conversion.
    arguments = {"payload": object()}
    content = Content(type="function_call", call_id="call_789", name="handoff", arguments=arguments)
    result = _to_otel_part(content)

    assert result is not None
    assert result["arguments"] is arguments
|
||||
|
||||
|
||||
def test_make_json_safe_non_callable_method_attribute():
    """Test make_json_safe handles objects where model_dump/to_dict/dict are non-callable attributes."""
    from agent_framework._serialization import make_json_safe

    class ObjWithNonCallableModelDump:
        model_dump = 42  # not callable

    obj = ObjWithNonCallableModelDump()
    result = make_json_safe(obj)
    assert result == {}
|
||||
|
||||
|
||||
def test_make_json_safe_callable_method_type_error_falls_through():
    """Test make_json_safe falls through when serializer-like methods require arguments."""
    from agent_framework._serialization import make_json_safe

    class ObjWithRequiredArgModelDump:
        def __init__(self) -> None:
            self.value = "fallback"

        # Cannot be called with no arguments, unlike a real serializer method.
        def model_dump(self, required: str) -> dict[str, str]:
            return {"required": required}

    obj = ObjWithRequiredArgModelDump()
    result = make_json_safe(obj)
    assert result == {"value": "fallback"}
|
||||
|
||||
|
||||
def test_make_json_safe_dict_with_non_string_keys():
    """Test make_json_safe converts non-primitive dict keys to strings."""
    import json
    from datetime import datetime

    from agent_framework._serialization import make_json_safe

    dt_key = datetime(2024, 1, 1)
    obj = {dt_key: "value", 42: "num_value", "str_key": "normal"}
    result = make_json_safe(obj)
    # json.dumps must not raise TypeError
    serialized = json.dumps(result)
    parsed = json.loads(serialized)
    assert parsed[str(dt_key)] == "value"
    assert parsed["42"] == "num_value"
    assert parsed["str_key"] == "normal"
|
||||
|
||||
|
||||
def test_to_otel_part_function_result():
|
||||
"""Test _to_otel_part with function_result content."""
|
||||
from agent_framework import Content
|
||||
@@ -3019,6 +3078,49 @@ async def test_system_instructions_preserves_non_ascii_characters(span_exporter:
|
||||
assert [msg.get("role") for msg in input_messages] == ["user"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("enable_sensitive_data", [True], indirect=True)
def test_capture_messages_with_prepared_request_info_function_call_arguments(span_exporter: InMemorySpanExporter):
    """Test _capture_messages handles request-info function-call arguments prepared at Content creation."""
    import dataclasses
    import json

    from opentelemetry import trace

    from agent_framework import WorkflowAgent

    @dataclasses.dataclass
    class HandoffRequest:
        target_agent: str
        reason: str

    # to_dict() is what prepares the dataclass payload before it reaches Content.
    arguments = WorkflowAgent.RequestInfoFunctionArgs(
        request_id="call_dc",
        data=HandoffRequest(target_agent="helper", reason="overflow"),
    ).to_dict()
    msg = Message(
        role="assistant",
        contents=[
            Content(
                type="function_call",
                call_id="call_dc",
                name="request_info",
                arguments=arguments,
            )
        ],
    )
    span_exporter.clear()
    tracer = trace.get_tracer("test")
    with tracer.start_as_current_span("test_span") as span:
        _capture_messages(span=span, provider_name="test_provider", messages=[msg])

    spans = span_exporter.get_finished_spans()
    span = spans[0]
    input_messages = json.loads(span.attributes[OtelAttr.INPUT_MESSAGES])
    tool_part = input_messages[0]["parts"][0]
    assert tool_part["type"] == "tool_call"
    assert tool_part["arguments"]["data"] == {"target_agent": "helper", "reason": "overflow"}
|
||||
|
||||
|
||||
def test_capture_messages_keeps_framework_instructions_out_of_logs_and_span_messages(
|
||||
span_exporter: InMemorySpanExporter,
|
||||
):
|
||||
|
||||
@@ -307,6 +307,63 @@ class TestHistoryProviderBase:
|
||||
assert provider.stored[0].text == "hello"
|
||||
assert provider.stored[1].text == "hi"
|
||||
|
||||
async def test_after_run_stores_coalesced_code_interpreter_chunks(self) -> None:
    """after_run stores one code interpreter call and one result for chunks sharing a call id."""
    from agent_framework import AgentResponse, AgentResponseUpdate, Content

    provider = ConcreteHistoryProvider("mem", store_inputs=False)
    # One result chunk followed by three call chunks, all for call id "ci_123".
    updates = [
        AgentResponseUpdate(
            role="assistant",
            contents=[
                Content.from_code_interpreter_tool_result(
                    call_id="ci_123",
                    outputs=[],
                )
            ],
        ),
        AgentResponseUpdate(
            contents=[
                Content.from_code_interpreter_tool_call(
                    call_id="ci_123",
                    inputs=[Content.from_text(text="import")],
                    additional_properties={"sequence_number": 1},
                )
            ],
        ),
        AgentResponseUpdate(
            contents=[
                Content.from_code_interpreter_tool_call(
                    call_id="ci_123",
                    inputs=[Content.from_text(text=" pandas")],
                    additional_properties={"sequence_number": 2},
                )
            ],
        ),
        AgentResponseUpdate(
            contents=[
                Content.from_code_interpreter_tool_call(
                    call_id="ci_123",
                    inputs=[Content.from_text(text="import pandas as pd")],
                    additional_properties={"sequence_number": 3},
                )
            ],
        ),
    ]
    ctx = SessionContext(session_id="s1", input_messages=[Message(role="user", contents=["make a sheet"])])
    ctx._response = AgentResponse.from_updates(updates)

    await provider.after_run(agent=None, session=AgentSession(), context=ctx, state={})  # type: ignore[arg-type]

    # store_inputs=False, so only the single response message is stored.
    assert len(provider.stored) == 1
    stored_contents = provider.stored[0].contents
    calls = [content for content in stored_contents if content.type == "code_interpreter_tool_call"]
    results = [content for content in stored_contents if content.type == "code_interpreter_tool_result"]
    assert len(calls) == 1
    assert len(results) == 1
    assert calls[0].inputs is not None
    assert len(calls[0].inputs) == 1
    assert calls[0].inputs[0].text == "import pandas as pd"
|
||||
|
||||
async def test_after_run_skips_inputs_when_disabled(self) -> None:
|
||||
from agent_framework import AgentResponse
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
@@ -1642,6 +1643,37 @@ class TestFunctionalWorkflowAgentHITL:
|
||||
break
|
||||
assert approval_found, "expected FunctionApprovalRequestContent in agent response"
|
||||
|
||||
async def test_request_info_dataclass_arguments_are_serialized_for_agent(self):
    """A dataclass request_info payload reaches the agent response as JSON-serializable arguments."""

    @dataclass
    class HandoffRequest:
        target_agent: str
        reason: str

    @workflow
    async def wf(x: str, ctx: RunContext) -> str:
        answer = await ctx.request_info(
            HandoffRequest(target_agent=x, reason="overflow"),
            response_type=str,
            request_id="rid-1",
        )
        return f"got:{answer}"

    agent = wf.as_agent()
    response = await agent.run("helper")

    # Take the first approval request across all messages; a plain inner-loop
    # ``break`` would keep scanning later messages and could overwrite it.
    function_call_arguments = next(
        (
            content.function_call.arguments
            for message in response.messages
            for content in message.contents
            if getattr(content, "type", None) == "function_approval_request" and content.function_call is not None
        ),
        None,
    )

    assert function_call_arguments == {
        "request_id": "rid-1",
        "data": {"target_agent": "helper", "reason": "overflow"},
    }
    # Round-tripping through JSON proves the arguments hold no raw dataclass.
    assert json.loads(json.dumps(function_call_arguments)) == function_call_arguments
|
||||
|
||||
async def test_resume_via_agent_responses_kwarg(self):
|
||||
@workflow
|
||||
async def wf(x: str, ctx: RunContext) -> str:
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from collections.abc import Awaitable, Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Literal, overload
|
||||
|
||||
import pytest
|
||||
@@ -23,6 +25,7 @@ from agent_framework import (
|
||||
WorkflowAgent,
|
||||
WorkflowBuilder,
|
||||
WorkflowContext,
|
||||
WorkflowEvent,
|
||||
executor,
|
||||
handler,
|
||||
response_handler,
|
||||
@@ -292,6 +295,33 @@ class TestWorkflowAgent:
|
||||
pending_requests = await workflow._runner_context.get_pending_request_info_events()
|
||||
assert len(pending_requests) == 0
|
||||
|
||||
def test_request_info_dataclass_arguments_are_serialized_when_content_is_created(self) -> None:
    """Test WorkflowAgent prepares request_info arguments before observability captures messages."""

    @dataclass
    class HandoffRequest:
        target_agent: str
        reason: str

    executor = SimpleExecutor(id="executor1", response_text="Response")
    workflow = WorkflowBuilder(start_executor=executor).build()
    agent = WorkflowAgent(workflow=workflow, name="Request Test Agent")
    event = WorkflowEvent.request_info(
        request_id="request_123",
        source_executor_id="executor1",
        request_data=HandoffRequest(target_agent="helper", reason="overflow"),
        response_type=str,
    )

    function_call, approval_request = agent._process_request_info_event(event)  # pyright: ignore[reportPrivateUsage]

    assert function_call.arguments == {
        "request_id": "request_123",
        "data": {"target_agent": "helper", "reason": "overflow"},
    }
    # The approval request must wrap the very same function-call object.
    assert approval_request.function_call is function_call
    assert json.loads(json.dumps(function_call.arguments)) == function_call.arguments
|
||||
|
||||
def test_workflow_as_agent_method(self) -> None:
|
||||
"""Test that Workflow.as_agent() creates a properly configured WorkflowAgent."""
|
||||
# Create a simple workflow
|
||||
|
||||
@@ -29,7 +29,7 @@ dependencies = [
|
||||
]
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"types-PyYaml==6.0.12.20250915"
|
||||
"types-PyYaml==6.0.12.20260518"
|
||||
]
|
||||
|
||||
[tool.uv]
|
||||
|
||||
@@ -375,13 +375,15 @@ class DevServer:
|
||||
logger.info("Starting Agent Framework Server")
|
||||
await self._ensure_executor()
|
||||
await self._ensure_openai_executor() # Initialize OpenAI executor
|
||||
yield
|
||||
# Shutdown
|
||||
logger.info("Shutting down Agent Framework Server")
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
# Shutdown
|
||||
logger.info("Shutting down Agent Framework Server")
|
||||
|
||||
# Cleanup entity resources (e.g., close credentials, clients)
|
||||
if self.executor:
|
||||
await self._cleanup_entities()
|
||||
# Cleanup entity resources (e.g., close credentials, clients)
|
||||
if self.executor:
|
||||
await self._cleanup_entities()
|
||||
|
||||
app = FastAPI(
|
||||
title="Agent Framework Server",
|
||||
|
||||
@@ -30,7 +30,7 @@ dependencies = [
|
||||
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"types-python-dateutil==2.9.0.20260402",
|
||||
"types-python-dateutil==2.9.0.20260518",
|
||||
]
|
||||
|
||||
[tool.uv]
|
||||
|
||||
@@ -12,6 +12,7 @@ from ._embedding_client import (
|
||||
)
|
||||
from ._foundry_evals import (
|
||||
FoundryEvals,
|
||||
GeneratedEvaluatorRef,
|
||||
evaluate_foundry_target,
|
||||
evaluate_traces,
|
||||
)
|
||||
@@ -33,6 +34,7 @@ __all__ = [
|
||||
"FoundryEmbeddingSettings",
|
||||
"FoundryEvals",
|
||||
"FoundryMemoryProvider",
|
||||
"GeneratedEvaluatorRef",
|
||||
"RawFoundryAgent",
|
||||
"RawFoundryAgentChatClient",
|
||||
"RawFoundryChatClient",
|
||||
|
||||
@@ -57,8 +57,6 @@ if TYPE_CHECKING:
|
||||
from agent_framework import (
|
||||
Agent,
|
||||
AgentRunInputs,
|
||||
ChatAndFunctionMiddlewareTypes,
|
||||
ContextProvider,
|
||||
MiddlewareTypes,
|
||||
ToolTypes,
|
||||
)
|
||||
|
||||
@@ -28,8 +28,9 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from collections.abc import Iterable, Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from agent_framework._evaluation import (
|
||||
AgentEvalConverter,
|
||||
@@ -39,6 +40,7 @@ from agent_framework._evaluation import (
|
||||
EvalItemResult,
|
||||
EvalResults,
|
||||
EvalScoreResult,
|
||||
RubricScore,
|
||||
)
|
||||
from agent_framework._feature_stage import ExperimentalFeature, experimental
|
||||
from openai import AsyncOpenAI
|
||||
@@ -51,6 +53,54 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# region Generated rubric evaluator references
|
||||
|
||||
|
||||
@experimental(feature_id=ExperimentalFeature.EVALS)
@dataclass(frozen=True)
class GeneratedEvaluatorRef:
    """Immutable pointer to a rubric evaluator that is already stored in Foundry.

    Instances are handed to :class:`FoundryEvals` alongside (or instead of)
    built-in evaluator names so that items are scored with a pre-existing
    rubric evaluator, whether it was authored by hand or generated through
    the Foundry portal. agent-framework only *consumes* the evaluator: it
    never creates or edits the definition, it merely refers to the persisted
    version by name.

    Pin ``version`` whenever the run has to be reproducible. With
    ``version=None`` the evaluator resolves to whatever version is current
    when the run executes, and a warning is logged when the testing criteria
    are built. CI gates should always supply a concrete version.

    Attributes:
        name: Name of the evaluator inside the Foundry project (for example
            ``"reservation-policy-rubric"``). This is not a built-in
            evaluator name such as ``"builtin.relevance"``.
        version: Evaluator version to pin. ``None`` means "latest", which is
            discouraged for CI and replay runs and triggers a warning.
        display_name: Optional human-readable label used in result
            summaries; ``name`` is used when it is unset.
    """

    name: str
    version: str | None = None
    display_name: str | None = None

    @classmethod
    def latest(cls, name: str, *, display_name: str | None = None) -> GeneratedEvaluatorRef:
        """Build a reference without a pinned version (latest version at run time).

        This is discouraged for reproducible runs: when the evaluator is
        updated in Foundry, CI and replay evaluations silently pick up the
        new definition. Prefer the regular constructor with an explicit
        ``version``.
        """
        # ``version`` is left at its ``None`` default on purpose.
        return cls(name, display_name=display_name)
|
||||
|
||||
|
||||
# endregion
|
||||
# Agent evaluators that accept query/response as conversation arrays.
|
||||
# Maintained manually — check https://learn.microsoft.com/en-us/azure/ai-studio/how-to/develop/evaluate-sdk
|
||||
# for the latest evaluator list. These are the evaluators that need conversation-format input.
|
||||
@@ -166,7 +216,7 @@ def _resolve_evaluator(name: str) -> str:
|
||||
|
||||
|
||||
def _build_testing_criteria(
|
||||
evaluators: Sequence[str],
|
||||
evaluators: Sequence[str | GeneratedEvaluatorRef],
|
||||
model: str,
|
||||
*,
|
||||
include_data_mapping: bool = False,
|
||||
@@ -175,7 +225,9 @@ def _build_testing_criteria(
|
||||
"""Build ``testing_criteria`` for ``evals.create()``.
|
||||
|
||||
Args:
|
||||
evaluators: Evaluator names.
|
||||
evaluators: Evaluator names (built-in shorts / fully-qualified
|
||||
``builtin.*`` names) or :class:`GeneratedEvaluatorRef`
|
||||
instances for generated rubric evaluators.
|
||||
model: Model deployment for the LLM judge.
|
||||
include_data_mapping: Whether to include field-level data mapping
|
||||
(required for the JSONL data source, not needed for response-based).
|
||||
@@ -183,7 +235,38 @@ def _build_testing_criteria(
|
||||
definitions.
|
||||
"""
|
||||
criteria: list[dict[str, Any]] = []
|
||||
for name in evaluators:
|
||||
for entry_spec in evaluators:
|
||||
if isinstance(entry_spec, GeneratedEvaluatorRef):
|
||||
short = entry_spec.display_name or entry_spec.name
|
||||
ref_entry: dict[str, Any] = {
|
||||
"type": "azure_ai_evaluator",
|
||||
"name": short,
|
||||
"evaluator_name": entry_spec.name,
|
||||
"initialization_parameters": {"deployment_name": model},
|
||||
}
|
||||
if entry_spec.version is not None:
|
||||
ref_entry["evaluator_version"] = entry_spec.version
|
||||
else:
|
||||
logger.warning(
|
||||
"GeneratedEvaluatorRef '%s' has no pinned version; the eval run "
|
||||
"will resolve to whichever version is current at execution time. "
|
||||
"Pin the version for reproducible runs.",
|
||||
entry_spec.name,
|
||||
)
|
||||
if include_data_mapping:
|
||||
# Rubric evaluators accept conversation arrays like agent
|
||||
# evaluators, plus tool_definitions when items are tool-aware.
|
||||
ref_mapping: dict[str, str] = {
|
||||
"query": "{{item.query_messages}}",
|
||||
"response": "{{item.response_messages}}",
|
||||
}
|
||||
if include_tool_definitions:
|
||||
ref_mapping["tool_definitions"] = "{{item.tool_definitions}}"
|
||||
ref_entry["data_mapping"] = ref_mapping
|
||||
criteria.append(ref_entry)
|
||||
continue
|
||||
|
||||
name = entry_spec
|
||||
qualified = _resolve_evaluator(name)
|
||||
short = name if not name.startswith("builtin.") else name.split(".")[-1]
|
||||
|
||||
@@ -247,9 +330,9 @@ def _build_item_schema(
|
||||
|
||||
|
||||
def _resolve_default_evaluators(
|
||||
evaluators: Sequence[str] | None,
|
||||
evaluators: Sequence[str | GeneratedEvaluatorRef] | None,
|
||||
items: Sequence[EvalItem | dict[str, Any]] | None = None,
|
||||
) -> list[str]:
|
||||
) -> list[str | GeneratedEvaluatorRef]:
|
||||
"""Resolve evaluators, applying defaults when ``None``.
|
||||
|
||||
Defaults to relevance + coherence + task_adherence. Automatically adds
|
||||
@@ -258,7 +341,7 @@ def _resolve_default_evaluators(
|
||||
if evaluators is not None:
|
||||
return list(evaluators)
|
||||
|
||||
result = list(_DEFAULT_EVALUATORS)
|
||||
result: list[str | GeneratedEvaluatorRef] = list(_DEFAULT_EVALUATORS)
|
||||
if items is not None:
|
||||
has_tools = any((item.tools if isinstance(item, EvalItem) else item.get("tool_definitions")) for item in items)
|
||||
if has_tools:
|
||||
@@ -267,14 +350,24 @@ def _resolve_default_evaluators(
|
||||
|
||||
|
||||
def _filter_tool_evaluators(
|
||||
evaluators: list[str],
|
||||
evaluators: list[str | GeneratedEvaluatorRef],
|
||||
items: Sequence[EvalItem | dict[str, Any]],
|
||||
) -> list[str]:
|
||||
"""Remove tool evaluators if no items have tool definitions."""
|
||||
) -> list[str | GeneratedEvaluatorRef]:
|
||||
"""Remove tool evaluators if no items have tool definitions.
|
||||
|
||||
Generated rubric evaluators are tool-aware but not tool-required; they
|
||||
are preserved regardless of whether items carry tool definitions.
|
||||
"""
|
||||
has_tools = any((item.tools if isinstance(item, EvalItem) else item.get("tool_definitions")) for item in items)
|
||||
if has_tools:
|
||||
return evaluators
|
||||
filtered = [e for e in evaluators if _resolve_evaluator(e) not in _TOOL_EVALUATORS]
|
||||
|
||||
def _is_tool_only(spec: str | GeneratedEvaluatorRef) -> bool:
|
||||
if isinstance(spec, GeneratedEvaluatorRef):
|
||||
return False
|
||||
return _resolve_evaluator(spec) in _TOOL_EVALUATORS
|
||||
|
||||
filtered = [e for e in evaluators if not _is_tool_only(e)]
|
||||
if not filtered:
|
||||
raise ValueError(
|
||||
f"All requested evaluators {evaluators} require tool definitions, "
|
||||
@@ -282,7 +375,7 @@ def _filter_tool_evaluators(
|
||||
"or choose evaluators that do not require tools."
|
||||
)
|
||||
if len(filtered) < len(evaluators):
|
||||
removed = [e for e in evaluators if _resolve_evaluator(e) in _TOOL_EVALUATORS]
|
||||
removed = [e for e in evaluators if _is_tool_only(e)]
|
||||
logger.info("Removed tool evaluators %s (no items have tools)", removed)
|
||||
return filtered
|
||||
|
||||
@@ -354,6 +447,114 @@ def _extract_per_evaluator(run: RunRetrieveResponse) -> dict[str, dict[str, int]
|
||||
return per_eval
|
||||
|
||||
|
||||
_RUBRIC_DIMENSION_KEYS: tuple[str, ...] = ("dimension_scores", "rubric_scores")
|
||||
"""Property keys that may carry per-dimension rubric breakdowns.
|
||||
|
||||
The published Foundry rubric-evaluator output format uses
|
||||
``properties.dimension_scores`` (see the Microsoft Learn "Rubric
|
||||
evaluators" reference). Earlier preview builds and some SDK shapes
|
||||
used ``rubric_scores``; we accept both for defensive forward/backward
|
||||
compatibility.
|
||||
"""
|
||||
|
||||
|
||||
def _parse_dimension_entries(raw: Any) -> list[RubricScore]:
    """Parse a raw list-like payload into ``RubricScore`` instances.

    Every entry may be either a ``dict`` or an object exposing ``id``,
    ``score``, ``applicable``, ``weight`` and ``reason`` as attributes.
    Entries lacking ``id``, ``weight`` or ``applicable`` are dropped, and so
    are entries whose values cannot be coerced to the target types. This
    includes non-finite floats such as ``inf``: ``int()`` raises
    ``OverflowError`` for them, which must not abort the whole parse.

    Args:
        raw: Candidate payload. Anything falsy or non-iterable yields ``[]``.

    Returns:
        The well-formed entries in their original order. The list is empty
        when ``raw`` is falsy, not iterable, or holds no well-formed entry.
    """
    if not raw:
        return []
    try:
        raw_iter: Iterable[Any] = iter(raw)
    except TypeError:
        return []

    parsed: list[RubricScore] = []
    for raw_entry in raw_iter:
        entry: Any = raw_entry
        try:
            rid: Any
            score_val: Any
            applicable: Any
            weight: Any
            reason: Any
            if isinstance(entry, dict):
                entry_any = cast("dict[str, Any]", entry)
                rid = entry_any.get("id")
                score_val = entry_any.get("score")
                applicable = entry_any.get("applicable")
                weight = entry_any.get("weight")
                reason = entry_any.get("reason", "")
            else:
                rid = getattr(entry, "id", None)
                score_val = getattr(entry, "score", None)
                applicable = getattr(entry, "applicable", None)
                weight = getattr(entry, "weight", None)
                reason = getattr(entry, "reason", "") or ""
            # ``id``, ``weight`` and ``applicable`` are mandatory; ``score`` may
            # legitimately be absent (it is stored as ``None`` in that case).
            if rid is None or weight is None or applicable is None:
                continue
            parsed.append(
                RubricScore(
                    id=str(rid),
                    score=int(score_val) if isinstance(score_val, (int, float)) else None,
                    applicable=bool(applicable),
                    weight=int(weight),
                    reason=str(reason) if reason is not None else "",
                )
            )
        except (TypeError, ValueError, OverflowError):
            # OverflowError covers ``int(float("inf"))``; ValueError covers NaN
            # and non-numeric strings. Either way the entry is skipped.
            logger.debug("Skipping malformed rubric dimension entry: %s", cast("Any", entry), exc_info=True)
    return parsed
|
||||
|
||||
|
||||
def _extract_rubric_scores(sample: Any) -> list[RubricScore] | None:
    """Pull the per-dimension rubric breakdown out of an evaluator result sample.

    The breakdown is looked up under each key of ``_RUBRIC_DIMENSION_KEYS``
    (in that order) in the following places, stopping at the first one that
    yields at least one well-formed entry:

    1. ``sample.properties`` when the sample exposes it as an attribute.
    2. ``sample["properties"]`` when the sample is a dict (skipped if it is
       the very object already found in step 1).
    3. The top level of the sample itself when it is a dict.

    Args:
        sample: Raw ``sample`` payload of a single evaluator result, either
            an SDK object, a dict, or ``None``.

    Returns:
        The parsed ``RubricScore`` list, or ``None`` when nothing usable was
        found in any of the locations above.
    """
    if sample is None:
        return None

    # Collect every place the breakdown may live, most specific first.
    candidates: list[Any] = []
    attr_properties: Any = getattr(sample, "properties", None)
    if attr_properties is not None:
        candidates.append(attr_properties)
    if isinstance(sample, dict):
        sample_dict = cast("dict[str, Any]", sample)
        item_properties: Any = sample_dict.get("properties")
        already_listed = item_properties is attr_properties
        if item_properties is not None and not already_listed:
            candidates.append(item_properties)
        candidates.append(sample_dict)

    for holder in candidates:
        holder_is_dict = isinstance(holder, dict)
        for key in _RUBRIC_DIMENSION_KEYS:
            if holder_is_dict:
                payload: Any = cast("dict[str, Any]", holder).get(key)
            else:
                payload = getattr(holder, key, None)
            scores = _parse_dimension_entries(payload)
            if scores:
                return scores
    return None
|
||||
|
||||
|
||||
async def _fetch_output_items(
|
||||
client: AsyncOpenAI,
|
||||
eval_id: str,
|
||||
@@ -377,12 +578,15 @@ async def _fetch_output_items(
|
||||
# Extract per-evaluator scores
|
||||
scores: list[EvalScoreResult] = []
|
||||
for r in oi.results or []:
|
||||
sample = r.sample
|
||||
dimensions = _extract_rubric_scores(sample)
|
||||
scores.append(
|
||||
EvalScoreResult(
|
||||
name=r.name,
|
||||
score=r.score,
|
||||
passed=r.passed,
|
||||
sample=r.sample,
|
||||
sample=sample,
|
||||
dimensions=dimensions,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -394,15 +598,18 @@ async def _fetch_output_items(
|
||||
output_text: str | None = None
|
||||
response_id: str | None = None
|
||||
|
||||
sample = oi.sample
|
||||
if sample is not None: # pyright: ignore[reportUnnecessaryComparison]
|
||||
err = sample.error
|
||||
if err is not None and (err.code or err.message): # pyright: ignore[reportUnnecessaryComparison]
|
||||
# mypy infers oi.sample as dict[str, object] | None, but the
|
||||
# OpenAI SDK actually returns a typed Sample model. Cast to Any so
|
||||
# both type checkers accept the attribute access pattern.
|
||||
oi_sample: Any = oi.sample
|
||||
if oi_sample is not None:
|
||||
err = oi_sample.error
|
||||
if err is not None and (err.code or err.message):
|
||||
error_code = err.code or None
|
||||
error_message = err.message or None
|
||||
|
||||
usage = sample.usage
|
||||
if usage is not None and usage.total_tokens: # pyright: ignore[reportUnnecessaryComparison]
|
||||
usage = oi_sample.usage
|
||||
if usage is not None and usage.total_tokens:
|
||||
token_usage = {
|
||||
"prompt_tokens": usage.prompt_tokens,
|
||||
"completion_tokens": usage.completion_tokens,
|
||||
@@ -411,13 +618,13 @@ async def _fetch_output_items(
|
||||
}
|
||||
|
||||
# Extract input/output text
|
||||
if sample.input:
|
||||
parts = [si.content for si in sample.input if si.role == "user"]
|
||||
if oi_sample.input:
|
||||
parts = [si.content for si in oi_sample.input if si.role == "user"]
|
||||
if parts:
|
||||
input_text = " ".join(parts)
|
||||
|
||||
if sample.output:
|
||||
parts = [so.content or "" for so in sample.output if so.role == "assistant"]
|
||||
if oi_sample.output:
|
||||
parts = [so.content or "" for so in oi_sample.output if so.role == "assistant"]
|
||||
if parts:
|
||||
output_text = " ".join(parts)
|
||||
|
||||
@@ -472,7 +679,7 @@ async def _evaluate_via_responses_impl(
|
||||
*,
|
||||
client: AsyncOpenAI,
|
||||
response_ids: Sequence[str],
|
||||
evaluators: list[str],
|
||||
evaluators: list[str | GeneratedEvaluatorRef],
|
||||
model: str,
|
||||
eval_name: str,
|
||||
poll_interval: float,
|
||||
@@ -573,8 +780,11 @@ class FoundryEvals:
|
||||
(from ``azure.ai.projects.aio``). Provide this or *client*.
|
||||
model: Model deployment name for the evaluator LLM judge.
|
||||
Resolved from ``client.model`` when omitted.
|
||||
evaluators: Evaluator names (e.g. ``["relevance", "tool_call_accuracy"]``).
|
||||
When ``None`` (default), uses smart defaults based on item data.
|
||||
evaluators: Evaluator specifications. Entries may be built-in
|
||||
short names (e.g. ``"relevance"``), fully-qualified
|
||||
``"builtin.*"`` names, or :class:`GeneratedEvaluatorRef`
|
||||
instances for previously generated rubric evaluators. When
|
||||
``None`` (default), uses smart defaults based on item data.
|
||||
conversation_split: How to split multi-turn conversations into
|
||||
query/response halves. Defaults to ``LAST_TURN``. Pass a
|
||||
``ConversationSplit`` enum value or a custom callable — see
|
||||
@@ -623,7 +833,7 @@ class FoundryEvals:
|
||||
client: FoundryChatClient | None = None,
|
||||
project_client: AIProjectClient | None = None,
|
||||
model: str | None = None,
|
||||
evaluators: Sequence[str] | None = None,
|
||||
evaluators: Sequence[str | GeneratedEvaluatorRef] | None = None,
|
||||
conversation_split: ConversationSplitter = ConversationSplit.LAST_TURN,
|
||||
poll_interval: float = 5.0,
|
||||
timeout: float = 180.0,
|
||||
@@ -642,7 +852,9 @@ class FoundryEvals:
|
||||
"Model is required. Pass model= explicitly or use a FoundryChatClient that has a model configured."
|
||||
)
|
||||
self._model = resolved_model
|
||||
self._evaluators = list(evaluators) if evaluators is not None else None
|
||||
self._evaluators: list[str | GeneratedEvaluatorRef] | None = (
|
||||
list(evaluators) if evaluators is not None else None
|
||||
)
|
||||
self._conversation_split = conversation_split
|
||||
self._poll_interval = poll_interval
|
||||
self._timeout = timeout
|
||||
@@ -678,7 +890,7 @@ class FoundryEvals:
|
||||
async def _evaluate_via_dataset(
|
||||
self,
|
||||
items: Sequence[EvalItem],
|
||||
evaluators: list[str],
|
||||
evaluators: list[str | GeneratedEvaluatorRef],
|
||||
eval_name: str,
|
||||
) -> EvalResults:
|
||||
"""Evaluate using JSONL dataset upload path."""
|
||||
|
||||
@@ -25,16 +25,25 @@ from agent_framework._evaluation import (
|
||||
from agent_framework._workflows._workflow import WorkflowRunResult
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from agent_framework_foundry import GeneratedEvaluatorRef
|
||||
from agent_framework_foundry._foundry_evals import (
|
||||
_AGENT_EVALUATORS,
|
||||
_BUILTIN_EVALUATORS,
|
||||
_TOOL_EVALUATORS,
|
||||
FoundryEvals,
|
||||
_build_item_schema,
|
||||
_build_testing_criteria,
|
||||
_extract_per_evaluator,
|
||||
_extract_result_counts,
|
||||
_extract_rubric_scores,
|
||||
_fetch_output_items,
|
||||
_filter_tool_evaluators,
|
||||
_poll_eval_run,
|
||||
_resolve_default_evaluators,
|
||||
_resolve_evaluator,
|
||||
_resolve_openai_client,
|
||||
evaluate_foundry_target,
|
||||
evaluate_traces,
|
||||
)
|
||||
|
||||
|
||||
@@ -806,6 +815,67 @@ class TestBuildTestingCriteria:
|
||||
for c in criteria:
|
||||
assert "tool_definitions" in c["data_mapping"], f"{c['name']} missing tool_definitions"
|
||||
|
||||
def test_generated_evaluator_ref_pinned_version(self) -> None:
    """A pinned reference becomes one ``azure_ai_evaluator`` criterion carrying its version."""
    pinned = GeneratedEvaluatorRef(name="my-rubric", version="1")

    criteria = _build_testing_criteria([pinned], "gpt-4o", include_data_mapping=True)

    assert len(criteria) == 1
    criterion = criteria[0]
    expected_fields = {
        "type": "azure_ai_evaluator",
        "evaluator_name": "my-rubric",
        "evaluator_version": "1",
        "name": "my-rubric",
        "initialization_parameters": {"deployment_name": "gpt-4o"},
        "data_mapping": {
            "query": "{{item.query_messages}}",
            "response": "{{item.response_messages}}",
        },
    }
    for field, expected in expected_fields.items():
        assert criterion[field] == expected
|
||||
|
||||
def test_generated_evaluator_ref_display_name_used_as_short(self) -> None:
    """``display_name`` becomes the criterion ``name`` while ``evaluator_name`` keeps the real name."""
    labelled = GeneratedEvaluatorRef(name="my-rubric", version="2", display_name="My Rubric")

    criterion = _build_testing_criteria([labelled], "gpt-4o")[0]

    assert criterion["name"] == "My Rubric"
    assert criterion["evaluator_name"] == "my-rubric"
|
||||
|
||||
def test_generated_evaluator_ref_tool_definitions_added(self) -> None:
    """With both mapping flags set, the reference's data mapping includes ``tool_definitions``."""
    pinned = GeneratedEvaluatorRef(name="my-rubric", version="1")
    options = {"include_data_mapping": True, "include_tool_definitions": True}

    criteria = _build_testing_criteria([pinned], "gpt-4o", **options)

    mapping = criteria[0]["data_mapping"]
    assert mapping["tool_definitions"] == "{{item.tool_definitions}}"
|
||||
|
||||
def test_generated_evaluator_ref_unpinned_warns(self, caplog: pytest.LogCaptureFixture) -> None:
    """A versionless reference omits ``evaluator_version`` and logs a warning about it."""
    import logging

    unpinned = GeneratedEvaluatorRef.latest("my-rubric")
    with caplog.at_level(logging.WARNING, logger="agent_framework_foundry._foundry_evals"):
        criteria = _build_testing_criteria([unpinned], "gpt-4o")

    criterion = criteria[0]
    assert "evaluator_version" not in criterion
    captured_messages = [record.message for record in caplog.records]
    assert any("no pinned version" in text for text in captured_messages)
|
||||
|
||||
def test_generated_evaluator_ref_mixed_with_builtins(self) -> None:
    """References and built-in names can be mixed; input order is kept in the criteria."""
    pinned = GeneratedEvaluatorRef(name="my-rubric", version="1")
    requested = ["relevance", pinned, "task_adherence"]

    criteria = _build_testing_criteria(requested, "gpt-4o", include_data_mapping=True)

    assert [criterion["name"] for criterion in criteria] == ["relevance", "my-rubric", "task_adherence"]
    first, second, third = criteria[0], criteria[1], criteria[2]
    assert first["evaluator_name"] == "builtin.relevance"
    assert second["evaluator_name"] == "my-rubric"
    assert third["evaluator_name"] == "builtin.task_adherence"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _build_item_schema
|
||||
@@ -1263,6 +1333,29 @@ class TestFilterToolEvaluators:
|
||||
items,
|
||||
)
|
||||
|
||||
def test_preserves_generated_ref_when_no_tools(self) -> None:
    """Filtering drops ``tool_call_accuracy`` but keeps the generated reference and ``relevance``."""
    rubric_ref = GeneratedEvaluatorRef(name="rubric", version="1")
    conversation = [Message("user", ["q"]), Message("assistant", ["r"])]

    kept = _filter_tool_evaluators(
        ["relevance", rubric_ref, "tool_call_accuracy"],
        [EvalItem(conversation=conversation)],
    )

    assert "relevance" in kept
    assert rubric_ref in kept
    assert "tool_call_accuracy" not in kept
|
||||
|
||||
def test_generated_ref_alone_does_not_raise(self) -> None:
    """A lone generated reference passes through the filter unchanged instead of raising."""
    rubric_ref = GeneratedEvaluatorRef(name="rubric", version="1")
    conversation = [Message("user", ["q"]), Message("assistant", ["r"])]

    kept = _filter_tool_evaluators([rubric_ref], [EvalItem(conversation=conversation)])

    assert kept == [rubric_ref]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# EvalResults
|
||||
@@ -2267,7 +2360,6 @@ class TestEvalResultsWithItems:
|
||||
|
||||
class TestFetchOutputItems:
|
||||
async def test_fetches_and_converts_output_items(self) -> None:
|
||||
from agent_framework_foundry._foundry_evals import _fetch_output_items
|
||||
|
||||
# Build mock output items matching the OpenAI SDK schema
|
||||
mock_result = MagicMock()
|
||||
@@ -2329,7 +2421,6 @@ class TestFetchOutputItems:
|
||||
assert item.error_code is None
|
||||
|
||||
async def test_handles_errored_item(self) -> None:
|
||||
from agent_framework_foundry._foundry_evals import _fetch_output_items
|
||||
|
||||
mock_error = MagicMock()
|
||||
mock_error.code = "QueryExtractionError"
|
||||
@@ -2361,7 +2452,6 @@ class TestFetchOutputItems:
|
||||
assert len(item.scores) == 0
|
||||
|
||||
async def test_handles_api_failure_gracefully(self) -> None:
|
||||
from agent_framework_foundry._foundry_evals import _fetch_output_items
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.evals.runs.output_items.list = AsyncMock(side_effect=TypeError("API error"))
|
||||
@@ -2369,6 +2459,166 @@ class TestFetchOutputItems:
|
||||
items = await _fetch_output_items(mock_client, "eval_1", "run_1")
|
||||
assert items == []
|
||||
|
||||
async def test_extracts_rubric_scores_from_dict_sample(self) -> None:
    """Rubric dimensions found under ``sample["properties"]["rubric_scores"]`` reach the score result."""
    rubric_payload = [
        {"id": "policy", "score": 4, "applicable": True, "weight": 1, "reason": "ok"},
        {"id": "safety", "score": None, "applicable": False, "weight": 1, "reason": "n/a"},
    ]
    evaluator_result = MagicMock(
        score=0.85,
        passed=True,
        sample={"properties": {"rubric_scores": rubric_payload}},
    )
    # ``name`` is reserved by the Mock constructor, so it has to be assigned afterwards.
    evaluator_result.name = "my-rubric"

    output_item = MagicMock(
        id="oi_1",
        status="pass",
        results=[evaluator_result],
        sample=None,
        datasource_item={},
    )

    client = MagicMock()
    client.evals.runs.output_items.list = AsyncMock(return_value=_AsyncPage([output_item]))

    items = await _fetch_output_items(client, "eval_1", "run_1")

    assert len(items) == 1
    scores = items[0].scores
    assert len(scores) == 1
    dimensions = scores[0].dimensions
    assert dimensions is not None
    assert len(dimensions) == 2

    policy = next(dim for dim in dimensions if dim.id == "policy")
    assert policy.score == 4
    assert policy.applicable is True
    assert policy.weight == 1
    assert policy.reason == "ok"

    safety = next(dim for dim in dimensions if dim.id == "safety")
    assert safety.score is None
    assert safety.applicable is False
|
||||
|
||||
async def test_no_rubric_scores_when_absent(self) -> None:
    """A result whose ``sample`` is ``None`` ends up with ``dimensions`` set to ``None``."""
    evaluator_result = MagicMock(score=0.85, passed=True, sample=None)
    # ``name`` is reserved by the Mock constructor, so it has to be assigned afterwards.
    evaluator_result.name = "relevance"

    output_item = MagicMock(
        id="oi_2",
        status="pass",
        results=[evaluator_result],
        sample=None,
        datasource_item={},
    )

    client = MagicMock()
    client.evals.runs.output_items.list = AsyncMock(return_value=_AsyncPage([output_item]))

    items = await _fetch_output_items(client, "eval_1", "run_1")

    first_score = items[0].scores[0]
    assert first_score.dimensions is None
|
||||
|
||||
|
||||
class TestExtractRubricScores:
    """Unit tests for ``_extract_rubric_scores`` across the accepted sample shapes."""

    def test_handles_attribute_style_properties(self) -> None:
        """The legacy ``rubric_scores`` key is read from an attribute-style ``properties`` object."""
        rs = MagicMock()
        rs.id = "policy"
        rs.score = 5
        rs.applicable = True
        rs.weight = 2
        rs.reason = "ok"

        sample = MagicMock()
        # Restrict the mock to the legacy key so ``dimension_scores`` is really
        # absent, rather than an auto-created MagicMock that merely happens to
        # iterate as empty. This mirrors test_dimension_scores_via_attribute.
        sample.properties = MagicMock(spec=["rubric_scores"])
        sample.properties.rubric_scores = [rs]

        result = _extract_rubric_scores(sample)
        assert result is not None
        assert result[0].id == "policy"
        assert result[0].score == 5
        assert result[0].weight == 2

    def test_top_level_rubric_scores_in_dict(self) -> None:
        """A dict sample may carry ``rubric_scores`` at its top level."""
        sample = {"rubric_scores": [{"id": "a", "score": 3, "applicable": True, "weight": 1, "reason": "r"}]}
        result = _extract_rubric_scores(sample)
        assert result is not None
        assert result[0].id == "a"

    def test_returns_none_when_missing(self) -> None:
        """``None`` is returned when the sample holds no rubric breakdown at all."""
        assert _extract_rubric_scores(None) is None
        assert _extract_rubric_scores({}) is None
        assert _extract_rubric_scores({"properties": {}}) is None

    def test_skips_malformed_entries(self) -> None:
        """An entry without ``weight`` is dropped while well-formed siblings are kept."""
        sample = {
            "properties": {
                "rubric_scores": [
                    {"id": "good", "score": 3, "applicable": True, "weight": 1, "reason": "ok"},
                    {"id": "bad-no-weight", "score": 2, "applicable": True, "reason": "x"},
                ]
            }
        }
        result = _extract_rubric_scores(sample)
        assert result is not None
        assert len(result) == 1
        assert result[0].id == "good"

    def test_canonical_dimension_scores_key_from_docs(self) -> None:
        """Per the Microsoft Learn docs, runtime output uses ``properties.dimension_scores``."""
        sample = {
            "properties": {
                "dimension_scores": [
                    {
                        "id": "intent_recognition",
                        "score": 5,
                        "applicable": True,
                        "weight": 9,
                        "reason": "Identified correctly.",
                    },
                    {
                        "id": "general_quality",
                        "score": 4,
                        "applicable": True,
                        "weight": 5,
                        "reason": "Strong overall.",
                    },
                ]
            }
        }
        result = _extract_rubric_scores(sample)
        assert result is not None
        assert [r.id for r in result] == ["intent_recognition", "general_quality"]
        assert [r.score for r in result] == [5, 4]
        assert [r.weight for r in result] == [9, 5]

    def test_dimension_scores_via_attribute(self) -> None:
        """Canonical key also resolves when properties exposes ``dimension_scores`` as an attr."""
        rs = MagicMock()
        rs.id = "policy_enforcement"
        rs.score = 1
        rs.applicable = True
        rs.weight = 5
        rs.reason = "violated"

        sample = MagicMock()
        sample.properties = MagicMock(spec=["dimension_scores"])
        sample.properties.dimension_scores = [rs]

        result = _extract_rubric_scores(sample)
        assert result is not None
        assert result[0].id == "policy_enforcement"
        assert result[0].score == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _poll_eval_run — timeout / failed / canceled paths
|
||||
@@ -2378,7 +2628,6 @@ class TestFetchOutputItems:
|
||||
class TestPollEvalRun:
|
||||
async def test_timeout_returns_timeout_status(self) -> None:
|
||||
"""Poll timeout returns EvalResults with status='timeout'."""
|
||||
from agent_framework_foundry._foundry_evals import _poll_eval_run
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_pending = MagicMock()
|
||||
@@ -2392,7 +2641,6 @@ class TestPollEvalRun:
|
||||
|
||||
async def test_failed_run_returns_error(self) -> None:
|
||||
"""Failed run returns EvalResults with error message."""
|
||||
from agent_framework_foundry._foundry_evals import _poll_eval_run
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_failed = MagicMock()
|
||||
@@ -2410,7 +2658,6 @@ class TestPollEvalRun:
|
||||
|
||||
async def test_canceled_run_returns_canceled_status(self) -> None:
|
||||
"""Canceled run returns EvalResults with status='canceled'."""
|
||||
from agent_framework_foundry._foundry_evals import _poll_eval_run
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_canceled = MagicMock()
|
||||
@@ -2435,7 +2682,6 @@ class TestPollEvalRun:
|
||||
class TestEvaluateTraces:
|
||||
async def test_raises_without_required_args(self) -> None:
|
||||
"""Raises ValueError when no response_ids, trace_ids, or agent_id given."""
|
||||
from agent_framework_foundry._foundry_evals import evaluate_traces
|
||||
|
||||
mock_client = MagicMock()
|
||||
with pytest.raises(ValueError, match="Provide at least one of"):
|
||||
@@ -2446,7 +2692,6 @@ class TestEvaluateTraces:
|
||||
|
||||
async def test_response_ids_path(self) -> None:
|
||||
"""evaluate_traces with response_ids uses the responses API path."""
|
||||
from agent_framework_foundry._foundry_evals import evaluate_traces
|
||||
|
||||
mock_client = MagicMock()
|
||||
|
||||
@@ -2494,7 +2739,6 @@ class TestEvaluateTraces:
|
||||
|
||||
async def test_trace_ids_path(self) -> None:
|
||||
"""evaluate_traces with trace_ids builds azure_ai_traces data source."""
|
||||
from agent_framework_foundry._foundry_evals import evaluate_traces
|
||||
|
||||
mock_client = MagicMock()
|
||||
|
||||
@@ -2534,7 +2778,6 @@ class TestEvaluateTraces:
|
||||
class TestEvaluateFoundryTarget:
|
||||
async def test_happy_path(self) -> None:
|
||||
"""evaluate_foundry_target creates eval + run and polls to completion."""
|
||||
from agent_framework_foundry._foundry_evals import evaluate_foundry_target
|
||||
|
||||
mock_client = MagicMock()
|
||||
|
||||
@@ -2670,13 +2913,11 @@ class TestEvaluatorSetConsistency:
|
||||
"""Verify that _AGENT_EVALUATORS and _TOOL_EVALUATORS are subsets of _BUILTIN_EVALUATORS."""
|
||||
|
||||
def test_agent_evaluators_subset(self):
|
||||
from agent_framework_foundry._foundry_evals import _AGENT_EVALUATORS, _BUILTIN_EVALUATORS
|
||||
|
||||
diff = _AGENT_EVALUATORS - set(_BUILTIN_EVALUATORS.values())
|
||||
assert not diff, f"_AGENT_EVALUATORS has names not in _BUILTIN_EVALUATORS: {diff}"
|
||||
|
||||
def test_tool_evaluators_subset(self):
|
||||
from agent_framework_foundry._foundry_evals import _BUILTIN_EVALUATORS, _TOOL_EVALUATORS
|
||||
|
||||
diff = _TOOL_EVALUATORS - set(_BUILTIN_EVALUATORS.values())
|
||||
assert not diff, f"_TOOL_EVALUATORS has names not in _BUILTIN_EVALUATORS: {diff}"
|
||||
@@ -2690,7 +2931,6 @@ class TestEvaluatorSetConsistency:
|
||||
class TestEvaluateTracesAgentId:
|
||||
async def test_agent_id_only_path(self) -> None:
|
||||
"""evaluate_traces with agent_id only builds azure_ai_traces data source."""
|
||||
from agent_framework_foundry._foundry_evals import evaluate_traces
|
||||
|
||||
mock_client = MagicMock()
|
||||
|
||||
@@ -2748,7 +2988,6 @@ class TestFilterToolEvaluatorsRaises:
|
||||
class TestEvaluateFoundryTargetValidation:
|
||||
async def test_target_without_type_raises(self) -> None:
|
||||
"""target dict without 'type' key raises ValueError."""
|
||||
from agent_framework_foundry._foundry_evals import evaluate_foundry_target
|
||||
|
||||
mock_client = MagicMock()
|
||||
with pytest.raises(ValueError, match="'type' key"):
|
||||
|
||||
@@ -57,19 +57,19 @@ math = [
|
||||
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"uv==0.11.6",
|
||||
"ruff==0.15.8",
|
||||
"uv==0.11.17",
|
||||
"ruff==0.15.15",
|
||||
"pytest==9.0.3",
|
||||
"mypy==1.20.0",
|
||||
"pyright==1.1.408",
|
||||
#tasks
|
||||
"poethepoet==0.42.1",
|
||||
"poethepoet==0.46.0",
|
||||
"rich>=13.7.1,<15.0.0",
|
||||
"tomli==2.4.1",
|
||||
"tomli-w==1.2.0",
|
||||
# tau2 from source (not available on PyPI)
|
||||
"tau2@ git+https://github.com/sierra-research/tau2-bench@5ba9e3e56db57c5e4114bf7f901291f09b2c5619",
|
||||
"prek==0.3.9",
|
||||
"prek==0.4.3",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
# Mistral Package (agent-framework-mistral)
|
||||
|
||||
Integration with Mistral AI for embedding generation.
|
||||
|
||||
## Main Classes
|
||||
|
||||
- **`MistralEmbeddingClient`** - Embedding client for Mistral AI models
|
||||
- **`MistralEmbeddingOptions`** - Options TypedDict for Mistral-specific embedding parameters
|
||||
- **`MistralEmbeddingSettings`** - TypedDict settings for Mistral configuration
|
||||
|
||||
## Usage
|
||||
|
||||
```python
|
||||
from agent_framework_mistral import MistralEmbeddingClient
|
||||
|
||||
# Requires MISTRAL_API_KEY environment variable (or pass api_key= directly)
|
||||
client = MistralEmbeddingClient(model="mistral-embed")
|
||||
result = await client.get_embeddings(["Hello, world!"])
|
||||
print(result[0].vector)
|
||||
```
|
||||
|
||||
## Import Path
|
||||
|
||||
```python
|
||||
from agent_framework_mistral import MistralEmbeddingClient
|
||||
```
|
||||
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) Microsoft Corporation.
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE
|
||||
@@ -0,0 +1,42 @@
|
||||
# Get Started with Microsoft Agent Framework Mistral AI
|
||||
|
||||
Please install this package:
|
||||
|
||||
```bash
|
||||
pip install agent-framework-mistral --pre
|
||||
```
|
||||
|
||||
and see the [README](https://github.com/microsoft/agent-framework/tree/main/python/README.md) for more information.
|
||||
|
||||
## Embedding Client
|
||||
|
||||
The `MistralEmbeddingClient` provides embedding generation using Mistral AI models.
|
||||
|
||||
### Quick Start
|
||||
|
||||
```python
|
||||
from agent_framework_mistral import MistralEmbeddingClient
|
||||
|
||||
# Using environment variables (MISTRAL_API_KEY, MISTRAL_EMBEDDING_MODEL)
|
||||
client = MistralEmbeddingClient()
|
||||
|
||||
# Or passing parameters directly
|
||||
client = MistralEmbeddingClient(
|
||||
model="mistral-embed",
|
||||
api_key="your-api-key",
|
||||
)
|
||||
|
||||
# Generate embeddings
|
||||
result = await client.get_embeddings(["Hello, world!", "How are you?"])
|
||||
for embedding in result:
|
||||
print(f"Dimensions: {embedding.dimensions}")
|
||||
print(f"Vector: {embedding.vector[:5]}...")
|
||||
```
|
||||
|
||||
### Configuration
|
||||
|
||||
| Environment Variable | Description |
|
||||
|---|---|
|
||||
| `MISTRAL_API_KEY` | Your Mistral AI API key |
|
||||
| `MISTRAL_EMBEDDING_MODEL` | Embedding model name (e.g., `mistral-embed`) |
|
||||
| `MISTRAL_SERVER_URL` | Optional server URL override |
|
||||
@@ -0,0 +1,17 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import importlib.metadata
|
||||
|
||||
from ._embedding_client import MistralEmbeddingClient, MistralEmbeddingOptions, MistralEmbeddingSettings
|
||||
|
||||
# Resolve the installed distribution version; importing from a source checkout
# without installed metadata falls back to a placeholder version.
try:
    __version__ = importlib.metadata.version(__name__)
except importlib.metadata.PackageNotFoundError:
    # Fallback for development mode
    __version__ = "0.0.0"

# Names re-exported as the public API of this package.
__all__ = [
    "MistralEmbeddingClient",
    "MistralEmbeddingOptions",
    "MistralEmbeddingSettings",
    "__version__",
]
|
||||
@@ -0,0 +1,250 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, ClassVar, Generic, TypedDict
|
||||
|
||||
from agent_framework import (
|
||||
BaseEmbeddingClient,
|
||||
Embedding,
|
||||
EmbeddingGenerationOptions,
|
||||
GeneratedEmbeddings,
|
||||
UsageDetails,
|
||||
load_settings,
|
||||
)
|
||||
from agent_framework._settings import SecretString
|
||||
from agent_framework.observability import EmbeddingTelemetryLayer
|
||||
from mistralai.client import Mistral
|
||||
|
||||
if sys.version_info >= (3, 13):
|
||||
from typing import TypeVar # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import TypeVar # type: ignore # pragma: no cover
|
||||
|
||||
|
||||
logger = logging.getLogger("agent_framework.mistral")
|
||||
|
||||
|
||||
class MistralEmbeddingOptions(EmbeddingGenerationOptions, total=False):
    """Options TypedDict accepted by the Mistral AI embedding clients.

    The class declares no keys of its own; every key comes from
    ``EmbeddingGenerationOptions``. The Mistral client reads ``model`` (to
    override the configured model for one call) and ``dimensions`` (forwarded
    to the API as the requested output dimension).

    Examples:
        .. code-block:: python

            from agent_framework_mistral import MistralEmbeddingOptions

            options: MistralEmbeddingOptions = {
                "model": "mistral-embed",
                "dimensions": 1024,
            }
    """
|
||||
|
||||
|
||||
# Options type parameter shared by the Mistral embedding clients. It is
# covariant, bounded to TypedDict, and falls back to ``MistralEmbeddingOptions``
# when a client is used without an explicit type argument.
MistralEmbeddingOptionsT = TypeVar(
    "MistralEmbeddingOptionsT",
    covariant=True,
    default="MistralEmbeddingOptions",
    bound=TypedDict,  # type: ignore[valid-type]
)
|
||||
|
||||
|
||||
class MistralEmbeddingSettings(TypedDict, total=False):
|
||||
"""Mistral AI embedding settings.
|
||||
|
||||
Fields:
|
||||
api_key: Mistral API key. Resolved from ``MISTRAL_API_KEY``.
|
||||
embedding_model: Embedding model name. Resolved from ``MISTRAL_EMBEDDING_MODEL``.
|
||||
server_url: Optional server URL override. Resolved from ``MISTRAL_SERVER_URL``.
|
||||
"""
|
||||
|
||||
api_key: str | None
|
||||
embedding_model: str | None
|
||||
server_url: str | None
|
||||
|
||||
|
||||
class RawMistralEmbeddingClient(
    BaseEmbeddingClient[str, list[float], MistralEmbeddingOptionsT],
    Generic[MistralEmbeddingOptionsT],
):
    """Raw Mistral AI embedding client without telemetry.

    Keyword Args:
        model: The Mistral embedding model (e.g. "mistral-embed").
            Can also be set via environment variable ``MISTRAL_EMBEDDING_MODEL``.
        api_key: Mistral API key. Defaults to ``MISTRAL_API_KEY`` environment variable.
            Only required when no ``client`` is supplied.
        server_url: Optional server URL override. Defaults to ``MISTRAL_SERVER_URL``
            environment variable, or the Mistral default.
        client: Optional pre-configured ``Mistral`` client instance. When given it is
            used as-is and no new client is constructed.
        additional_properties: Additional properties stored on the client instance.
        env_file_path: Path to ``.env`` file for settings.
        env_file_encoding: Encoding for ``.env`` file.
    """

    INJECTABLE: ClassVar[set[str]] = {"client"}

    def __init__(
        self,
        *,
        model: str | None = None,
        api_key: str | SecretString | None = None,
        server_url: str | None = None,
        client: Mistral | None = None,
        additional_properties: dict[str, Any] | None = None,
        env_file_path: str | None = None,
        env_file_encoding: str | None = None,
    ) -> None:
        """Initialize a raw Mistral AI embedding client.

        Settings are resolved from the explicit arguments first and otherwise
        from ``MISTRAL_``-prefixed environment variables (or the given ``.env``
        file). The embedding model is always required; the API key is required
        only when this constructor has to build the ``Mistral`` client itself.
        """
        # An injected client already carries its own credentials, so demanding an
        # API key in that case would reject a perfectly usable configuration.
        required_fields = ["embedding_model"] if client is not None else ["embedding_model", "api_key"]
        mistral_settings = load_settings(
            MistralEmbeddingSettings,
            env_prefix="MISTRAL_",
            required_fields=required_fields,
            api_key=str(api_key) if isinstance(api_key, SecretString) else api_key,
            embedding_model=model,
            server_url=server_url,
            env_file_path=env_file_path,
            env_file_encoding=env_file_encoding,
        )

        self.model: str = mistral_settings["embedding_model"]  # type: ignore[assignment]
        resolved_server_url = mistral_settings.get("server_url")

        if client is not None:
            self.client = client
        else:
            resolved_api_key: str = mistral_settings["api_key"]  # type: ignore[assignment]
            client_kwargs: dict[str, Any] = {"api_key": resolved_api_key}
            # Only pass server_url when one was configured so the SDK default applies otherwise.
            if resolved_server_url:
                client_kwargs["server_url"] = resolved_server_url
            self.client = Mistral(**client_kwargs)

        # NOTE: with an injected client this reflects the configured setting, not
        # necessarily the URL the injected client actually talks to.
        self.server_url = resolved_server_url
        super().__init__(additional_properties=additional_properties)

    def service_url(self) -> str:
        """Get the URL of the service.

        Returns:
            The configured server URL, or ``https://api.mistral.ai`` when none is set.
        """
        return self.server_url or "https://api.mistral.ai"

    async def get_embeddings(
        self,
        values: Sequence[str],
        *,
        options: MistralEmbeddingOptionsT | None = None,
    ) -> GeneratedEmbeddings[list[float], MistralEmbeddingOptionsT]:
        """Call the Mistral AI embeddings API.

        Args:
            values: The text values to generate embeddings for.
            options: Optional embedding generation options. ``model`` overrides the
                client's configured model for this call; ``dimensions`` is sent to
                the API as ``output_dimension``.

        Returns:
            Generated embeddings with usage metadata, ordered by the index the API
            reports for each item. An empty ``values`` yields an empty result
            without calling the API.

        Raises:
            ValueError: If no model is available from ``options`` or the client.
        """
        if not values:
            return GeneratedEmbeddings([], options=options)

        opts: dict[str, Any] = options or {}  # type: ignore
        model = opts.get("model") or self.model
        if not model:
            raise ValueError("model is required")

        kwargs: dict[str, Any] = {"model": model, "inputs": list(values)}
        # Treat an explicit ``None`` like an absent key instead of forwarding a null.
        dimensions = opts.get("dimensions")
        if dimensions is not None:
            kwargs["output_dimension"] = dimensions

        response = await self.client.embeddings.create_async(**kwargs)

        embeddings: list[Embedding[list[float]]] = []
        if response and response.data:
            # Sort by the reported index so results line up with ``values``.
            items = sorted(response.data, key=lambda d: d.index if d.index is not None else 0)
            for item in items:
                vector = list(item.embedding) if item.embedding else []
                embeddings.append(
                    Embedding(
                        vector=vector,
                        dimensions=len(vector),
                        model=response.model or model,
                    )
                )

        usage_dict: UsageDetails | None = None
        if response and response.usage:
            usage_dict = {
                "input_token_count": response.usage.prompt_tokens,
                "total_token_count": response.usage.total_tokens,
            }

        return GeneratedEmbeddings(embeddings, options=options, usage=usage_dict)
|
||||
|
||||
|
||||
class MistralEmbeddingClient(
    EmbeddingTelemetryLayer[str, list[float], MistralEmbeddingOptionsT],
    RawMistralEmbeddingClient[MistralEmbeddingOptionsT],
    Generic[MistralEmbeddingOptionsT],
):
    """Mistral AI embedding client layered with ``EmbeddingTelemetryLayer``.

    Keyword Args:
        model: The Mistral embedding model (e.g. "mistral-embed").
            Can also be set via environment variable ``MISTRAL_EMBEDDING_MODEL``.
        api_key: Mistral API key. Defaults to ``MISTRAL_API_KEY`` environment variable.
        server_url: Optional server URL override. Defaults to ``MISTRAL_SERVER_URL``
            environment variable, or the Mistral default.
        client: Optional pre-configured ``Mistral`` client instance.
        otel_provider_name: Optional telemetry provider name override.
        additional_properties: Additional properties stored on the client instance.
        env_file_path: Path to ``.env`` file for settings.
        env_file_encoding: Encoding for ``.env`` file.

    Examples:
        .. code-block:: python

            from agent_framework_mistral import MistralEmbeddingClient

            # Using environment variables
            # Set MISTRAL_API_KEY=your-key
            # Set MISTRAL_EMBEDDING_MODEL=mistral-embed
            client = MistralEmbeddingClient()

            # Or passing parameters directly
            client = MistralEmbeddingClient(
                model="mistral-embed",
                api_key="your-api-key",
            )

            # Generate embeddings
            result = await client.get_embeddings(["Hello, world!"])
            print(result[0].vector)
    """

    OTEL_PROVIDER_NAME: ClassVar[str] = "mistralai"

    def __init__(
        self,
        *,
        model: str | None = None,
        api_key: str | SecretString | None = None,
        server_url: str | None = None,
        client: Mistral | None = None,
        otel_provider_name: str | None = None,
        additional_properties: dict[str, Any] | None = None,
        env_file_path: str | None = None,
        env_file_encoding: str | None = None,
    ) -> None:
        """Initialize a Mistral AI embedding client.

        Every argument is passed by keyword to the next ``__init__`` in the MRO.
        """
        super().__init__(
            otel_provider_name=otel_provider_name,
            additional_properties=additional_properties,
            client=client,
            model=model,
            api_key=api_key,
            server_url=server_url,
            env_file_path=env_file_path,
            env_file_encoding=env_file_encoding,
        )
|
||||
@@ -0,0 +1 @@
|
||||
|
||||
@@ -0,0 +1,105 @@
|
||||
[project]
|
||||
name = "agent-framework-mistral"
|
||||
description = "Mistral AI integration for Microsoft Agent Framework."
|
||||
authors = [{ name = "Microsoft", email = "af-support@microsoft.com"}]
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
version = "1.0.0a260505"
|
||||
license-files = ["LICENSE"]
|
||||
urls.homepage = "https://learn.microsoft.com/en-us/agent-framework/"
|
||||
urls.source = "https://github.com/microsoft/agent-framework/tree/main/python"
|
||||
urls.release_notes = "https://github.com/microsoft/agent-framework/releases?q=tag%3Apython-1&expanded=true"
|
||||
urls.issues = "https://github.com/microsoft/agent-framework/issues"
|
||||
classifiers = [
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Development Status :: 3 - Alpha",
|
||||
"Intended Audience :: Developers",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
"Programming Language :: Python :: 3.13",
|
||||
"Framework :: Pydantic :: 2",
|
||||
"Typing :: Typed",
|
||||
]
|
||||
dependencies = [
|
||||
"agent-framework-core>=1.1.0,<2",
|
||||
"mistralai>=2.0.0,<3",
|
||||
]
|
||||
|
||||
[tool.uv]
|
||||
prerelease = "if-necessary-or-explicit"
|
||||
environments = [
|
||||
"sys_platform == 'darwin'",
|
||||
"sys_platform == 'linux'",
|
||||
"sys_platform == 'win32'"
|
||||
]
|
||||
|
||||
[tool.uv-dynamic-versioning]
|
||||
fallback-version = "0.0.0"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = 'tests'
|
||||
addopts = "-ra -q -r fEX"
|
||||
asyncio_mode = "auto"
|
||||
asyncio_default_fixture_loop_scope = "function"
|
||||
filterwarnings = []
|
||||
markers = [
|
||||
"integration: marks tests as integration tests that require external services",
|
||||
]
|
||||
timeout = 120
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 120
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ["E", "F", "I", "N", "W"]
|
||||
|
||||
[tool.coverage.run]
|
||||
omit = [
|
||||
"**/__init__.py"
|
||||
]
|
||||
|
||||
[tool.pyright]
|
||||
extends = "../../pyproject.toml"
|
||||
include = ["agent_framework_mistral"]
|
||||
exclude = ['tests']
|
||||
|
||||
[tool.mypy]
|
||||
plugins = ['pydantic.mypy']
|
||||
strict = true
|
||||
python_version = "3.10"
|
||||
ignore_missing_imports = true
|
||||
disallow_untyped_defs = true
|
||||
no_implicit_optional = true
|
||||
check_untyped_defs = true
|
||||
warn_return_any = true
|
||||
show_error_codes = true
|
||||
warn_unused_ignores = false
|
||||
disallow_incomplete_defs = true
|
||||
disallow_untyped_decorators = true
|
||||
disallow_any_unimported = true
|
||||
|
||||
[tool.bandit]
|
||||
targets = ["agent_framework_mistral"]
|
||||
exclude_dirs = ["tests"]
|
||||
|
||||
[tool.poe]
|
||||
executor.type = "uv"
|
||||
include = "../../shared_tasks.toml"
|
||||
|
||||
[tool.poe.tasks.mypy]
|
||||
help = "Run MyPy for this package."
|
||||
cmd = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_mistral"
|
||||
|
||||
[tool.poe.tasks.test]
|
||||
help = "Run the default unit test suite for this package."
|
||||
cmd = 'pytest -m "not integration" --cov=agent_framework_mistral --cov-report=term-missing:skip-covered tests'
|
||||
|
||||
[tool.uv.build-backend]
|
||||
module-name = "agent_framework_mistral"
|
||||
module-root = ""
|
||||
|
||||
[build-system]
|
||||
requires = ["uv_build>=0.8.2,<0.9.0"]
|
||||
build-backend = "uv_build"
|
||||
@@ -0,0 +1,15 @@
|
||||
# Mistral AI Embedding Examples
|
||||
|
||||
This folder contains examples demonstrating how to use Mistral AI embedding models with the Agent Framework.
|
||||
|
||||
## Examples
|
||||
|
||||
| File | Description |
|
||||
|------|-------------|
|
||||
| [`mistral_embeddings.py`](mistral_embeddings.py) | Basic embedding generation with the Mistral AI embedding client. |
|
||||
|
||||
## Environment Variables
|
||||
|
||||
- `MISTRAL_API_KEY`: Your Mistral AI API key
|
||||
- `MISTRAL_EMBEDDING_MODEL`: Embedding model name (e.g., `mistral-embed`)
|
||||
- `MISTRAL_SERVER_URL` (optional): Server URL override for custom deployments
|
||||
@@ -0,0 +1 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
@@ -0,0 +1,77 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Shows how to generate embeddings using the Mistral AI embedding client.
|
||||
|
||||
Requires ``MISTRAL_API_KEY`` and ``MISTRAL_EMBEDDING_MODEL`` environment variables.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from agent_framework_mistral import MistralEmbeddingClient
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
async def basic_embedding_example() -> None:
    """Embed a small batch of texts and print each vector's size and prefix."""
    print("=== Basic Embedding Generation ===")

    # The client reads MISTRAL_API_KEY and MISTRAL_EMBEDDING_MODEL from the environment.
    embedding_client = MistralEmbeddingClient()

    # One request covers all inputs.
    inputs = ["Hello, world!", "How are you?", "Agent Framework with Mistral AI"]
    result = await embedding_client.get_embeddings(inputs)

    print(f"Generated {len(result)} embeddings")
    for i, embedding in enumerate(result):
        print(f" Text {i + 1}: dimensions={embedding.dimensions}, vector={embedding.vector[:5]}...")

    # Usage metadata is optional; stop here when the response carried none.
    if not result.usage:
        return
    print(
        f" Usage: {result.usage['input_token_count']} input tokens, "
        f"{result.usage['total_token_count']} total tokens"
    )
|
||||
|
||||
|
||||
async def embedding_with_options_example() -> None:
    """Embed one text while requesting a specific output dimension."""
    print("\n=== Embedding with Custom Dimensions ===")

    from agent_framework_mistral import MistralEmbeddingOptions

    embedding_client = MistralEmbeddingClient()

    # The selected model has to support the requested dimension.
    options: MistralEmbeddingOptions = {"dimensions": 256}
    result = await embedding_client.get_embeddings(
        ["Dimensionality reduction example"],
        options=options,
    )

    print(f" Dimensions: {result[0].dimensions}")
    print(f" Vector (first 5): {result[0].vector[:5]}...")
|
||||
|
||||
|
||||
async def main() -> None:
    """Run both embedding examples, one after the other."""
    for example in (basic_embedding_example, embedding_with_options_example):
        await example()


# Only run the examples when this file is executed as a script.
if __name__ == "__main__":
    asyncio.run(main())
|
||||
|
||||
"""
|
||||
Sample output:
|
||||
=== Basic Embedding Generation ===
|
||||
Generated 3 embeddings
|
||||
Text 1: dimensions=1024, vector=[0.0123, -0.0456, 0.0789, -0.0012, 0.0345]...
|
||||
Text 2: dimensions=1024, vector=[0.0234, -0.0567, 0.0891, -0.0023, 0.0456]...
|
||||
Text 3: dimensions=1024, vector=[0.0345, -0.0678, 0.0912, -0.0034, 0.0567]...
|
||||
Usage: 15 input tokens, 15 total tokens
|
||||
|
||||
=== Embedding with Custom Dimensions ===
|
||||
Dimensions: 256
|
||||
Vector (first 5): [0.0456, -0.0789, 0.0123, -0.0456, 0.0789]...
|
||||
"""
|
||||
@@ -0,0 +1,267 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import os
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from agent_framework import Embedding, GeneratedEmbeddings
|
||||
|
||||
from agent_framework_mistral import MistralEmbeddingClient, MistralEmbeddingOptions
|
||||
|
||||
# region: Unit Tests
|
||||
|
||||
|
||||
def test_mistral_embedding_construction(monkeypatch: pytest.MonkeyPatch) -> None:
    """The client picks up its model and API key from ``MISTRAL_*`` env vars."""
    for name, value in (
        ("MISTRAL_EMBEDDING_MODEL", "mistral-embed"),
        ("MISTRAL_API_KEY", "test-key"),
    ):
        monkeypatch.setenv(name, value)

    # Replace the SDK class so no real Mistral client is created.
    with patch("agent_framework_mistral._embedding_client.Mistral") as mistral_cls:
        mistral_cls.return_value = MagicMock()
        embedding_client = MistralEmbeddingClient()

    assert embedding_client.model == "mistral-embed"
|
||||
|
||||
|
||||
def test_mistral_embedding_construction_with_params() -> None:
    """Explicit ``model`` and ``api_key`` arguments configure the client."""
    with patch("agent_framework_mistral._embedding_client.Mistral") as mistral_cls:
        mistral_cls.return_value = MagicMock()
        embedding_client = MistralEmbeddingClient(model="mistral-embed", api_key="test-key")

    assert embedding_client.model == "mistral-embed"
    # Without a server URL, only the API key reaches the SDK constructor.
    mistral_cls.assert_called_once_with(api_key="test-key")
|
||||
|
||||
|
||||
def test_mistral_embedding_construction_with_server_url() -> None:
    """A custom ``server_url`` is stored and forwarded to the SDK constructor."""
    custom_kwargs = {"api_key": "test-key", "server_url": "https://custom.mistral.ai"}

    with patch("agent_framework_mistral._embedding_client.Mistral") as mistral_cls:
        mistral_cls.return_value = MagicMock()
        embedding_client = MistralEmbeddingClient(model="mistral-embed", **custom_kwargs)

    assert embedding_client.model == "mistral-embed"
    assert embedding_client.server_url == "https://custom.mistral.ai"
    mistral_cls.assert_called_once_with(
        api_key="test-key",
        server_url="https://custom.mistral.ai",
    )
|
||||
|
||||
|
||||
def test_mistral_embedding_construction_with_client() -> None:
    """Test construction with a pre-configured client.

    The injected client must be stored as-is, and the SDK class must not be
    called to build a second client.
    """
    mock_client = MagicMock()
    with patch("agent_framework_mistral._embedding_client.Mistral") as mock_cls:
        client = MistralEmbeddingClient(
            model="mistral-embed",
            api_key="test-key",
            client=mock_client,
        )
        assert client.client is mock_client
        # Guards against a regression where a new client is built despite injection.
        mock_cls.assert_not_called()
|
||||
|
||||
|
||||
def test_mistral_embedding_construction_missing_model_raises(monkeypatch: pytest.MonkeyPatch) -> None:
    """Construction fails when no embedding model can be resolved."""
    from agent_framework.exceptions import SettingNotFoundError

    # Provide the API key so the model is the only missing setting.
    monkeypatch.setenv("MISTRAL_API_KEY", "test-key")
    monkeypatch.delenv("MISTRAL_EMBEDDING_MODEL", raising=False)

    with pytest.raises(SettingNotFoundError):
        MistralEmbeddingClient()
|
||||
|
||||
|
||||
def test_mistral_embedding_construction_missing_api_key_raises(monkeypatch: pytest.MonkeyPatch) -> None:
    """Construction fails when no API key can be resolved."""
    from agent_framework.exceptions import SettingNotFoundError

    # Provide the model so the API key is the only missing setting.
    monkeypatch.setenv("MISTRAL_EMBEDDING_MODEL", "mistral-embed")
    monkeypatch.delenv("MISTRAL_API_KEY", raising=False)

    with pytest.raises(SettingNotFoundError):
        MistralEmbeddingClient()
|
||||
|
||||
|
||||
def test_mistral_embedding_service_url() -> None:
    """``service_url`` falls back to the public Mistral endpoint."""
    with patch("agent_framework_mistral._embedding_client.Mistral") as mistral_cls:
        mistral_cls.return_value = MagicMock()
        embedding_client = MistralEmbeddingClient(model="mistral-embed", api_key="test-key")

    reported_url = embedding_client.service_url()
    assert reported_url == "https://api.mistral.ai"
|
||||
|
||||
|
||||
def test_mistral_embedding_service_url_custom() -> None:
    """``service_url`` reports the custom server URL when one is configured."""
    with patch("agent_framework_mistral._embedding_client.Mistral") as mistral_cls:
        mistral_cls.return_value = MagicMock()
        embedding_client = MistralEmbeddingClient(
            server_url="https://custom.mistral.ai",
            api_key="test-key",
            model="mistral-embed",
        )

    reported_url = embedding_client.service_url()
    assert reported_url == "https://custom.mistral.ai"
|
||||
|
||||
|
||||
async def test_mistral_embedding_get_embeddings() -> None:
    """Embeddings, model name and usage are mapped from the API response."""
    # Fake API response holding two embedding items plus usage counters.
    api_response = MagicMock()
    api_response.data = [
        MagicMock(embedding=[0.1, 0.2, 0.3], index=0, object="embedding"),
        MagicMock(embedding=[0.4, 0.5, 0.6], index=1, object="embedding"),
    ]
    api_response.model = "mistral-embed"
    api_response.usage = MagicMock(prompt_tokens=10, total_tokens=10)

    sdk_client = MagicMock()
    sdk_client.embeddings = MagicMock()
    sdk_client.embeddings.create_async = AsyncMock(return_value=api_response)

    with patch("agent_framework_mistral._embedding_client.Mistral") as mistral_cls:
        mistral_cls.return_value = sdk_client
        embedding_client = MistralEmbeddingClient(model="mistral-embed", api_key="test-key")
        result = await embedding_client.get_embeddings(["hello", "world"])

    assert isinstance(result, GeneratedEmbeddings)
    assert len(result) == 2
    assert result[0].vector == [0.1, 0.2, 0.3]
    assert result[1].vector == [0.4, 0.5, 0.6]
    assert result[0].model == "mistral-embed"
    assert result.usage == {"input_token_count": 10, "total_token_count": 10}

    # With no options given, only the model and the inputs are sent.
    sdk_client.embeddings.create_async.assert_called_once_with(
        model="mistral-embed",
        inputs=["hello", "world"],
    )
|
||||
|
||||
|
||||
async def test_mistral_embedding_get_embeddings_empty_input() -> None:
    """An empty input list produces an empty ``GeneratedEmbeddings``."""
    sdk_client = MagicMock()

    with patch("agent_framework_mistral._embedding_client.Mistral") as mistral_cls:
        mistral_cls.return_value = sdk_client
        embedding_client = MistralEmbeddingClient(model="mistral-embed", api_key="test-key")
        result = await embedding_client.get_embeddings([])

    assert isinstance(result, GeneratedEmbeddings)
    assert len(result) == 0
|
||||
|
||||
|
||||
async def test_mistral_embedding_get_embeddings_with_dimensions() -> None:
    """The ``dimensions`` option is sent to the API as ``output_dimension``."""
    api_response = MagicMock()
    api_response.data = [
        MagicMock(embedding=[0.1, 0.2], index=0, object="embedding"),
    ]
    api_response.model = "mistral-embed"
    api_response.usage = MagicMock(prompt_tokens=5, total_tokens=5)

    sdk_client = MagicMock()
    sdk_client.embeddings = MagicMock()
    sdk_client.embeddings.create_async = AsyncMock(return_value=api_response)

    with patch("agent_framework_mistral._embedding_client.Mistral") as mistral_cls:
        mistral_cls.return_value = sdk_client
        embedding_client = MistralEmbeddingClient(model="mistral-embed", api_key="test-key")
        options: MistralEmbeddingOptions = {"dimensions": 512}
        result = await embedding_client.get_embeddings(["hello"], options=options)

    assert len(result) == 1
    sdk_client.embeddings.create_async.assert_called_once_with(
        model="mistral-embed",
        inputs=["hello"],
        output_dimension=512,
    )
|
||||
|
||||
|
||||
async def test_mistral_embedding_get_embeddings_no_model_raises() -> None:
    """``get_embeddings`` raises ``ValueError`` when no model is available at call time."""
    sdk_client = MagicMock()

    with patch("agent_framework_mistral._embedding_client.Mistral") as mistral_cls:
        mistral_cls.return_value = sdk_client
        embedding_client = MistralEmbeddingClient(model="mistral-embed", api_key="test-key")

        # Clear the model after construction to simulate a client left without one.
        embedding_client.model = None  # type: ignore[assignment]

        with pytest.raises(ValueError, match="model is required"):
            await embedding_client.get_embeddings(["hello"])
|
||||
|
||||
|
||||
async def test_mistral_embedding_get_embeddings_model_override() -> None:
    """Check that a ``model`` entry in the options replaces the client's configured model."""
    # Canned SDK reply that reports the overriding model name.
    fake_response = MagicMock()
    fake_response.data = [MagicMock(embedding=[0.1, 0.2, 0.3], index=0, object="embedding")]
    fake_response.model = "custom-embed"
    fake_response.usage = MagicMock(prompt_tokens=5, total_tokens=5)

    with patch("agent_framework_mistral._embedding_client.Mistral") as mistral_cls:
        # Replace the SDK client so no network request is made.
        sdk_client = MagicMock()
        sdk_client.embeddings = MagicMock()
        sdk_client.embeddings.create_async = AsyncMock(return_value=fake_response)
        mistral_cls.return_value = sdk_client

        embedding_client = MistralEmbeddingClient(model="mistral-embed", api_key="test-key")
        options: MistralEmbeddingOptions = {"model": "custom-embed"}
        result = await embedding_client.get_embeddings(["hello"], options=options)

    assert len(result) == 1
    assert result[0].model == "custom-embed"
    # The SDK must be called with the overriding model, not the constructor's one.
    sdk_client.embeddings.create_async.assert_called_once_with(
        model="custom-embed",
        inputs=["hello"],
    )
|
||||
|
||||
|
||||
async def test_mistral_embedding_get_embeddings_no_usage() -> None:
    """Check that a response with ``usage`` set to ``None`` yields a result whose usage is ``None``."""
    # Canned SDK reply with one embedding and no usage block.
    fake_response = MagicMock()
    fake_response.data = [MagicMock(embedding=[0.1, 0.2, 0.3], index=0, object="embedding")]
    fake_response.model = "mistral-embed"
    fake_response.usage = None

    with patch("agent_framework_mistral._embedding_client.Mistral") as mistral_cls:
        # Replace the SDK client so no network request is made.
        sdk_client = MagicMock()
        sdk_client.embeddings = MagicMock()
        sdk_client.embeddings.create_async = AsyncMock(return_value=fake_response)
        mistral_cls.return_value = sdk_client

        embedding_client = MistralEmbeddingClient(model="mistral-embed", api_key="test-key")
        result = await embedding_client.get_embeddings(["hello"])

    assert len(result) == 1
    assert result.usage is None
|
||||
|
||||
|
||||
# region: Integration Tests
|
||||
|
||||
# Skip integration tests when MISTRAL_API_KEY is empty/unset, or when
# MISTRAL_EMBEDDING_MODEL is empty/unset or equal to "test-model".
skip_if_mistral_embedding_integration_tests_disabled = pytest.mark.skipif(
    (
        os.getenv("MISTRAL_EMBEDDING_MODEL", "") in {"", "test-model"}
        or not os.getenv("MISTRAL_API_KEY", "")
    ),
    reason="No real Mistral embedding model or API key provided; skipping integration tests.",
)
|
||||
|
||||
|
||||
@pytest.mark.flaky
@pytest.mark.integration
@skip_if_mistral_embedding_integration_tests_disabled
async def test_mistral_embedding_integration() -> None:
    """Embed two strings with a default-constructed client and check vectors and usage."""
    # No arguments are passed to the constructor here.
    embedding_client = MistralEmbeddingClient()
    result = await embedding_client.get_embeddings(["Hello, world!", "How are you?"])

    assert isinstance(result, GeneratedEmbeddings)
    assert len(result) == 2

    # Each input must yield a non-empty list of floats.
    for item in result:
        assert isinstance(item, Embedding)
        vector = item.vector
        assert isinstance(vector, list)
        assert len(vector) > 0
        assert all(isinstance(component, float) for component in vector)

    # Usage must be present with a positive input token count.
    usage = result.usage
    assert usage is not None
    input_tokens = usage["input_token_count"]
    assert input_tokens is not None
    assert input_tokens > 0
|
||||
Reference in New Issue
Block a user