mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Fix core observability unsafe serialization of function-call arguments containing dataclass/framework objects (#6026)
* fix: safely serialize function-call arguments in core observability Apply make_json_safe() to content.arguments in _to_otel_part() before building the otel message dict, so that dataclass/framework payloads (e.g. workflow request_info events) do not cause a TypeError when _capture_messages() calls json.dumps(). Lift make_json_safe() into agent_framework._serialization (no new external deps — dataclasses/datetime only) so the core observability path can use it without a dependency on the ag-ui adapter. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * fix(core): safely serialize workflow request_info payloads in observability (#5733) - Add make_json_safe() helper to recursively convert non-serializable objects - Use make_json_safe() in _to_otel_part() for function_call arguments - Fix CustomPayload test class to use @dataclass (resolves B903 lint error) Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * fix(serialization): guard callability and normalize dict keys in make_json_safe (#5733) - Use callable(getattr(obj, method, None)) instead of hasattr() so that non-callable attributes named model_dump/to_dict/dict do not raise TypeError at runtime. - Wrap each call in try/except TypeError to handle callables with mandatory arguments gracefully. - Convert dict keys to str() so that non-string keys (e.g. datetime, int) cannot cause json.dumps to raise TypeError. - Add regression tests for both scenarios. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Address observability serialization review feedback --------- Co-authored-by: Copilot <copilot@github.com> Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
03e14ca187
commit
f36096ce1a
@@ -7,6 +7,8 @@ import json
|
||||
import logging
|
||||
import re
|
||||
from collections.abc import Mapping, MutableMapping
|
||||
from dataclasses import asdict, is_dataclass
|
||||
from datetime import date, datetime
|
||||
from typing import Any, ClassVar, Protocol, TypeVar, runtime_checkable
|
||||
|
||||
logger = logging.getLogger("agent_framework")
|
||||
@@ -614,3 +616,46 @@ class SerializationMixin:
|
||||
# Fallback and default
|
||||
# Convert class name to snake_case
|
||||
return _CAMEL_TO_SNAKE_PATTERN.sub("_", cls.__name__).lower()
|
||||
|
||||
|
||||
def make_json_safe(obj: Any) -> Any:
|
||||
"""Recursively convert an object to a JSON-serializable form.
|
||||
|
||||
Handles dataclasses, Pydantic models, objects with ``to_dict``/``dict``/``__dict__``,
|
||||
datetimes, lists, dicts, and primitives. Falls back to ``str()`` for any remaining
|
||||
non-serializable value so that ``json.dumps`` never raises a ``TypeError``.
|
||||
|
||||
Args:
|
||||
obj: Object to make JSON safe.
|
||||
|
||||
Returns:
|
||||
A JSON-serializable version of the object.
|
||||
"""
|
||||
if obj is None or isinstance(obj, (str, int, float, bool)):
|
||||
return obj
|
||||
if isinstance(obj, (datetime, date)):
|
||||
return obj.isoformat()
|
||||
if is_dataclass(obj) and not isinstance(obj, type):
|
||||
return make_json_safe(asdict(obj)) # type: ignore[arg-type]
|
||||
if callable(getattr(obj, "model_dump", None)):
|
||||
try:
|
||||
return make_json_safe(obj.model_dump()) # type: ignore[no-any-return]
|
||||
except TypeError:
|
||||
pass
|
||||
if callable(getattr(obj, "to_dict", None)):
|
||||
try:
|
||||
return make_json_safe(obj.to_dict()) # type: ignore[no-any-return]
|
||||
except TypeError:
|
||||
pass
|
||||
if callable(getattr(obj, "dict", None)):
|
||||
try:
|
||||
return make_json_safe(obj.dict()) # type: ignore[no-any-return]
|
||||
except TypeError:
|
||||
pass
|
||||
if isinstance(obj, dict):
|
||||
return {str(key): make_json_safe(value) for key, value in obj.items()} # type: ignore[misc]
|
||||
if isinstance(obj, (list, tuple)):
|
||||
return [make_json_safe(item) for item in obj] # type: ignore[misc]
|
||||
if hasattr(obj, "__dict__"):
|
||||
return {key: make_json_safe(value) for key, value in vars(obj).items()} # type: ignore[misc]
|
||||
return str(obj)
|
||||
|
||||
@@ -12,6 +12,7 @@ from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast, overload
|
||||
|
||||
from .._agents import BaseAgent
|
||||
from .._serialization import make_json_safe
|
||||
from .._sessions import (
|
||||
AgentSession,
|
||||
ContextProvider,
|
||||
@@ -61,7 +62,7 @@ class WorkflowAgent(BaseAgent):
|
||||
data: Any
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {"request_id": self.request_id, "data": self.data}
|
||||
return {"request_id": self.request_id, "data": make_json_safe(self.data)}
|
||||
|
||||
def to_json(self) -> str:
|
||||
return json.dumps(self.to_dict())
|
||||
|
||||
@@ -47,6 +47,7 @@ from copy import deepcopy
|
||||
from typing import Any, Generic, Literal, TypeVar, overload
|
||||
|
||||
from .._feature_stage import ExperimentalFeature, experimental
|
||||
from .._serialization import make_json_safe
|
||||
from .._types import AgentResponse, AgentResponseUpdate, ResponseStream
|
||||
from ..observability import OtelAttr, capture_exception, create_workflow_span
|
||||
from ._checkpoint import CheckpointStorage, WorkflowCheckpoint
|
||||
@@ -1515,7 +1516,7 @@ class FunctionalWorkflowAgent:
|
||||
function_call = Content.from_function_call(
|
||||
call_id=request_id,
|
||||
name=self.REQUEST_INFO_FUNCTION_NAME,
|
||||
arguments={"request_id": request_id, "data": event.data},
|
||||
arguments={"request_id": request_id, "data": make_json_safe(event.data)},
|
||||
)
|
||||
return Content.from_function_approval_request(
|
||||
id=request_id,
|
||||
|
||||
@@ -1691,6 +1691,65 @@ def test_to_otel_part_function_call():
|
||||
}
|
||||
|
||||
|
||||
def test_to_otel_part_function_call_reuses_prepared_arguments():
|
||||
"""Test _to_otel_part does not re-serialize function-call arguments in the observability hot path."""
|
||||
from agent_framework import Content
|
||||
from agent_framework.observability import _to_otel_part
|
||||
|
||||
arguments = {"payload": object()}
|
||||
content = Content(type="function_call", call_id="call_789", name="handoff", arguments=arguments)
|
||||
result = _to_otel_part(content)
|
||||
|
||||
assert result is not None
|
||||
assert result["arguments"] is arguments
|
||||
|
||||
|
||||
def test_make_json_safe_non_callable_method_attribute():
|
||||
"""Test make_json_safe handles objects where model_dump/to_dict/dict are non-callable attributes."""
|
||||
from agent_framework._serialization import make_json_safe
|
||||
|
||||
class ObjWithNonCallableModelDump:
|
||||
model_dump = 42 # not callable
|
||||
|
||||
obj = ObjWithNonCallableModelDump()
|
||||
result = make_json_safe(obj)
|
||||
assert result == {}
|
||||
|
||||
|
||||
def test_make_json_safe_callable_method_type_error_falls_through():
|
||||
"""Test make_json_safe falls through when serializer-like methods require arguments."""
|
||||
from agent_framework._serialization import make_json_safe
|
||||
|
||||
class ObjWithRequiredArgModelDump:
|
||||
def __init__(self) -> None:
|
||||
self.value = "fallback"
|
||||
|
||||
def model_dump(self, required: str) -> dict[str, str]:
|
||||
return {"required": required}
|
||||
|
||||
obj = ObjWithRequiredArgModelDump()
|
||||
result = make_json_safe(obj)
|
||||
assert result == {"value": "fallback"}
|
||||
|
||||
|
||||
def test_make_json_safe_dict_with_non_string_keys():
|
||||
"""Test make_json_safe converts non-primitive dict keys to strings."""
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
from agent_framework._serialization import make_json_safe
|
||||
|
||||
dt_key = datetime(2024, 1, 1)
|
||||
obj = {dt_key: "value", 42: "num_value", "str_key": "normal"}
|
||||
result = make_json_safe(obj)
|
||||
# json.dumps must not raise TypeError
|
||||
serialized = json.dumps(result)
|
||||
parsed = json.loads(serialized)
|
||||
assert parsed[str(dt_key)] == "value"
|
||||
assert parsed["42"] == "num_value"
|
||||
assert parsed["str_key"] == "normal"
|
||||
|
||||
|
||||
def test_to_otel_part_function_result():
|
||||
"""Test _to_otel_part with function_result content."""
|
||||
from agent_framework import Content
|
||||
@@ -3019,6 +3078,49 @@ async def test_system_instructions_preserves_non_ascii_characters(span_exporter:
|
||||
assert [msg.get("role") for msg in input_messages] == ["user"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("enable_sensitive_data", [True], indirect=True)
|
||||
def test_capture_messages_with_prepared_request_info_function_call_arguments(span_exporter: InMemorySpanExporter):
|
||||
"""Test _capture_messages handles request-info function-call arguments prepared at Content creation."""
|
||||
import dataclasses
|
||||
import json
|
||||
|
||||
from opentelemetry import trace
|
||||
|
||||
from agent_framework import WorkflowAgent
|
||||
|
||||
@dataclasses.dataclass
|
||||
class HandoffRequest:
|
||||
target_agent: str
|
||||
reason: str
|
||||
|
||||
arguments = WorkflowAgent.RequestInfoFunctionArgs(
|
||||
request_id="call_dc",
|
||||
data=HandoffRequest(target_agent="helper", reason="overflow"),
|
||||
).to_dict()
|
||||
msg = Message(
|
||||
role="assistant",
|
||||
contents=[
|
||||
Content(
|
||||
type="function_call",
|
||||
call_id="call_dc",
|
||||
name="request_info",
|
||||
arguments=arguments,
|
||||
)
|
||||
],
|
||||
)
|
||||
span_exporter.clear()
|
||||
tracer = trace.get_tracer("test")
|
||||
with tracer.start_as_current_span("test_span") as span:
|
||||
_capture_messages(span=span, provider_name="test_provider", messages=[msg])
|
||||
|
||||
spans = span_exporter.get_finished_spans()
|
||||
span = spans[0]
|
||||
input_messages = json.loads(span.attributes[OtelAttr.INPUT_MESSAGES])
|
||||
tool_part = input_messages[0]["parts"][0]
|
||||
assert tool_part["type"] == "tool_call"
|
||||
assert tool_part["arguments"]["data"] == {"target_agent": "helper", "reason": "overflow"}
|
||||
|
||||
|
||||
def test_capture_messages_keeps_framework_instructions_out_of_logs_and_span_messages(
|
||||
span_exporter: InMemorySpanExporter,
|
||||
):
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
@@ -1642,6 +1643,37 @@ class TestFunctionalWorkflowAgentHITL:
|
||||
break
|
||||
assert approval_found, "expected FunctionApprovalRequestContent in agent response"
|
||||
|
||||
async def test_request_info_dataclass_arguments_are_serialized_for_agent(self):
|
||||
@dataclass
|
||||
class HandoffRequest:
|
||||
target_agent: str
|
||||
reason: str
|
||||
|
||||
@workflow
|
||||
async def wf(x: str, ctx: RunContext) -> str:
|
||||
answer = await ctx.request_info(
|
||||
HandoffRequest(target_agent=x, reason="overflow"),
|
||||
response_type=str,
|
||||
request_id="rid-1",
|
||||
)
|
||||
return f"got:{answer}"
|
||||
|
||||
agent = wf.as_agent()
|
||||
response = await agent.run("helper")
|
||||
|
||||
function_call_arguments = None
|
||||
for message in response.messages:
|
||||
for content in message.contents:
|
||||
if getattr(content, "type", None) == "function_approval_request" and content.function_call is not None:
|
||||
function_call_arguments = content.function_call.arguments
|
||||
break
|
||||
|
||||
assert function_call_arguments == {
|
||||
"request_id": "rid-1",
|
||||
"data": {"target_agent": "helper", "reason": "overflow"},
|
||||
}
|
||||
assert json.loads(json.dumps(function_call_arguments)) == function_call_arguments
|
||||
|
||||
async def test_resume_via_agent_responses_kwarg(self):
|
||||
@workflow
|
||||
async def wf(x: str, ctx: RunContext) -> str:
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from collections.abc import Awaitable, Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Literal, overload
|
||||
|
||||
import pytest
|
||||
@@ -23,6 +25,7 @@ from agent_framework import (
|
||||
WorkflowAgent,
|
||||
WorkflowBuilder,
|
||||
WorkflowContext,
|
||||
WorkflowEvent,
|
||||
executor,
|
||||
handler,
|
||||
response_handler,
|
||||
@@ -293,6 +296,33 @@ class TestWorkflowAgent:
|
||||
# Verify cleanup - pending requests should be cleared after function response handling
|
||||
assert len(agent.pending_requests) == 0
|
||||
|
||||
def test_request_info_dataclass_arguments_are_serialized_when_content_is_created(self) -> None:
|
||||
"""Test WorkflowAgent prepares request_info arguments before observability captures messages."""
|
||||
|
||||
@dataclass
|
||||
class HandoffRequest:
|
||||
target_agent: str
|
||||
reason: str
|
||||
|
||||
executor = SimpleExecutor(id="executor1", response_text="Response")
|
||||
workflow = WorkflowBuilder(start_executor=executor).build()
|
||||
agent = WorkflowAgent(workflow=workflow, name="Request Test Agent")
|
||||
event = WorkflowEvent.request_info(
|
||||
request_id="request_123",
|
||||
source_executor_id="executor1",
|
||||
request_data=HandoffRequest(target_agent="helper", reason="overflow"),
|
||||
response_type=str,
|
||||
)
|
||||
|
||||
function_call, approval_request = agent._process_request_info_event(event) # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
assert function_call.arguments == {
|
||||
"request_id": "request_123",
|
||||
"data": {"target_agent": "helper", "reason": "overflow"},
|
||||
}
|
||||
assert approval_request.function_call is function_call
|
||||
assert json.loads(json.dumps(function_call.arguments)) == function_call.arguments
|
||||
|
||||
def test_workflow_as_agent_method(self) -> None:
|
||||
"""Test that Workflow.as_agent() creates a properly configured WorkflowAgent."""
|
||||
# Create a simple workflow
|
||||
|
||||
Reference in New Issue
Block a user