mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
dd3d085539
* Include reasoning messages in MESSAGES_SNAPSHOT (#4843) FlowState now tracks reasoning messages emitted during a run. _emit_text_reasoning() persists reasoning (including encrypted_value) into flow.reasoning_messages, and _build_messages_snapshot() appends them to the final MESSAGES_SNAPSHOT event. Changes: - Add reasoning_messages field to FlowState - Update _emit_text_reasoning() to accept optional flow parameter - Include reasoning_messages in _build_messages_snapshot() - Add 'reasoning' to ALLOWED_AGUI_ROLES so normalize_agui_role() preserves the role through snapshot round-trips - Skip reasoning messages in agui_messages_to_agent_framework() since they are UI-only state and should not be forwarded to LLM providers - Add regression tests for snapshot emission, encrypted value preservation, and multi-turn round-trip with reasoning Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Python: Include reasoning messages in MESSAGES_SNAPSHOT events Fixes #4843 * Fix PR review feedback for reasoning persistence (#4843) - Accumulate reasoning text per message_id (append deltas) instead of storing only the current chunk, matching flow.accumulated_text pattern - Use camelCase encryptedValue in snapshot JSON to match AG-UI protocol conventions (toolCallId, encryptedValue) - Normalize snake_case encrypted_value to encryptedValue in agui_messages_to_snapshot_format for input compatibility - Update normalize_agui_role docstring to include reasoning role - Add tests for incremental reasoning accumulation and key normalization Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Address review feedback for #4843: Python: agent-framework-ag-ui: include reasoning messages in MESSAGES_SNAPSHOT --------- Co-authored-by: Copilot <copilot@github.com> Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
530 lines
15 KiB
Python
530 lines
15 KiB
Python
# Copyright (c) Microsoft. All rights reserved.
|
|
|
|
"""Tests for utilities."""
|
|
|
|
from dataclasses import dataclass
|
|
from datetime import date, datetime
|
|
|
|
from agent_framework_ag_ui._utils import (
|
|
generate_event_id,
|
|
make_json_safe,
|
|
merge_state,
|
|
)
|
|
|
|
|
|
def test_generate_event_id():
    """Two consecutive calls yield distinct, non-empty string IDs."""
    first, second = generate_event_id(), generate_event_id()

    assert first != second
    assert isinstance(first, str)
    assert len(first) > 0
|
|
|
|
|
|
def test_merge_state():
    """Keys from the update override or extend the current state."""
    base: dict[str, int] = {"a": 1, "b": 2}
    patch: dict[str, int] = {"b": 3, "c": 4}

    merged = merge_state(base, patch)

    # "a" is kept from the base, "b" is overridden, "c" is newly added.
    for key, expected in (("a", 1), ("b", 3), ("c", 4)):
        assert merged[key] == expected
|
|
|
|
|
|
def test_merge_state_empty_update():
    """An empty update yields a copy equal to, but distinct from, the current state."""
    base: dict[str, int] = {"x": 10, "y": 20}
    no_changes: dict[str, int] = {}

    merged = merge_state(base, no_changes)

    assert merged == base
    # The result must be a new object, not the caller's dict.
    assert merged is not base
|
|
|
|
|
|
def test_merge_state_empty_current():
    """Merging into an empty state yields exactly the update's contents."""
    base: dict[str, int] = {}
    patch: dict[str, int] = {"a": 1, "b": 2}

    assert merge_state(base, patch) == patch
|
|
|
|
|
|
def test_merge_state_deep_copy():
    """Mutating nested values in the merged result must not leak back into the input."""
    base: dict[str, dict[str, object]] = {"recipe": {"name": "Cake", "ingredients": ["flour", "sugar"]}}
    patch: dict[str, str] = {"other": "value"}

    merged = merge_state(base, patch)
    merged["recipe"]["ingredients"].append("eggs")

    # The input's nested list must be untouched by the append above.
    original_ingredients = base["recipe"]["ingredients"]
    assert "eggs" not in original_ingredients
    assert original_ingredients == ["flour", "sugar"]
    assert merged["recipe"]["ingredients"] == ["flour", "sugar", "eggs"]
|
|
|
|
|
|
def test_make_json_safe_basic():
    """JSON-native scalar values pass through unchanged."""
    # Plain values are compared by equality.
    for value in ("text", 123, 3.14):
        assert make_json_safe(value) == value

    # None/True/False must come back as the very same singleton objects.
    for singleton in (None, True, False):
        assert make_json_safe(singleton) is singleton
|
|
|
|
|
|
def test_make_json_safe_datetime():
    """A naive datetime is rendered as its ISO-8601 string."""
    moment = datetime(2025, 10, 30, 12, 30, 45)
    assert make_json_safe(moment) == "2025-10-30T12:30:45"
|
|
|
|
|
|
def test_make_json_safe_date():
    """A date is rendered as its ISO-8601 string."""
    day = date(2025, 10, 30)
    assert make_json_safe(day) == "2025-10-30"
|
|
|
|
|
|
@dataclass
class SampleDataclass:
    """Minimal two-field dataclass used as a ``make_json_safe`` fixture."""

    # Field order matters for positional construction: name first, then value.
    name: str
    value: int
|
|
|
|
|
|
def test_make_json_safe_dataclass():
    """Dataclass instances are converted to a field-name -> value mapping."""
    expected = {"name": "test", "value": 42}
    assert make_json_safe(SampleDataclass(name="test", value=42)) == expected
|
|
|
|
|
|
class ModelDumpObject:
    """Fixture exposing a Pydantic-style ``model_dump`` hook."""

    def model_dump(self):
        # A fresh dict is built on every call.
        payload = {"type": "model", "data": "dump"}
        return payload
|
|
|
|
|
|
def test_make_json_safe_model_dump():
    """Objects exposing ``model_dump`` are serialized through that method."""
    serialized = make_json_safe(ModelDumpObject())
    assert serialized == {"type": "model", "data": "dump"}
|
|
|
|
|
|
class ToDictObject:
    """Fixture mimicking the ``to_dict`` hook of SerializationMixin-style classes."""

    def to_dict(self):
        # A fresh dict is built on every call.
        payload = {"type": "serialization_mixin", "method": "to_dict"}
        return payload
|
|
|
|
|
|
def test_make_json_safe_to_dict():
    """Objects exposing ``to_dict`` (the SerializationMixin pattern) use that method."""
    serialized = make_json_safe(ToDictObject())
    assert serialized == {"type": "serialization_mixin", "method": "to_dict"}
|
|
|
|
|
|
class DictObject:
    """Fixture exposing a legacy (Pydantic v1-style) ``dict()`` hook."""

    def dict(self):
        # A fresh dict is built on every call.
        payload = {"type": "dict", "method": "call"}
        return payload
|
|
|
|
|
|
def test_make_json_safe_dict_method():
    """Objects exposing a ``dict()`` method are serialized through it."""
    serialized = make_json_safe(DictObject())
    assert serialized == {"type": "dict", "method": "call"}
|
|
|
|
|
|
class CustomObject:
    """Fixture with no serialization hooks; its state lives only in ``__dict__``."""

    def __init__(self):
        super().__init__()
        # Two attributes of different JSON-native types.
        self.field1 = "value1"
        self.field2 = 123
|
|
|
|
|
|
def test_make_json_safe_dict_attribute():
    """Plain objects fall back to serializing their instance ``__dict__``."""
    serialized = make_json_safe(CustomObject())
    assert serialized == {"field1": "value1", "field2": 123}
|
|
|
|
|
|
def test_make_json_safe_list():
    """A list of JSON-native values comes back with equal contents."""
    items = [1, "text", None, {"key": "value"}]
    converted = make_json_safe(items)
    assert converted == [1, "text", None, {"key": "value"}]
|
|
|
|
|
|
def test_make_json_safe_tuple():
    """Tuples are converted to lists (a tuple never compares equal to a list)."""
    assert make_json_safe((1, 2, 3)) == [1, 2, 3]
|
|
|
|
|
|
def test_make_json_safe_dict():
    """Nested dicts of JSON-native values come back with equal contents."""
    source = {"a": 1, "b": {"c": 2}}
    assert make_json_safe(source) == {"a": 1, "b": {"c": 2}}
|
|
|
|
|
|
def test_make_json_safe_nested():
    """Conversion recurses through dicts and lists into datetimes, objects and dataclasses."""
    payload = {
        "datetime": datetime(2025, 10, 30),
        "list": [1, 2, CustomObject()],
        "nested": {"value": SampleDataclass(name="nested", value=99)},
    }

    converted = make_json_safe(payload)
    converted_list = converted["list"]

    assert converted["datetime"] == "2025-10-30T00:00:00"
    assert converted_list[0] == 1
    assert converted_list[2] == {"field1": "value1", "field2": 123}
    assert converted["nested"]["value"] == {"name": "nested", "value": 99}
|
|
|
|
|
|
class UnserializableObject:
    """Object with no serialization hooks (no ``to_dict``/``model_dump``/``dict``).

    Instances still carry an (empty) instance ``__dict__``; the fallback test
    relies on that to exercise the ``__dict__`` path of ``make_json_safe``.
    """

    def __init__(self) -> None:
        # Deliberately sets no attributes: every ordinary instance already has a
        # ``__dict__``, so even an empty one is enough to reach the ``__dict__``
        # fallback path.
        pass
|
|
|
|
|
|
def test_make_json_safe_fallback():
    """An object with no hooks but a ``__dict__`` is converted to a dict."""
    converted = make_json_safe(UnserializableObject())
    # Anything carrying an instance __dict__ is serialized via that mapping.
    assert isinstance(converted, dict)
|
|
|
|
|
|
def test_make_json_safe_dataclass_with_nested_to_dict_object():
    """Test dataclass containing a to_dict object (like HandoffAgentUserRequest with AgentResponse).

    This test verifies the fix for the AG-UI JSON serialization error when
    HandoffAgentUserRequest (a dataclass) contains an AgentResponse (SerializationMixin).
    """
    import json

    class NestedToDictObject:
        """Simulates SerializationMixin objects like AgentResponse."""

        def __init__(self, contents: list[str]):
            self.contents = contents

        def to_dict(self):
            return {"type": "response", "contents": self.contents}

    @dataclass
    class ContainerDataclass:
        """Simulates HandoffAgentUserRequest dataclass."""

        response: NestedToDictObject

    obj = ContainerDataclass(response=NestedToDictObject(contents=["hello", "world"]))
    result = make_json_safe(obj)

    # Verify the nested to_dict object was properly serialized
    assert result == {"response": {"type": "response", "contents": ["hello", "world"]}}

    # Verify the result is actually JSON serializable. json.dumps never returns
    # None (it raises TypeError on unserializable values), so a round-trip
    # comparison is used instead of a vacuous `is not None` check.
    assert json.loads(json.dumps(result)) == result
|
|
|
|
|
|
def test_convert_tools_to_agui_format_with_tool():
    """A ``@tool``-decorated function is described by name, docstring and a parameter schema."""
    from agent_framework import tool

    from agent_framework_ag_ui._utils import convert_tools_to_agui_format

    @tool
    def test_func(param: str, count: int = 5) -> str:
        """Test function."""
        return f"{param} {count}"

    specs = convert_tools_to_agui_format([test_func])

    assert specs is not None
    assert len(specs) == 1

    spec = specs[0]
    assert spec["name"] == "test_func"
    assert spec["description"] == "Test function."
    assert "parameters" in spec
    assert "properties" in spec["parameters"]
|
|
|
|
|
|
def test_convert_tools_to_agui_format_with_callable():
    """An undecorated callable is also converted, using its name and docstring."""
    from agent_framework_ag_ui._utils import convert_tools_to_agui_format

    def plain_func(x: int) -> int:
        """A plain function."""
        return x * 2

    specs = convert_tools_to_agui_format([plain_func])

    assert specs is not None
    assert len(specs) == 1

    spec = specs[0]
    assert spec["name"] == "plain_func"
    assert spec["description"] == "A plain function."
    assert "parameters" in spec
|
|
|
|
|
|
def test_convert_tools_to_agui_format_with_dict():
    """A tool already given as a dict comes back equal to the input."""
    from agent_framework_ag_ui._utils import convert_tools_to_agui_format

    raw_tool = {
        "name": "custom_tool",
        "description": "Custom tool",
        "parameters": {"type": "object"},
    }

    specs = convert_tools_to_agui_format([raw_tool])

    assert specs is not None
    assert len(specs) == 1
    assert specs[0] == raw_tool
|
|
|
|
|
|
def test_convert_tools_to_agui_format_with_none():
    """Passing ``None`` instead of a tool collection yields ``None``."""
    from agent_framework_ag_ui._utils import convert_tools_to_agui_format

    assert convert_tools_to_agui_format(None) is None
|
|
|
|
|
|
def test_convert_tools_to_agui_format_with_single_tool():
    """A lone tool passed without a surrounding list is still converted."""
    from agent_framework import tool

    from agent_framework_ag_ui._utils import convert_tools_to_agui_format

    @tool
    def single_tool(arg: str) -> str:
        """Single tool."""
        return arg

    specs = convert_tools_to_agui_format(single_tool)

    assert specs is not None
    assert len(specs) == 1
    assert specs[0]["name"] == "single_tool"
|
|
|
|
|
|
def test_convert_tools_to_agui_format_with_multiple_tools():
    """Several tools are converted in their input order."""
    from agent_framework import tool

    from agent_framework_ag_ui._utils import convert_tools_to_agui_format

    @tool
    def tool1(x: int) -> int:
        """Tool 1."""
        return x

    @tool
    def tool2(y: str) -> str:
        """Tool 2."""
        return y

    specs = convert_tools_to_agui_format([tool1, tool2])

    assert specs is not None
    assert len(specs) == 2
    assert [spec["name"] for spec in specs] == ["tool1", "tool2"]
|
|
|
|
|
|
# Additional tests for utils coverage
|
|
|
|
|
|
def test_safe_json_parse_with_dict():
    """A dict argument is returned with equal contents."""
    from agent_framework_ag_ui._utils import safe_json_parse

    payload = {"key": "value"}
    assert safe_json_parse(payload) == payload
|
|
|
|
|
|
def test_safe_json_parse_with_json_string():
    """A JSON object string is decoded into a dict."""
    from agent_framework_ag_ui._utils import safe_json_parse

    decoded = safe_json_parse('{"key": "value"}')
    assert decoded == {"key": "value"}
|
|
|
|
|
|
def test_safe_json_parse_with_invalid_json():
    """Malformed JSON yields ``None`` rather than raising."""
    from agent_framework_ag_ui._utils import safe_json_parse

    assert safe_json_parse("not json") is None
|
|
|
|
|
|
def test_safe_json_parse_with_non_dict_json():
    """Valid JSON that is not an object (here a list) yields ``None``."""
    from agent_framework_ag_ui._utils import safe_json_parse

    assert safe_json_parse("[1, 2, 3]") is None
|
|
|
|
|
|
def test_safe_json_parse_with_none():
    """A ``None`` argument yields ``None``."""
    from agent_framework_ag_ui._utils import safe_json_parse

    assert safe_json_parse(None) is None
|
|
|
|
|
|
def test_get_role_value_with_enum():
    """The role of a framework ``Message`` is reported as the plain string ``"user"``."""
    from agent_framework import Content, Message

    from agent_framework_ag_ui._utils import get_role_value

    msg = Message(role="user", contents=[Content.from_text("test")])
    assert get_role_value(msg) == "user"
|
|
|
|
|
|
def test_get_role_value_with_string():
    """A duck-typed message whose ``role`` is a bare string reports that string."""
    from agent_framework_ag_ui._utils import get_role_value

    class StubMessage:
        role = "assistant"

    assert get_role_value(StubMessage()) == "assistant"
|
|
|
|
|
|
def test_get_role_value_with_none():
    """An object lacking a ``role`` attribute reports the empty string."""
    from agent_framework_ag_ui._utils import get_role_value

    class StubMessage:
        pass

    assert get_role_value(StubMessage()) == ""
|
|
|
|
|
|
def test_normalize_agui_role_developer():
    """The AG-UI ``developer`` role is folded into ``system``."""
    from agent_framework_ag_ui._utils import normalize_agui_role

    normalized = normalize_agui_role("developer")
    assert normalized == "system"
|
|
|
|
|
|
def test_normalize_agui_role_valid():
    """Every supported role, including ``reasoning``, maps to itself."""
    from agent_framework_ag_ui._utils import normalize_agui_role

    for role in ("user", "assistant", "system", "tool", "reasoning"):
        assert normalize_agui_role(role) == role
|
|
|
|
|
|
def test_normalize_agui_role_invalid():
    """Unknown role strings and non-string values default to ``user``."""
    from agent_framework_ag_ui._utils import normalize_agui_role

    for bad_role in ("invalid", 123):
        assert normalize_agui_role(bad_role) == "user"
|
|
|
|
|
|
def test_extract_state_from_tool_args():
    """Cover key lookup, the ``*`` wildcard, a missing key, and ``None`` args."""
    from agent_framework_ag_ui._utils import extract_state_from_tool_args as extract

    # A named key returns just that value.
    assert extract({"key": "value"}, "key") == "value"

    # The "*" wildcard returns the whole argument mapping.
    all_args = {"a": 1, "b": 2}
    assert extract(all_args, "*") == all_args

    # A missing key and missing args both produce None.
    assert extract({"other": "value"}, "key") is None
    assert extract(None, "key") is None
|
|
|
|
|
|
def test_convert_agui_tools_to_agent_framework():
    """AG-UI tool dicts become declaration-only framework tools."""
    from agent_framework_ag_ui._utils import convert_agui_tools_to_agent_framework

    schema = {"type": "object", "properties": {"arg": {"type": "string"}}}
    agui_tools = [
        {
            "name": "test_tool",
            "description": "A test tool",
            "parameters": schema,
        }
    ]

    converted = convert_agui_tools_to_agent_framework(agui_tools)

    assert converted is not None
    assert len(converted) == 1

    only = converted[0]
    assert only.name == "test_tool"
    assert only.description == "A test tool"
    assert only.declaration_only is True
|
|
|
|
|
|
def test_convert_agui_tools_to_agent_framework_none():
    """A ``None`` tool list converts to ``None``."""
    from agent_framework_ag_ui._utils import convert_agui_tools_to_agent_framework

    assert convert_agui_tools_to_agent_framework(None) is None
|
|
|
|
|
|
def test_convert_agui_tools_to_agent_framework_empty():
    """An empty tool list converts to ``None`` rather than an empty list."""
    from agent_framework_ag_ui._utils import convert_agui_tools_to_agent_framework

    assert convert_agui_tools_to_agent_framework([]) is None
|
|
|
|
|
|
def test_make_json_safe_unconvertible():
    """Objects with no conversion hook and no ``__dict__`` fall back to ``str()``."""
    from agent_framework_ag_ui._utils import make_json_safe

    class NoConversion:
        # Empty __slots__ means instances carry no __dict__ at all.
        __slots__ = ()

    rendered = make_json_safe(NoConversion())
    # The last-resort path stringifies the object.
    assert isinstance(rendered, str)
|