mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Fix executor_completed event with non-copyable raw_representation in mixed workflows (#4493)
* Python: Fix `executor_completed` event with non-copyable raw_representation in mixed workflows Fixes #4455 * fix(#4455): use class-level sets for deepcopy field exclusion - SerializationMixin.__deepcopy__: check type(self).DEFAULT_EXCLUDE instead of hardcoding 'raw_representation' - Content.__deepcopy__: add _SHALLOW_COPY_FIELDS class variable and check against it instead of hardcoding - Fix tautological assertion in test (was always True) - Add second excluded field to test to verify DEFAULT_EXCLUDE is respected generically Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Decouple __deepcopy__ from DEFAULT_EXCLUDE in SerializationMixin (#4455) Introduce _SHALLOW_COPY_FIELDS class variable in SerializationMixin to separate deep-copy semantics from serialization semantics. Previously, __deepcopy__ used DEFAULT_EXCLUDE to decide which fields to shallow-copy, conflating 'not serialized' with 'not safe to deep-copy'. A field added to DEFAULT_EXCLUDE purely for serialization (e.g. additional_properties) would be silently shared between original and copy. - Add _SHALLOW_COPY_FIELDS (default {'raw_representation'}) to SerializationMixin, matching the pattern already used by Content - Update __deepcopy__ to read from _SHALLOW_COPY_FIELDS instead of DEFAULT_EXCLUDE - Add test verifying DEFAULT_EXCLUDE fields are deep-copied unless also in _SHALLOW_COPY_FIELDS - Add test for Content._SHALLOW_COPY_FIELDS identity preservation - Add test for ChatResponse deep-copying additional_properties Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Add test for _SHALLOW_COPY_FIELDS and DEFAULT_EXCLUDE independence Add test_deepcopy_shallow_copy_fields_override_default_exclude to verify that a field in both DEFAULT_EXCLUDE and _SHALLOW_COPY_FIELDS is shallow-copied (controlled by _SHALLOW_COPY_FIELDS), while a field in DEFAULT_EXCLUDE only is still deep-copied. This addresses review comment #11 ensuring the two class variables control independent concerns. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Remove unnecessary local variable in __deepcopy__ Inline cls._SHALLOW_COPY_FIELDS directly in the loop check instead of assigning to a local variable first, per review feedback. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Apply pre-commit auto-fixes --------- Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
a3bfad4791
commit
e35f530f2e
@@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
@@ -263,6 +264,25 @@ class SerializationMixin:
|
||||
|
||||
DEFAULT_EXCLUDE: ClassVar[set[str]] = set()
|
||||
INJECTABLE: ClassVar[set[str]] = set()
|
||||
_SHALLOW_COPY_FIELDS: ClassVar[set[str]] = {"raw_representation"}
|
||||
|
||||
def __deepcopy__(self, memo: dict[int, Any]) -> SerializationMixin:
|
||||
"""Create a deep copy, preserving ``_SHALLOW_COPY_FIELDS`` by reference.
|
||||
|
||||
Fields listed in ``_SHALLOW_COPY_FIELDS`` may contain LLM SDK objects
|
||||
(e.g., proto/gRPC responses) that are not safe to deep-copy. They are
|
||||
kept as shallow references in the copy; all other attributes are
|
||||
deep-copied normally.
|
||||
"""
|
||||
cls = type(self)
|
||||
result = cls.__new__(cls)
|
||||
memo[id(self)] = result
|
||||
for k, v in self.__dict__.items():
|
||||
if k in cls._SHALLOW_COPY_FIELDS:
|
||||
object.__setattr__(result, k, v)
|
||||
else:
|
||||
object.__setattr__(result, k, copy.deepcopy(v, memo))
|
||||
return result
|
||||
|
||||
def to_dict(self, *, exclude: set[str] | None = None, exclude_none: bool = True) -> dict[str, Any]:
|
||||
"""Convert the instance and any nested objects to a dictionary.
|
||||
|
||||
@@ -445,6 +445,8 @@ class Content:
|
||||
`Content.from_uri()`, etc. to create instances.
|
||||
"""
|
||||
|
||||
_SHALLOW_COPY_FIELDS: ClassVar[set[str]] = {"raw_representation"}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
type: ContentType,
|
||||
@@ -546,6 +548,23 @@ class Content:
|
||||
self.approved = approved
|
||||
self.consent_link = consent_link
|
||||
|
||||
def __deepcopy__(self, memo: dict[int, Any]) -> Content:
|
||||
"""Create a deep copy, preserving ``_SHALLOW_COPY_FIELDS`` by reference.
|
||||
|
||||
Fields listed in ``_SHALLOW_COPY_FIELDS`` may contain LLM SDK objects
|
||||
(e.g., proto/gRPC responses) that are not safe to deep-copy.
|
||||
"""
|
||||
cls = type(self)
|
||||
result = cls.__new__(cls)
|
||||
memo[id(self)] = result
|
||||
shallow = cls._SHALLOW_COPY_FIELDS
|
||||
for k, v in self.__dict__.items():
|
||||
if k in shallow:
|
||||
object.__setattr__(result, k, v)
|
||||
else:
|
||||
object.__setattr__(result, k, deepcopy(v, memo))
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def from_text(
|
||||
cls: type[ContentT],
|
||||
|
||||
@@ -427,3 +427,103 @@ class TestSerializationMixin:
|
||||
|
||||
assert obj.options["existing"] == "value"
|
||||
assert obj.options["injected"] == "option"
|
||||
|
||||
def test_deepcopy_preserves_shallow_copy_fields_by_reference(self):
|
||||
"""Test that deepcopy keeps _SHALLOW_COPY_FIELDS fields as shallow references."""
|
||||
import copy
|
||||
|
||||
class NonCopyable:
|
||||
def __deepcopy__(self, memo):
|
||||
raise TypeError("cannot deepcopy")
|
||||
|
||||
class TestClass(SerializationMixin):
|
||||
_SHALLOW_COPY_FIELDS = {"raw_representation", "other_opaque"}
|
||||
|
||||
def __init__(self, items: list, raw_representation: Any = None, other_opaque: Any = None):
|
||||
self.items = items
|
||||
self.raw_representation = raw_representation
|
||||
self.other_opaque = other_opaque
|
||||
|
||||
raw = NonCopyable()
|
||||
opaque = NonCopyable()
|
||||
original_items = ["a", "b"]
|
||||
obj = TestClass(items=original_items, raw_representation=raw, other_opaque=opaque)
|
||||
cloned = copy.deepcopy(obj)
|
||||
|
||||
# _SHALLOW_COPY_FIELDS fields should be the same object (shallow copy)
|
||||
assert cloned.raw_representation is raw
|
||||
assert cloned.other_opaque is opaque
|
||||
# Normal attributes should be independent copies
|
||||
assert cloned.items is not original_items
|
||||
assert cloned.items == ["a", "b"]
|
||||
|
||||
def test_deepcopy_deep_copies_non_shallow_copy_fields(self):
|
||||
"""Test that deepcopy fully copies fields not in _SHALLOW_COPY_FIELDS."""
|
||||
import copy
|
||||
|
||||
class TestClass(SerializationMixin):
|
||||
_SHALLOW_COPY_FIELDS = {"raw_representation"}
|
||||
|
||||
def __init__(self, items: list, raw_representation: Any = None):
|
||||
self.items = items
|
||||
self.raw_representation = raw_representation
|
||||
|
||||
original_list = ["a", "b"]
|
||||
obj = TestClass(items=original_list, raw_representation="raw")
|
||||
cloned = copy.deepcopy(obj)
|
||||
|
||||
# list should be a new object
|
||||
assert cloned.items is not original_list
|
||||
assert cloned.items == ["a", "b"]
|
||||
# raw_representation should be the same object
|
||||
assert cloned.raw_representation is obj.raw_representation
|
||||
|
||||
def test_deepcopy_deep_copies_default_exclude_fields(self):
|
||||
"""Test that DEFAULT_EXCLUDE fields are deep-copied unless also in _SHALLOW_COPY_FIELDS."""
|
||||
import copy
|
||||
|
||||
class TestClass(SerializationMixin):
|
||||
DEFAULT_EXCLUDE = {"additional_properties"}
|
||||
|
||||
def __init__(self, items: list, additional_properties: dict | None = None):
|
||||
self.items = items
|
||||
self.additional_properties = additional_properties or {}
|
||||
|
||||
original_props = {"key": "value"}
|
||||
obj = TestClass(items=["a"], additional_properties=original_props)
|
||||
cloned = copy.deepcopy(obj)
|
||||
|
||||
# DEFAULT_EXCLUDE field should be deep-copied (independent copy)
|
||||
assert cloned.additional_properties is not original_props
|
||||
assert cloned.additional_properties == {"key": "value"}
|
||||
|
||||
def test_deepcopy_shallow_copy_fields_override_default_exclude(self):
|
||||
"""Test that _SHALLOW_COPY_FIELDS controls deepcopy independently of DEFAULT_EXCLUDE."""
|
||||
import copy
|
||||
|
||||
class NonCopyable:
|
||||
def __deepcopy__(self, memo):
|
||||
raise TypeError("cannot deepcopy")
|
||||
|
||||
class TestClass(SerializationMixin):
|
||||
DEFAULT_EXCLUDE = {"opaque", "additional_properties"}
|
||||
_SHALLOW_COPY_FIELDS = {"opaque"}
|
||||
|
||||
def __init__(self, items: list, opaque: Any = None, additional_properties: dict | None = None):
|
||||
self.items = items
|
||||
self.opaque = opaque
|
||||
self.additional_properties = additional_properties or {}
|
||||
|
||||
opaque = NonCopyable()
|
||||
original_props = {"key": "value"}
|
||||
obj = TestClass(items=["a"], opaque=opaque, additional_properties=original_props)
|
||||
cloned = copy.deepcopy(obj)
|
||||
|
||||
# Field in both DEFAULT_EXCLUDE and _SHALLOW_COPY_FIELDS: shallow-copied
|
||||
assert cloned.opaque is opaque
|
||||
# Field in DEFAULT_EXCLUDE only: deep-copied
|
||||
assert cloned.additional_properties is not original_props
|
||||
assert cloned.additional_properties == {"key": "value"}
|
||||
# Normal field: deep-copied
|
||||
assert cloned.items is not obj.items
|
||||
assert cloned.items == ["a"]
|
||||
|
||||
@@ -1860,6 +1860,170 @@ def test_agent_run_response_update_all_content_types():
|
||||
assert update_str.role == "user"
|
||||
|
||||
|
||||
# region DeepCopy
|
||||
|
||||
|
||||
class _NonCopyableRaw:
|
||||
"""Simulates an LLM SDK response object that cannot be deep-copied (e.g., proto/gRPC)."""
|
||||
|
||||
def __deepcopy__(self, memo: dict) -> Any:
|
||||
raise TypeError("Cannot deepcopy this object")
|
||||
|
||||
|
||||
def test_content_deepcopy_preserves_raw_representation():
|
||||
"""Test that deepcopy of Content keeps raw_representation by reference."""
|
||||
import copy
|
||||
|
||||
raw = _NonCopyableRaw()
|
||||
content = Content.from_text("hello", raw_representation=raw)
|
||||
|
||||
cloned = copy.deepcopy(content)
|
||||
|
||||
assert cloned.text == "hello"
|
||||
assert cloned.raw_representation is raw
|
||||
assert cloned.additional_properties is not content.additional_properties
|
||||
|
||||
|
||||
def test_message_deepcopy_preserves_raw_representation():
|
||||
"""Test that deepcopy of Message keeps raw_representation by reference."""
|
||||
import copy
|
||||
|
||||
raw = _NonCopyableRaw()
|
||||
msg = Message("assistant", ["hello"], raw_representation=raw)
|
||||
|
||||
cloned = copy.deepcopy(msg)
|
||||
|
||||
assert cloned.text == "hello"
|
||||
assert cloned.raw_representation is raw
|
||||
assert cloned.contents is not msg.contents
|
||||
|
||||
|
||||
def test_agent_response_deepcopy_preserves_raw_representation():
|
||||
"""Test that deepcopy of AgentResponse keeps raw_representation by reference."""
|
||||
import copy
|
||||
|
||||
raw = _NonCopyableRaw()
|
||||
response = AgentResponse(
|
||||
messages=[Message("assistant", ["test"])],
|
||||
raw_representation=raw,
|
||||
)
|
||||
|
||||
cloned = copy.deepcopy(response)
|
||||
|
||||
assert cloned.text == "test"
|
||||
assert cloned.raw_representation is raw
|
||||
assert cloned.messages is not response.messages
|
||||
|
||||
|
||||
def test_chat_response_deepcopy_preserves_raw_representation():
|
||||
"""Test that deepcopy of ChatResponse keeps raw_representation by reference."""
|
||||
import copy
|
||||
|
||||
raw = _NonCopyableRaw()
|
||||
response = ChatResponse(
|
||||
messages=[Message("assistant", ["test"])],
|
||||
raw_representation=raw,
|
||||
)
|
||||
|
||||
cloned = copy.deepcopy(response)
|
||||
|
||||
assert cloned.text == "test"
|
||||
assert cloned.raw_representation is raw
|
||||
assert cloned.messages is not response.messages
|
||||
|
||||
|
||||
def test_chat_response_update_deepcopy_preserves_raw_representation():
|
||||
"""Test that deepcopy of ChatResponseUpdate keeps raw_representation by reference."""
|
||||
import copy
|
||||
|
||||
raw = _NonCopyableRaw()
|
||||
update = ChatResponseUpdate(
|
||||
contents=[Content.from_text("hello")],
|
||||
role="assistant",
|
||||
raw_representation=raw,
|
||||
)
|
||||
|
||||
cloned = copy.deepcopy(update)
|
||||
|
||||
assert cloned.text == "hello"
|
||||
assert cloned.raw_representation is raw
|
||||
assert cloned.contents is not update.contents
|
||||
|
||||
|
||||
def test_agent_response_update_deepcopy_preserves_raw_representation():
|
||||
"""Test that deepcopy of AgentResponseUpdate keeps raw_representation by reference."""
|
||||
import copy
|
||||
|
||||
raw = _NonCopyableRaw()
|
||||
update = AgentResponseUpdate(
|
||||
contents=[Content.from_text("hello")],
|
||||
role="assistant",
|
||||
raw_representation=raw,
|
||||
)
|
||||
|
||||
cloned = copy.deepcopy(update)
|
||||
|
||||
assert cloned.text == "hello"
|
||||
assert cloned.raw_representation is raw
|
||||
assert cloned.contents is not update.contents
|
||||
|
||||
|
||||
def test_nested_deepcopy_preserves_raw_representation():
|
||||
"""Test that deepcopy of an AgentResponse with nested Message raw_representations works."""
|
||||
import copy
|
||||
|
||||
raw_msg = _NonCopyableRaw()
|
||||
raw_response = _NonCopyableRaw()
|
||||
response = AgentResponse(
|
||||
messages=[Message("assistant", ["hello"], raw_representation=raw_msg)],
|
||||
raw_representation=raw_response,
|
||||
)
|
||||
|
||||
cloned = copy.deepcopy(response)
|
||||
|
||||
assert cloned.raw_representation is raw_response
|
||||
assert cloned.messages[0].raw_representation is raw_msg
|
||||
assert cloned.messages is not response.messages
|
||||
assert cloned.text == "hello"
|
||||
|
||||
|
||||
def test_content_deepcopy_shallow_copy_fields_identity():
|
||||
"""Test that Content._SHALLOW_COPY_FIELDS fields are identity-preserved while others are deep-copied."""
|
||||
import copy
|
||||
|
||||
raw = _NonCopyableRaw()
|
||||
content = Content.from_text("hello", raw_representation=raw)
|
||||
content.additional_properties["key"] = "value"
|
||||
|
||||
cloned = copy.deepcopy(content)
|
||||
|
||||
# _SHALLOW_COPY_FIELDS (raw_representation) should be same object
|
||||
assert cloned.raw_representation is raw
|
||||
# Non-shallow fields should be independent deep copies
|
||||
assert cloned.additional_properties is not content.additional_properties
|
||||
assert cloned.additional_properties == {"key": "value"}
|
||||
|
||||
|
||||
def test_chat_response_deepcopy_deep_copies_additional_properties():
|
||||
"""Test that ChatResponse deepcopy deep-copies additional_properties despite it being in DEFAULT_EXCLUDE."""
|
||||
import copy
|
||||
|
||||
response = ChatResponse(
|
||||
messages=[Message("assistant", ["test"])],
|
||||
additional_properties={"key": [1, 2, 3]},
|
||||
)
|
||||
|
||||
cloned = copy.deepcopy(response)
|
||||
|
||||
# additional_properties is in DEFAULT_EXCLUDE for serialization but not in _SHALLOW_COPY_FIELDS,
|
||||
# so it should be deep-copied (independent copy)
|
||||
assert cloned.additional_properties is not response.additional_properties
|
||||
assert cloned.additional_properties == {"key": [1, 2, 3]}
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
# region Serialization
|
||||
|
||||
|
||||
|
||||
@@ -383,3 +383,60 @@ async def test_agent_executor_run_with_messages_kwarg_does_not_raise() -> None:
|
||||
result = await workflow.run("hello", messages=["stale"])
|
||||
assert result is not None
|
||||
assert agent.call_count == 1
|
||||
|
||||
|
||||
class _NonCopyableRaw:
|
||||
"""Simulates an LLM SDK response object that cannot be deep-copied (e.g., proto/gRPC)."""
|
||||
|
||||
def __deepcopy__(self, memo: dict) -> Any:
|
||||
raise TypeError("Cannot deepcopy this object")
|
||||
|
||||
|
||||
class _AgentWithRawRepr(BaseAgent):
|
||||
"""Agent that returns responses with a non-copyable raw_representation."""
|
||||
|
||||
def __init__(self, raw: Any, **kwargs: Any):
|
||||
super().__init__(**kwargs)
|
||||
self._raw = raw
|
||||
|
||||
def run(
|
||||
self,
|
||||
messages: str | Message | list[str] | list[Message] | None = None,
|
||||
*,
|
||||
stream: bool = False,
|
||||
session: AgentSession | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]:
|
||||
async def _run() -> AgentResponse:
|
||||
return AgentResponse(
|
||||
messages=[Message("assistant", [f"reply from {self.name}"])],
|
||||
raw_representation=self._raw,
|
||||
)
|
||||
|
||||
return _run()
|
||||
|
||||
|
||||
async def test_agent_executor_workflow_with_non_copyable_raw_representation() -> None:
|
||||
"""Workflow should complete when AgentResponse contains a raw_representation that cannot be deep-copied."""
|
||||
raw = _NonCopyableRaw()
|
||||
|
||||
agent_a = _AgentWithRawRepr(raw=raw, id="a", name="AgentA")
|
||||
agent_b = _CountingAgent(id="b", name="AgentB")
|
||||
|
||||
exec_a = AgentExecutor(agent_a, id="exec_a")
|
||||
exec_b = AgentExecutor(agent_b, id="exec_b")
|
||||
|
||||
workflow = SequentialBuilder(participants=[exec_a, exec_b]).build()
|
||||
events = await workflow.run("hello")
|
||||
|
||||
completed = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "executor_completed"]
|
||||
completed_a = [e for e in completed if e.executor_id == "exec_a"]
|
||||
|
||||
assert len(completed_a) == 1
|
||||
assert completed_a[0].data is not None
|
||||
|
||||
# The yielded AgentResponse should preserve its raw_representation reference
|
||||
agent_responses = [d for d in completed_a[0].data if isinstance(d, AgentResponse)]
|
||||
assert len(agent_responses) > 0
|
||||
assert agent_responses[0].text == "reply from AgentA"
|
||||
assert agent_responses[0].raw_representation is raw
|
||||
|
||||
Reference in New Issue
Block a user