mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
4268080c20
* Fix magentic manager warning
* Use typing_extensions.Sentinel for _MISSING sentinel value
Replace the bare object() sentinel with typing_extensions.Sentinel per
PEP 661 (now final). Sentinel provides a proper name and repr
('<_MISSING>') and is the idiomatic approach going forward.
Refs #4306
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
* fix: correct Sentinel type annotation for max_stall_count param (#6261)
Use int | Sentinel for max_stall_count parameter type annotation instead
of int with cast(Any, _MISSING) to properly express that the parameter
can hold either an int or the _MISSING sentinel value. This fixes the
pyright reportUnnecessaryComparison errors caused by the types int and
Sentinel having no overlap.
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
* Rename _MISSING sentinel to UNSET in orchestrations
The sentinel is user-visible as a default in public init signatures, so
use UNSET (no leading underscore) instead of the private _MISSING name.
Drop the now-unnecessary reportPrivateUsage ignores on the UNSET imports.
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
---------
Co-authored-by: Copilot <copilot@github.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
384 lines
12 KiB
Python
384 lines
12 KiB
Python
# Copyright (c) Microsoft. All rights reserved.
|
|
|
|
from __future__ import annotations
|
|
|
|
import copy
|
|
import json
|
|
from typing import Any
|
|
from unittest.mock import patch
|
|
|
|
import pytest
|
|
from agent_framework import Content, Message
|
|
from botocore.exceptions import ClientError
|
|
from pydantic import BaseModel
|
|
|
|
from agent_framework_bedrock import BedrockChatClient
|
|
|
|
# region Test models
|
|
|
|
|
|
# Flat structured-output model used as the `response_format` in these tests.
# Documented with comments rather than a class docstring so the JSON schema
# pydantic generates for the model is left untouched.
class WeatherReport(BaseModel):
    city: str  # location the report is for
    temperature: float  # numeric reading; the tests do not fix a unit
    summary: str  # free-text description
|
|
|
|
|
|
# Address sub-model; embedded in `Person` so the generated schema contains a
# nested object definition. Comments are used instead of a docstring so the
# generated JSON schema is left untouched.
class NestedAddress(BaseModel):
    street: str
    city: str
    zip_code: str
|
|
|
|
|
|
# Model with a nested `NestedAddress` field, used to check how nested object
# schemas are handled. Comments are used instead of a docstring so the
# generated JSON schema is left untouched.
class Person(BaseModel):
    name: str
    age: int
    address: NestedAddress  # nested object type
|
|
|
|
|
|
# endregion
|
|
|
|
|
|
# region Helpers
|
|
|
|
|
|
class _StubBedrockRuntime:
|
|
"""Stub that records calls and returns a canned response."""
|
|
|
|
def __init__(self, response_text: str = "Bedrock says hi") -> None:
|
|
self.calls: list[dict[str, Any]] = []
|
|
self._response_text = response_text
|
|
|
|
def converse(self, **kwargs: Any) -> dict[str, Any]:
|
|
self.calls.append(kwargs)
|
|
return {
|
|
"modelId": kwargs["modelId"],
|
|
"responseId": "resp-structured",
|
|
"usage": {"inputTokens": 10, "outputTokens": 20, "totalTokens": 30},
|
|
"output": {
|
|
"completionReason": "end_turn",
|
|
"message": {
|
|
"id": "msg-structured",
|
|
"role": "assistant",
|
|
"content": [{"text": self._response_text}],
|
|
},
|
|
},
|
|
}
|
|
|
|
|
|
def _make_client(response_text: str = "Bedrock says hi") -> tuple[BedrockChatClient, _StubBedrockRuntime]:
    """Build a ``BedrockChatClient`` backed by a fresh stub runtime.

    Args:
        response_text: Text the stub returns as the assistant message content.

    Returns:
        The client and the stub it was given, so tests can inspect recorded calls.
    """
    runtime = _StubBedrockRuntime(response_text)
    chat_client = BedrockChatClient(
        model="us.anthropic.claude-haiku-4-5-v1:0",
        region="us-east-1",
        client=runtime,
    )
    return chat_client, runtime
|
|
|
|
|
|
def _user_messages() -> list[Message]:
    """Return a one-message conversation: a single user text prompt."""
    prompt = Content.from_text(text="Give me a weather report")
    return [Message(role="user", contents=[prompt])]
|
|
|
|
|
|
# endregion
|
|
|
|
|
|
# region Tests
|
|
|
|
|
|
def test_prepare_output_config_correct_wire_shape() -> None:
    """A Pydantic model yields textFormat -> structure -> jsonSchema.

    The textFormat entry must be typed 'json_schema', and the jsonSchema entry
    must carry the model name and a schema.
    """
    client, _ = _make_client()

    config = client._prepare_output_config(WeatherReport)
    assert config is not None

    fmt = config["textFormat"]
    assert fmt["type"] == "json_schema"
    assert "structure" in fmt

    schema_entry = fmt["structure"]["jsonSchema"]
    assert schema_entry["name"] == "WeatherReport"
    assert "schema" in schema_entry
|
|
|
|
|
|
def test_prepare_output_config_schema_is_json_string() -> None:
    """jsonSchema['schema'] is a serialized JSON string rather than a dict."""
    client, _ = _make_client()

    config = client._prepare_output_config(WeatherReport)
    assert config is not None

    raw_schema = config["textFormat"]["structure"]["jsonSchema"]["schema"]
    assert isinstance(raw_schema, str), f"Expected str, got {type(raw_schema)}"

    # The string must decode to a JSON object describing an object type.
    decoded = json.loads(raw_schema)
    assert isinstance(decoded, dict)
    assert decoded["type"] == "object"
|
|
|
|
|
|
def test_additional_properties_false_set_recursively() -> None:
    """additionalProperties is False on the root schema and on NestedAddress in $defs."""
    client, _ = _make_client()

    config = client._prepare_output_config(Person)
    assert config is not None

    schema = json.loads(config["textFormat"]["structure"]["jsonSchema"]["schema"])

    # Root object.
    assert schema.get("additionalProperties") is False

    # Nested model referenced through $defs.
    definitions = schema.get("$defs", {})
    assert "NestedAddress" in definitions, "Expected NestedAddress to be present in $defs"
    nested_address = definitions["NestedAddress"]
    assert nested_address.get("additionalProperties") is False, (
        "Expected additionalProperties=False on nested NestedAddress schema"
    )
|
|
|
|
|
|
def test_no_output_config_when_response_format_none() -> None:
    """Options without a response_format produce a request with no outputConfig key."""
    client, _ = _make_client()

    prepared = client._prepare_options(_user_messages(), {"max_tokens": 100})

    assert "outputConfig" not in prepared, (
        f"outputConfig should not be present when response_format is None, got: {prepared.get('outputConfig')}"
    )
|
|
|
|
|
|
async def test_chat_response_value_populated() -> None:
    """With response_format set, the response's .value is a populated WeatherReport.

    Also checks that exactly one request reached the stub and that it carried
    an outputConfig with textFormat type 'json_schema'.
    """
    payload = json.dumps({"city": "Seattle", "temperature": 72.5, "summary": "Sunny and warm"})
    client, stub = _make_client(response_text=payload)

    response = await client.get_response(
        messages=_user_messages(),
        options={"max_tokens": 100, "response_format": WeatherReport},
    )

    assert response.text == payload

    report = response.value
    assert report is not None
    assert isinstance(report, WeatherReport)
    assert report.city == "Seattle"
    assert report.temperature == 72.5
    assert report.summary == "Sunny and warm"

    # The structured-output config must have been sent with the request.
    assert len(stub.calls) == 1
    sent = stub.calls[0]
    assert "outputConfig" in sent
    assert sent["outputConfig"]["textFormat"]["type"] == "json_schema"
|
|
|
|
|
|
def test_dict_schema_response_format() -> None:
    """A dict response_format wrapped in a 'json_schema' key is converted like a model."""
    client, _ = _make_client()

    inner_schema = {
        "type": "object",
        "properties": {
            "city": {"type": "string"},
            "temp": {"type": "number"},
        },
    }
    response_format = {"json_schema": {"name": "weather_output", "schema": inner_schema}}

    config = client._prepare_output_config(response_format)
    assert config is not None

    schema_entry = config["textFormat"]["structure"]["jsonSchema"]
    assert schema_entry["name"] == "weather_output"

    decoded = json.loads(schema_entry["schema"])
    assert decoded["type"] == "object"
    assert "city" in decoded["properties"]
|
|
|
|
|
|
def test_prepare_output_config_none_returns_none() -> None:
    """Passing None as the response format yields None."""
    client, _ = _make_client()

    assert client._prepare_output_config(None) is None
|
|
|
|
|
|
async def test_chat_response_value_populated_streaming() -> None:
    """With stream=True, the final response's .value is a populated WeatherReport."""
    payload = json.dumps({"city": "Portland", "temperature": 68.0, "summary": "Cloudy"})
    client, stub = _make_client(response_text=payload)

    stream = client.get_response(
        messages=_user_messages(),
        stream=True,
        options={"max_tokens": 100, "response_format": WeatherReport},
    )

    # Drain every update before asking for the aggregated final response.
    async for _ in stream:
        pass
    final = await stream.get_final_response()

    report = final.value
    assert report is not None
    assert isinstance(report, WeatherReport)
    assert report.city == "Portland"

    # Exactly one request reached the stub, and it carried the outputConfig.
    assert len(stub.calls) == 1
    assert "outputConfig" in stub.calls[0]
|
|
|
|
|
|
async def test_unsupported_model_validation_exception() -> None:
    """A ValidationException that mentions outputConfig surfaces as a descriptive ValueError."""

    class _RejectingRuntime:
        """Runtime whose converse() always fails like a botocore validation error."""

        def converse(self, **kwargs: Any) -> dict[str, Any]:
            failure = {"Error": {"Code": "ValidationException", "Message": "Invalid field outputConfig"}}
            raise ClientError(failure, "Converse")

    client = BedrockChatClient(
        model="us.anthropic.claude-v2",
        region="us-east-1",
        client=_RejectingRuntime(),
    )

    with pytest.raises(ValueError) as exc_info:
        await client.get_response(
            messages=_user_messages(),
            options={"response_format": WeatherReport},
        )

    message = str(exc_info.value)
    assert "does not support structured output via outputConfig.textFormat" in message
    assert "Check the model's Bedrock Converse outputConfig/textFormat support." in message
|
|
|
|
|
|
def test_invalid_response_format_type_raises() -> None:
    """A plain string response_format is rejected with TypeError."""
    client, _ = _make_client()
    bad_format = "not_a_valid_format"

    with pytest.raises(TypeError, match="Pydantic BaseModel subclass"):
        client._prepare_output_config(bad_format)
|
|
|
|
|
|
def test_mapping_response_format_accepted() -> None:
    """A Mapping that is not a dict subclass is accepted as response_format.

    The resulting outputConfig must carry the given name, and the serialized
    schema must have additionalProperties set to False.
    """
    from collections.abc import MutableMapping

    class _DictBackedMapping(MutableMapping):
        """MutableMapping that delegates to a private dict without subclassing dict."""

        def __init__(self, data):
            self._store = dict(data)

        def __getitem__(self, key):
            return self._store[key]

        def __setitem__(self, key, value):
            self._store[key] = value

        def __delitem__(self, key):
            del self._store[key]

        def __iter__(self):
            return iter(self._store)

        def __len__(self):
            return len(self._store)

    client, _ = _make_client()

    inner_schema = {
        "type": "object",
        "properties": {"result": {"type": "string"}},
    }
    response_format = _DictBackedMapping({"json_schema": {"name": "test_output", "schema": inner_schema}})

    config = client._prepare_output_config(response_format)
    assert config is not None

    schema_entry = config["textFormat"]["structure"]["jsonSchema"]
    assert schema_entry["name"] == "test_output"

    decoded = json.loads(schema_entry["schema"])
    assert decoded.get("additionalProperties") is False
|
|
|
|
|
|
def test_shape_b_dict_schema_wire_format() -> None:
    """A dict with 'name' and 'schema' at the top level (no 'json_schema' wrapper) is accepted."""
    client, _ = _make_client()

    properties = {
        "city": {"type": "string"},
        "temperature": {"type": "number"},
    }
    shape_b_format = {
        "name": "weather_output",
        "schema": {"type": "object", "properties": properties},
    }

    config = client._prepare_output_config(shape_b_format)
    assert config is not None

    fmt = config["textFormat"]
    assert fmt["type"] == "json_schema"

    schema_entry = fmt["structure"]["jsonSchema"]
    assert schema_entry["name"] == "weather_output"

    decoded = json.loads(schema_entry["schema"])
    assert decoded.get("additionalProperties") is False
|
|
|
|
|
|
def test_dict_schema_not_mutated() -> None:
    """_prepare_output_config leaves the caller's dict schema equal to its prior state."""
    client, _ = _make_client()

    original_schema = {
        "json_schema": {
            "name": "test",
            "schema": {
                "type": "object",
                "properties": {"a": {"type": "string"}},
            },
        }
    }
    # Deep copy taken up front so nested changes would be caught too.
    before = copy.deepcopy(original_schema)

    client._prepare_output_config(original_schema)

    assert original_schema == before, "Original dict schema was mutated"
|
|
|
|
|
|
async def test_non_outputconfig_validation_exception_propagates() -> None:
    """A ValidationException that does not mention outputConfig reaches the caller as ClientError."""
    client, _ = _make_client()

    failure = ClientError(
        {"Error": {"Code": "ValidationException", "Message": "Invalid message format"}},
        "Converse",
    )
    # Swap the client's runtime for a mock whose converse() raises the error above.
    runtime_patch = patch.object(client, "_bedrock_client", **{"converse.side_effect": failure})

    with runtime_patch:
        with pytest.raises(ClientError):
            await client.get_response(
                messages=_user_messages(),
                options={"max_tokens": 100},
            )
|
|
|
|
|
|
# endregion
|