Merge branch 'main' into local-branch-fix-workflow-as-agent-pending-request-handling

This commit is contained in:
Tao Chen
2026-06-01 22:15:14 -07:00
Unverified
116 changed files with 9188 additions and 4649 deletions
@@ -2,6 +2,7 @@
import base64
import logging
import uuid
from asyncio import CancelledError
from collections.abc import Mapping
from functools import partial
@@ -181,9 +182,18 @@ class A2AExecutor(AgentExecutor):
"""Run the agent in streaming mode and publish updates to the task updater."""
response_stream = self._agent.run(query, session=session, stream=True, **self._run_kwargs)
streamed_artifact_ids: set[str] = set()
# Generate a stable artifact ID for the entire stream so all chunks share the same ID.
# This ensures clients can coalesce streaming tokens into a single artifact/message
# per the A2A spec (TaskArtifactUpdateEvent with append=True on same artifactId).
default_artifact_id = str(uuid.uuid4())
await (
response_stream.with_transform_hook(
partial(self.handle_events, updater=updater, streamed_artifact_ids=streamed_artifact_ids)
partial(
self.handle_events,
updater=updater,
streamed_artifact_ids=streamed_artifact_ids,
default_artifact_id=default_artifact_id,
)
)
).get_final_response()
@@ -199,7 +209,11 @@ class A2AExecutor(AgentExecutor):
await self.handle_events(message, updater)
async def handle_events(
self, item: Message | AgentResponseUpdate, updater: TaskUpdater, streamed_artifact_ids: set[str] | None = None
self,
item: Message | AgentResponseUpdate,
updater: TaskUpdater,
streamed_artifact_ids: set[str] | None = None,
default_artifact_id: str | None = None,
) -> None:
"""Convert agent response items (Messages or Updates) to A2A protocol events.
@@ -213,7 +227,10 @@ class A2AExecutor(AgentExecutor):
item: The agent response item (Message or AgentResponseUpdate) to process.
updater: The task updater to publish events to.
streamed_artifact_ids: A set of artifact IDs that have already been streamed.
Used to prevent duplicate updates for the same artifact.
Used to track which artifacts need append=True on subsequent chunks.
default_artifact_id: A stable artifact ID to use when the item does not provide one.
This ensures all streaming chunks for a single response share the same artifact ID,
allowing clients to coalesce them into a single message.
Example:
.. code-block:: python
@@ -224,6 +241,7 @@ class A2AExecutor(AgentExecutor):
item: Message | AgentResponseUpdate,
updater: TaskUpdater,
streamed_artifact_ids: set[str] | None = None,
default_artifact_id: str | None = None,
) -> None:
# Custom logic to transform item contents
if item.role == "assistant" and item.contents:
@@ -260,19 +278,20 @@ class A2AExecutor(AgentExecutor):
if parts:
if isinstance(item, AgentResponseUpdate):
# Resolve artifact ID: use item's message_id if available, otherwise fall back
# to the stable default_artifact_id so all streaming chunks share the same ID.
artifact_id = item.message_id or default_artifact_id
# For streaming updates, we send TaskArtifactUpdateEvent via add_artifact
await updater.add_artifact(
parts=parts,
artifact_id=item.message_id,
artifact_id=artifact_id,
metadata=metadata,
append=(
True
if streamed_artifact_ids is not None and item.message_id in (streamed_artifact_ids or set())
else None
True if streamed_artifact_ids is not None and artifact_id in streamed_artifact_ids else None
),
)
if item.message_id and streamed_artifact_ids is not None:
streamed_artifact_ids.add(item.message_id)
if artifact_id and streamed_artifact_ids is not None:
streamed_artifact_ids.add(artifact_id)
else:
# For final messages, we send TaskStatusUpdateEvent with 'working' state
await updater.update_status(
@@ -492,6 +492,7 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
contents=contents,
role="assistant" if msg.role == A2ARole.ROLE_AGENT else "user",
response_id=msg.message_id or str(uuid.uuid4()),
message_id=msg.message_id,
additional_properties={"a2a_metadata": metadata} if metadata else None,
raw_representation=msg,
)
@@ -732,6 +733,7 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
contents=contents,
role="assistant" if message.role == A2ARole.ROLE_AGENT else "user",
response_id=update_event.task_id,
message_id=message.message_id,
additional_properties={"a2a_metadata": merged_metadata} if merged_metadata else None,
raw_representation=update_event,
)
+5 -2
View File
@@ -420,6 +420,7 @@ async def test_run_streaming_with_message_response(a2a_agent: A2AAgent, mock_a2a
assert content.text == "Streaming response from agent!"
assert updates[0].response_id == "msg-stream-123"
assert updates[0].message_id == "msg-stream-123"
assert mock_a2a_client.call_count == 1
@@ -1422,7 +1423,7 @@ async def test_streaming_status_update_event_yields_content(
status=TaskStatus(
state=TaskState.TASK_STATE_COMPLETED,
message=A2AMessage(
message_id=str(uuid4()),
message_id="msg-status-done",
role=A2ARole.ROLE_AGENT,
parts=[Part(text="Done")],
),
@@ -1437,6 +1438,7 @@ async def test_streaming_status_update_event_yields_content(
assert len(updates) == 1
assert updates[0].text == "Done"
assert updates[0].role == "assistant"
assert updates[0].message_id == "msg-status-done"
assert updates[0].raw_representation == update_event
@@ -1449,7 +1451,7 @@ async def test_streaming_input_required_emits_content(a2a_agent: A2AAgent, mock_
status=TaskStatus(
state=TaskState.TASK_STATE_INPUT_REQUIRED,
message=A2AMessage(
message_id=str(uuid4()),
message_id="msg-input-req",
role=A2ARole.ROLE_AGENT,
parts=[Part(text="What is your name?")],
),
@@ -1463,6 +1465,7 @@ async def test_streaming_input_required_emits_content(a2a_agent: A2AAgent, mock_
assert len(updates) == 1
assert updates[0].text == "What is your name?"
assert updates[0].message_id == "msg-input-req"
@mark.asyncio
@@ -803,6 +803,15 @@ class RawAnthropicClient(
}
a_content.append(mcp_result)
case "text_reasoning":
if content.text is None:
if (
content.protected_data
and a_content
and a_content[-1].get("type") == "thinking"
and "signature" not in a_content[-1]
):
a_content[-1]["signature"] = content.protected_data
continue
thinking_block: dict[str, Any] = {"type": "thinking", "thinking": content.text}
if content.protected_data:
thinking_block["signature"] = content.protected_data
@@ -485,6 +485,48 @@ def test_prepare_message_for_anthropic_text_reasoning_with_signature(
assert result["content"][0]["signature"] == "sig_abc123"
def test_prepare_message_for_anthropic_attaches_signature_only_reasoning(
mock_anthropic_client: MagicMock,
) -> None:
client = create_test_anthropic_client(mock_anthropic_client)
message = Message(
role="assistant",
contents=[
Content.from_text_reasoning(text="Let me think about this..."),
Content.from_text_reasoning(text=None, protected_data="sig_abc123"),
],
)
result = client._prepare_message_for_anthropic(message)
assert result["content"] == [
{"type": "thinking", "thinking": "Let me think about this...", "signature": "sig_abc123"}
]
def test_prepare_message_for_anthropic_skips_orphan_signature_only_reasoning(
mock_anthropic_client: MagicMock,
) -> None:
client = create_test_anthropic_client(mock_anthropic_client)
message = Message(
role="assistant",
contents=[
Content.from_text_reasoning(text=None, protected_data="sig_abc123"),
Content.from_function_call(
call_id="call_123",
name="get_weather",
arguments={"location": "San Francisco"},
),
],
)
result = client._prepare_message_for_anthropic(message)
assert len(result["content"]) == 1
assert result["content"][0]["type"] == "tool_use"
assert result["content"][0]["id"] == "call_123"
def test_prepare_message_for_anthropic_mcp_server_tool_call(
mock_anthropic_client: MagicMock,
) -> None:
@@ -365,6 +365,14 @@ def _start_function_app(sample_path: Path, port: int) -> subprocess.Popen[Any]:
# use the task hub name to separate orchestration state.
env["TASKHUB_NAME"] = f"test{uuid.uuid4().hex[:8]}"
# The Azure Functions Python worker's dependency isolation mechanism crashes
# on Python 3.13 with a SIGSEGV in the protobuf C extension (google._upb).
# Disabling isolation lets the worker load dependencies from the app's own
# environment, which avoids the crash.
# See: https://github.com/Azure/azure-functions-python-worker/issues/1797
if sys.version_info >= (3, 13):
env.setdefault("PYTHON_ISOLATE_WORKER_DEPENDENCIES", "0")
# On Windows, use CREATE_NEW_PROCESS_GROUP to allow proper termination
# shell=True only on Windows to handle PATH resolution
if sys.platform == "win32":
@@ -375,8 +383,15 @@ def _start_function_app(sample_path: Path, port: int) -> subprocess.Popen[Any]:
shell=True,
env=env,
)
# On Unix, don't use shell=True to avoid shell wrapper issues
return subprocess.Popen(["func", "start", "--port", str(port)], cwd=str(sample_path), env=env)
# On Unix, use start_new_session=True to isolate the process group from the
# pytest-xdist worker. Without this, signals (e.g. from test-timeout) can
# propagate to the func host and vice-versa, potentially killing the worker.
return subprocess.Popen(
["func", "start", "--port", str(port)],
cwd=str(sample_path),
env=env,
start_new_session=True,
)
def _wait_for_function_app_ready(func_process: subprocess.Popen[Any], port: int, max_wait: int = 60) -> None:
@@ -533,18 +548,33 @@ def function_app_for_test(request: pytest.FixtureRequest) -> Iterator[dict[str,
_load_and_validate_env(sample_path)
max_attempts = 3
# The overall budget MUST be shorter than the pytest-timeout value
# (--timeout=120 by default) so that the fixture finishes cleanly instead
# of being killed by os._exit() which crashes the xdist worker.
overall_budget = 100 # seconds leaves headroom below the 120 s test timeout
last_error: Exception | None = None
func_process: subprocess.Popen[Any] | None = None
base_url = ""
port = 0
overall_start = time.monotonic()
attempts_made = 0
for _ in range(max_attempts):
remaining = overall_budget - (time.monotonic() - overall_start)
if remaining < 10:
# Not enough time for another attempt; bail out.
break
attempts_made += 1
port = _find_available_port()
base_url = _build_base_url(port)
func_process = _start_function_app(sample_path, port)
try:
_wait_for_function_app_ready(func_process, port)
# Cap each attempt's wait to the remaining budget minus a small
# buffer for cleanup.
per_attempt_wait = min(60, int(remaining) - 5)
_wait_for_function_app_ready(func_process, port, max_wait=max(per_attempt_wait, 10))
last_error = None
break
except FunctionAppStartupError as exc:
@@ -553,7 +583,8 @@ def function_app_for_test(request: pytest.FixtureRequest) -> Iterator[dict[str,
func_process = None
if func_process is None:
error_message = f"Function app failed to start after {max_attempts} attempt(s)."
elapsed = int(time.monotonic() - overall_start)
error_message = f"Function app failed to start after {attempts_made} attempt(s) ({elapsed}s elapsed)."
if last_error is not None:
error_message += f" Last error: {last_error}"
pytest.fail(error_message)
@@ -4,6 +4,7 @@
from __future__ import annotations
import asyncio
import copy
import json
import logging
import sys
@@ -36,6 +37,7 @@ from agent_framework.observability import ChatTelemetryLayer
from boto3.session import Session as Boto3Session
from botocore.client import BaseClient
from botocore.config import Config as BotoConfig
from botocore.exceptions import ClientError
from pydantic import BaseModel
if sys.version_info >= (3, 13):
@@ -115,13 +117,20 @@ class BedrockChatOptions(ChatOptions[ResponseModelT], Generic[ResponseModelT], t
translates to ``toolConfig.tools``.
tool_choice: How the model should use tools,
translates to ``toolConfig.toolChoice``.
response_format: Structured output format. Accepts a Pydantic BaseModel
subclass or an OpenAI-style dict schema
(``{"json_schema": {"name": ..., "schema": ...}}``).
When provided, the Converse API request includes
``outputConfig.textFormat`` with the schema serialized as a JSON
string. ``ChatResponse.value`` will be populated with the parsed
model instance. Only supported on models that support
``outputConfig.textFormat``. Unsupported models raise a ValueError.
# Options not supported in Bedrock Converse API:
seed: Not supported.
frequency_penalty: Not supported.
presence_penalty: Not supported.
allow_multiple_tool_calls: Not supported (models handle parallel calls automatically).
response_format: Not directly supported (use model-specific prompting).
user: Not supported.
store: Not supported.
logit_bias: Not supported.
@@ -161,9 +170,6 @@ class BedrockChatOptions(ChatOptions[ResponseModelT], Generic[ResponseModelT], t
allow_multiple_tool_calls: None # type: ignore[misc]
"""Not supported. Bedrock models handle parallel tool calls automatically."""
response_format: None # type: ignore[misc]
"""Not directly supported. Use model-specific prompting for JSON output."""
user: None # type: ignore[misc]
"""Not supported in Bedrock Converse API."""
@@ -324,10 +330,28 @@ class BedrockChatClient(
return Boto3Session(**session_kwargs)
def _invoke_converse(self, request: Mapping[str, Any]) -> dict[str, Any]:
response = self._bedrock_client.converse(**request)
if not isinstance(response, Mapping):
raise ChatClientInvalidResponseException("Bedrock converse response must be a mapping.")
return response
try:
response = self._bedrock_client.converse(**request)
if not isinstance(response, Mapping):
raise ChatClientInvalidResponseException("Bedrock converse response must be a mapping.")
return response
except ClientError as e:
error_details = e.response.get("Error", {})
error_code = error_details.get("Code", "")
error_message = error_details.get("Message", "")
# "outputConfig" in error_message catches cases where Bedrock explicitly
# rejects the outputConfig field (unsupported model). Other ValidationExceptions
# (e.g. malformed schema shape, invalid property values) will not mention
# "outputConfig" and will bubble up as raw ClientError without being misdiagnosed.
if error_code == "ValidationException" and (
"outputconfig" in error_message.lower() or "outputconfig" in str(e).lower()
):
raise ValueError(
f"Model '{self.model}' does not support structured output via outputConfig.textFormat. "
"Check the model's Bedrock Converse outputConfig/textFormat support. "
f"AWS error Code: {error_code}. AWS error Message: {error_message}"
) from e
raise
@override
def _inner_get_response(
@@ -344,7 +368,7 @@ class BedrockChatClient(
# Streaming mode - simulate streaming by yielding a single update
async def _stream() -> AsyncIterable[ChatResponseUpdate]:
response = await asyncio.to_thread(self._invoke_converse, request)
parsed_response = self._process_converse_response(response)
parsed_response = self._process_converse_response(response, options)
contents = list(parsed_response.messages[0].contents if parsed_response.messages else [])
if parsed_response.usage_details:
contents.append(Content.from_usage(usage_details=parsed_response.usage_details)) # type: ignore[arg-type]
@@ -360,12 +384,12 @@ class BedrockChatClient(
raw_representation=parsed_response.raw_representation,
)
return self._build_response_stream(_stream())
return self._build_response_stream(_stream(), response_format=options.get("response_format"))
# Non-streaming mode
async def _get_response() -> ChatResponse:
raw_response = await asyncio.to_thread(self._invoke_converse, request)
return self._process_converse_response(raw_response)
return self._process_converse_response(raw_response, options)
return _get_response()
@@ -430,6 +454,9 @@ class BedrockChatClient(
if tool_config:
run_options["toolConfig"] = tool_config
if output_config := self._prepare_output_config(options.get("response_format")):
run_options["outputConfig"] = output_config
return run_options
def _prepare_bedrock_messages(
@@ -628,7 +655,9 @@ class BedrockChatClient(
def _generate_tool_call_id() -> str:
return f"tool-call-{uuid4().hex}"
def _process_converse_response(self, response: dict[str, Any]) -> ChatResponse:
def _process_converse_response(
self, response: dict[str, Any], options: Mapping[str, Any] | None = None
) -> ChatResponse:
"""Convert Bedrock Converse API response to ChatResponse."""
output = response.get("output") or {}
message = output.get("message") or {}
@@ -646,6 +675,7 @@ class BedrockChatClient(
usage_details=usage_details,
model=model,
finish_reason=finish_reason,
response_format=options.get("response_format") if options else None,
raw_representation=response,
)
@@ -728,6 +758,108 @@ class BedrockChatClient(
return None
return FINISH_REASON_MAP.get(reason.lower())
def _prepare_output_config(self, response_format: Any | None) -> dict[str, Any] | None:
"""Convert response_format into the AWS Bedrock outputConfig wire format.
Args:
response_format: A Pydantic model class or a dict schema, or None.
Returns:
A dict for the Converse API ``outputConfig`` parameter, or None if
response_format is not set.
"""
if response_format is None:
return None
if isinstance(response_format, Mapping):
if "json_schema" in response_format:
# Shape A — OpenAI-style wrapper
json_schema_config = response_format["json_schema"]
schema_src = json_schema_config.get("schema", {})
name = json_schema_config.get("name", "output_schema")
elif "schema" in response_format:
# Shape B — inner shape directly {"name": ..., "schema": ...}
schema_src = response_format["schema"]
name = response_format.get("name", "output_schema")
else:
# Shape C — assume entire dict is the raw schema
logger.warning(
"response_format dict has no 'json_schema' or 'schema' key; "
"treating entire dict as raw JSON schema."
)
schema_src = dict(response_format)
name = "output_schema"
if isinstance(schema_src, str):
schema_src = json.loads(schema_src)
schema = copy.deepcopy(schema_src)
else:
if not isinstance(response_format, type) or not issubclass(response_format, BaseModel):
raise TypeError(
"response_format must be None, a dict JSON schema, "
"or a Pydantic BaseModel subclass."
)
# response_format is a Pydantic model class
schema = response_format.model_json_schema()
name = response_format.__name__
self._set_additional_properties_false(schema)
json_schema: dict[str, Any] = {
"name": name,
"schema": json.dumps(schema),
}
description = getattr(response_format, "__doc__", None) if not isinstance(response_format, Mapping) else None
if description and isinstance(description, str) and description.strip():
json_schema["description"] = description.strip()
return {
"textFormat": {
"type": "json_schema",
"structure": {
"jsonSchema": json_schema
},
}
}
def _set_additional_properties_false(self, schema: dict[str, Any]) -> None:
"""Recursively set additionalProperties: false on all object types in a JSON schema.
AWS requires strict schema enforcement. This mirrors the approach used by
AnthropicChatClient._prepare_response_format().
Args:
schema: The JSON schema dict to modify in-place.
"""
visited: set[int] = set()
def walk(node: Any) -> None:
if isinstance(node, dict):
node_id = id(node)
if node_id in visited:
return
visited.add(node_id)
if node.get("type") == "object" or (
"properties" in node and "type" not in node
):
existing = node.get("additionalProperties")
if existing is None or existing is True:
node["additionalProperties"] = False
for value in node.values():
if isinstance(value, (dict, list)):
walk(value)
elif isinstance(node, list):
node_id = id(node)
if node_id in visited:
return
visited.add(node_id)
for item in node:
if isinstance(item, (dict, list)):
walk(item)
walk(schema)
def service_url(self) -> str:
"""Returns the service URL for the Bedrock runtime in the configured AWS region.
@@ -0,0 +1,382 @@
# Copyright (c) Microsoft. All rights reserved.
from __future__ import annotations
import copy
import json
from typing import Any
from unittest.mock import patch
import pytest
from agent_framework import Content, Message
from botocore.exceptions import ClientError
from pydantic import BaseModel
from agent_framework_bedrock import BedrockChatClient
# region Test models
class WeatherReport(BaseModel):
city: str
temperature: float
summary: str
class NestedAddress(BaseModel):
street: str
city: str
zip_code: str
class Person(BaseModel):
name: str
age: int
address: NestedAddress
# endregion
# region Helpers
class _StubBedrockRuntime:
"""Stub that records calls and returns a canned response."""
def __init__(self, response_text: str = "Bedrock says hi") -> None:
self.calls: list[dict[str, Any]] = []
self._response_text = response_text
def converse(self, **kwargs: Any) -> dict[str, Any]:
self.calls.append(kwargs)
return {
"modelId": kwargs["modelId"],
"responseId": "resp-structured",
"usage": {"inputTokens": 10, "outputTokens": 20, "totalTokens": 30},
"output": {
"completionReason": "end_turn",
"message": {
"id": "msg-structured",
"role": "assistant",
"content": [{"text": self._response_text}],
},
},
}
def _make_client(response_text: str = "Bedrock says hi") -> tuple[BedrockChatClient, _StubBedrockRuntime]:
stub = _StubBedrockRuntime(response_text)
client = BedrockChatClient(
model="us.anthropic.claude-haiku-4-5-v1:0",
region="us-east-1",
client=stub,
)
return client, stub
def _user_messages() -> list[Message]:
return [Message(role="user", contents=[Content.from_text(text="Give me a weather report")])]
# endregion
# region Tests
def test_prepare_output_config_correct_wire_shape() -> None:
"""_prepare_output_config(WeatherReport) must produce the correct
textFormat → structure → jsonSchema shape with type: 'json_schema'."""
client, _ = _make_client()
output_config = client._prepare_output_config(WeatherReport)
assert output_config is not None
text_format = output_config["textFormat"]
assert text_format["type"] == "json_schema"
assert "structure" in text_format
json_schema = text_format["structure"]["jsonSchema"]
assert json_schema["name"] == "WeatherReport"
assert "schema" in json_schema
def test_prepare_output_config_schema_is_json_string() -> None:
"""The schema value inside jsonSchema must be a JSON string, not a dict."""
client, _ = _make_client()
output_config = client._prepare_output_config(WeatherReport)
assert output_config is not None
schema_value = output_config["textFormat"]["structure"]["jsonSchema"]["schema"]
assert isinstance(schema_value, str), f"Expected str, got {type(schema_value)}"
# Verify it's valid JSON
parsed = json.loads(schema_value)
assert isinstance(parsed, dict)
assert parsed["type"] == "object"
def test_additional_properties_false_set_recursively() -> None:
"""additionalProperties: false must be set on all nested object types."""
client, _ = _make_client()
output_config = client._prepare_output_config(Person)
assert output_config is not None
schema_str = output_config["textFormat"]["structure"]["jsonSchema"]["schema"]
schema = json.loads(schema_str)
# Top-level object
assert schema.get("additionalProperties") is False
# Check $defs for NestedAddress
defs = schema.get("$defs", {})
assert "NestedAddress" in defs, "Expected NestedAddress to be present in $defs"
assert defs["NestedAddress"].get("additionalProperties") is False, (
"Expected additionalProperties=False on nested NestedAddress schema"
)
def test_no_output_config_when_response_format_none() -> None:
"""When response_format is None, no outputConfig key should appear in the request."""
client, stub = _make_client()
messages = _user_messages()
request = client._prepare_options(messages, {"max_tokens": 100})
assert "outputConfig" not in request, (
f"outputConfig should not be present when response_format is None, got: {request.get('outputConfig')}"
)
async def test_chat_response_value_populated() -> None:
"""After a mocked response with response_format, .value should be a populated Pydantic model."""
json_response = json.dumps({"city": "Seattle", "temperature": 72.5, "summary": "Sunny and warm"})
client, stub = _make_client(response_text=json_response)
messages = _user_messages()
response = await client.get_response(
messages=messages,
options={"max_tokens": 100, "response_format": WeatherReport},
)
assert response.text == json_response
assert response.value is not None
assert isinstance(response.value, WeatherReport)
assert response.value.city == "Seattle"
assert response.value.temperature == 72.5
assert response.value.summary == "Sunny and warm"
# Verify outputConfig was sent to the API
assert len(stub.calls) == 1
api_request = stub.calls[0]
assert "outputConfig" in api_request
assert api_request["outputConfig"]["textFormat"]["type"] == "json_schema"
def test_dict_schema_response_format() -> None:
"""_prepare_output_config should work when response_format is a dict, not just a Pydantic class."""
client, _ = _make_client()
dict_schema = {
"json_schema": {
"name": "weather_output",
"schema": {
"type": "object",
"properties": {
"city": {"type": "string"},
"temp": {"type": "number"},
},
},
}
}
output_config = client._prepare_output_config(dict_schema)
assert output_config is not None
json_schema = output_config["textFormat"]["structure"]["jsonSchema"]
assert json_schema["name"] == "weather_output"
schema_parsed = json.loads(json_schema["schema"])
assert schema_parsed["type"] == "object"
assert "city" in schema_parsed["properties"]
def test_prepare_output_config_none_returns_none() -> None:
"""_prepare_output_config(None) must return None."""
client, _ = _make_client()
result = client._prepare_output_config(None)
assert result is None
async def test_chat_response_value_populated_streaming() -> None:
"""In streaming mode, .value should also be populated on the final response."""
json_response = json.dumps({"city": "Portland", "temperature": 68.0, "summary": "Cloudy"})
client, stub = _make_client(response_text=json_response)
messages = _user_messages()
stream = client.get_response(
messages=messages,
stream=True,
options={"max_tokens": 100, "response_format": WeatherReport},
)
# Consume stream and get final response
async for _ in stream:
pass
response = await stream.get_final_response()
assert response.value is not None
assert isinstance(response.value, WeatherReport)
assert response.value.city == "Portland"
# Verify outputConfig was sent
assert len(stub.calls) == 1
assert "outputConfig" in stub.calls[0]
async def test_unsupported_model_validation_exception() -> None:
"""When a model doesn't support outputConfig, a clear error should be raised."""
class _FailingStubBedrockRuntime:
def converse(self, **kwargs: Any) -> dict[str, Any]:
# Simulate botocore ClientError for ValidationException
error_response = {"Error": {"Code": "ValidationException", "Message": "Invalid field outputConfig"}}
raise ClientError(error_response, "Converse")
client = BedrockChatClient(
model="us.anthropic.claude-v2",
region="us-east-1",
client=_FailingStubBedrockRuntime(),
)
with pytest.raises(ValueError) as exc:
await client.get_response(
messages=_user_messages(),
options={"response_format": WeatherReport},
)
assert "does not support structured output via outputConfig.textFormat" in str(exc.value)
assert "Check the model's Bedrock Converse outputConfig/textFormat support." in str(exc.value)
def test_invalid_response_format_type_raises() -> None:
"""Non-dict, non-BaseModel response_format should raise TypeError."""
client, _ = _make_client()
with pytest.raises(TypeError, match="Pydantic BaseModel subclass"):
client._prepare_output_config("not_a_valid_format")
def test_mapping_response_format_accepted() -> None:
"""A non-dict Mapping response_format must be accepted and produce
correct outputConfig, not raise TypeError."""
from collections.abc import MutableMapping
class _WrappedMapping(MutableMapping):
def __init__(self, data):
self._data = dict(data)
def __getitem__(self, key):
return self._data[key]
def __setitem__(self, key, value):
self._data[key] = value
def __delitem__(self, key):
del self._data[key]
def __iter__(self):
return iter(self._data)
def __len__(self):
return len(self._data)
client, _ = _make_client()
mapping_format = _WrappedMapping({
"json_schema": {
"name": "test_output",
"schema": {
"type": "object",
"properties": {"result": {"type": "string"}},
},
}
})
output_config = client._prepare_output_config(mapping_format)
assert output_config is not None
json_schema = output_config["textFormat"]["structure"]["jsonSchema"]
assert json_schema["name"] == "test_output"
schema = json.loads(json_schema["schema"])
assert schema.get("additionalProperties") is False
def test_shape_b_dict_schema_wire_format() -> None:
"""Dict response_format in Shape B (inner shape directly) should
produce correct outputConfig."""
client, _ = _make_client()
response_format = {
"name": "weather_output",
"schema": {
"type": "object",
"properties": {
"city": {"type": "string"},
"temperature": {"type": "number"},
},
},
}
output_config = client._prepare_output_config(response_format)
assert output_config is not None
text_format = output_config["textFormat"]
assert text_format["type"] == "json_schema"
json_schema = text_format["structure"]["jsonSchema"]
assert json_schema["name"] == "weather_output"
schema = json.loads(json_schema["schema"])
assert schema.get("additionalProperties") is False
def test_dict_schema_not_mutated() -> None:
"""Caller's dict schema must not be mutated by _prepare_output_config."""
client, _ = _make_client()
original_schema = {
"json_schema": {
"name": "test",
"schema": {
"type": "object",
"properties": {"a": {"type": "string"}},
},
}
}
snapshot = copy.deepcopy(original_schema)
client._prepare_output_config(original_schema)
assert original_schema == snapshot, "Original dict schema was mutated"
async def test_non_outputconfig_validation_exception_propagates() -> None:
"""ValidationException unrelated to outputConfig must propagate
as raw ClientError, not be caught and reclassified."""
client, _ = _make_client()
error_response = {
"Error": {
"Code": "ValidationException",
"Message": "Invalid message format",
}
}
with (
patch.object(
client,
"_bedrock_client",
**{"converse.side_effect": ClientError(error_response, "Converse")},
),
pytest.raises(ClientError),
):
await client.get_response(
messages=_user_messages(),
options={"max_tokens": 100},
)
# endregion
@@ -71,6 +71,7 @@ from ._evaluation import (
Evaluator,
ExpectedToolCall,
LocalEvaluator,
RubricScore,
evaluate_agent,
evaluate_workflow,
evaluator,
@@ -460,6 +461,7 @@ __all__ = [
"ResponseStream",
"Role",
"RoleLiteral",
"RubricScore",
"RunContext",
"Runner",
"RunnerContext",
@@ -311,12 +311,15 @@ class EvalScoreResult:
score: Numeric score from the evaluator.
passed: Whether the item passed this evaluator's threshold.
sample: Optional raw evaluator output (rationale, metadata).
dimensions: Per-dimension scores when this evaluator is a rubric
evaluator. ``None`` for non-rubric (e.g. built-in) evaluators.
"""
name: str
score: float
passed: bool | None = None
sample: dict[str, Any] | None = None
dimensions: list[RubricScore] | None = None
@experimental(feature_id=ExperimentalFeature.EVALS)
@@ -496,6 +499,179 @@ class EvalResults:
detail += f" Errored items: {', '.join(summaries)}."
raise EvalNotPassedError(detail)
def assert_score_at_least(
self,
min_score: float,
*,
evaluator: str | None = None,
msg: str | None = None,
) -> None:
"""Assert every item's score (optionally filtered by evaluator) is ``>= min_score``.
Designed for CI gates on generated rubric evaluators (e.g.
``results.assert_score_at_least(0.80)``). Includes any
sub-results from workflow evaluations.
Args:
min_score: Minimum acceptable score (inclusive).
evaluator: When set, only check scores from the evaluator
whose ``EvalScoreResult.name`` matches.
msg: Optional custom failure message.
Raises:
EvalNotPassedError: When any matching score is below the threshold.
"""
offenders: list[str] = []
def _check(results: EvalResults) -> None:
for item in results.items:
for score in item.scores:
if evaluator is not None and score.name != evaluator:
continue
if score.score < min_score:
offenders.append(f"{item.item_id}/{score.name}={score.score:.3f}")
for sub in results.sub_results.values():
_check(sub)
_check(self)
if offenders:
detail = msg or (
f"{len(offenders)} score(s) below threshold {min_score}"
f"{' for ' + evaluator if evaluator else ''}: {', '.join(offenders[:5])}"
+ (f" (+{len(offenders) - 5} more)" if len(offenders) > 5 else "")
)
raise EvalNotPassedError(detail)
def assert_dimension_score_at_least(
self,
dimension_id: str,
min_score: float,
*,
evaluator: str | None = None,
require_applicable: bool = False,
msg: str | None = None,
) -> None:
"""Assert every item's score for a rubric *dimension* is ``>= min_score``.
Walks ``EvalScoreResult.dimensions`` looking for the named
dimension across all items (and sub-results). Non-applicable
dimensions are skipped by default; pass
``require_applicable=True`` to fail when no applicable score is
produced.
Args:
dimension_id: Dimension id (matches the rubric definition).
min_score: Minimum acceptable dimension score (inclusive).
evaluator: When set, only consider scores from the evaluator
whose ``EvalScoreResult.name`` matches.
require_applicable: When ``True``, missing or non-applicable
dimension scores raise. Defaults to ``False`` (skip).
msg: Optional custom failure message.
Raises:
EvalNotPassedError: When the dimension fails the threshold.
"""
offenders: list[str] = []
missing_items: list[str] = []
def _check(results: EvalResults) -> None:
for item in results.items:
found_applicable = False
for score in item.scores:
if evaluator is not None and score.name != evaluator:
continue
if not score.dimensions:
continue
for rs in score.dimensions:
if rs.id != dimension_id:
continue
if not rs.applicable:
continue
found_applicable = True
if rs.score is None or rs.score < min_score:
offenders.append(
f"{item.item_id}/{score.name}/{dimension_id}="
f"{rs.score if rs.score is not None else 'None'}"
)
if require_applicable and not found_applicable:
missing_items.append(item.item_id)
for sub in results.sub_results.values():
_check(sub)
_check(self)
problems: list[str] = []
if offenders:
problems.append(
f"{len(offenders)} dimension score(s) for '{dimension_id}' below {min_score}: "
f"{', '.join(offenders[:5])}" + (f" (+{len(offenders) - 5} more)" if len(offenders) > 5 else "")
)
if missing_items:
problems.append(
f"Dimension '{dimension_id}' not applicable on {len(missing_items)} item(s): "
f"{', '.join(missing_items[:5])}"
)
if problems:
raise EvalNotPassedError(msg or "; ".join(problems))
def assert_no_failed_items(self, msg: str | None = None) -> None:
"""Assert no item ended in ``fail`` or ``error`` status.
Includes any sub-results from workflow evaluations.
Args:
msg: Optional custom failure message.
Raises:
EvalNotPassedError: When any item failed or errored.
"""
bad: list[str] = []
def _check(results: EvalResults) -> None:
for item in results.items:
if item.is_failed or item.is_error:
bad.append(f"{item.item_id}:{item.status}")
for sub in results.sub_results.values():
_check(sub)
_check(self)
if bad:
detail = msg or (
f"{len(bad)} item(s) failed or errored: {', '.join(bad[:5])}"
+ (f" (+{len(bad) - 5} more)" if len(bad) > 5 else "")
)
raise EvalNotPassedError(detail)
# endregion
# region Generated rubric evaluators
@experimental(feature_id=ExperimentalFeature.EVALS)
@dataclass(frozen=True)
class RubricScore:
"""A single dimension's score from a rubric-based evaluator run.
Rubric evaluators emit one ``RubricScore`` per dimension per item.
Attached to :class:`EvalScoreResult` as a typed view of the raw
``properties.rubric_scores`` payload returned by providers such as
Foundry's generated rubric evaluators.
Attributes:
id: Dimension id (matches the rubric definition).
score: Numeric score, or ``None`` when the dimension was marked
non-applicable for this item.
applicable: Whether the dimension applied to this item.
weight: Dimension weight (mirrors the rubric definition).
reason: Short rationale produced by the evaluator.
"""
id: str
score: int | None
applicable: bool
weight: int
reason: str
# endregion
@@ -14,12 +14,13 @@ import logging
from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, Any
from .._agents import Agent
from .._agents import Agent, SupportsAgentRun
from .._clients import SupportsWebSearchTool
from .._compaction import CompactionProvider, ContextWindowCompactionStrategy, ToolResultCompactionStrategy
from .._feature_stage import ExperimentalFeature, experimental
from .._sessions import ContextProvider, HistoryProvider, InMemoryHistoryProvider
from .._skills import SkillsProvider
from ._background_agents import BackgroundAgentsProvider
from ._memory import MemoryContextProvider, MemoryStore
from ._mode import AgentModeProvider
from ._todo import TodoProvider
@@ -103,6 +104,8 @@ def _assemble_context_providers(
memory_store: MemoryStore | None,
skills_provider: SkillsProvider | None,
skills_paths: Sequence[str] | None,
background_agents: Sequence[SupportsAgentRun] | None,
background_agents_instructions: str | None,
extra_context_providers: Sequence[ContextProvider] | None,
) -> list[ContextProvider]:
"""Assemble the ordered list of context providers."""
@@ -130,6 +133,10 @@ def _assemble_context_providers(
if skills_paths:
providers.append(SkillsProvider.from_paths(*skills_paths))
# Background agents are opt-in: only added when agents are provided.
if background_agents:
providers.append(BackgroundAgentsProvider(background_agents, instructions=background_agents_instructions))
# Append any user-supplied additional providers.
if extra_context_providers:
providers.extend(extra_context_providers)
@@ -165,6 +172,8 @@ def create_harness_agent(
memory_store: MemoryStore | None = None,
skills_provider: SkillsProvider | None = None,
skills_paths: Sequence[str] | None = None,
background_agents: Sequence[SupportsAgentRun] | None = None,
background_agents_instructions: str | None = None,
disable_web_search: bool = False,
otel_provider_name: str | None = None,
context_providers: Sequence[ContextProvider] | None = None,
@@ -182,6 +191,7 @@ def create_harness_agent(
- **AgentModeProvider** — plan/execute mode tracking
- **MemoryContextProvider** — file-based durable memory (when ``memory_store`` provided)
- **SkillsProvider** — skill discovery and progressive loading
- **BackgroundAgentsProvider** — delegate work to background sub-agents
- **OpenTelemetry** — observability via ``AgentTelemetryLayer``
Each feature can be disabled or customized via keyword arguments.
@@ -253,6 +263,13 @@ def create_harness_agent(
skills_paths: Paths for file-based skill discovery (looks for SKILL.md files).
Can be combined with ``skills_provider``. When neither ``skills_provider``
nor ``skills_paths`` is provided, no SkillsProvider is added.
background_agents: Collection of agents available for background task delegation.
When provided, a ``BackgroundAgentsProvider`` is automatically included,
enabling the agent to start, monitor, and retrieve results from background tasks.
Each agent must have a non-empty, unique name (case-insensitive).
background_agents_instructions: Optional instruction override for the
``BackgroundAgentsProvider``. May include ``{background_agents}`` placeholder
which will be replaced with the agent listing.
disable_web_search: When True, skip automatic web search tool inclusion.
When False (default), the web search tool is automatically added if the
client implements SupportsWebSearchTool. A warning is logged if the client
@@ -302,6 +319,8 @@ def create_harness_agent(
memory_store=memory_store,
skills_provider=skills_provider,
skills_paths=skills_paths,
background_agents=background_agents,
background_agents_instructions=background_agents_instructions,
extra_context_providers=context_providers,
)
+72 -20
View File
@@ -10,7 +10,7 @@ import logging
import re
import sys
from abc import abstractmethod
from collections.abc import Callable, Collection, Coroutine, Sequence
from collections.abc import Callable, Collection, Coroutine, Mapping, Sequence
from contextlib import AsyncExitStack, _AsyncGeneratorContextManager # type: ignore
from datetime import timedelta
from functools import partial
@@ -142,6 +142,13 @@ def _inject_otel_into_mcp_meta(meta: dict[str, Any] | None = None) -> dict[str,
return meta
def _url_origin(url: Any) -> tuple[str, str, int | None]:
port = url.port
if port is None:
port = 443 if url.scheme == "https" else 80 if url.scheme == "http" else None
return (url.scheme, url.host or "", port)
def streamable_http_client(*args: Any, **kwargs: Any) -> _AsyncGeneratorContextManager[Any, None]:
"""Lazily import the MCP streamable HTTP transport."""
try:
@@ -255,6 +262,7 @@ class MCPTool:
self._exit_stack = AsyncExitStack()
self._lifecycle_lock = asyncio.Lock()
self._lifecycle_request_lock = asyncio.Lock()
self._function_load_lock = asyncio.Lock()
self._lifecycle_queue: asyncio.Queue[tuple[str, bool, bool, asyncio.Future[None]]] | None = None
self._lifecycle_owner_task: asyncio.Task[None] | None = None
self.session = session
@@ -655,6 +663,11 @@ class MCPTool:
raise
except asyncio.CancelledError:
logger.warning("Could not cleanly close MCP exit stack because the lifecycle owner task was cancelled.")
except Exception as e:
if type(e).__name__ == "ExceptionGroup":
logger.warning("Could not cleanly close MCP exit stack due to cleanup error group. Error: %s", e)
else:
raise
async def _close_and_check_cancelled(self, ex: BaseException) -> bool:
"""Close the exit stack and return True if *ex* is a genuine task cancellation.
@@ -1018,6 +1031,10 @@ class MCPTool:
Raises:
ToolExecutionException: If the MCP server is not connected.
"""
async with self._function_load_lock:
await self._load_prompts_locked()
async def _load_prompts_locked(self) -> None:
from anyio import ClosedResourceError
from mcp import types
@@ -1100,6 +1117,10 @@ class MCPTool:
Raises:
ToolExecutionException: If the MCP server is not connected.
"""
async with self._function_load_lock:
await self._load_tools_locked()
async def _load_tools_locked(self) -> None:
from anyio import ClosedResourceError
from mcp import types
@@ -1109,7 +1130,7 @@ class MCPTool:
# Track existing function names to prevent duplicates
existing_names = {func.name for func in self._functions}
self._tool_call_meta_by_name.clear()
tool_call_meta_by_name: dict[str, dict[str, Any]] = {}
params: types.PaginatedRequestParams | None = None
while True:
@@ -1145,7 +1166,7 @@ class MCPTool:
for tool in tool_list.tools:
if tool.meta is not None:
self._tool_call_meta_by_name[tool.name] = dict(tool.meta)
tool_call_meta_by_name[tool.name] = dict(tool.meta)
normalized_name = _normalize_mcp_name(tool.name)
local_name = _build_prefixed_mcp_name(normalized_name, self.tool_name_prefix)
@@ -1194,6 +1215,8 @@ class MCPTool:
break
params = types.PaginatedRequestParams(cursor=tool_list.nextCursor)
self._tool_call_meta_by_name = tool_call_meta_by_name
async def _close_on_owner(self) -> None:
# Cancel any pending reload tasks before tearing down the session.
tasks = list(self._pending_reload_tasks)
@@ -1276,7 +1299,11 @@ class MCPTool:
tool_name: The name of the tool to call.
Keyword Args:
kwargs: Arguments to pass to the tool.
_meta: Optional ``dict[str, Any]`` of MCP request metadata. This reserved key is passed as the
``meta`` parameter of the underlying ``session.call_tool`` call rather than as a tool argument.
User-supplied keys override metadata from ``tools/list``; OpenTelemetry propagation fills in
non-conflicting keys.
kwargs: Remaining arguments to pass to the tool.
Returns:
A list of Content items representing the tool output. The default
@@ -1294,6 +1321,19 @@ class MCPTool:
raise ToolExecutionException(
"Tools are not loaded for this server, please set load_tools=True in the constructor."
)
raw_user_meta: object | None = kwargs.get("_meta")
user_meta: dict[str, Any] | None = None
if raw_user_meta is not None and not isinstance(raw_user_meta, dict):
raise ToolExecutionException("MCP tool metadata provided via _meta must be a dict.")
if isinstance(raw_user_meta, dict):
raw_user_meta_dict = cast(Mapping[object, object], raw_user_meta)
user_meta = {}
for key, value in raw_user_meta_dict.items():
if not isinstance(key, str):
raise ToolExecutionException("MCP tool metadata provided via _meta must use string keys.")
user_meta[key] = value
# Filter out framework kwargs that cannot be serialized by the MCP SDK.
# These are internal objects passed through the function invocation pipeline
# that should not be forwarded to external MCP servers.
@@ -1313,12 +1353,16 @@ class MCPTool:
"conversation_id",
"options",
"response_format",
"_meta",
}
}
# Some MCP proxies require their tools/list metadata to be echoed on tools/call.
tool_meta = self._tool_call_meta_by_name.get(tool_name)
meta = _inject_otel_into_mcp_meta(dict(tool_meta) if tool_meta is not None else None)
request_meta = dict(tool_meta) if tool_meta is not None else None
if user_meta is not None:
request_meta = {**(request_meta or {}), **user_meta}
meta = _inject_otel_into_mcp_meta(request_meta)
parser = self.parse_tool_results or self._parse_tool_result_from_mcp
# Try the operation, reconnecting once if the connection is closed
@@ -1336,28 +1380,33 @@ class MCPTool:
return parser(result)
except ToolExecutionException:
raise
except ClosedResourceError as cl_ex:
except (ClosedResourceError, McpError) as call_ex:
is_session_terminated = (
isinstance(call_ex, McpError) and "session terminated" in call_ex.error.message.lower()
)
is_connection_lost = isinstance(call_ex, ClosedResourceError) or is_session_terminated
if not is_connection_lost:
error_message = call_ex.error.message if isinstance(call_ex, McpError) else str(call_ex)
raise ToolExecutionException(error_message, inner_exception=call_ex) from call_ex
if attempt == 0:
# First attempt failed, try reconnecting
logger.info("MCP connection closed unexpectedly. Reconnecting...")
# First attempt failed, try reconnecting.
logger.info("MCP connection closed or terminated unexpectedly. Reconnecting...")
try:
await self.connect(reset=True)
continue # Retry the operation
continue
except Exception as reconn_ex:
raise ToolExecutionException(
"Failed to reconnect to MCP server.",
inner_exception=reconn_ex,
) from reconn_ex
else:
# Second attempt also failed, give up
logger.error(f"MCP connection closed unexpectedly after reconnection: {cl_ex}")
raise ToolExecutionException(
f"Failed to call tool '{tool_name}' - connection lost.",
inner_exception=cl_ex,
) from cl_ex
except McpError as mcp_exc:
error_message = mcp_exc.error.message
raise ToolExecutionException(error_message, inner_exception=mcp_exc) from mcp_exc
# Second attempt also failed, give up.
logger.error("MCP connection closed unexpectedly after reconnection: %s", call_ex)
raise ToolExecutionException(
f"Failed to call tool '{tool_name}' - connection lost.",
inner_exception=call_ex,
) from call_ex
except Exception as ex:
raise ToolExecutionException(f"Failed to call tool '{tool_name}'.", inner_exception=ex) from ex
raise ToolExecutionException(f"Failed to call tool '{tool_name}' after retries.")
@@ -1718,10 +1767,11 @@ class MCPStreamableHTTPTool(MCPTool):
Returns:
An async context manager for the streamable HTTP client transport.
"""
from httpx import AsyncClient, Request, Timeout
from httpx import URL, AsyncClient, Request, Timeout
http_client = self._httpx_client
if self._header_provider is not None:
target_origin = _url_origin(URL(self.url))
if http_client is None:
http_client = AsyncClient(
follow_redirects=True,
@@ -1732,6 +1782,8 @@ class MCPStreamableHTTPTool(MCPTool):
if not hasattr(self, "_inject_headers_hook"):
async def _inject_headers(request: Request) -> None: # noqa: RUF029
if _url_origin(request.url) != target_origin:
return
headers = _mcp_call_headers.get({})
for key, value in headers.items():
request.headers[key] = value
@@ -36,11 +36,10 @@ if TYPE_CHECKING:
from pydantic import BaseModel
from ._agents import SupportsAgentRun
from ._clients import SupportsChatGetResponse
from ._compaction import CompactionStrategy, TokenizerProtocol
from ._sessions import AgentSession
from ._tools import FunctionTool, ToolTypes
from ._types import ChatOptions, ChatResponse, ChatResponseUpdate
from ._types import ChatOptions
ResponseModelBoundT = TypeVar("ResponseModelBoundT", bound=BaseModel)
@@ -7,6 +7,8 @@ import json
import logging
import re
from collections.abc import Mapping, MutableMapping
from dataclasses import asdict, is_dataclass
from datetime import date, datetime
from typing import Any, ClassVar, Protocol, TypeVar, runtime_checkable
logger = logging.getLogger("agent_framework")
@@ -614,3 +616,46 @@ class SerializationMixin:
# Fallback and default
# Convert class name to snake_case
return _CAMEL_TO_SNAKE_PATTERN.sub("_", cls.__name__).lower()
def make_json_safe(obj: Any) -> Any:
"""Recursively convert an object to a JSON-serializable form.
Handles dataclasses, Pydantic models, objects with ``to_dict``/``dict``/``__dict__``,
datetimes, lists, dicts, and primitives. Falls back to ``str()`` for any remaining
non-serializable value so that ``json.dumps`` never raises a ``TypeError``.
Args:
obj: Object to make JSON safe.
Returns:
A JSON-serializable version of the object.
"""
if obj is None or isinstance(obj, (str, int, float, bool)):
return obj
if isinstance(obj, (datetime, date)):
return obj.isoformat()
if is_dataclass(obj) and not isinstance(obj, type):
return make_json_safe(asdict(obj)) # type: ignore[arg-type]
if callable(getattr(obj, "model_dump", None)):
try:
return make_json_safe(obj.model_dump()) # type: ignore[no-any-return]
except TypeError:
pass
if callable(getattr(obj, "to_dict", None)):
try:
return make_json_safe(obj.to_dict()) # type: ignore[no-any-return]
except TypeError:
pass
if callable(getattr(obj, "dict", None)):
try:
return make_json_safe(obj.dict()) # type: ignore[no-any-return]
except TypeError:
pass
if isinstance(obj, dict):
return {str(key): make_json_safe(value) for key, value in obj.items()} # type: ignore[misc]
if isinstance(obj, (list, tuple)):
return [make_json_safe(item) for item in obj] # type: ignore[misc]
if hasattr(obj, "__dict__"):
return {key: make_json_safe(value) for key, value in vars(obj).items()} # type: ignore[misc]
return str(obj)
@@ -1973,11 +1973,97 @@ def _coalesce_text_content(contents: list[Content], type_str: Literal["text", "t
contents.extend(coalesced_contents)
def _content_items_text(items: Any) -> str | None:
"""Return concatenated text when a content item list only contains text."""
if not isinstance(items, list):
return None
text_parts: list[str] = []
content_items = cast(list[object], items)
for item in content_items:
if not isinstance(item, Content) or item.type != "text":
return None
text_parts.append(item.text or "")
return "".join(text_parts)
def _merge_content_item_lists(existing: Any, incoming: Any) -> Any:
"""Merge streamed nested content lists, replacing deltas with a later full value when present."""
if incoming is None:
return existing
if existing is None:
return deepcopy(incoming)
existing_text = _content_items_text(existing)
incoming_text = _content_items_text(incoming)
if existing_text is not None and incoming_text is not None:
if incoming_text.startswith(existing_text):
return deepcopy(incoming)
if existing_text.startswith(incoming_text):
return existing
existing_items = cast(list[Content], existing)
merged = deepcopy(existing_items[0])
merged.text = existing_text + incoming_text
return [merged]
if isinstance(existing, list) and isinstance(incoming, list):
existing_list = cast(list[object], existing)
incoming_list = cast(list[object], incoming)
return [*existing_list, *deepcopy(incoming_list)]
return deepcopy(incoming)
def _merge_code_interpreter_content(existing: Content, incoming: Content) -> None:
"""Merge two code interpreter content items for the same logical call."""
existing.inputs = _merge_content_item_lists(existing.inputs, incoming.inputs)
existing.outputs = _merge_content_item_lists(existing.outputs, incoming.outputs)
existing.annotations = _combine_annotations(existing.annotations, incoming.annotations)
existing.additional_properties = {**existing.additional_properties, **incoming.additional_properties}
existing.raw_representation = _combine_raw_representations(existing.raw_representation, incoming.raw_representation)
def _code_interpreter_key(content: Content) -> tuple[str, str] | None:
"""Return the aggregation key for code interpreter call/result content."""
if content.type not in {"code_interpreter_tool_call", "code_interpreter_tool_result"}:
return None
call_id = content.call_id or content.additional_properties.get("item_id")
if not isinstance(call_id, str) or not call_id:
return None
return content.type, call_id
def _coalesce_code_interpreter_content(contents: list[Content]) -> None:
"""Coalesce streaming code interpreter chunks by call id."""
if not contents:
return
coalesced_contents: list[Content] = []
seen: dict[tuple[str, str], Content] = {}
for content in contents:
key = _code_interpreter_key(content)
if key is None:
coalesced_contents.append(content)
continue
existing = seen.get(key)
if existing is None:
copied = deepcopy(content)
seen[key] = copied
coalesced_contents.append(copied)
continue
_merge_code_interpreter_content(existing, content)
contents.clear()
contents.extend(coalesced_contents)
def _finalize_response(response: ChatResponse | AgentResponse) -> None:
"""Finalizes the response by performing any necessary post-processing."""
for msg in response.messages:
_coalesce_text_content(msg.contents, "text")
_coalesce_text_content(msg.contents, "text_reasoning")
_coalesce_code_interpreter_content(msg.contents)
# region ContinuationToken
@@ -12,6 +12,7 @@ from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast, overload
from .._agents import BaseAgent
from .._serialization import make_json_safe
from .._sessions import (
AgentSession,
ContextProvider,
@@ -62,7 +63,7 @@ class WorkflowAgent(BaseAgent):
data: Any
def to_dict(self) -> dict[str, Any]:
return {"request_id": self.request_id, "data": self.data}
return {"request_id": self.request_id, "data": make_json_safe(self.data)}
def to_json(self) -> str:
return json.dumps(self.to_dict())
@@ -47,6 +47,7 @@ from copy import deepcopy
from typing import Any, Generic, Literal, TypeVar, overload
from .._feature_stage import ExperimentalFeature, experimental
from .._serialization import make_json_safe
from .._types import AgentResponse, AgentResponseUpdate, ResponseStream
from ..observability import OtelAttr, capture_exception, create_workflow_span
from ._checkpoint import CheckpointStorage, WorkflowCheckpoint
@@ -1515,7 +1516,7 @@ class FunctionalWorkflowAgent:
function_call = Content.from_function_call(
call_id=request_id,
name=self.REQUEST_INFO_FUNCTION_NAME,
arguments={"request_id": request_id, "data": event.data},
arguments={"request_id": request_id, "data": make_json_safe(event.data)},
)
return Content.from_function_approval_request(
id=request_id,
@@ -34,6 +34,7 @@ _IMPORTS: dict[str, tuple[str, str]] = {
"FoundryLocalChatOptions": ("agent_framework_foundry_local", "agent-framework-foundry-local"),
"FoundryLocalClient": ("agent_framework_foundry_local", "agent-framework-foundry-local"),
"FoundryLocalSettings": ("agent_framework_foundry_local", "agent-framework-foundry-local"),
"GeneratedEvaluatorRef": ("agent_framework_foundry", "agent-framework-foundry"),
"RawAnthropicFoundryClient": ("agent_framework_anthropic", "agent-framework-anthropic"),
"RawFoundryAgent": ("agent_framework_foundry", "agent-framework-foundry"),
"RawFoundryAgentChatClient": ("agent_framework_foundry", "agent-framework-foundry"),
@@ -20,6 +20,7 @@ from agent_framework_foundry import (
FoundryEmbeddingSettings,
FoundryEvals,
FoundryMemoryProvider,
GeneratedEvaluatorRef,
RawFoundryAgent,
RawFoundryAgentChatClient,
RawFoundryChatClient,
@@ -52,6 +53,7 @@ __all__ = [
"FoundryLocalClient",
"FoundryLocalSettings",
"FoundryMemoryProvider",
"GeneratedEvaluatorRef",
"RawAnthropicFoundryClient",
"RawFoundryAgent",
"RawFoundryAgentChatClient",
@@ -394,3 +394,94 @@ def test_create_harness_agent_logs_warning_when_no_web_search(caplog: pytest.Log
max_output_tokens=16_384,
)
assert any("SupportsWebSearchTool" in msg for msg in caplog.messages)
# --- Background Agents Tests ---
class _FakeBackgroundAgent:
"""Minimal agent stub satisfying SupportsAgentRun for background agents tests."""
def __init__(self, name: str, description: str | None = None):
self.id = f"agent-{name}"
self.name = name
self.description = description
def create_session(self, *, session_id: str | None = None) -> AgentSession:
return AgentSession(session_id=session_id)
def get_session(self, service_session_id: str, *, session_id: str | None = None) -> AgentSession:
return AgentSession(service_session_id=service_session_id, session_id=session_id)
async def run(self, messages: Any = None, *, stream: bool = False, session: Any = None, **kwargs: Any) -> Any:
from agent_framework import AgentResponse
return AgentResponse(messages=[], response_id="fake-bg-response")
def test_create_harness_agent_no_background_agents_by_default() -> None:
"""No BackgroundAgentsProvider should be included when background_agents is not provided."""
from agent_framework._harness._background_agents import BackgroundAgentsProvider
agent = create_harness_agent(
client=_FakeChatClient(), # type: ignore[arg-type]
max_context_window_tokens=128_000,
max_output_tokens=16_384,
disable_web_search=True,
)
providers = agent.context_providers or []
assert not any(isinstance(p, BackgroundAgentsProvider) for p in providers)
def test_create_harness_agent_adds_background_agents_provider() -> None:
"""BackgroundAgentsProvider should be included when background_agents are provided."""
from agent_framework._harness._background_agents import BackgroundAgentsProvider
bg_agent = _FakeBackgroundAgent("WebSearcher", "Searches the web")
agent = create_harness_agent(
client=_FakeChatClient(), # type: ignore[arg-type]
max_context_window_tokens=128_000,
max_output_tokens=16_384,
disable_web_search=True,
background_agents=[bg_agent],
)
providers = agent.context_providers or []
bg_providers = [p for p in providers if isinstance(p, BackgroundAgentsProvider)]
assert len(bg_providers) == 1
def test_create_harness_agent_background_agents_custom_instructions() -> None:
"""Custom instructions should be passed to BackgroundAgentsProvider."""
from agent_framework._harness._background_agents import BackgroundAgentsProvider
custom_instructions = "## Custom\n\nUse agents wisely.\n\n{background_agents}"
bg_agent = _FakeBackgroundAgent("Helper", "A helper agent")
agent = create_harness_agent(
client=_FakeChatClient(), # type: ignore[arg-type]
max_context_window_tokens=128_000,
max_output_tokens=16_384,
disable_web_search=True,
background_agents=[bg_agent],
background_agents_instructions=custom_instructions,
)
providers = agent.context_providers or []
bg_providers = [p for p in providers if isinstance(p, BackgroundAgentsProvider)]
assert len(bg_providers) == 1
# Verify the custom instructions were used (placeholder replaced with agent list).
assert "Custom" in bg_providers[0]._instructions
assert "Helper" in bg_providers[0]._instructions
def test_create_harness_agent_empty_background_agents_list() -> None:
"""An empty background_agents list should NOT add a BackgroundAgentsProvider."""
from agent_framework._harness._background_agents import BackgroundAgentsProvider
agent = create_harness_agent(
client=_FakeChatClient(), # type: ignore[arg-type]
max_context_window_tokens=128_000,
max_output_tokens=16_384,
disable_web_search=True,
background_agents=[],
)
providers = agent.context_providers or []
assert not any(isinstance(p, BackgroundAgentsProvider) for p in providers)
@@ -11,8 +11,13 @@ import pytest
from agent_framework._evaluation import (
CheckResult,
EvalItem,
EvalItemResult,
EvalNotPassedError,
EvalResults,
EvalScoreResult,
ExpectedToolCall,
LocalEvaluator,
RubricScore,
_coerce_result,
evaluator,
keyword_check,
@@ -1010,19 +1015,101 @@ class TestAllPassedSubResults:
# ---------------------------------------------------------------------------
# r5 review: _build_overall_item with empty outputs
# Rubric assertions (EvalResults.assert_*)
# ---------------------------------------------------------------------------
class TestBuildOverallItemEmpty:
"""Test _build_overall_item returns None for empty workflow outputs."""
def _rubric_results(*scores_per_item: list[EvalScoreResult]) -> EvalResults:
items = [
EvalItemResult(item_id=f"item-{i}", status="pass", scores=scores) for i, scores in enumerate(scores_per_item)
]
return EvalResults(
provider="test",
eval_id="ev1",
run_id="run1",
result_counts={"passed": len(items), "failed": 0, "errored": 0, "total": len(items)},
items=items,
)
def test_returns_none_for_empty_outputs(self):
from unittest.mock import MagicMock
from agent_framework._evaluation import _build_overall_item
class TestRubricAssertions:
"""Tests for EvalResults.assert_dimension_score_at_least."""
mock_result = MagicMock()
mock_result.get_outputs.return_value = []
item = _build_overall_item("Hello", mock_result)
assert item is None
def test_dimension_at_or_above_threshold_passes(self) -> None:
results = _rubric_results(
[
EvalScoreResult(
name="policy",
score=0.9,
dimensions=[RubricScore(id="clarity", score=4, applicable=True, weight=1, reason="")],
)
],
)
# Should not raise.
results.assert_dimension_score_at_least("clarity", 3)
def test_dimension_below_threshold_raises(self) -> None:
results = _rubric_results(
[
EvalScoreResult(
name="policy",
score=0.5,
dimensions=[RubricScore(id="clarity", score=2, applicable=True, weight=1, reason="")],
)
],
)
with pytest.raises(EvalNotPassedError):
results.assert_dimension_score_at_least("clarity", 3)
def test_non_applicable_skipped_by_default(self) -> None:
results = _rubric_results(
[
EvalScoreResult(
name="policy",
score=1.0,
dimensions=[RubricScore(id="clarity", score=None, applicable=False, weight=1, reason="n/a")],
)
],
)
# No applicable scores; default behaviour is to skip silently.
results.assert_dimension_score_at_least("clarity", 3)
def test_require_applicable_raises_when_dimension_absent(self) -> None:
results = _rubric_results(
[EvalScoreResult(name="policy", score=1.0, dimensions=[])],
)
with pytest.raises(EvalNotPassedError, match="not applicable"):
results.assert_dimension_score_at_least("clarity", 3, require_applicable=True)
def test_require_applicable_raises_when_filtered_evaluator_missing(self) -> None:
# Regression: previously the (not evaluator or found_any) guard caused
# this case to silently pass even with require_applicable=True.
results = _rubric_results(
[
EvalScoreResult(
name="other",
score=0.9,
dimensions=[RubricScore(id="clarity", score=4, applicable=True, weight=1, reason="")],
)
],
)
with pytest.raises(EvalNotPassedError, match="not applicable"):
results.assert_dimension_score_at_least("clarity", 3, evaluator="policy", require_applicable=True)
def test_evaluator_filter_isolates_offenders(self) -> None:
results = _rubric_results(
[
EvalScoreResult(
name="other",
score=0.1,
dimensions=[RubricScore(id="clarity", score=1, applicable=True, weight=1, reason="")],
),
EvalScoreResult(
name="policy",
score=0.9,
dimensions=[RubricScore(id="clarity", score=4, applicable=True, weight=1, reason="")],
),
],
)
# The low-scoring "other" evaluator is filtered out; "policy" passes.
results.assert_dimension_score_at_least("clarity", 3, evaluator="policy")
+206
View File
@@ -1161,6 +1161,43 @@ async def test_local_mcp_server_function_execution_error():
await func.invoke(param="test_value")
async def test_mcp_tool_reconnects_after_session_terminated_error():
"""Session termination errors should reconnect once and retry the tool call."""
class TestServer(MCPTool):
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.connect_count = 0
self.sessions: list[Any] = []
async def connect(self, *, reset: bool = False) -> None:
self.connect_count += 1
self.session = Mock(spec=ClientSession)
self.sessions.append(self.session)
if self.connect_count == 1:
self.session.call_tool = AsyncMock(
side_effect=McpError(types.ErrorData(code=-32000, message="Session terminated"))
)
else:
self.session.call_tool = AsyncMock(
return_value=types.CallToolResult(content=[types.TextContent(type="text", text="recovered")])
)
self.is_connected = True
def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
return None
server = TestServer(name="test_server")
await server.connect()
result = await server.call_tool("test_tool", param="test_value")
assert _mcp_result_to_text(result) == "recovered"
assert server.connect_count == 2
assert server.sessions[0].call_tool.await_count == 1
assert server.sessions[1].call_tool.await_count == 1
async def test_mcp_tool_call_tool_raises_on_is_error():
"""Test that call_tool raises ToolExecutionException when MCP returns isError=True."""
@@ -3260,6 +3297,68 @@ async def test_load_prompts_pagination_with_duplicates():
assert [f.name for f in tool._functions] == ["prompt_1", "prompt_2"]
async def test_load_tools_concurrent_reload_does_not_duplicate_tools_and_preserves_meta():
"""Concurrent tool reloads should not duplicate functions or lose tools/list metadata."""
tool = MCPTool(name="test_tool")
mock_session = AsyncMock()
tool.session = mock_session
tool.load_tools_flag = True
page = Mock()
page.tools = [
types.Tool(
name="tool_1",
description="First tool",
inputSchema={"type": "object", "properties": {"param": {"type": "string"}}},
_meta={"echo": "tool_1"},
),
]
page.nextCursor = None
async def mock_list_tools(params: Any = None) -> Any:
assert params is None
await asyncio.sleep(0)
return page
mock_session.list_tools = AsyncMock(side_effect=mock_list_tools)
await asyncio.wait_for(asyncio.gather(tool.load_tools(), tool.load_tools()), timeout=1)
assert mock_session.list_tools.call_count == 2
assert [f.name for f in tool._functions] == ["tool_1"]
assert tool._tool_call_meta_by_name == {"tool_1": {"echo": "tool_1"}}
async def test_load_prompts_concurrent_reload_does_not_duplicate_prompts():
"""Concurrent prompt reloads should not duplicate functions."""
tool = MCPTool(name="test_tool")
mock_session = AsyncMock()
tool.session = mock_session
tool.load_prompts_flag = True
page = Mock()
page.prompts = [
types.Prompt(
name="prompt_1",
description="First prompt",
arguments=[types.PromptArgument(name="arg1", description="Arg 1", required=True)],
),
]
page.nextCursor = None
async def mock_list_prompts(params: Any = None) -> Any:
assert params is None
await asyncio.sleep(0)
return page
mock_session.list_prompts = AsyncMock(side_effect=mock_list_prompts)
await asyncio.wait_for(asyncio.gather(tool.load_prompts(), tool.load_prompts()), timeout=1)
assert mock_session.list_prompts.call_count == 2
assert [f.name for f in tool._functions] == ["prompt_1"]
async def test_load_tools_pagination_exception_handling():
"""Test that load_tools handles exceptions during pagination gracefully."""
from unittest.mock import AsyncMock
@@ -3891,6 +3990,31 @@ async def test_mcp_tool_safe_close_handles_cancelled_error():
mock_exit_stack.aclose.assert_called_once()
async def test_mcp_tool_safe_close_handles_cleanup_exception_group():
"""Cleanup task groups should not hide the original connect failure."""
import builtins
from contextlib import AsyncExitStack
exception_group_type = getattr(builtins, "ExceptionGroup", None)
if exception_group_type is None:
pytest.skip("ExceptionGroup is not available on this Python version")
tool = MCPStreamableHTTPTool(
name="test",
url="http://example.com/mcp",
load_tools=False,
load_prompts=False,
)
mock_exit_stack = AsyncMock(spec=AsyncExitStack)
mock_exit_stack.aclose = AsyncMock(side_effect=exception_group_type("cleanup failed", [RuntimeError("reader")]))
tool._exit_stack = mock_exit_stack
await tool._safe_close_exit_stack()
mock_exit_stack.aclose.assert_called_once()
async def test_connect_sets_logging_level_when_logger_level_is_set():
"""Test that connect() sets the MCP server logging level when the logger level is not NOTSET."""
@@ -4389,6 +4513,52 @@ async def test_mcp_tool_call_tool_forwards_tool_list_meta():
assert server.session.call_tool.call_args.kwargs["meta"] == tool_meta
async def test_mcp_tool_call_tool_user_meta_merges_with_tool_list_meta():
"""User-provided _meta should be sent as MCP request metadata, not tool arguments."""
from opentelemetry import trace
tool_meta = {"from_tool": "tool-value", "shared": "tool-value"}
user_meta = {"from_user": "user-value", "shared": "user-value"}
class TestServer(MCPTool):
async def connect(self) -> None:
self.session = Mock(spec=ClientSession)
self.session.list_tools = AsyncMock(
return_value=types.ListToolsResult(
tools=[
types.Tool(
name="test_tool",
description="Test tool",
inputSchema={"type": "object", "properties": {"param": {"type": "string"}}},
_meta=tool_meta,
)
]
)
)
self.session.call_tool = AsyncMock(
return_value=types.CallToolResult(content=[types.TextContent(type="text", text="result")])
)
def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
return None
server = TestServer(name="test_server")
async with server:
await server.load_tools()
with trace.use_span(trace.NonRecordingSpan(trace.INVALID_SPAN_CONTEXT)):
await server.call_tool("test_tool", param="test_value", _meta=user_meta)
call_kwargs = server.session.call_tool.call_args.kwargs
assert call_kwargs["arguments"] == {"param": "test_value"}
assert call_kwargs["meta"] == {
"from_tool": "tool-value",
"from_user": "user-value",
"shared": "user-value",
}
assert user_meta == {"from_user": "user-value", "shared": "user-value"}
async def test_mcp_streamable_http_tool_hook_not_duplicated_on_repeated_get_mcp_client():
"""Test that calling get_mcp_client multiple times does not accumulate duplicate hooks."""
tool = MCPStreamableHTTPTool(
@@ -4641,6 +4811,42 @@ async def test_mcp_streamable_http_tool_header_provider_with_httpx_event_hook():
await tool._httpx_client.aclose()
async def test_mcp_streamable_http_tool_header_provider_skips_cross_origin_redirect():
"""The request hook must not re-add caller headers after a cross-origin redirect."""
import httpx
from agent_framework._mcp import _mcp_call_headers
tool = MCPStreamableHTTPTool(
name="test",
url="http://example.com/mcp",
header_provider=lambda kw: {"Authorization": f"Bearer {kw.get('token', '')}"},
)
try:
with patch("agent_framework._mcp.streamable_http_client"):
tool.get_mcp_client()
assert tool._httpx_client is not None
hooks = tool._httpx_client.event_hooks.get("request", [])
assert len(hooks) == 1
token = _mcp_call_headers.set({"Authorization": "Bearer secret"})
try:
same_origin = httpx.Request("POST", "http://example.com/redirected")
await hooks[0](same_origin)
assert same_origin.headers.get("Authorization") == "Bearer secret"
cross_origin = httpx.Request("POST", "http://attacker.example/capture")
await hooks[0](cross_origin)
assert "Authorization" not in cross_origin.headers
finally:
_mcp_call_headers.reset(token)
finally:
if getattr(tool, "_httpx_client", None) is not None:
await tool._httpx_client.aclose()
async def test_mcp_streamable_http_tool_header_provider_with_user_httpx_client():
"""Test that header_provider works when the user provides their own httpx client."""
import httpx
@@ -1691,6 +1691,65 @@ def test_to_otel_part_function_call():
}
def test_to_otel_part_function_call_reuses_prepared_arguments():
"""Test _to_otel_part does not re-serialize function-call arguments in the observability hot path."""
from agent_framework import Content
from agent_framework.observability import _to_otel_part
arguments = {"payload": object()}
content = Content(type="function_call", call_id="call_789", name="handoff", arguments=arguments)
result = _to_otel_part(content)
assert result is not None
assert result["arguments"] is arguments
def test_make_json_safe_non_callable_method_attribute():
"""Test make_json_safe handles objects where model_dump/to_dict/dict are non-callable attributes."""
from agent_framework._serialization import make_json_safe
class ObjWithNonCallableModelDump:
model_dump = 42 # not callable
obj = ObjWithNonCallableModelDump()
result = make_json_safe(obj)
assert result == {}
def test_make_json_safe_callable_method_type_error_falls_through():
"""Test make_json_safe falls through when serializer-like methods require arguments."""
from agent_framework._serialization import make_json_safe
class ObjWithRequiredArgModelDump:
def __init__(self) -> None:
self.value = "fallback"
def model_dump(self, required: str) -> dict[str, str]:
return {"required": required}
obj = ObjWithRequiredArgModelDump()
result = make_json_safe(obj)
assert result == {"value": "fallback"}
def test_make_json_safe_dict_with_non_string_keys():
"""Test make_json_safe converts non-primitive dict keys to strings."""
import json
from datetime import datetime
from agent_framework._serialization import make_json_safe
dt_key = datetime(2024, 1, 1)
obj = {dt_key: "value", 42: "num_value", "str_key": "normal"}
result = make_json_safe(obj)
# json.dumps must not raise TypeError
serialized = json.dumps(result)
parsed = json.loads(serialized)
assert parsed[str(dt_key)] == "value"
assert parsed["42"] == "num_value"
assert parsed["str_key"] == "normal"
def test_to_otel_part_function_result():
"""Test _to_otel_part with function_result content."""
from agent_framework import Content
@@ -3019,6 +3078,49 @@ async def test_system_instructions_preserves_non_ascii_characters(span_exporter:
assert [msg.get("role") for msg in input_messages] == ["user"]
@pytest.mark.parametrize("enable_sensitive_data", [True], indirect=True)
def test_capture_messages_with_prepared_request_info_function_call_arguments(span_exporter: InMemorySpanExporter):
"""Test _capture_messages handles request-info function-call arguments prepared at Content creation."""
import dataclasses
import json
from opentelemetry import trace
from agent_framework import WorkflowAgent
@dataclasses.dataclass
class HandoffRequest:
target_agent: str
reason: str
arguments = WorkflowAgent.RequestInfoFunctionArgs(
request_id="call_dc",
data=HandoffRequest(target_agent="helper", reason="overflow"),
).to_dict()
msg = Message(
role="assistant",
contents=[
Content(
type="function_call",
call_id="call_dc",
name="request_info",
arguments=arguments,
)
],
)
span_exporter.clear()
tracer = trace.get_tracer("test")
with tracer.start_as_current_span("test_span") as span:
_capture_messages(span=span, provider_name="test_provider", messages=[msg])
spans = span_exporter.get_finished_spans()
span = spans[0]
input_messages = json.loads(span.attributes[OtelAttr.INPUT_MESSAGES])
tool_part = input_messages[0]["parts"][0]
assert tool_part["type"] == "tool_call"
assert tool_part["arguments"]["data"] == {"target_agent": "helper", "reason": "overflow"}
def test_capture_messages_keeps_framework_instructions_out_of_logs_and_span_messages(
span_exporter: InMemorySpanExporter,
):
@@ -307,6 +307,63 @@ class TestHistoryProviderBase:
assert provider.stored[0].text == "hello"
assert provider.stored[1].text == "hi"
async def test_after_run_stores_coalesced_code_interpreter_chunks(self) -> None:
from agent_framework import AgentResponse, AgentResponseUpdate, Content
provider = ConcreteHistoryProvider("mem", store_inputs=False)
updates = [
AgentResponseUpdate(
role="assistant",
contents=[
Content.from_code_interpreter_tool_result(
call_id="ci_123",
outputs=[],
)
],
),
AgentResponseUpdate(
contents=[
Content.from_code_interpreter_tool_call(
call_id="ci_123",
inputs=[Content.from_text(text="import")],
additional_properties={"sequence_number": 1},
)
],
),
AgentResponseUpdate(
contents=[
Content.from_code_interpreter_tool_call(
call_id="ci_123",
inputs=[Content.from_text(text=" pandas")],
additional_properties={"sequence_number": 2},
)
],
),
AgentResponseUpdate(
contents=[
Content.from_code_interpreter_tool_call(
call_id="ci_123",
inputs=[Content.from_text(text="import pandas as pd")],
additional_properties={"sequence_number": 3},
)
],
),
]
ctx = SessionContext(session_id="s1", input_messages=[Message(role="user", contents=["make a sheet"])])
ctx._response = AgentResponse.from_updates(updates)
await provider.after_run(agent=None, session=AgentSession(), context=ctx, state={}) # type: ignore[arg-type]
assert len(provider.stored) == 1
stored_contents = provider.stored[0].contents
calls = [content for content in stored_contents if content.type == "code_interpreter_tool_call"]
results = [content for content in stored_contents if content.type == "code_interpreter_tool_result"]
assert len(calls) == 1
assert len(results) == 1
assert calls[0].inputs is not None
assert len(calls[0].inputs) == 1
assert calls[0].inputs[0].text == "import pandas as pd"
async def test_after_run_skips_inputs_when_disabled(self) -> None:
from agent_framework import AgentResponse
@@ -5,6 +5,7 @@
from __future__ import annotations
import asyncio
import json
import logging
from collections.abc import Iterator
from contextlib import contextmanager
@@ -1642,6 +1643,37 @@ class TestFunctionalWorkflowAgentHITL:
break
assert approval_found, "expected FunctionApprovalRequestContent in agent response"
async def test_request_info_dataclass_arguments_are_serialized_for_agent(self):
@dataclass
class HandoffRequest:
target_agent: str
reason: str
@workflow
async def wf(x: str, ctx: RunContext) -> str:
answer = await ctx.request_info(
HandoffRequest(target_agent=x, reason="overflow"),
response_type=str,
request_id="rid-1",
)
return f"got:{answer}"
agent = wf.as_agent()
response = await agent.run("helper")
function_call_arguments = None
for message in response.messages:
for content in message.contents:
if getattr(content, "type", None) == "function_approval_request" and content.function_call is not None:
function_call_arguments = content.function_call.arguments
break
assert function_call_arguments == {
"request_id": "rid-1",
"data": {"target_agent": "helper", "reason": "overflow"},
}
assert json.loads(json.dumps(function_call_arguments)) == function_call_arguments
async def test_resume_via_agent_responses_kwarg(self):
@workflow
async def wf(x: str, ctx: RunContext) -> str:
@@ -1,7 +1,9 @@
# Copyright (c) Microsoft. All rights reserved.
import json
import uuid
from collections.abc import Awaitable, Sequence
from dataclasses import dataclass
from typing import Any, Literal, overload
import pytest
@@ -23,6 +25,7 @@ from agent_framework import (
WorkflowAgent,
WorkflowBuilder,
WorkflowContext,
WorkflowEvent,
executor,
handler,
response_handler,
@@ -292,6 +295,33 @@ class TestWorkflowAgent:
pending_requests = await workflow._runner_context.get_pending_request_info_events()
assert len(pending_requests) == 0
def test_request_info_dataclass_arguments_are_serialized_when_content_is_created(self) -> None:
"""Test WorkflowAgent prepares request_info arguments before observability captures messages."""
@dataclass
class HandoffRequest:
target_agent: str
reason: str
executor = SimpleExecutor(id="executor1", response_text="Response")
workflow = WorkflowBuilder(start_executor=executor).build()
agent = WorkflowAgent(workflow=workflow, name="Request Test Agent")
event = WorkflowEvent.request_info(
request_id="request_123",
source_executor_id="executor1",
request_data=HandoffRequest(target_agent="helper", reason="overflow"),
response_type=str,
)
function_call, approval_request = agent._process_request_info_event(event) # pyright: ignore[reportPrivateUsage]
assert function_call.arguments == {
"request_id": "request_123",
"data": {"target_agent": "helper", "reason": "overflow"},
}
assert approval_request.function_call is function_call
assert json.loads(json.dumps(function_call.arguments)) == function_call.arguments
def test_workflow_as_agent_method(self) -> None:
"""Test that Workflow.as_agent() creates a properly configured WorkflowAgent."""
# Create a simple workflow
+1 -1
View File
@@ -29,7 +29,7 @@ dependencies = [
]
[dependency-groups]
dev = [
"types-PyYaml==6.0.12.20250915"
"types-PyYaml==6.0.12.20260518"
]
[tool.uv]
@@ -375,13 +375,15 @@ class DevServer:
logger.info("Starting Agent Framework Server")
await self._ensure_executor()
await self._ensure_openai_executor() # Initialize OpenAI executor
yield
# Shutdown
logger.info("Shutting down Agent Framework Server")
try:
yield
finally:
# Shutdown
logger.info("Shutting down Agent Framework Server")
# Cleanup entity resources (e.g., close credentials, clients)
if self.executor:
await self._cleanup_entities()
# Cleanup entity resources (e.g., close credentials, clients)
if self.executor:
await self._cleanup_entities()
app = FastAPI(
title="Agent Framework Server",
+1 -1
View File
@@ -30,7 +30,7 @@ dependencies = [
[dependency-groups]
dev = [
"types-python-dateutil==2.9.0.20260402",
"types-python-dateutil==2.9.0.20260518",
]
[tool.uv]
@@ -12,6 +12,7 @@ from ._embedding_client import (
)
from ._foundry_evals import (
FoundryEvals,
GeneratedEvaluatorRef,
evaluate_foundry_target,
evaluate_traces,
)
@@ -33,6 +34,7 @@ __all__ = [
"FoundryEmbeddingSettings",
"FoundryEvals",
"FoundryMemoryProvider",
"GeneratedEvaluatorRef",
"RawFoundryAgent",
"RawFoundryAgentChatClient",
"RawFoundryChatClient",
@@ -57,8 +57,6 @@ if TYPE_CHECKING:
from agent_framework import (
Agent,
AgentRunInputs,
ChatAndFunctionMiddlewareTypes,
ContextProvider,
MiddlewareTypes,
ToolTypes,
)
@@ -28,8 +28,9 @@ from __future__ import annotations
import asyncio
import logging
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any
from collections.abc import Iterable, Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, cast
from agent_framework._evaluation import (
AgentEvalConverter,
@@ -39,6 +40,7 @@ from agent_framework._evaluation import (
EvalItemResult,
EvalResults,
EvalScoreResult,
RubricScore,
)
from agent_framework._feature_stage import ExperimentalFeature, experimental
from openai import AsyncOpenAI
@@ -51,6 +53,54 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
# region Generated rubric evaluator references
@experimental(feature_id=ExperimentalFeature.EVALS)
@dataclass(frozen=True)
class GeneratedEvaluatorRef:
"""A reference to a rubric evaluator that already exists in Foundry.
Pass instances of this class to :class:`FoundryEvals` to score items
with a pre-existing rubric evaluator (manually authored or
auto-generated through the Foundry portal). agent-framework is a
consumer here: it does not create or modify the evaluator definition;
it only references the persisted version by name.
Pinning ``version`` is strongly recommended so evaluation runs are
reproducible. ``version=None`` resolves to whichever version is
current at execution time; :class:`FoundryEvals` emits a warning when
a versionless reference is used. CI gates should always pass a
concrete version.
Attributes:
name: Evaluator name as stored in the Foundry project (for
example ``"reservation-policy-rubric"``). Distinct from
built-in evaluators such as ``"builtin.relevance"``.
version: Pinned evaluator version. ``None`` means "latest"
this is discouraged for CI/repro and :class:`FoundryEvals`
will emit a warning when used.
display_name: Optional human-readable name used in result
summaries. Defaults to ``name`` when unset.
"""
name: str
version: str | None = None
display_name: str | None = None
@classmethod
def latest(cls, name: str, *, display_name: str | None = None) -> GeneratedEvaluatorRef:
"""Construct a versionless reference (resolves to the latest version at run time).
Discouraged for reproducible runs. Prefer the constructor with
an explicit ``version`` so CI and replay evaluations stay stable
when the evaluator is updated in Foundry.
"""
return cls(name=name, version=None, display_name=display_name)
# endregion
# Agent evaluators that accept query/response as conversation arrays.
# Maintained manually — check https://learn.microsoft.com/en-us/azure/ai-studio/how-to/develop/evaluate-sdk
# for the latest evaluator list. These are the evaluators that need conversation-format input.
@@ -166,7 +216,7 @@ def _resolve_evaluator(name: str) -> str:
def _build_testing_criteria(
evaluators: Sequence[str],
evaluators: Sequence[str | GeneratedEvaluatorRef],
model: str,
*,
include_data_mapping: bool = False,
@@ -175,7 +225,9 @@ def _build_testing_criteria(
"""Build ``testing_criteria`` for ``evals.create()``.
Args:
evaluators: Evaluator names.
evaluators: Evaluator names (built-in shorts / fully-qualified
``builtin.*`` names) or :class:`GeneratedEvaluatorRef`
instances for generated rubric evaluators.
model: Model deployment for the LLM judge.
include_data_mapping: Whether to include field-level data mapping
(required for the JSONL data source, not needed for response-based).
@@ -183,7 +235,38 @@ def _build_testing_criteria(
definitions.
"""
criteria: list[dict[str, Any]] = []
for name in evaluators:
for entry_spec in evaluators:
if isinstance(entry_spec, GeneratedEvaluatorRef):
short = entry_spec.display_name or entry_spec.name
ref_entry: dict[str, Any] = {
"type": "azure_ai_evaluator",
"name": short,
"evaluator_name": entry_spec.name,
"initialization_parameters": {"deployment_name": model},
}
if entry_spec.version is not None:
ref_entry["evaluator_version"] = entry_spec.version
else:
logger.warning(
"GeneratedEvaluatorRef '%s' has no pinned version; the eval run "
"will resolve to whichever version is current at execution time. "
"Pin the version for reproducible runs.",
entry_spec.name,
)
if include_data_mapping:
# Rubric evaluators accept conversation arrays like agent
# evaluators, plus tool_definitions when items are tool-aware.
ref_mapping: dict[str, str] = {
"query": "{{item.query_messages}}",
"response": "{{item.response_messages}}",
}
if include_tool_definitions:
ref_mapping["tool_definitions"] = "{{item.tool_definitions}}"
ref_entry["data_mapping"] = ref_mapping
criteria.append(ref_entry)
continue
name = entry_spec
qualified = _resolve_evaluator(name)
short = name if not name.startswith("builtin.") else name.split(".")[-1]
@@ -247,9 +330,9 @@ def _build_item_schema(
def _resolve_default_evaluators(
evaluators: Sequence[str] | None,
evaluators: Sequence[str | GeneratedEvaluatorRef] | None,
items: Sequence[EvalItem | dict[str, Any]] | None = None,
) -> list[str]:
) -> list[str | GeneratedEvaluatorRef]:
"""Resolve evaluators, applying defaults when ``None``.
Defaults to relevance + coherence + task_adherence. Automatically adds
@@ -258,7 +341,7 @@ def _resolve_default_evaluators(
if evaluators is not None:
return list(evaluators)
result = list(_DEFAULT_EVALUATORS)
result: list[str | GeneratedEvaluatorRef] = list(_DEFAULT_EVALUATORS)
if items is not None:
has_tools = any((item.tools if isinstance(item, EvalItem) else item.get("tool_definitions")) for item in items)
if has_tools:
@@ -267,14 +350,24 @@ def _resolve_default_evaluators(
def _filter_tool_evaluators(
evaluators: list[str],
evaluators: list[str | GeneratedEvaluatorRef],
items: Sequence[EvalItem | dict[str, Any]],
) -> list[str]:
"""Remove tool evaluators if no items have tool definitions."""
) -> list[str | GeneratedEvaluatorRef]:
"""Remove tool evaluators if no items have tool definitions.
Generated rubric evaluators are tool-aware but not tool-required; they
are preserved regardless of whether items carry tool definitions.
"""
has_tools = any((item.tools if isinstance(item, EvalItem) else item.get("tool_definitions")) for item in items)
if has_tools:
return evaluators
filtered = [e for e in evaluators if _resolve_evaluator(e) not in _TOOL_EVALUATORS]
def _is_tool_only(spec: str | GeneratedEvaluatorRef) -> bool:
if isinstance(spec, GeneratedEvaluatorRef):
return False
return _resolve_evaluator(spec) in _TOOL_EVALUATORS
filtered = [e for e in evaluators if not _is_tool_only(e)]
if not filtered:
raise ValueError(
f"All requested evaluators {evaluators} require tool definitions, "
@@ -282,7 +375,7 @@ def _filter_tool_evaluators(
"or choose evaluators that do not require tools."
)
if len(filtered) < len(evaluators):
removed = [e for e in evaluators if _resolve_evaluator(e) in _TOOL_EVALUATORS]
removed = [e for e in evaluators if _is_tool_only(e)]
logger.info("Removed tool evaluators %s (no items have tools)", removed)
return filtered
@@ -354,6 +447,114 @@ def _extract_per_evaluator(run: RunRetrieveResponse) -> dict[str, dict[str, int]
return per_eval
_RUBRIC_DIMENSION_KEYS: tuple[str, ...] = ("dimension_scores", "rubric_scores")
"""Property keys that may carry per-dimension rubric breakdowns.
The published Foundry rubric-evaluator output format uses
``properties.dimension_scores`` (see the Microsoft Learn "Rubric
evaluators" reference). Earlier preview builds and some SDK shapes
used ``rubric_scores``; we accept both for defensive forward/backward
compatibility.
"""
def _parse_dimension_entries(raw: Any) -> list[RubricScore]:
"""Parse a raw list-like payload into ``RubricScore`` instances.
Returns an empty list when ``raw`` is falsy, not iterable, or
contains no well-formed entries.
"""
if not raw:
return []
try:
raw_iter: Iterable[Any] = iter(raw)
except TypeError:
return []
parsed: list[RubricScore] = []
for raw_entry in raw_iter:
entry: Any = raw_entry
try:
rid: Any
score_val: Any
applicable: Any
weight: Any
reason: Any
if isinstance(entry, dict):
entry_any = cast("dict[str, Any]", entry)
rid = entry_any.get("id")
score_val = entry_any.get("score")
applicable = entry_any.get("applicable")
weight = entry_any.get("weight")
reason = entry_any.get("reason", "")
else:
rid = getattr(entry, "id", None)
score_val = getattr(entry, "score", None)
applicable = getattr(entry, "applicable", None)
weight = getattr(entry, "weight", None)
reason = getattr(entry, "reason", "") or ""
if rid is None or weight is None or applicable is None:
continue
parsed.append(
RubricScore(
id=str(rid),
score=int(score_val) if isinstance(score_val, (int, float)) else None,
applicable=bool(applicable),
weight=int(weight),
reason=str(reason) if reason is not None else "",
)
)
except (TypeError, ValueError):
logger.debug("Skipping malformed rubric dimension entry: %s", cast("Any", entry), exc_info=True)
return parsed
def _extract_rubric_scores(sample: Any) -> list[RubricScore] | None:
"""Extract typed ``RubricScore`` instances from an evaluator's raw sample payload.
Foundry rubric evaluators include a per-dimension breakdown under
``properties.dimension_scores`` on each result (preview builds used
``rubric_scores``; both keys are accepted, with the canonical
``dimension_scores`` taking priority). The exact location may
vary across SDK versions, so this helper accepts a few shapes:
* The SDK ``sample`` object exposes
``properties.dimension_scores`` / ``properties.rubric_scores``.
* The ``sample`` is a dict containing the same under
``properties.<key>``.
* The ``sample`` is a dict with ``dimension_scores`` /
``rubric_scores`` at the top level.
Returns ``None`` when no rubric scores are present (i.e. the
evaluator was not a rubric evaluator).
"""
if sample is None:
return None
containers: list[Any] = []
properties: Any = getattr(sample, "properties", None)
if properties is not None:
containers.append(properties)
if isinstance(sample, dict):
sample_any = cast("dict[str, Any]", sample)
props_dict: Any = sample_any.get("properties")
if props_dict is not None and props_dict is not properties:
containers.append(props_dict)
containers.append(sample_any)
for container in containers:
for key in _RUBRIC_DIMENSION_KEYS:
raw: Any = None
if isinstance(container, dict):
raw = cast("dict[str, Any]", container).get(key)
elif hasattr(container, key):
raw = getattr(container, key, None)
parsed = _parse_dimension_entries(raw)
if parsed:
return parsed
return None
async def _fetch_output_items(
client: AsyncOpenAI,
eval_id: str,
@@ -377,12 +578,15 @@ async def _fetch_output_items(
# Extract per-evaluator scores
scores: list[EvalScoreResult] = []
for r in oi.results or []:
sample = r.sample
dimensions = _extract_rubric_scores(sample)
scores.append(
EvalScoreResult(
name=r.name,
score=r.score,
passed=r.passed,
sample=r.sample,
sample=sample,
dimensions=dimensions,
)
)
@@ -394,15 +598,18 @@ async def _fetch_output_items(
output_text: str | None = None
response_id: str | None = None
sample = oi.sample
if sample is not None: # pyright: ignore[reportUnnecessaryComparison]
err = sample.error
if err is not None and (err.code or err.message): # pyright: ignore[reportUnnecessaryComparison]
# mypy infers oi.sample as dict[str, object] | None, but the
# OpenAI SDK actually returns a typed Sample model. Cast to Any so
# both type checkers accept the attribute access pattern.
oi_sample: Any = oi.sample
if oi_sample is not None:
err = oi_sample.error
if err is not None and (err.code or err.message):
error_code = err.code or None
error_message = err.message or None
usage = sample.usage
if usage is not None and usage.total_tokens: # pyright: ignore[reportUnnecessaryComparison]
usage = oi_sample.usage
if usage is not None and usage.total_tokens:
token_usage = {
"prompt_tokens": usage.prompt_tokens,
"completion_tokens": usage.completion_tokens,
@@ -411,13 +618,13 @@ async def _fetch_output_items(
}
# Extract input/output text
if sample.input:
parts = [si.content for si in sample.input if si.role == "user"]
if oi_sample.input:
parts = [si.content for si in oi_sample.input if si.role == "user"]
if parts:
input_text = " ".join(parts)
if sample.output:
parts = [so.content or "" for so in sample.output if so.role == "assistant"]
if oi_sample.output:
parts = [so.content or "" for so in oi_sample.output if so.role == "assistant"]
if parts:
output_text = " ".join(parts)
@@ -472,7 +679,7 @@ async def _evaluate_via_responses_impl(
*,
client: AsyncOpenAI,
response_ids: Sequence[str],
evaluators: list[str],
evaluators: list[str | GeneratedEvaluatorRef],
model: str,
eval_name: str,
poll_interval: float,
@@ -573,8 +780,11 @@ class FoundryEvals:
(from ``azure.ai.projects.aio``). Provide this or *client*.
model: Model deployment name for the evaluator LLM judge.
Resolved from ``client.model`` when omitted.
evaluators: Evaluator names (e.g. ``["relevance", "tool_call_accuracy"]``).
When ``None`` (default), uses smart defaults based on item data.
evaluators: Evaluator specifications. Entries may be built-in
short names (e.g. ``"relevance"``), fully-qualified
``"builtin.*"`` names, or :class:`GeneratedEvaluatorRef`
instances for previously generated rubric evaluators. When
``None`` (default), uses smart defaults based on item data.
conversation_split: How to split multi-turn conversations into
query/response halves. Defaults to ``LAST_TURN``. Pass a
``ConversationSplit`` enum value or a custom callable see
@@ -623,7 +833,7 @@ class FoundryEvals:
client: FoundryChatClient | None = None,
project_client: AIProjectClient | None = None,
model: str | None = None,
evaluators: Sequence[str] | None = None,
evaluators: Sequence[str | GeneratedEvaluatorRef] | None = None,
conversation_split: ConversationSplitter = ConversationSplit.LAST_TURN,
poll_interval: float = 5.0,
timeout: float = 180.0,
@@ -642,7 +852,9 @@ class FoundryEvals:
"Model is required. Pass model= explicitly or use a FoundryChatClient that has a model configured."
)
self._model = resolved_model
self._evaluators = list(evaluators) if evaluators is not None else None
self._evaluators: list[str | GeneratedEvaluatorRef] | None = (
list(evaluators) if evaluators is not None else None
)
self._conversation_split = conversation_split
self._poll_interval = poll_interval
self._timeout = timeout
@@ -678,7 +890,7 @@ class FoundryEvals:
async def _evaluate_via_dataset(
self,
items: Sequence[EvalItem],
evaluators: list[str],
evaluators: list[str | GeneratedEvaluatorRef],
eval_name: str,
) -> EvalResults:
"""Evaluate using JSONL dataset upload path."""
@@ -25,16 +25,25 @@ from agent_framework._evaluation import (
from agent_framework._workflows._workflow import WorkflowRunResult
from openai import AsyncOpenAI
from agent_framework_foundry import GeneratedEvaluatorRef
from agent_framework_foundry._foundry_evals import (
_AGENT_EVALUATORS,
_BUILTIN_EVALUATORS,
_TOOL_EVALUATORS,
FoundryEvals,
_build_item_schema,
_build_testing_criteria,
_extract_per_evaluator,
_extract_result_counts,
_extract_rubric_scores,
_fetch_output_items,
_filter_tool_evaluators,
_poll_eval_run,
_resolve_default_evaluators,
_resolve_evaluator,
_resolve_openai_client,
evaluate_foundry_target,
evaluate_traces,
)
@@ -806,6 +815,67 @@ class TestBuildTestingCriteria:
for c in criteria:
assert "tool_definitions" in c["data_mapping"], f"{c['name']} missing tool_definitions"
def test_generated_evaluator_ref_pinned_version(self) -> None:
ref = GeneratedEvaluatorRef(name="my-rubric", version="1")
criteria = _build_testing_criteria([ref], "gpt-4o", include_data_mapping=True)
assert len(criteria) == 1
c = criteria[0]
assert c["type"] == "azure_ai_evaluator"
assert c["evaluator_name"] == "my-rubric"
assert c["evaluator_version"] == "1"
assert c["name"] == "my-rubric"
assert c["initialization_parameters"] == {"deployment_name": "gpt-4o"}
assert c["data_mapping"] == {
"query": "{{item.query_messages}}",
"response": "{{item.response_messages}}",
}
def test_generated_evaluator_ref_display_name_used_as_short(self) -> None:
ref = GeneratedEvaluatorRef(name="my-rubric", version="2", display_name="My Rubric")
criteria = _build_testing_criteria([ref], "gpt-4o")
assert criteria[0]["name"] == "My Rubric"
assert criteria[0]["evaluator_name"] == "my-rubric"
def test_generated_evaluator_ref_tool_definitions_added(self) -> None:
ref = GeneratedEvaluatorRef(name="my-rubric", version="1")
criteria = _build_testing_criteria(
[ref],
"gpt-4o",
include_data_mapping=True,
include_tool_definitions=True,
)
assert criteria[0]["data_mapping"]["tool_definitions"] == "{{item.tool_definitions}}"
def test_generated_evaluator_ref_unpinned_warns(self, caplog: pytest.LogCaptureFixture) -> None:
import logging
ref = GeneratedEvaluatorRef.latest("my-rubric")
with caplog.at_level(logging.WARNING, logger="agent_framework_foundry._foundry_evals"):
criteria = _build_testing_criteria([ref], "gpt-4o")
assert "evaluator_version" not in criteria[0]
assert any("no pinned version" in r.message for r in caplog.records)
def test_generated_evaluator_ref_mixed_with_builtins(self) -> None:
ref = GeneratedEvaluatorRef(name="my-rubric", version="1")
criteria = _build_testing_criteria(
["relevance", ref, "task_adherence"],
"gpt-4o",
include_data_mapping=True,
)
assert [c["name"] for c in criteria] == ["relevance", "my-rubric", "task_adherence"]
assert criteria[0]["evaluator_name"] == "builtin.relevance"
assert criteria[1]["evaluator_name"] == "my-rubric"
assert criteria[2]["evaluator_name"] == "builtin.task_adherence"
# ---------------------------------------------------------------------------
# _build_item_schema
@@ -1263,6 +1333,29 @@ class TestFilterToolEvaluators:
items,
)
def test_preserves_generated_ref_when_no_tools(self) -> None:
ref = GeneratedEvaluatorRef(name="rubric", version="1")
items = [
EvalItem(conversation=[Message("user", ["q"]), Message("assistant", ["r"])]),
]
result = _filter_tool_evaluators(
["relevance", ref, "tool_call_accuracy"],
items,
)
assert "relevance" in result
assert ref in result
assert "tool_call_accuracy" not in result
def test_generated_ref_alone_does_not_raise(self) -> None:
ref = GeneratedEvaluatorRef(name="rubric", version="1")
items = [
EvalItem(conversation=[Message("user", ["q"]), Message("assistant", ["r"])]),
]
result = _filter_tool_evaluators([ref], items)
assert result == [ref]
# ---------------------------------------------------------------------------
# EvalResults
@@ -2267,7 +2360,6 @@ class TestEvalResultsWithItems:
class TestFetchOutputItems:
async def test_fetches_and_converts_output_items(self) -> None:
from agent_framework_foundry._foundry_evals import _fetch_output_items
# Build mock output items matching the OpenAI SDK schema
mock_result = MagicMock()
@@ -2329,7 +2421,6 @@ class TestFetchOutputItems:
assert item.error_code is None
async def test_handles_errored_item(self) -> None:
from agent_framework_foundry._foundry_evals import _fetch_output_items
mock_error = MagicMock()
mock_error.code = "QueryExtractionError"
@@ -2361,7 +2452,6 @@ class TestFetchOutputItems:
assert len(item.scores) == 0
async def test_handles_api_failure_gracefully(self) -> None:
from agent_framework_foundry._foundry_evals import _fetch_output_items
mock_client = MagicMock()
mock_client.evals.runs.output_items.list = AsyncMock(side_effect=TypeError("API error"))
@@ -2369,6 +2459,166 @@ class TestFetchOutputItems:
items = await _fetch_output_items(mock_client, "eval_1", "run_1")
assert items == []
async def test_extracts_rubric_scores_from_dict_sample(self) -> None:
mock_result = MagicMock()
mock_result.name = "my-rubric"
mock_result.score = 0.85
mock_result.passed = True
mock_result.sample = {
"properties": {
"rubric_scores": [
{"id": "policy", "score": 4, "applicable": True, "weight": 1, "reason": "ok"},
{"id": "safety", "score": None, "applicable": False, "weight": 1, "reason": "n/a"},
]
}
}
mock_oi = MagicMock()
mock_oi.id = "oi_1"
mock_oi.status = "pass"
mock_oi.results = [mock_result]
mock_oi.sample = None
mock_oi.datasource_item = {}
mock_client = MagicMock()
mock_client.evals.runs.output_items.list = AsyncMock(return_value=_AsyncPage([mock_oi]))
items = await _fetch_output_items(mock_client, "eval_1", "run_1")
assert len(items) == 1
scores = items[0].scores
assert len(scores) == 1
assert scores[0].dimensions is not None
assert len(scores[0].dimensions) == 2
policy = next(d for d in scores[0].dimensions if d.id == "policy")
assert policy.score == 4
assert policy.applicable is True
assert policy.weight == 1
assert policy.reason == "ok"
safety = next(d for d in scores[0].dimensions if d.id == "safety")
assert safety.score is None
assert safety.applicable is False
async def test_no_rubric_scores_when_absent(self) -> None:
mock_result = MagicMock()
mock_result.name = "relevance"
mock_result.score = 0.85
mock_result.passed = True
mock_result.sample = None
mock_oi = MagicMock()
mock_oi.id = "oi_2"
mock_oi.status = "pass"
mock_oi.results = [mock_result]
mock_oi.sample = None
mock_oi.datasource_item = {}
mock_client = MagicMock()
mock_client.evals.runs.output_items.list = AsyncMock(return_value=_AsyncPage([mock_oi]))
items = await _fetch_output_items(mock_client, "eval_1", "run_1")
assert items[0].scores[0].dimensions is None
class TestExtractRubricScores:
def test_handles_attribute_style_properties(self) -> None:
rs = MagicMock()
rs.id = "policy"
rs.score = 5
rs.applicable = True
rs.weight = 2
rs.reason = "ok"
sample = MagicMock()
sample.properties = MagicMock()
sample.properties.rubric_scores = [rs]
result = _extract_rubric_scores(sample)
assert result is not None
assert result[0].id == "policy"
assert result[0].score == 5
assert result[0].weight == 2
def test_top_level_rubric_scores_in_dict(self) -> None:
sample = {"rubric_scores": [{"id": "a", "score": 3, "applicable": True, "weight": 1, "reason": "r"}]}
result = _extract_rubric_scores(sample)
assert result is not None
assert result[0].id == "a"
def test_returns_none_when_missing(self) -> None:
assert _extract_rubric_scores(None) is None
assert _extract_rubric_scores({}) is None
assert _extract_rubric_scores({"properties": {}}) is None
def test_skips_malformed_entries(self) -> None:
sample = {
"properties": {
"rubric_scores": [
{"id": "good", "score": 3, "applicable": True, "weight": 1, "reason": "ok"},
{"id": "bad-no-weight", "score": 2, "applicable": True, "reason": "x"},
]
}
}
result = _extract_rubric_scores(sample)
assert result is not None
assert len(result) == 1
assert result[0].id == "good"
def test_canonical_dimension_scores_key_from_docs(self) -> None:
"""Per the Microsoft Learn docs, runtime output uses ``properties.dimension_scores``."""
sample = {
"properties": {
"dimension_scores": [
{
"id": "intent_recognition",
"score": 5,
"applicable": True,
"weight": 9,
"reason": "Identified correctly.",
},
{
"id": "general_quality",
"score": 4,
"applicable": True,
"weight": 5,
"reason": "Strong overall.",
},
]
}
}
result = _extract_rubric_scores(sample)
assert result is not None
assert [r.id for r in result] == ["intent_recognition", "general_quality"]
assert [r.score for r in result] == [5, 4]
assert [r.weight for r in result] == [9, 5]
def test_dimension_scores_via_attribute(self) -> None:
"""Canonical key also resolves when properties exposes ``dimension_scores`` as an attr."""
rs = MagicMock()
rs.id = "policy_enforcement"
rs.score = 1
rs.applicable = True
rs.weight = 5
rs.reason = "violated"
sample = MagicMock()
sample.properties = MagicMock(spec=["dimension_scores"])
sample.properties.dimension_scores = [rs]
result = _extract_rubric_scores(sample)
assert result is not None
assert result[0].id == "policy_enforcement"
assert result[0].score == 1
# ---------------------------------------------------------------------------
# _poll_eval_run — timeout / failed / canceled paths
@@ -2378,7 +2628,6 @@ class TestFetchOutputItems:
class TestPollEvalRun:
async def test_timeout_returns_timeout_status(self) -> None:
"""Poll timeout returns EvalResults with status='timeout'."""
from agent_framework_foundry._foundry_evals import _poll_eval_run
mock_client = MagicMock()
mock_pending = MagicMock()
@@ -2392,7 +2641,6 @@ class TestPollEvalRun:
async def test_failed_run_returns_error(self) -> None:
"""Failed run returns EvalResults with error message."""
from agent_framework_foundry._foundry_evals import _poll_eval_run
mock_client = MagicMock()
mock_failed = MagicMock()
@@ -2410,7 +2658,6 @@ class TestPollEvalRun:
async def test_canceled_run_returns_canceled_status(self) -> None:
"""Canceled run returns EvalResults with status='canceled'."""
from agent_framework_foundry._foundry_evals import _poll_eval_run
mock_client = MagicMock()
mock_canceled = MagicMock()
@@ -2435,7 +2682,6 @@ class TestPollEvalRun:
class TestEvaluateTraces:
async def test_raises_without_required_args(self) -> None:
"""Raises ValueError when no response_ids, trace_ids, or agent_id given."""
from agent_framework_foundry._foundry_evals import evaluate_traces
mock_client = MagicMock()
with pytest.raises(ValueError, match="Provide at least one of"):
@@ -2446,7 +2692,6 @@ class TestEvaluateTraces:
async def test_response_ids_path(self) -> None:
"""evaluate_traces with response_ids uses the responses API path."""
from agent_framework_foundry._foundry_evals import evaluate_traces
mock_client = MagicMock()
@@ -2494,7 +2739,6 @@ class TestEvaluateTraces:
async def test_trace_ids_path(self) -> None:
"""evaluate_traces with trace_ids builds azure_ai_traces data source."""
from agent_framework_foundry._foundry_evals import evaluate_traces
mock_client = MagicMock()
@@ -2534,7 +2778,6 @@ class TestEvaluateTraces:
class TestEvaluateFoundryTarget:
async def test_happy_path(self) -> None:
"""evaluate_foundry_target creates eval + run and polls to completion."""
from agent_framework_foundry._foundry_evals import evaluate_foundry_target
mock_client = MagicMock()
@@ -2670,13 +2913,11 @@ class TestEvaluatorSetConsistency:
"""Verify that _AGENT_EVALUATORS and _TOOL_EVALUATORS are subsets of _BUILTIN_EVALUATORS."""
def test_agent_evaluators_subset(self):
from agent_framework_foundry._foundry_evals import _AGENT_EVALUATORS, _BUILTIN_EVALUATORS
diff = _AGENT_EVALUATORS - set(_BUILTIN_EVALUATORS.values())
assert not diff, f"_AGENT_EVALUATORS has names not in _BUILTIN_EVALUATORS: {diff}"
def test_tool_evaluators_subset(self):
from agent_framework_foundry._foundry_evals import _BUILTIN_EVALUATORS, _TOOL_EVALUATORS
diff = _TOOL_EVALUATORS - set(_BUILTIN_EVALUATORS.values())
assert not diff, f"_TOOL_EVALUATORS has names not in _BUILTIN_EVALUATORS: {diff}"
@@ -2690,7 +2931,6 @@ class TestEvaluatorSetConsistency:
class TestEvaluateTracesAgentId:
async def test_agent_id_only_path(self) -> None:
"""evaluate_traces with agent_id only builds azure_ai_traces data source."""
from agent_framework_foundry._foundry_evals import evaluate_traces
mock_client = MagicMock()
@@ -2748,7 +2988,6 @@ class TestFilterToolEvaluatorsRaises:
class TestEvaluateFoundryTargetValidation:
async def test_target_without_type_raises(self) -> None:
"""target dict without 'type' key raises ValueError."""
from agent_framework_foundry._foundry_evals import evaluate_foundry_target
mock_client = MagicMock()
with pytest.raises(ValueError, match="'type' key"):
+4 -4
View File
@@ -57,19 +57,19 @@ math = [
[dependency-groups]
dev = [
"uv==0.11.6",
"ruff==0.15.8",
"uv==0.11.17",
"ruff==0.15.15",
"pytest==9.0.3",
"mypy==1.20.0",
"pyright==1.1.408",
#tasks
"poethepoet==0.42.1",
"poethepoet==0.46.0",
"rich>=13.7.1,<15.0.0",
"tomli==2.4.1",
"tomli-w==1.2.0",
# tau2 from source (not available on PyPI)
"tau2@ git+https://github.com/sierra-research/tau2-bench@5ba9e3e56db57c5e4114bf7f901291f09b2c5619",
"prek==0.3.9",
"prek==0.4.3",
]
[project.scripts]
+26
View File
@@ -0,0 +1,26 @@
# Mistral Package (agent-framework-mistral)
Integration with Mistral AI for embedding generation.
## Main Classes
- **`MistralEmbeddingClient`** - Embedding client for Mistral AI models
- **`MistralEmbeddingOptions`** - Options TypedDict for Mistral-specific embedding parameters
- **`MistralEmbeddingSettings`** - TypedDict settings for Mistral configuration
## Usage
```python
from agent_framework_mistral import MistralEmbeddingClient
# Requires MISTRAL_API_KEY environment variable (or pass api_key= directly)
client = MistralEmbeddingClient(model="mistral-embed")
result = await client.get_embeddings(["Hello, world!"])
print(result[0].vector)
```
## Import Path
```python
from agent_framework_mistral import MistralEmbeddingClient
```
+21
View File
@@ -0,0 +1,21 @@
MIT License
Copyright (c) Microsoft Corporation.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE
+42
View File
@@ -0,0 +1,42 @@
# Get Started with Microsoft Agent Framework Mistral AI
Please install this package:
```bash
pip install agent-framework-mistral --pre
```
and see the [README](https://github.com/microsoft/agent-framework/tree/main/python/README.md) for more information.
## Embedding Client
The `MistralEmbeddingClient` provides embedding generation using Mistral AI models.
### Quick Start
```python
from agent_framework_mistral import MistralEmbeddingClient
# Using environment variables (MISTRAL_API_KEY, MISTRAL_EMBEDDING_MODEL)
client = MistralEmbeddingClient()
# Or passing parameters directly
client = MistralEmbeddingClient(
model="mistral-embed",
api_key="your-api-key",
)
# Generate embeddings
result = await client.get_embeddings(["Hello, world!", "How are you?"])
for embedding in result:
print(f"Dimensions: {embedding.dimensions}")
print(f"Vector: {embedding.vector[:5]}...")
```
### Configuration
| Environment Variable | Description |
|---|---|
| `MISTRAL_API_KEY` | Your Mistral AI API key |
| `MISTRAL_EMBEDDING_MODEL` | Embedding model name (e.g., `mistral-embed`) |
| `MISTRAL_SERVER_URL` | Optional server URL override |
@@ -0,0 +1,17 @@
# Copyright (c) Microsoft. All rights reserved.
import importlib.metadata
from ._embedding_client import MistralEmbeddingClient, MistralEmbeddingOptions, MistralEmbeddingSettings
try:
__version__ = importlib.metadata.version(__name__)
except importlib.metadata.PackageNotFoundError:
__version__ = "0.0.0" # Fallback for development mode
__all__ = [
"MistralEmbeddingClient",
"MistralEmbeddingOptions",
"MistralEmbeddingSettings",
"__version__",
]
@@ -0,0 +1,250 @@
# Copyright (c) Microsoft. All rights reserved.
from __future__ import annotations
import logging
import sys
from collections.abc import Sequence
from typing import Any, ClassVar, Generic, TypedDict
from agent_framework import (
BaseEmbeddingClient,
Embedding,
EmbeddingGenerationOptions,
GeneratedEmbeddings,
UsageDetails,
load_settings,
)
from agent_framework._settings import SecretString
from agent_framework.observability import EmbeddingTelemetryLayer
from mistralai.client import Mistral
if sys.version_info >= (3, 13):
from typing import TypeVar # type: ignore # pragma: no cover
else:
from typing_extensions import TypeVar # type: ignore # pragma: no cover
logger = logging.getLogger("agent_framework.mistral")
class MistralEmbeddingOptions(EmbeddingGenerationOptions, total=False):
"""Mistral AI-specific embedding options.
Extends EmbeddingGenerationOptions with Mistral-specific fields.
Examples:
.. code-block:: python
from agent_framework_mistral import MistralEmbeddingOptions
options: MistralEmbeddingOptions = {
"model": "mistral-embed",
"dimensions": 1024,
}
"""
MistralEmbeddingOptionsT = TypeVar(
"MistralEmbeddingOptionsT",
bound=TypedDict, # type: ignore[valid-type]
default="MistralEmbeddingOptions",
covariant=True,
)
class MistralEmbeddingSettings(TypedDict, total=False):
"""Mistral AI embedding settings.
Fields:
api_key: Mistral API key. Resolved from ``MISTRAL_API_KEY``.
embedding_model: Embedding model name. Resolved from ``MISTRAL_EMBEDDING_MODEL``.
server_url: Optional server URL override. Resolved from ``MISTRAL_SERVER_URL``.
"""
api_key: str | None
embedding_model: str | None
server_url: str | None
class RawMistralEmbeddingClient(
BaseEmbeddingClient[str, list[float], MistralEmbeddingOptionsT],
Generic[MistralEmbeddingOptionsT],
):
"""Raw Mistral AI embedding client without telemetry.
Keyword Args:
model: The Mistral embedding model (e.g. "mistral-embed").
Can also be set via environment variable ``MISTRAL_EMBEDDING_MODEL``.
api_key: Mistral API key. Defaults to ``MISTRAL_API_KEY`` environment variable.
server_url: Optional server URL override. Defaults to ``MISTRAL_SERVER_URL``
environment variable, or the Mistral default.
client: Optional pre-configured ``Mistral`` client instance.
additional_properties: Additional properties stored on the client instance.
env_file_path: Path to ``.env`` file for settings.
env_file_encoding: Encoding for ``.env`` file.
"""
INJECTABLE: ClassVar[set[str]] = {"client"}
def __init__(
self,
*,
model: str | None = None,
api_key: str | SecretString | None = None,
server_url: str | None = None,
client: Mistral | None = None,
additional_properties: dict[str, Any] | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
) -> None:
"""Initialize a raw Mistral AI embedding client."""
mistral_settings = load_settings(
MistralEmbeddingSettings,
env_prefix="MISTRAL_",
required_fields=["embedding_model", "api_key"],
api_key=str(api_key) if isinstance(api_key, SecretString) else api_key,
embedding_model=model,
server_url=server_url,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
)
self.model: str = mistral_settings["embedding_model"] # type: ignore[assignment]
resolved_api_key: str = mistral_settings["api_key"] # type: ignore[assignment]
resolved_server_url = mistral_settings.get("server_url")
if client is not None:
self.client = client
else:
client_kwargs: dict[str, Any] = {"api_key": resolved_api_key}
if resolved_server_url:
client_kwargs["server_url"] = resolved_server_url
self.client = Mistral(**client_kwargs)
self.server_url = resolved_server_url
super().__init__(additional_properties=additional_properties)
def service_url(self) -> str:
"""Get the URL of the service."""
return self.server_url or "https://api.mistral.ai"
async def get_embeddings(
self,
values: Sequence[str],
*,
options: MistralEmbeddingOptionsT | None = None,
) -> GeneratedEmbeddings[list[float], MistralEmbeddingOptionsT]:
"""Call the Mistral AI embeddings API.
Args:
values: The text values to generate embeddings for.
options: Optional embedding generation options.
Returns:
Generated embeddings with usage metadata.
Raises:
ValueError: If model is not provided or values is empty.
"""
if not values:
return GeneratedEmbeddings([], options=options)
opts: dict[str, Any] = options or {} # type: ignore
model = opts.get("model") or self.model
if not model:
raise ValueError("model is required")
kwargs: dict[str, Any] = {"model": model, "inputs": list(values)}
if "dimensions" in opts:
kwargs["output_dimension"] = opts["dimensions"]
response = await self.client.embeddings.create_async(**kwargs)
embeddings: list[Embedding[list[float]]] = []
if response and response.data:
items = sorted(response.data, key=lambda d: d.index if d.index is not None else 0)
for item in items:
vector = list(item.embedding) if item.embedding else []
embeddings.append(
Embedding(
vector=vector,
dimensions=len(vector),
model=response.model or model,
)
)
usage_dict: UsageDetails | None = None
if response and response.usage:
usage_dict = {
"input_token_count": response.usage.prompt_tokens,
"total_token_count": response.usage.total_tokens,
}
return GeneratedEmbeddings(embeddings, options=options, usage=usage_dict)
class MistralEmbeddingClient(
EmbeddingTelemetryLayer[str, list[float], MistralEmbeddingOptionsT],
RawMistralEmbeddingClient[MistralEmbeddingOptionsT],
Generic[MistralEmbeddingOptionsT],
):
"""Mistral AI embedding client with telemetry support.
Keyword Args:
model: The Mistral embedding model (e.g. "mistral-embed").
Can also be set via environment variable ``MISTRAL_EMBEDDING_MODEL``.
api_key: Mistral API key. Defaults to ``MISTRAL_API_KEY`` environment variable.
server_url: Optional server URL override. Defaults to ``MISTRAL_SERVER_URL``
environment variable, or the Mistral default.
client: Optional pre-configured ``Mistral`` client instance.
otel_provider_name: Optional telemetry provider name override.
env_file_path: Path to ``.env`` file for settings.
env_file_encoding: Encoding for ``.env`` file.
Examples:
.. code-block:: python
from agent_framework_mistral import MistralEmbeddingClient
# Using environment variables
# Set MISTRAL_API_KEY=your-key
# Set MISTRAL_EMBEDDING_MODEL=mistral-embed
client = MistralEmbeddingClient()
# Or passing parameters directly
client = MistralEmbeddingClient(
model="mistral-embed",
api_key="your-api-key",
)
# Generate embeddings
result = await client.get_embeddings(["Hello, world!"])
print(result[0].vector)
"""
OTEL_PROVIDER_NAME: ClassVar[str] = "mistralai"
def __init__(
self,
*,
model: str | None = None,
api_key: str | SecretString | None = None,
server_url: str | None = None,
client: Mistral | None = None,
otel_provider_name: str | None = None,
additional_properties: dict[str, Any] | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
) -> None:
"""Initialize a Mistral AI embedding client."""
super().__init__(
model=model,
api_key=api_key,
server_url=server_url,
client=client,
additional_properties=additional_properties,
otel_provider_name=otel_provider_name,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
)
@@ -0,0 +1 @@
+105
View File
@@ -0,0 +1,105 @@
[project]
name = "agent-framework-mistral"
description = "Mistral AI integration for Microsoft Agent Framework."
authors = [{ name = "Microsoft", email = "af-support@microsoft.com"}]
readme = "README.md"
requires-python = ">=3.10"
version = "1.0.0a260505"
license-files = ["LICENSE"]
urls.homepage = "https://learn.microsoft.com/en-us/agent-framework/"
urls.source = "https://github.com/microsoft/agent-framework/tree/main/python"
urls.release_notes = "https://github.com/microsoft/agent-framework/releases?q=tag%3Apython-1&expanded=true"
urls.issues = "https://github.com/microsoft/agent-framework/issues"
classifiers = [
"License :: OSI Approved :: MIT License",
"Development Status :: 3 - Alpha",
"Intended Audience :: Developers",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Framework :: Pydantic :: 2",
"Typing :: Typed",
]
dependencies = [
"agent-framework-core>=1.1.0,<2",
"mistralai>=2.0.0,<3",
]
[tool.uv]
prerelease = "if-necessary-or-explicit"
environments = [
"sys_platform == 'darwin'",
"sys_platform == 'linux'",
"sys_platform == 'win32'"
]
[tool.uv-dynamic-versioning]
fallback-version = "0.0.0"
[tool.pytest.ini_options]
testpaths = 'tests'
addopts = "-ra -q -r fEX"
asyncio_mode = "auto"
asyncio_default_fixture_loop_scope = "function"
filterwarnings = []
markers = [
"integration: marks tests as integration tests that require external services",
]
timeout = 120
[tool.ruff]
line-length = 120
[tool.ruff.lint]
select = ["E", "F", "I", "N", "W"]
[tool.coverage.run]
omit = [
"**/__init__.py"
]
[tool.pyright]
extends = "../../pyproject.toml"
include = ["agent_framework_mistral"]
exclude = ['tests']
[tool.mypy]
plugins = ['pydantic.mypy']
strict = true
python_version = "3.10"
ignore_missing_imports = true
disallow_untyped_defs = true
no_implicit_optional = true
check_untyped_defs = true
warn_return_any = true
show_error_codes = true
warn_unused_ignores = false
disallow_incomplete_defs = true
disallow_untyped_decorators = true
disallow_any_unimported = true
[tool.bandit]
targets = ["agent_framework_mistral"]
exclude_dirs = ["tests"]
[tool.poe]
executor.type = "uv"
include = "../../shared_tasks.toml"
[tool.poe.tasks.mypy]
help = "Run MyPy for this package."
cmd = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_mistral"
[tool.poe.tasks.test]
help = "Run the default unit test suite for this package."
cmd = 'pytest -m "not integration" --cov=agent_framework_mistral --cov-report=term-missing:skip-covered tests'
[tool.uv.build-backend]
module-name = "agent_framework_mistral"
module-root = ""
[build-system]
requires = ["uv_build>=0.8.2,<0.9.0"]
build-backend = "uv_build"
+15
View File
@@ -0,0 +1,15 @@
# Mistral AI Embedding Examples
This folder contains examples demonstrating how to use Mistral AI embedding models with the Agent Framework.
## Examples
| File | Description |
|------|-------------|
| [`mistral_embeddings.py`](mistral_embeddings.py) | Basic embedding generation with the Mistral AI embedding client. |
## Environment Variables
- `MISTRAL_API_KEY`: Your Mistral AI API key
- `MISTRAL_EMBEDDING_MODEL`: Embedding model name (e.g., `mistral-embed`)
- `MISTRAL_SERVER_URL` (optional): Server URL override for custom deployments
@@ -0,0 +1 @@
# Copyright (c) Microsoft. All rights reserved.
@@ -0,0 +1,77 @@
# Copyright (c) Microsoft. All rights reserved.
"""Shows how to generate embeddings using the Mistral AI embedding client.
Requires ``MISTRAL_API_KEY`` and ``MISTRAL_EMBEDDING_MODEL`` environment variables.
"""
import asyncio
from dotenv import load_dotenv
from agent_framework_mistral import MistralEmbeddingClient
load_dotenv()
async def basic_embedding_example() -> None:
"""Generate embeddings for a list of texts."""
print("=== Basic Embedding Generation ===")
# 1. Create the embedding client (uses MISTRAL_API_KEY and MISTRAL_EMBEDDING_MODEL env vars).
client = MistralEmbeddingClient()
# 2. Generate embeddings for multiple texts.
texts = ["Hello, world!", "How are you?", "Agent Framework with Mistral AI"]
result = await client.get_embeddings(texts)
# 3. Print results.
print(f"Generated {len(result)} embeddings")
for i, embedding in enumerate(result):
print(f" Text {i + 1}: dimensions={embedding.dimensions}, vector={embedding.vector[:5]}...")
if result.usage:
print(
f" Usage: {result.usage['input_token_count']} input tokens, "
f"{result.usage['total_token_count']} total tokens"
)
async def embedding_with_options_example() -> None:
"""Generate embeddings with custom dimensions."""
print("\n=== Embedding with Custom Dimensions ===")
from agent_framework_mistral import MistralEmbeddingOptions
client = MistralEmbeddingClient()
# Request a specific output dimension (model must support it).
options: MistralEmbeddingOptions = {"dimensions": 256}
result = await client.get_embeddings(["Dimensionality reduction example"], options=options)
print(f" Dimensions: {result[0].dimensions}")
print(f" Vector (first 5): {result[0].vector[:5]}...")
async def main() -> None:
"""Run embedding examples."""
await basic_embedding_example()
await embedding_with_options_example()
if __name__ == "__main__":
asyncio.run(main())
"""
Sample output:
=== Basic Embedding Generation ===
Generated 3 embeddings
Text 1: dimensions=1024, vector=[0.0123, -0.0456, 0.0789, -0.0012, 0.0345]...
Text 2: dimensions=1024, vector=[0.0234, -0.0567, 0.0891, -0.0023, 0.0456]...
Text 3: dimensions=1024, vector=[0.0345, -0.0678, 0.0912, -0.0034, 0.0567]...
Usage: 15 input tokens, 15 total tokens
=== Embedding with Custom Dimensions ===
Dimensions: 256
Vector (first 5): [0.0456, -0.0789, 0.0123, -0.0456, 0.0789]...
"""
@@ -0,0 +1,267 @@
# Copyright (c) Microsoft. All rights reserved.
import os
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from agent_framework import Embedding, GeneratedEmbeddings
from agent_framework_mistral import MistralEmbeddingClient, MistralEmbeddingOptions
# region: Unit Tests
def test_mistral_embedding_construction(monkeypatch: pytest.MonkeyPatch) -> None:
"""Test construction with environment variables."""
monkeypatch.setenv("MISTRAL_EMBEDDING_MODEL", "mistral-embed")
monkeypatch.setenv("MISTRAL_API_KEY", "test-key")
with patch("agent_framework_mistral._embedding_client.Mistral") as mock_cls:
mock_cls.return_value = MagicMock()
client = MistralEmbeddingClient()
assert client.model == "mistral-embed"
def test_mistral_embedding_construction_with_params() -> None:
"""Test construction with explicit parameters."""
with patch("agent_framework_mistral._embedding_client.Mistral") as mock_cls:
mock_cls.return_value = MagicMock()
client = MistralEmbeddingClient(
model="mistral-embed",
api_key="test-key",
)
assert client.model == "mistral-embed"
mock_cls.assert_called_once_with(api_key="test-key")
def test_mistral_embedding_construction_with_server_url() -> None:
"""Test construction with custom server URL."""
with patch("agent_framework_mistral._embedding_client.Mistral") as mock_cls:
mock_cls.return_value = MagicMock()
client = MistralEmbeddingClient(
model="mistral-embed",
api_key="test-key",
server_url="https://custom.mistral.ai",
)
assert client.model == "mistral-embed"
assert client.server_url == "https://custom.mistral.ai"
mock_cls.assert_called_once_with(
api_key="test-key",
server_url="https://custom.mistral.ai",
)
def test_mistral_embedding_construction_with_client() -> None:
"""Test construction with a pre-configured client."""
mock_client = MagicMock()
with patch("agent_framework_mistral._embedding_client.Mistral"):
client = MistralEmbeddingClient(
model="mistral-embed",
api_key="test-key",
client=mock_client,
)
assert client.client is mock_client
def test_mistral_embedding_construction_missing_model_raises(monkeypatch: pytest.MonkeyPatch) -> None:
"""Test that missing model raises an error."""
monkeypatch.delenv("MISTRAL_EMBEDDING_MODEL", raising=False)
monkeypatch.setenv("MISTRAL_API_KEY", "test-key")
from agent_framework.exceptions import SettingNotFoundError
with pytest.raises(SettingNotFoundError):
MistralEmbeddingClient()
def test_mistral_embedding_construction_missing_api_key_raises(monkeypatch: pytest.MonkeyPatch) -> None:
"""Test that missing API key raises an error."""
monkeypatch.delenv("MISTRAL_API_KEY", raising=False)
monkeypatch.setenv("MISTRAL_EMBEDDING_MODEL", "mistral-embed")
from agent_framework.exceptions import SettingNotFoundError
with pytest.raises(SettingNotFoundError):
MistralEmbeddingClient()
def test_mistral_embedding_service_url() -> None:
"""Test service_url returns the correct URL."""
with patch("agent_framework_mistral._embedding_client.Mistral") as mock_cls:
mock_cls.return_value = MagicMock()
client = MistralEmbeddingClient(
model="mistral-embed",
api_key="test-key",
)
assert client.service_url() == "https://api.mistral.ai"
def test_mistral_embedding_service_url_custom() -> None:
"""Test service_url returns custom URL when set."""
with patch("agent_framework_mistral._embedding_client.Mistral") as mock_cls:
mock_cls.return_value = MagicMock()
client = MistralEmbeddingClient(
model="mistral-embed",
api_key="test-key",
server_url="https://custom.mistral.ai",
)
assert client.service_url() == "https://custom.mistral.ai"
async def test_mistral_embedding_get_embeddings() -> None:
"""Test generating embeddings via the Mistral API."""
mock_response = MagicMock()
mock_response.data = [
MagicMock(embedding=[0.1, 0.2, 0.3], index=0, object="embedding"),
MagicMock(embedding=[0.4, 0.5, 0.6], index=1, object="embedding"),
]
mock_response.model = "mistral-embed"
mock_response.usage = MagicMock(prompt_tokens=10, total_tokens=10)
with patch("agent_framework_mistral._embedding_client.Mistral") as mock_cls:
mock_client = MagicMock()
mock_client.embeddings = MagicMock()
mock_client.embeddings.create_async = AsyncMock(return_value=mock_response)
mock_cls.return_value = mock_client
client = MistralEmbeddingClient(model="mistral-embed", api_key="test-key")
result = await client.get_embeddings(["hello", "world"])
assert isinstance(result, GeneratedEmbeddings)
assert len(result) == 2
assert result[0].vector == [0.1, 0.2, 0.3]
assert result[1].vector == [0.4, 0.5, 0.6]
assert result[0].model == "mistral-embed"
assert result.usage == {"input_token_count": 10, "total_token_count": 10}
mock_client.embeddings.create_async.assert_called_once_with(
model="mistral-embed",
inputs=["hello", "world"],
)
async def test_mistral_embedding_get_embeddings_empty_input() -> None:
"""Test generating embeddings with empty input."""
with patch("agent_framework_mistral._embedding_client.Mistral") as mock_cls:
mock_client = MagicMock()
mock_cls.return_value = mock_client
client = MistralEmbeddingClient(model="mistral-embed", api_key="test-key")
result = await client.get_embeddings([])
assert isinstance(result, GeneratedEmbeddings)
assert len(result) == 0
async def test_mistral_embedding_get_embeddings_with_dimensions() -> None:
"""Test generating embeddings with custom dimensions option."""
mock_response = MagicMock()
mock_response.data = [
MagicMock(embedding=[0.1, 0.2], index=0, object="embedding"),
]
mock_response.model = "mistral-embed"
mock_response.usage = MagicMock(prompt_tokens=5, total_tokens=5)
with patch("agent_framework_mistral._embedding_client.Mistral") as mock_cls:
mock_client = MagicMock()
mock_client.embeddings = MagicMock()
mock_client.embeddings.create_async = AsyncMock(return_value=mock_response)
mock_cls.return_value = mock_client
client = MistralEmbeddingClient(model="mistral-embed", api_key="test-key")
options: MistralEmbeddingOptions = {"dimensions": 512}
result = await client.get_embeddings(["hello"], options=options)
assert len(result) == 1
mock_client.embeddings.create_async.assert_called_once_with(
model="mistral-embed",
inputs=["hello"],
output_dimension=512,
)
async def test_mistral_embedding_get_embeddings_no_model_raises() -> None:
"""Test that missing model at call time raises ValueError."""
with patch("agent_framework_mistral._embedding_client.Mistral") as mock_cls:
mock_client = MagicMock()
mock_cls.return_value = mock_client
client = MistralEmbeddingClient(model="mistral-embed", api_key="test-key")
client.model = None # type: ignore[assignment]
with pytest.raises(ValueError, match="model is required"):
await client.get_embeddings(["hello"])
async def test_mistral_embedding_get_embeddings_model_override() -> None:
"""Test that model can be overridden via options."""
mock_response = MagicMock()
mock_response.data = [
MagicMock(embedding=[0.1, 0.2, 0.3], index=0, object="embedding"),
]
mock_response.model = "custom-embed"
mock_response.usage = MagicMock(prompt_tokens=5, total_tokens=5)
with patch("agent_framework_mistral._embedding_client.Mistral") as mock_cls:
mock_client = MagicMock()
mock_client.embeddings = MagicMock()
mock_client.embeddings.create_async = AsyncMock(return_value=mock_response)
mock_cls.return_value = mock_client
client = MistralEmbeddingClient(model="mistral-embed", api_key="test-key")
options: MistralEmbeddingOptions = {"model": "custom-embed"}
result = await client.get_embeddings(["hello"], options=options)
assert len(result) == 1
assert result[0].model == "custom-embed"
mock_client.embeddings.create_async.assert_called_once_with(
model="custom-embed",
inputs=["hello"],
)
async def test_mistral_embedding_get_embeddings_no_usage() -> None:
"""Test handling response without usage information."""
mock_response = MagicMock()
mock_response.data = [
MagicMock(embedding=[0.1, 0.2, 0.3], index=0, object="embedding"),
]
mock_response.model = "mistral-embed"
mock_response.usage = None
with patch("agent_framework_mistral._embedding_client.Mistral") as mock_cls:
mock_client = MagicMock()
mock_client.embeddings = MagicMock()
mock_client.embeddings.create_async = AsyncMock(return_value=mock_response)
mock_cls.return_value = mock_client
client = MistralEmbeddingClient(model="mistral-embed", api_key="test-key")
result = await client.get_embeddings(["hello"])
assert len(result) == 1
assert result.usage is None
# region: Integration Tests
skip_if_mistral_embedding_integration_tests_disabled = pytest.mark.skipif(
os.getenv("MISTRAL_EMBEDDING_MODEL", "") in ("", "test-model") or os.getenv("MISTRAL_API_KEY", "") == "",
reason="No real Mistral embedding model or API key provided; skipping integration tests.",
)
@pytest.mark.flaky
@pytest.mark.integration
@skip_if_mistral_embedding_integration_tests_disabled
async def test_mistral_embedding_integration() -> None:
"""Integration test for Mistral AI embedding client."""
client = MistralEmbeddingClient()
result = await client.get_embeddings(["Hello, world!", "How are you?"])
assert isinstance(result, GeneratedEmbeddings)
assert len(result) == 2
for embedding in result:
assert isinstance(embedding, Embedding)
assert isinstance(embedding.vector, list)
assert len(embedding.vector) > 0
assert all(isinstance(v, float) for v in embedding.vector)
assert result.usage is not None
assert result.usage["input_token_count"] is not None
assert result.usage["input_token_count"] > 0