mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Fix state snapshot to use deepcopy so nested mutations are detected in durable workflow activities (#4518)
* Use deepcopy for state snapshot to detect nested mutations (#4500) Replace dict() shallow copy with copy.deepcopy() when snapshotting workflow state before activity execution. The shallow copy shared references to nested objects (dicts, lists), so in-place mutations by executors were reflected in both the snapshot and live state, producing an empty diff and preventing state updates from propagating to downstream activities. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Python: Fix state snapshot to use deepcopy so nested mutations are detected in durable workflow activities Fixes #4500 * Address PR review: remove report, extract testable helpers (#4500) - Delete REPRODUCTION_REPORT.md (debugging artifact with local paths and raw LLM output) - Extract _create_state_snapshot() and _compute_state_updates() as module-level helpers in _app.py so tests exercise the production code path - Update TestStateSnapshotDiff to import and use production helpers instead of reimplementing snapshot/diff logic locally Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Apply pre-commit auto-fixes * Add regression tests proving shallow copy bug and deep copy isolation (#4500) Add two additional tests to TestStateSnapshotDiff: - test_shallow_copy_would_miss_nested_mutations: reproduces the original bug by demonstrating that dict() (shallow copy) misses nested mutations - test_create_state_snapshot_isolates_nested_objects: verifies the production _create_state_snapshot helper creates a true deep copy These tests ensure a regression back to shallow copy would be caught. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Add integration test exercising full activity code path (#4500) Address PR review comment: add test_executor_activity_detects_nested_state_mutations that captures the actual executor_activity function from _setup_executor_activity and verifies it detects in-place nested mutations. This test would fail if _app.py line 314 regressed from _create_state_snapshot() back to dict(). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Address review feedback for #4518: review comment fixes * Address PR review feedback for state snapshot diff - Inline _compute_state_updates logic at call site to reuse precomputed original_keys/current_keys sets, avoiding redundant set allocations - Fix test docstring to describe behavioral regression instead of hard-coding a specific line number - Use SOURCE_ORCHESTRATOR constant in integration test instead of literal string Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Apply pre-commit auto-fixes * fix: remove unused _compute_state_updates from _app.py (#4518) The function was inlined per review comment, making the module-level helper unused and triggering a pyright reportUnusedFunction error. Move the helper into the test file where it is still needed for unit testing the diffing logic. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
aa2ff672fb
commit
ed2fb3b9dd
@@ -14,6 +14,7 @@ import logging
|
||||
import re
|
||||
import uuid
|
||||
from collections.abc import Callable, Mapping
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING, Any, TypeVar, cast
|
||||
@@ -58,6 +59,11 @@ EntityHandler = Callable[[df.DurableEntityContext], None]
|
||||
HandlerT = TypeVar("HandlerT", bound=Callable[..., Any])
|
||||
|
||||
|
||||
def _create_state_snapshot(state: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Create a deep copy of the deserialized state for later diffing."""
|
||||
return deepcopy(state)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentMetadata:
|
||||
"""Metadata for a registered agent.
|
||||
@@ -306,7 +312,7 @@ class AgentFunctionApp(DFAppBase):
|
||||
deserialized_state: dict[str, Any] = {
|
||||
str(k): deserialize_value(v) for k, v in shared_state_snapshot.items()
|
||||
}
|
||||
original_snapshot: dict[str, Any] = dict(deserialized_state)
|
||||
original_snapshot = _create_state_snapshot(deserialized_state)
|
||||
shared_state.import_state(deserialized_state)
|
||||
|
||||
if is_hitl_response:
|
||||
@@ -339,9 +345,10 @@ class AgentFunctionApp(DFAppBase):
|
||||
deletes: set[str] = original_keys - current_keys
|
||||
|
||||
# Updates = keys in current that are new or have different values
|
||||
updates = {
|
||||
k: v for k, v in current_state.items() if k not in original_snapshot or original_snapshot[k] != v
|
||||
}
|
||||
updates: dict[str, Any] = {}
|
||||
for key in current_keys:
|
||||
if key not in original_keys or current_state[key] != original_snapshot.get(key):
|
||||
updates[key] = current_state[key]
|
||||
|
||||
# Drain messages and events from runner context
|
||||
sent_messages = await runner_context.drain_messages()
|
||||
|
||||
@@ -26,6 +26,7 @@ from agent_framework_durabletask import (
|
||||
|
||||
from agent_framework_azurefunctions import AgentFunctionApp
|
||||
from agent_framework_azurefunctions._entities import create_agent_entity
|
||||
from agent_framework_azurefunctions._workflow import SOURCE_ORCHESTRATOR
|
||||
|
||||
FuncT = TypeVar("FuncT", bound=Callable[..., Any])
|
||||
|
||||
@@ -1441,5 +1442,286 @@ class TestAgentFunctionAppWorkflow:
|
||||
assert "instance-456" in url
|
||||
|
||||
|
||||
def _compute_state_updates(original_snapshot: dict[str, Any], current_state: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Compute state updates by comparing current state against the original snapshot.
|
||||
|
||||
This mirrors the inlined logic in ``_app.py``'s ``executor_activity.run()``.
|
||||
"""
|
||||
original_keys = set(original_snapshot.keys())
|
||||
current_keys = set(current_state.keys())
|
||||
updates: dict[str, Any] = {}
|
||||
for key in current_keys:
|
||||
if key not in original_keys or current_state[key] != original_snapshot.get(key):
|
||||
updates[key] = current_state[key]
|
||||
return updates
|
||||
|
||||
|
||||
class TestStateSnapshotDiff:
|
||||
"""Test suite for state snapshot diffing in activity execution.
|
||||
|
||||
The activity executor snapshots state before execution and diffs against the
|
||||
post-execution state to determine which keys were updated. These tests exercise
|
||||
the production snapshot helper and the state-update diffing logic to ensure that
|
||||
in-place mutations to nested objects (dicts, lists) are correctly detected as changes.
|
||||
"""
|
||||
|
||||
def test_nested_dict_mutation_detected_in_diff(self) -> None:
|
||||
"""Test that mutating values inside a nested dict appears in the diff."""
|
||||
from agent_framework._workflows._state import State
|
||||
|
||||
from agent_framework_azurefunctions._app import _create_state_snapshot
|
||||
|
||||
deserialized_state: dict[str, Any] = {
|
||||
"Local.config": {"code": "", "enabled": False},
|
||||
"simple_key": "simple_value",
|
||||
}
|
||||
|
||||
original_snapshot = _create_state_snapshot(deserialized_state)
|
||||
|
||||
shared_state = State()
|
||||
shared_state.import_state(deserialized_state)
|
||||
|
||||
config = shared_state.get("Local.config")
|
||||
config["code"] = "SOMECODEXXX"
|
||||
config["enabled"] = True
|
||||
|
||||
shared_state.commit()
|
||||
current_state = shared_state.export_state()
|
||||
|
||||
updates = _compute_state_updates(original_snapshot, current_state)
|
||||
|
||||
assert "Local.config" in updates
|
||||
assert updates["Local.config"]["code"] == "SOMECODEXXX"
|
||||
assert updates["Local.config"]["enabled"] is True
|
||||
|
||||
def test_new_key_in_nested_dict_detected_in_diff(self) -> None:
|
||||
"""Test that adding a key to a nested dict appears in the diff."""
|
||||
from agent_framework._workflows._state import State
|
||||
|
||||
from agent_framework_azurefunctions._app import _create_state_snapshot
|
||||
|
||||
deserialized_state: dict[str, Any] = {
|
||||
"Local.data": {"existing": "value"},
|
||||
}
|
||||
|
||||
original_snapshot = _create_state_snapshot(deserialized_state)
|
||||
|
||||
shared_state = State()
|
||||
shared_state.import_state(deserialized_state)
|
||||
|
||||
data = shared_state.get("Local.data")
|
||||
data["code"] = "NEW_CODE"
|
||||
|
||||
shared_state.commit()
|
||||
current_state = shared_state.export_state()
|
||||
|
||||
updates = _compute_state_updates(original_snapshot, current_state)
|
||||
|
||||
assert "Local.data" in updates
|
||||
assert updates["Local.data"]["code"] == "NEW_CODE"
|
||||
|
||||
def test_nested_list_mutation_detected_in_diff(self) -> None:
|
||||
"""Test that appending to a nested list appears in the diff."""
|
||||
from agent_framework._workflows._state import State
|
||||
|
||||
from agent_framework_azurefunctions._app import _create_state_snapshot
|
||||
|
||||
deserialized_state: dict[str, Any] = {
|
||||
"Local.items": [1, 2, 3],
|
||||
}
|
||||
|
||||
original_snapshot = _create_state_snapshot(deserialized_state)
|
||||
|
||||
shared_state = State()
|
||||
shared_state.import_state(deserialized_state)
|
||||
|
||||
items = shared_state.get("Local.items")
|
||||
items.append(4)
|
||||
|
||||
shared_state.commit()
|
||||
current_state = shared_state.export_state()
|
||||
|
||||
updates = _compute_state_updates(original_snapshot, current_state)
|
||||
|
||||
assert "Local.items" in updates
|
||||
assert updates["Local.items"] == [1, 2, 3, 4]
|
||||
|
||||
def test_new_top_level_key_detected_in_diff(self) -> None:
|
||||
"""Test that setting a new top-level key appears in the diff."""
|
||||
from agent_framework._workflows._state import State
|
||||
|
||||
from agent_framework_azurefunctions._app import _create_state_snapshot
|
||||
|
||||
deserialized_state: dict[str, Any] = {
|
||||
"existing": "value",
|
||||
}
|
||||
|
||||
original_snapshot = _create_state_snapshot(deserialized_state)
|
||||
|
||||
shared_state = State()
|
||||
shared_state.import_state(deserialized_state)
|
||||
|
||||
shared_state.set("Local.code", "SOMECODEXXX")
|
||||
|
||||
shared_state.commit()
|
||||
current_state = shared_state.export_state()
|
||||
|
||||
updates = _compute_state_updates(original_snapshot, current_state)
|
||||
|
||||
assert "Local.code" in updates
|
||||
assert updates["Local.code"] == "SOMECODEXXX"
|
||||
|
||||
def test_unchanged_nested_state_produces_empty_diff(self) -> None:
|
||||
"""Test that unmodified nested state produces no updates."""
|
||||
from agent_framework._workflows._state import State
|
||||
|
||||
from agent_framework_azurefunctions._app import _create_state_snapshot
|
||||
|
||||
deserialized_state: dict[str, Any] = {
|
||||
"Local.config": {"code": "existing", "enabled": True},
|
||||
"simple_key": "simple_value",
|
||||
}
|
||||
|
||||
original_snapshot = _create_state_snapshot(deserialized_state)
|
||||
|
||||
shared_state = State()
|
||||
shared_state.import_state(deserialized_state)
|
||||
|
||||
# No mutations performed
|
||||
shared_state.commit()
|
||||
current_state = shared_state.export_state()
|
||||
|
||||
updates = _compute_state_updates(original_snapshot, current_state)
|
||||
|
||||
assert updates == {}
|
||||
|
||||
def test_shallow_copy_would_miss_nested_mutations(self) -> None:
|
||||
"""Regression test: a shallow copy (dict()) shares nested refs, hiding mutations.
|
||||
|
||||
This reproduces the original bug from #4500 where ``dict(deserialized_state)``
|
||||
was used instead of ``copy.deepcopy()``. With a shallow copy the snapshot and
|
||||
the live state share nested objects, so in-place mutations appear in both and
|
||||
the diff produces an empty update set.
|
||||
"""
|
||||
from agent_framework._workflows._state import State
|
||||
|
||||
deserialized_state: dict[str, Any] = {
|
||||
"Local.config": {"code": "", "enabled": False},
|
||||
}
|
||||
|
||||
# Shallow copy (the OLD, buggy behaviour)
|
||||
shallow_snapshot = dict(deserialized_state)
|
||||
|
||||
shared_state = State()
|
||||
shared_state.import_state(deserialized_state)
|
||||
|
||||
config = shared_state.get("Local.config")
|
||||
config["code"] = "SOMECODEXXX"
|
||||
config["enabled"] = True
|
||||
|
||||
shared_state.commit()
|
||||
current_state = shared_state.export_state()
|
||||
|
||||
# With a shallow copy the mutation leaks into the snapshot → empty diff
|
||||
updates_shallow = _compute_state_updates(shallow_snapshot, current_state)
|
||||
assert updates_shallow == {}, "shallow copy should miss nested mutations (demonstrating the bug)"
|
||||
|
||||
def test_create_state_snapshot_isolates_nested_objects(self) -> None:
|
||||
"""Verify _create_state_snapshot produces a deep copy that is mutation-proof.
|
||||
|
||||
This ensures the production snapshot helper is not equivalent to ``dict()``
|
||||
and will correctly isolate nested objects so that later mutations are detected.
|
||||
"""
|
||||
from agent_framework_azurefunctions._app import _create_state_snapshot
|
||||
|
||||
original: dict[str, Any] = {
|
||||
"nested_dict": {"a": 1},
|
||||
"nested_list": [1, 2, 3],
|
||||
}
|
||||
|
||||
snapshot = _create_state_snapshot(original)
|
||||
|
||||
# Mutate the originals in place
|
||||
original["nested_dict"]["a"] = 999
|
||||
original["nested_list"].append(4)
|
||||
|
||||
# Snapshot must be unaffected
|
||||
assert snapshot["nested_dict"]["a"] == 1
|
||||
assert snapshot["nested_list"] == [1, 2, 3]
|
||||
|
||||
def test_executor_activity_detects_nested_state_mutations(self) -> None:
|
||||
"""Integration test: the full activity wrapper detects nested mutations.
|
||||
|
||||
This exercises the actual executor_activity function registered by
|
||||
_setup_executor_activity to verify the production code path uses
|
||||
_create_state_snapshot (deep copy) rather than dict() (shallow copy).
|
||||
If the implementation regressed to using a shallow copy such as
|
||||
``dict(deserialized_state)``, this test would fail because in-place
|
||||
mutations would leak into the snapshot and produce an empty diff.
|
||||
"""
|
||||
mock_executor = Mock()
|
||||
mock_executor.id = "test-exec"
|
||||
|
||||
async def mutate_nested_state(
|
||||
message: Any,
|
||||
source_executor_ids: Any,
|
||||
state: Any,
|
||||
runner_context: Any,
|
||||
) -> None:
|
||||
config = state.get("Local.config")
|
||||
config["code"] = "MUTATED"
|
||||
config["enabled"] = True
|
||||
state.commit()
|
||||
|
||||
mock_executor.execute = AsyncMock(side_effect=mutate_nested_state)
|
||||
|
||||
mock_workflow = Mock()
|
||||
mock_workflow.executors = {"test-exec": mock_executor}
|
||||
|
||||
# Capture the activity function by making decorators pass-through
|
||||
captured_activity: dict[str, Any] = {}
|
||||
|
||||
def passthrough_function_name(name: str) -> Callable[[FuncT], FuncT]:
|
||||
def decorator(fn: FuncT) -> FuncT:
|
||||
captured_activity["fn"] = fn
|
||||
return fn
|
||||
|
||||
return decorator
|
||||
|
||||
def passthrough_activity_trigger(input_name: str) -> Callable[[FuncT], FuncT]:
|
||||
def decorator(fn: FuncT) -> FuncT:
|
||||
return fn
|
||||
|
||||
return decorator
|
||||
|
||||
with (
|
||||
patch.object(AgentFunctionApp, "function_name", side_effect=passthrough_function_name),
|
||||
patch.object(AgentFunctionApp, "activity_trigger", side_effect=passthrough_activity_trigger),
|
||||
patch.object(AgentFunctionApp, "_setup_workflow_orchestration"),
|
||||
):
|
||||
AgentFunctionApp(workflow=mock_workflow)
|
||||
|
||||
assert "fn" in captured_activity, "activity function was not captured"
|
||||
|
||||
# Call the activity with nested state that the executor will mutate
|
||||
input_data = json.dumps({
|
||||
"message": "test",
|
||||
"shared_state_snapshot": {
|
||||
"Local.config": {"code": "", "enabled": False},
|
||||
},
|
||||
"source_executor_ids": [SOURCE_ORCHESTRATOR],
|
||||
})
|
||||
|
||||
result = json.loads(captured_activity["fn"](input_data))
|
||||
|
||||
# The deep copy snapshot must detect the in-place nested mutations
|
||||
assert "Local.config" in result["shared_state_updates"], (
|
||||
"nested mutation not detected — snapshot may be using shallow copy"
|
||||
)
|
||||
updated_config = result["shared_state_updates"]["Local.config"]
|
||||
assert updated_config["code"] == "MUTATED"
|
||||
assert updated_config["enabled"] is True
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v", "--tb=short"])
|
||||
|
||||
Reference in New Issue
Block a user