mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Fix structured_output propagation in ClaudeAgent (#4137)
* Fix structured_output propagation in ClaudeAgent

  Capture structured_output from ResultMessage in _get_stream() and propagate it to AgentResponse.value via a custom finalizer. Previously, structured_output was silently discarded, making output_format unusable.

  Fixes #4095

* Address review feedback: use value parameter instead of private properties

  - Extend AgentResponse.from_updates() to accept an optional value parameter
  - Remove structured_output yield from _get_stream()
  - Update _finalize_response() to pass value via the public API
  - Update streaming test to use get_final_response()

* Fix mypy errors: add value parameter to from_updates overloads

  Add the value parameter to both @overload signatures of AgentResponse.from_updates() so that mypy recognizes the argument.

---------

Co-authored-by: Amit Mukherjee <amimukherjee@microsoft.com>
Co-authored-by: Dmytro Struk <13853051+dmytrostruk@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
69eabcd1fc
commit
11628c3166
@@ -618,12 +618,24 @@ class ClaudeAgent(BaseAgent, Generic[OptionsT]):
|
||||
"""
|
||||
response = ResponseStream(
|
||||
self._get_stream(messages, session=session, options=options, **kwargs),
|
||||
finalizer=AgentResponse.from_updates,
|
||||
finalizer=self._finalize_response,
|
||||
)
|
||||
if stream:
|
||||
return response
|
||||
return response.get_final_response()
|
||||
|
||||
def _finalize_response(self, updates: Sequence[AgentResponseUpdate]) -> AgentResponse[Any]:
    """Assemble the final AgentResponse from the collected stream updates.

    Args:
        updates: The updates gathered from the response stream.

    Returns:
        The response produced by ``AgentResponse.from_updates``, whose ``value``
        is this agent's ``_structured_output`` attribute, or ``None`` when that
        attribute has not been set on the instance.
    """
    # getattr with a default keeps this safe on an agent that has never
    # assigned ``_structured_output``.
    return AgentResponse.from_updates(
        updates,
        value=getattr(self, "_structured_output", None),
    )
|
||||
|
||||
async def _get_stream(
|
||||
self,
|
||||
messages: AgentRunInputs | None = None,
|
||||
@@ -647,6 +659,7 @@ class ClaudeAgent(BaseAgent, Generic[OptionsT]):
|
||||
await self._apply_runtime_options(dict(options) if options else None)
|
||||
|
||||
session_id: str | None = None
|
||||
structured_output: Any = None
|
||||
|
||||
await self._client.query(prompt)
|
||||
async for message in self._client.receive_response():
|
||||
@@ -700,7 +713,11 @@ class ClaudeAgent(BaseAgent, Generic[OptionsT]):
|
||||
error_msg = message.result or "Unknown error from Claude API"
|
||||
raise AgentException(f"Claude API error: {error_msg}")
|
||||
session_id = message.session_id
|
||||
structured_output = message.structured_output
|
||||
|
||||
# Update session with session ID
|
||||
if session_id:
|
||||
session.service_session_id = session_id
|
||||
|
||||
# Store structured output for the finalizer
|
||||
self._structured_output = structured_output
|
||||
|
||||
@@ -785,3 +785,163 @@ class TestApplyRuntimeOptions:
|
||||
await agent._apply_runtime_options(None) # type: ignore[reportPrivateUsage]
|
||||
mock_client.set_model.assert_not_called()
|
||||
mock_client.set_permission_mode.assert_not_called()
|
||||
|
||||
|
||||
# region Test ClaudeAgent Structured Output
|
||||
|
||||
|
||||
class TestClaudeAgentStructuredOutput:
    """Checks that ClaudeAgent exposes ResultMessage.structured_output as response.value."""

    @staticmethod
    async def _create_async_generator(items: list[Any]) -> Any:
        """Yield every element of ``items``, in order, from an async generator."""
        for entry in items:
            yield entry

    def _create_mock_client(self, messages: list[Any]) -> MagicMock:
        """Return a mocked ClaudeSDKClient whose receive_response() yields ``messages``."""
        client = MagicMock()
        # These client methods are awaited by the agent, so they must be AsyncMock.
        for awaited_name in ("connect", "disconnect", "query", "set_model", "set_permission_mode"):
            setattr(client, awaited_name, AsyncMock())
        client.receive_response = MagicMock(return_value=self._create_async_generator(messages))
        return client

    @staticmethod
    def _success_messages(text: str, **result_fields: Any) -> list[Any]:
        """Build the stream event / assistant message / success result triple used by the tests.

        Args:
            text: Text carried by both the stream delta and the assistant message.
            **result_fields: Extra keyword arguments forwarded to ``ResultMessage``
                (for example ``structured_output``); omitted fields keep the SDK defaults.
        """
        from claude_agent_sdk import AssistantMessage, ResultMessage, TextBlock
        from claude_agent_sdk.types import StreamEvent

        delta_event = StreamEvent(
            event={
                "type": "content_block_delta",
                "delta": {"type": "text_delta", "text": text},
            },
            uuid="event-1",
            session_id="session-123",
        )
        assistant_message = AssistantMessage(
            content=[TextBlock(text=text)],
            model="claude-sonnet",
        )
        result_message = ResultMessage(
            subtype="success",
            duration_ms=100,
            duration_api_ms=50,
            is_error=False,
            num_turns=1,
            session_id="session-123",
            **result_fields,
        )
        return [delta_event, assistant_message, result_message]

    async def test_structured_output_propagated_to_response(self) -> None:
        """A structured_output on the ResultMessage ends up in response.value."""
        structured_data = {"name": "Alice", "age": 30}
        client = self._create_mock_client(
            self._success_messages(
                '{"name": "Alice", "age": 30}',
                structured_output=structured_data,
            )
        )

        with patch("agent_framework_claude._agent.ClaudeSDKClient", return_value=client):
            agent = ClaudeAgent()
            response = await agent.run("Return structured data")
            assert response.value == structured_data

    async def test_structured_output_none_when_not_present(self) -> None:
        """response.value stays None when the ResultMessage has no structured_output."""
        client = self._create_mock_client(self._success_messages("Hello!"))

        with patch("agent_framework_claude._agent.ClaudeSDKClient", return_value=client):
            agent = ClaudeAgent()
            response = await agent.run("Hello")
            assert response.value is None

    async def test_structured_output_with_streaming(self) -> None:
        """After the stream is consumed, get_final_response() carries structured_output."""
        structured_data = {"key": "value"}
        client = self._create_mock_client(
            self._success_messages(
                '{"key": "value"}',
                structured_output=structured_data,
            )
        )

        with patch("agent_framework_claude._agent.ClaudeSDKClient", return_value=client):
            agent = ClaudeAgent()
            stream = agent.run("Return structured data", stream=True)
            # Drain every update; only the final response is inspected.
            async for _update in stream:
                continue
            final_response = await stream.get_final_response()
            assert final_response.value == structured_data

    async def test_structured_output_with_error_does_not_propagate(self) -> None:
        """An error ResultMessage raises AgentException even if it carries structured_output."""
        from agent_framework.exceptions import AgentException
        from claude_agent_sdk import ResultMessage

        error_result = ResultMessage(
            subtype="error",
            duration_ms=100,
            duration_api_ms=50,
            is_error=True,
            num_turns=0,
            session_id="error-session",
            result="Something went wrong",
            structured_output={"some": "data"},
        )
        client = self._create_mock_client([error_result])

        with patch("agent_framework_claude._agent.ClaudeSDKClient", return_value=client):
            agent = ClaudeAgent()
            with pytest.raises(AgentException) as exc_info:
                await agent.run("Hello")
            assert "Something went wrong" in str(exc_info.value)
|
||||
|
||||
@@ -2256,6 +2256,7 @@ class AgentResponse(SerializationMixin, Generic[ResponseModelT]):
|
||||
updates: Sequence[AgentResponseUpdate],
|
||||
*,
|
||||
output_format_type: type[ResponseModelBoundT],
|
||||
value: Any | None = None,
|
||||
) -> AgentResponse[ResponseModelBoundT]: ...
|
||||
|
||||
@overload
|
||||
@@ -2265,6 +2266,7 @@ class AgentResponse(SerializationMixin, Generic[ResponseModelT]):
|
||||
updates: Sequence[AgentResponseUpdate],
|
||||
*,
|
||||
output_format_type: None = None,
|
||||
value: Any | None = None,
|
||||
) -> AgentResponse[Any]: ...
|
||||
|
||||
@classmethod
|
||||
@@ -2273,6 +2275,7 @@ class AgentResponse(SerializationMixin, Generic[ResponseModelT]):
|
||||
updates: Sequence[AgentResponseUpdate],
|
||||
*,
|
||||
output_format_type: type[BaseModel] | None = None,
|
||||
value: Any | None = None,
|
||||
) -> AgentResponseT:
|
||||
"""Joins multiple updates into a single AgentResponse.
|
||||
|
||||
@@ -2281,8 +2284,9 @@ class AgentResponse(SerializationMixin, Generic[ResponseModelT]):
|
||||
|
||||
Keyword Args:
|
||||
output_format_type: Optional Pydantic model type to parse the response text into structured data.
|
||||
value: Optional pre-parsed structured output value to set directly on the response.
|
||||
"""
|
||||
msg = cls(messages=[], response_format=output_format_type)
|
||||
msg = cls(messages=[], response_format=output_format_type, value=value)
|
||||
for update in updates:
|
||||
_process_update(msg, update)
|
||||
_finalize_response(msg)
|
||||
|
||||
Reference in New Issue
Block a user