mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: fix Foundry handoff argument serialization (#5861)
This commit is contained in:
committed by
GitHub
Unverified
parent
578416a379
commit
cf91819625
@@ -9,8 +9,9 @@ import logging
|
||||
import os
|
||||
import tempfile
|
||||
import threading
|
||||
from collections.abc import AsyncIterable, AsyncIterator, Generator, Mapping, Sequence
|
||||
from collections.abc import AsyncIterable, AsyncIterator, Generator, Sequence
|
||||
from contextlib import suppress
|
||||
from dataclasses import asdict, is_dataclass
|
||||
from pathlib import Path
|
||||
from contextlib import AbstractAsyncContextManager, AsyncExitStack, suppress
|
||||
from typing import Protocol, cast
|
||||
@@ -1505,11 +1506,20 @@ def _convert_message_content(content: MessageContent) -> Content:
|
||||
# region Output Item Conversion
|
||||
|
||||
|
||||
def _arguments_to_str(arguments: str | Mapping[str, Any] | None) -> str:
|
||||
def _argument_json_default(value: Any) -> Any:
|
||||
if is_dataclass(value) and not isinstance(value, type):
|
||||
return asdict(value)
|
||||
to_dict = getattr(value, "to_dict", None)
|
||||
if callable(to_dict):
|
||||
return to_dict()
|
||||
raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")
|
||||
|
||||
|
||||
def _arguments_to_str(arguments: Any | None) -> str:
|
||||
"""Convert arguments to a JSON string.
|
||||
|
||||
Args:
|
||||
arguments: The arguments to convert, can be a string, mapping, or None.
|
||||
arguments: The arguments to convert, can be a string, JSON-like object, or None.
|
||||
|
||||
Returns:
|
||||
The arguments as a JSON string.
|
||||
@@ -1518,7 +1528,7 @@ def _arguments_to_str(arguments: str | Mapping[str, Any] | None) -> str:
|
||||
return ""
|
||||
if isinstance(arguments, str):
|
||||
return arguments
|
||||
return json.dumps(arguments)
|
||||
return json.dumps(arguments, default=_argument_json_default)
|
||||
|
||||
|
||||
async def _to_outputs(
|
||||
|
||||
@@ -12,6 +12,7 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import AsyncIterator, Callable
|
||||
from dataclasses import dataclass
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import httpx
|
||||
@@ -405,6 +406,36 @@ class TestStreaming:
|
||||
assert len(args_done) == 1
|
||||
assert args_done[0]["data"]["arguments"] == '{"q": "hello"}'
|
||||
|
||||
async def test_function_call_streaming_serializes_dataclass_arguments(self) -> None:
|
||||
@dataclass
|
||||
class HandoffLikeRequest:
|
||||
agent_response: AgentResponse
|
||||
|
||||
request = HandoffLikeRequest(
|
||||
agent_response=AgentResponse(
|
||||
messages=[Message(role="assistant", contents=[Content.from_text("Need more details")])]
|
||||
)
|
||||
)
|
||||
agent = _make_agent(
|
||||
stream_updates=[
|
||||
AgentResponseUpdate(
|
||||
contents=[Content.from_function_call("call_1", "handoff_to_refund", arguments=request)],
|
||||
role="assistant",
|
||||
),
|
||||
]
|
||||
)
|
||||
server = _make_server(agent)
|
||||
resp = await _post(server, stream=True)
|
||||
|
||||
assert resp.status_code == 200
|
||||
events = _parse_sse_events(resp.text)
|
||||
args_done = [e for e in events if e["event"] == "response.function_call_arguments.done"]
|
||||
assert len(args_done) == 1
|
||||
|
||||
payload = json.loads(args_done[0]["data"]["arguments"])
|
||||
assert payload["agent_response"]["type"] == "agent_response"
|
||||
assert payload["agent_response"]["messages"][0]["contents"][0]["text"] == "Need more details"
|
||||
|
||||
async def test_alternating_text_and_function_call(self) -> None:
|
||||
agent = _make_agent(
|
||||
stream_updates=[
|
||||
|
||||
Reference in New Issue
Block a user