diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py index c108f7739d..1c43264398 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py @@ -14,6 +14,7 @@ import logging import re import uuid from collections.abc import Callable, Mapping +from copy import deepcopy from dataclasses import dataclass from datetime import datetime, timezone from typing import TYPE_CHECKING, Any, TypeVar, cast @@ -58,6 +59,11 @@ EntityHandler = Callable[[df.DurableEntityContext], None] HandlerT = TypeVar("HandlerT", bound=Callable[..., Any]) +def _create_state_snapshot(state: dict[str, Any]) -> dict[str, Any]: + """Create a deep copy of the deserialized state for later diffing.""" + return deepcopy(state) + + @dataclass class AgentMetadata: """Metadata for a registered agent. @@ -306,7 +312,7 @@ class AgentFunctionApp(DFAppBase): deserialized_state: dict[str, Any] = { str(k): deserialize_value(v) for k, v in shared_state_snapshot.items() } - original_snapshot: dict[str, Any] = dict(deserialized_state) + original_snapshot = _create_state_snapshot(deserialized_state) shared_state.import_state(deserialized_state) if is_hitl_response: @@ -339,9 +345,10 @@ class AgentFunctionApp(DFAppBase): deletes: set[str] = original_keys - current_keys # Updates = keys in current that are new or have different values - updates = { - k: v for k, v in current_state.items() if k not in original_snapshot or original_snapshot[k] != v - } + updates: dict[str, Any] = {} + for key in current_keys: + if key not in original_keys or current_state[key] != original_snapshot.get(key): + updates[key] = current_state[key] # Drain messages and events from runner context sent_messages = await runner_context.drain_messages() diff --git a/python/packages/azurefunctions/tests/test_app.py b/python/packages/azurefunctions/tests/test_app.py index f4b86ba2d7..03084d5ada 100644 --- a/python/packages/azurefunctions/tests/test_app.py +++ b/python/packages/azurefunctions/tests/test_app.py @@ -26,6 +26,7 @@ from agent_framework_durabletask import ( from agent_framework_azurefunctions import AgentFunctionApp from agent_framework_azurefunctions._entities import create_agent_entity +from agent_framework_azurefunctions._workflow import SOURCE_ORCHESTRATOR FuncT = TypeVar("FuncT", bound=Callable[..., Any]) @@ -1441,5 +1442,286 @@ class TestAgentFunctionAppWorkflow: assert "instance-456" in url +def _compute_state_updates(original_snapshot: dict[str, Any], current_state: dict[str, Any]) -> dict[str, Any]: + """Compute state updates by comparing current state against the original snapshot. + + This mirrors the inlined logic in ``_app.py``'s ``executor_activity.run()``. + """ + original_keys = set(original_snapshot.keys()) + current_keys = set(current_state.keys()) + updates: dict[str, Any] = {} + for key in current_keys: + if key not in original_keys or current_state[key] != original_snapshot.get(key): + updates[key] = current_state[key] + return updates + + +class TestStateSnapshotDiff: + """Test suite for state snapshot diffing in activity execution. + + The activity executor snapshots state before execution and diffs against the + post-execution state to determine which keys were updated. These tests exercise + the production snapshot helper and the state-update diffing logic to ensure that + in-place mutations to nested objects (dicts, lists) are correctly detected as changes. + """ + + def test_nested_dict_mutation_detected_in_diff(self) -> None: + """Test that mutating values inside a nested dict appears in the diff.""" + from agent_framework._workflows._state import State + + from agent_framework_azurefunctions._app import _create_state_snapshot + + deserialized_state: dict[str, Any] = { + "Local.config": {"code": "", "enabled": False}, + "simple_key": "simple_value", + } + + original_snapshot = _create_state_snapshot(deserialized_state) + + shared_state = State() + shared_state.import_state(deserialized_state) + + config = shared_state.get("Local.config") + config["code"] = "SOMECODEXXX" + config["enabled"] = True + + shared_state.commit() + current_state = shared_state.export_state() + + updates = _compute_state_updates(original_snapshot, current_state) + + assert "Local.config" in updates + assert updates["Local.config"]["code"] == "SOMECODEXXX" + assert updates["Local.config"]["enabled"] is True + + def test_new_key_in_nested_dict_detected_in_diff(self) -> None: + """Test that adding a key to a nested dict appears in the diff.""" + from agent_framework._workflows._state import State + + from agent_framework_azurefunctions._app import _create_state_snapshot + + deserialized_state: dict[str, Any] = { + "Local.data": {"existing": "value"}, + } + + original_snapshot = _create_state_snapshot(deserialized_state) + + shared_state = State() + shared_state.import_state(deserialized_state) + + data = shared_state.get("Local.data") + data["code"] = "NEW_CODE" + + shared_state.commit() + current_state = shared_state.export_state() + + updates = _compute_state_updates(original_snapshot, current_state) + + assert "Local.data" in updates + assert updates["Local.data"]["code"] == "NEW_CODE" + + def test_nested_list_mutation_detected_in_diff(self) -> None: + """Test that appending to a nested list appears in the diff.""" + from agent_framework._workflows._state import State + + from agent_framework_azurefunctions._app import _create_state_snapshot + + deserialized_state: dict[str, Any] = { + "Local.items": [1, 2, 3], + } + + original_snapshot = _create_state_snapshot(deserialized_state) + + shared_state = State() + shared_state.import_state(deserialized_state) + + items = shared_state.get("Local.items") + items.append(4) + + shared_state.commit() + current_state = shared_state.export_state() + + updates = _compute_state_updates(original_snapshot, current_state) + + assert "Local.items" in updates + assert updates["Local.items"] == [1, 2, 3, 4] + + def test_new_top_level_key_detected_in_diff(self) -> None: + """Test that setting a new top-level key appears in the diff.""" + from agent_framework._workflows._state import State + + from agent_framework_azurefunctions._app import _create_state_snapshot + + deserialized_state: dict[str, Any] = { + "existing": "value", + } + + original_snapshot = _create_state_snapshot(deserialized_state) + + shared_state = State() + shared_state.import_state(deserialized_state) + + shared_state.set("Local.code", "SOMECODEXXX") + + shared_state.commit() + current_state = shared_state.export_state() + + updates = _compute_state_updates(original_snapshot, current_state) + + assert "Local.code" in updates + assert updates["Local.code"] == "SOMECODEXXX" + + def test_unchanged_nested_state_produces_empty_diff(self) -> None: + """Test that unmodified nested state produces no updates.""" + from agent_framework._workflows._state import State + + from agent_framework_azurefunctions._app import _create_state_snapshot + + deserialized_state: dict[str, Any] = { + "Local.config": {"code": "existing", "enabled": True}, + "simple_key": "simple_value", + } + + original_snapshot = _create_state_snapshot(deserialized_state) + + shared_state = State() + shared_state.import_state(deserialized_state) + + # No mutations performed + shared_state.commit() + current_state = shared_state.export_state() + + updates = _compute_state_updates(original_snapshot, current_state) + + assert updates == {} + + def test_shallow_copy_would_miss_nested_mutations(self) -> None: + """Regression test: a shallow copy (dict()) shares nested refs, hiding mutations. + + This reproduces the original bug from #4500 where ``dict(deserialized_state)`` + was used instead of ``copy.deepcopy()``. With a shallow copy the snapshot and + the live state share nested objects, so in-place mutations appear in both and + the diff produces an empty update set. + """ + from agent_framework._workflows._state import State + + deserialized_state: dict[str, Any] = { + "Local.config": {"code": "", "enabled": False}, + } + + # Shallow copy (the OLD, buggy behaviour) + shallow_snapshot = dict(deserialized_state) + + shared_state = State() + shared_state.import_state(deserialized_state) + + config = shared_state.get("Local.config") + config["code"] = "SOMECODEXXX" + config["enabled"] = True + + shared_state.commit() + current_state = shared_state.export_state() + + # With a shallow copy the mutation leaks into the snapshot → empty diff + updates_shallow = _compute_state_updates(shallow_snapshot, current_state) + assert updates_shallow == {}, "shallow copy should miss nested mutations (demonstrating the bug)" + + def test_create_state_snapshot_isolates_nested_objects(self) -> None: + """Verify _create_state_snapshot produces a deep copy that is mutation-proof. + + This ensures the production snapshot helper is not equivalent to ``dict()`` + and will correctly isolate nested objects so that later mutations are detected. + """ + from agent_framework_azurefunctions._app import _create_state_snapshot + + original: dict[str, Any] = { + "nested_dict": {"a": 1}, + "nested_list": [1, 2, 3], + } + + snapshot = _create_state_snapshot(original) + + # Mutate the originals in place + original["nested_dict"]["a"] = 999 + original["nested_list"].append(4) + + # Snapshot must be unaffected + assert snapshot["nested_dict"]["a"] == 1 + assert snapshot["nested_list"] == [1, 2, 3] + + def test_executor_activity_detects_nested_state_mutations(self) -> None: + """Integration test: the full activity wrapper detects nested mutations. + + This exercises the actual executor_activity function registered by + _setup_executor_activity to verify the production code path uses + _create_state_snapshot (deep copy) rather than dict() (shallow copy). + If the implementation regressed to using a shallow copy such as + ``dict(deserialized_state)``, this test would fail because in-place + mutations would leak into the snapshot and produce an empty diff. + """ + mock_executor = Mock() + mock_executor.id = "test-exec" + + async def mutate_nested_state( + message: Any, + source_executor_ids: Any, + state: Any, + runner_context: Any, + ) -> None: + config = state.get("Local.config") + config["code"] = "MUTATED" + config["enabled"] = True + state.commit() + + mock_executor.execute = AsyncMock(side_effect=mutate_nested_state) + + mock_workflow = Mock() + mock_workflow.executors = {"test-exec": mock_executor} + + # Capture the activity function by making decorators pass-through + captured_activity: dict[str, Any] = {} + + def passthrough_function_name(name: str) -> Callable[[FuncT], FuncT]: + def decorator(fn: FuncT) -> FuncT: + captured_activity["fn"] = fn + return fn + + return decorator + + def passthrough_activity_trigger(input_name: str) -> Callable[[FuncT], FuncT]: + def decorator(fn: FuncT) -> FuncT: + return fn + + return decorator + + with ( + patch.object(AgentFunctionApp, "function_name", side_effect=passthrough_function_name), + patch.object(AgentFunctionApp, "activity_trigger", side_effect=passthrough_activity_trigger), + patch.object(AgentFunctionApp, "_setup_workflow_orchestration"), + ): + AgentFunctionApp(workflow=mock_workflow) + + assert "fn" in captured_activity, "activity function was not captured" + + # Call the activity with nested state that the executor will mutate + input_data = json.dumps({ + "message": "test", + "shared_state_snapshot": { + "Local.config": {"code": "", "enabled": False}, + }, + "source_executor_ids": [SOURCE_ORCHESTRATOR], + }) + + result = json.loads(captured_activity["fn"](input_data)) + + # The deep copy snapshot must detect the in-place nested mutations + assert "Local.config" in result["shared_state_updates"], ( + "nested mutation not detected — snapshot may be using shallow copy" + ) + updated_config = result["shared_state_updates"]["Local.config"] + assert updated_config["code"] == "MUTATED" + assert updated_config["enabled"] is True + + if __name__ == "__main__": pytest.main([__file__, "-v", "--tb=short"])