From e7937947d91ffc129d8e885644c8a5f365be075a Mon Sep 17 00:00:00 2001 From: Peter Ibekwe <109177538+peibekwe@users.noreply.github.com> Date: Thu, 11 Jun 2026 15:34:15 -0700 Subject: [PATCH] Python: Bug fix for declarative workflows (#6468) * Fix declarative object parsing bug * Remove unnecessary comment * Address PR comments * Address PR comments. * Fix CI failures. --- .../_workflows/_declarative_base.py | 181 ++++++--- .../test_declarative_state_path_safety.py | 364 ++++++++++++++++++ .../declarative/tests/test_graph_coverage.py | 8 +- 3 files changed, 489 insertions(+), 64 deletions(-) create mode 100644 python/packages/declarative/tests/test_declarative_state_path_safety.py diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_base.py b/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_base.py index e6fc0a820d..6a035a448a 100644 --- a/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_base.py +++ b/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_base.py @@ -63,6 +63,9 @@ logger = logging.getLogger(__name__) _ENV_REFERENCE_RE = re.compile(r"\bEnv\.([A-Za-z_][A-Za-z0-9_]*)") +# Allowed identifier shape for object-attribute steps in declarative state paths +_SAFE_PATH_SEGMENT_RE = re.compile(r"^[A-Za-z][A-Za-z0-9_]*$") + @dataclass(frozen=True) class DeclarativeEnvConfig: @@ -266,6 +269,9 @@ class DeclarativeWorkflowState: - Conversation: Conversation history """ + # Sentinel marking "no prior value" for temporary-key bookkeeping. + _MISSING: Any = object() + def __init__(self, state: State, env_config: DeclarativeEnvConfig | None = None): """Initialize with a State instance. @@ -331,16 +337,21 @@ class DeclarativeWorkflowState: def get(self, path: str, default: Any = None) -> Any: """Get a value from the state using a dot-notated path. + Dict-keyed segments may use arbitrary string keys (e.g. UUIDs in + ``System.conversations..messages``). Segments that would resolve + via object-attribute access must be valid declarative identifiers + (``[A-Za-z][A-Za-z0-9_]*``); other shapes return ``default``. + Args: path: Dot-notated path like 'Local.results' or 'Workflow.Inputs.query' default: Default value if path doesn't exist Returns: - The value at the path, or default if not found + The value at the path, or default if not found or unreachable. """ state_data = self.get_state_data() parts = path.split(".") - if not parts: + if not parts or any(not p for p in parts): return default namespace = parts[0] @@ -377,10 +388,19 @@ class DeclarativeWorkflowState: obj = obj.get(part, default) # type: ignore[union-attr] if obj is default: return default - elif hasattr(obj, part): # type: ignore[arg-type] - obj = getattr(obj, part) # type: ignore[arg-type] else: - return default + # Attribute access is only allowed for safe declarative identifiers. + if not _SAFE_PATH_SEGMENT_RE.match(part): + logger.warning( + "DeclarativeWorkflowState.get: rejecting attribute segment %r in path %r", + part, + path, + ) + return default + if hasattr(obj, part): # type: ignore[arg-type] + obj = getattr(obj, part) # type: ignore[arg-type] + else: + return default return obj # type: ignore[return-value] @@ -392,12 +412,14 @@ class DeclarativeWorkflowState: value: The value to set Raises: - ValueError: If attempting to set Workflow.Inputs (which is read-only) + ValueError: If ``path`` is empty or contains empty segments + (e.g. ``"Local."``, ``"Local..foo"``), or if attempting to set + ``Workflow.Inputs`` (which is read-only). """ state_data = self.get_state_data() parts = path.split(".") - if not parts: - return + if not parts or any(not p for p in parts): + raise ValueError(f"Invalid path {path!r}: empty segments are not allowed") namespace = parts[0] remaining = parts[1:] @@ -453,7 +475,16 @@ class DeclarativeWorkflowState: Args: path: Dot-notated path to a list value: The value to append + + Raises: + ValueError: If ``path`` is empty or contains empty segments + (e.g. ``"Local."``, ``"Local..foo"``), or if the existing + value at ``path`` is not a list. """ + parts = path.split(".") + if not parts or any(not p for p in parts): + raise ValueError(f"Invalid path {path!r}: empty segments are not allowed") + existing = self.get(path) if existing is None: self.set(path, [value]) @@ -464,6 +495,15 @@ class DeclarativeWorkflowState: else: raise ValueError(f"Cannot append to non-list at path '{path}'") + def _clear_local_path(self, name: str) -> None: + """Remove ``name`` from the ``Local`` namespace, if present.""" + state_data = self.get_state_data() + local = state_data.get("Local") + if local is None or name not in local: + return + local.pop(name, None) + self.set_state_data(state_data) + def eval(self, expression: str) -> Any: """Evaluate a PowerFx expression with the current state. @@ -504,53 +544,64 @@ class DeclarativeWorkflowState: return result # Pre-process nested custom functions (e.g., Upper(MessageText(...))) - # Replace them with their evaluated results before sending to PowerFx - formula = self._preprocess_custom_functions(formula) + # and run PowerFx. The finally below restores any temporary state + # written during preprocessing, regardless of where execution exits. + temp_writes: list[tuple[str, Any]] = [] - if Engine is None: - raise RuntimeError( - f"PowerFx is not available (dotnet runtime not installed). " - f"Expression '={formula[:80]}' cannot be evaluated. " - f"Install dotnet and the powerfx package for full PowerFx support." - ) - - symbols = self._to_powerfx_symbols() - # Use setlocale(category) query form so we can restore the exact prior value. - # getlocale() returns a normalized tuple and is not always a lossless - # round-trip for setlocale across platforms/locales. - original_numeric_locale = locale.setlocale(locale.LC_NUMERIC) try: - for locale_candidate in _POWERFX_NUMERIC_LOCALE_CANDIDATES: - try: - locale.setlocale(locale.LC_NUMERIC, locale_candidate) - break - except locale.Error: - continue + formula = self._preprocess_custom_functions(formula, temp_writes) - engine = Engine() - try: - from System.Globalization import ( # pyright: ignore[reportMissingImports] - CultureInfo, # pyright: ignore[reportUnknownVariableType] + if Engine is None: + raise RuntimeError( + f"PowerFx is not available (dotnet runtime not installed). " + f"Expression '={formula[:80]}' cannot be evaluated. " + f"Install dotnet and the powerfx package for full PowerFx support." ) - except ImportError: - return engine.eval(formula, symbols=symbols, locale=_POWERFX_EVAL_LOCALE) - original_culture = cast(Any, CultureInfo.CurrentCulture) # pyright: ignore[reportUnknownMemberType] + symbols = self._to_powerfx_symbols() + # Use setlocale(category) query form so we can restore the exact prior value. + # getlocale() returns a normalized tuple and is not always a lossless + # round-trip for setlocale across platforms/locales. + original_numeric_locale = locale.setlocale(locale.LC_NUMERIC) try: - CultureInfo.CurrentCulture = CultureInfo(_POWERFX_EVAL_LOCALE) # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] - return engine.eval(formula, symbols=symbols, locale=_POWERFX_EVAL_LOCALE) + for locale_candidate in _POWERFX_NUMERIC_LOCALE_CANDIDATES: + try: + locale.setlocale(locale.LC_NUMERIC, locale_candidate) + break + except locale.Error: + continue + + engine = Engine() + try: + from System.Globalization import ( # pyright: ignore[reportMissingImports] + CultureInfo, # pyright: ignore[reportUnknownVariableType] + ) + except ImportError: + return engine.eval(formula, symbols=symbols, locale=_POWERFX_EVAL_LOCALE) + + original_culture = cast(Any, CultureInfo.CurrentCulture) # pyright: ignore[reportUnknownMemberType] + try: + CultureInfo.CurrentCulture = CultureInfo(_POWERFX_EVAL_LOCALE) # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] + return engine.eval(formula, symbols=symbols, locale=_POWERFX_EVAL_LOCALE) + finally: + CultureInfo.CurrentCulture = original_culture # pyright: ignore[reportUnknownMemberType] + except ValueError as e: + error_msg = str(e) + # Handle undefined variable errors gracefully by returning None + # This matches the behavior of the legacy fallback parser + if "isn't recognized" in error_msg or "Name isn't valid" in error_msg: + logger.debug(f"PowerFx: undefined variable in expression '{formula}', returning None") + return None + raise finally: - CultureInfo.CurrentCulture = original_culture # pyright: ignore[reportUnknownMemberType] - except ValueError as e: - error_msg = str(e) - # Handle undefined variable errors gracefully by returning None - # This matches the behavior of the legacy fallback parser - if "isn't recognized" in error_msg or "Name isn't valid" in error_msg: - logger.debug(f"PowerFx: undefined variable in expression '{formula}', returning None") - return None - raise + locale.setlocale(locale.LC_NUMERIC, original_numeric_locale) finally: - locale.setlocale(locale.LC_NUMERIC, original_numeric_locale) + # Restore each temporary key to its prior value (or remove it). + for path, previous in reversed(temp_writes): + if previous is self._MISSING: + self._clear_local_path(path.removeprefix("Local.")) + else: + self.set(path, previous) def _eval_custom_function(self, formula: str) -> Any | None: """Handle custom functions not supported by the Python PowerFx library. @@ -609,7 +660,7 @@ class DeclarativeWorkflowState: return None - def _preprocess_custom_functions(self, formula: str) -> str: + def _preprocess_custom_functions(self, formula: str, temp_writes: list[tuple[str, Any]]) -> str: """Pre-process custom functions nested inside other PowerFx functions. Custom functions like MessageText() are not supported by the PowerFx engine. @@ -624,9 +675,14 @@ class DeclarativeWorkflowState: Args: formula: The PowerFx formula to pre-process + temp_writes: Caller-owned list. Each write to a temporary key + appends a ``(path, previous_value)`` entry where + ``previous_value`` is the value at ``path`` before the write + or :attr:`_MISSING` if none. The caller must restore every + entry, including when this method raises mid-write. Returns: - The formula with custom function calls replaced by their evaluated results + The rewritten formula. """ import re @@ -635,7 +691,6 @@ class DeclarativeWorkflowState: # We use 500 to leave room for the rest of the expression around the replaced value. MAX_INLINE_LENGTH = 500 - # Counter for generating unique temp variable names temp_var_counter = 0 # Custom functions that need pre-processing: (regex pattern, handler) @@ -691,11 +746,14 @@ class DeclarativeWorkflowState: # Replace in formula if isinstance(replacement, str): if len(replacement) > MAX_INLINE_LENGTH: - # Store long strings in a temp variable to avoid PowerFx expression limit + # Store long results in an underscore-prefixed temp key; + # record the prior value so eval() can restore it. temp_var_name = f"_TempMessageText{temp_var_counter}" temp_var_counter += 1 - self.set(f"Local.{temp_var_name}", replacement) - replacement_str = f"Local.{temp_var_name}" + temp_var_path = f"Local.{temp_var_name}" + temp_writes.append((temp_var_path, self.get(temp_var_path, default=self._MISSING))) + self.set(temp_var_path, replacement) + replacement_str = temp_var_path logger.debug( f"Stored long MessageText result ({len(replacement)} chars) " f"in temp variable {temp_var_name}" @@ -847,11 +905,13 @@ class DeclarativeWorkflowState: return value def interpolate_string(self, text: str) -> str: - """Interpolate {Variable.Path} references in a string. + """Interpolate ``{Variable.Path}`` references in a string. - This handles template-style variable substitution like: - - "Created ticket #{Local.TicketParameters.TicketId}" - - "Routing to {Local.RoutingParameters.TeamName}" + Captures brace-delimited tokens whose root segment is an identifier + (``[A-Za-z][A-Za-z0-9_]*``) followed by zero or more ``.`` separated + dict-key segments. Resolution is delegated to :meth:`get`; unresolved + tokens are replaced with the empty string. Tokens that do not look + like state paths (e.g. ``{foo-bar}``, ``{Ctrl+C}``) are left literal. Args: text: Text that may contain {Variable.Path} references @@ -866,10 +926,11 @@ class DeclarativeWorkflowState: value = self.get(var_path) return str(value) if value is not None else "" - # Match {Variable.Path} patterns - pattern = r"\{([A-Za-z][A-Za-z0-9_.]*)\}" + # Root segment must be an identifier; follow-on segments accept any + # non-empty dict-key (e.g. ``_id``, ``1``, UUIDs). ``get()`` enforces + # per-segment safety on attribute traversal. + pattern = r"\{([A-Za-z][A-Za-z0-9_]*(?:\.[^{}\s.]+)*)\}" - # Replace all matches result = text for match in re.finditer(pattern, text): replacement = replace_var(match) diff --git a/python/packages/declarative/tests/test_declarative_state_path_safety.py b/python/packages/declarative/tests/test_declarative_state_path_safety.py new file mode 100644 index 0000000000..2446fc3cf4 --- /dev/null +++ b/python/packages/declarative/tests/test_declarative_state_path_safety.py @@ -0,0 +1,364 @@ +# Copyright (c) Microsoft. All rights reserved. +# pyright: reportUnknownParameterType=false, reportUnknownArgumentType=false +# pyright: reportMissingParameterType=false, reportUnknownMemberType=false +# pyright: reportPrivateUsage=false, reportUnknownVariableType=false +# pyright: reportGeneralTypeIssues=false + +"""Path-segment validation tests for DeclarativeWorkflowState. + +Path segments handed to ``get``/``set``/``append`` and ``{Variable.Path}`` +placeholders in ``interpolate_string`` are subject to three distinct rules +that this module pins: + +- **Empty segments** (e.g. ``""``, ``"Local."``, ``"Local..foo"``) are rejected + by all of ``get``/``set``/``append`` and ``interpolate_string``. ``get`` and + ``interpolate_string`` return their default / leave the placeholder literal; + ``set`` and ``append`` raise ``ValueError``. +- **Object-attribute segments** — segments that ``get`` would resolve via + ``getattr`` because the parent is a non-dict object — must match the safe + identifier shape ``[A-Za-z][A-Za-z0-9_]*``. Other shapes are rejected with a + warning log and the default is returned. +- **Dict-keyed segments** — segments that resolve via dict lookup because the + parent is a ``dict`` — may use arbitrary non-empty string keys (e.g. UUIDs + or hyphenated identifiers like ``System.conversations..messages``). +""" + +import logging +from dataclasses import dataclass +from typing import Any +from unittest.mock import MagicMock + +import pytest + +from agent_framework_declarative._workflows import DeclarativeWorkflowState + +try: + import powerfx # noqa: F401 + + _powerfx_available = True +except (ImportError, RuntimeError): + _powerfx_available = False + +_requires_powerfx = pytest.mark.skipif(not _powerfx_available, reason="PowerFx engine not available") + + +@pytest.fixture +def mock_state() -> MagicMock: + """In-memory mock for the underlying State.""" + ms = MagicMock() + ms._data = {} + + def get(key: str, default: Any = None) -> Any: + return ms._data.get(key, default) + + def set_(key: str, value: Any) -> None: + ms._data[key] = value + + def has(key: str) -> bool: + return key in ms._data + + def delete(key: str) -> None: + ms._data.pop(key, None) + + ms.get = MagicMock(side_effect=get) + ms.set = MagicMock(side_effect=set_) + ms.has = MagicMock(side_effect=has) + ms.delete = MagicMock(side_effect=delete) + return ms + + +@pytest.fixture +def state(mock_state: MagicMock) -> DeclarativeWorkflowState: + s = DeclarativeWorkflowState(mock_state) + s.initialize() + return s + + +@dataclass +class _PlainObj: + """Non-dict object so ``get`` falls through to attribute access.""" + + text: str = "hi" + + +# --------------------------------------------------------------------------- +# get(): invalid paths return default +# --------------------------------------------------------------------------- + + +class TestGetRejectsInvalidPaths: + def test_rejects_dunder_segment_via_attribute_access(self, state: DeclarativeWorkflowState) -> None: + state.set("Local.obj", _PlainObj()) + assert state.get("Local.obj.__class__") is None + assert state.get("Local.obj.__class__", default="DEF") == "DEF" + + def test_rejects_full_env_exfil_chain(self, state: DeclarativeWorkflowState, monkeypatch) -> None: + sentinel = "agent-framework-path-safety-sentinel" + monkeypatch.setenv("AF_PATH_SAFETY_SENTINEL", sentinel) + state.set("Local.obj", _PlainObj()) + + result = state.get("Local.obj.__class__.__init__.__globals__.os.environ") + + assert result is None + assert sentinel not in str(result) + + def test_rejects_leading_underscore_via_attribute_access(self, state: DeclarativeWorkflowState) -> None: + state.set("Local.obj", _PlainObj()) + assert state.get("Local.obj._private") is None + + def test_rejects_invalid_chars_via_attribute_access(self, state: DeclarativeWorkflowState) -> None: + state.set("Local.obj", _PlainObj()) + assert state.get("Local.obj.text bar") is None + assert state.get("Local.obj.text-bar") is None + + def test_rejects_empty_path_and_empty_segments(self, state: DeclarativeWorkflowState) -> None: + assert state.get("") is None + assert state.get(".") is None + assert state.get("Local.") is None + assert state.get(".Local") is None + + def test_warning_logged_on_rejected_attribute_segment( + self, + state: DeclarativeWorkflowState, + caplog: pytest.LogCaptureFixture, + ) -> None: + state.set("Local.obj", _PlainObj()) + with caplog.at_level(logging.WARNING, logger="agent_framework_declarative._workflows._declarative_base"): + state.get("Local.obj.__class__") + assert any("rejecting attribute segment" in r.message for r in caplog.records) + + def test_dict_keyed_dunder_is_not_attribute_access(self, state: DeclarativeWorkflowState) -> None: + """A literal dunder dict key is harmless because dict lookup never reaches getattr.""" + state.set("Local.bag", {"__class__": "harmless-string"}) + assert state.get("Local.bag.__class__") == "harmless-string" + + +# --------------------------------------------------------------------------- +# get(): legitimate paths continue to work +# --------------------------------------------------------------------------- + + +class TestGetAllowsValidPaths: + def test_underscore_inside_identifier(self, state: DeclarativeWorkflowState) -> None: + state.set("Local.user_input", "ok") + assert state.get("Local.user_input") == "ok" + + def test_mixed_case_identifiers(self, state: DeclarativeWorkflowState) -> None: + state.set("Local.UserInput", "u1") + state.set("Local.userInput", "u2") + assert state.get("Local.UserInput") == "u1" + assert state.get("Local.userInput") == "u2" + + def test_object_attribute_traversal_still_works(self, state: DeclarativeWorkflowState) -> None: + state.set("Local.msg", _PlainObj(text="hello")) + assert state.get("Local.msg.text") == "hello" + + def test_nested_dict_traversal_still_works(self, state: DeclarativeWorkflowState) -> None: + state.set("Local.params", {"team": {"name": "alpha"}}) + assert state.get("Local.params.team.name") == "alpha" + + def test_uuid_and_hyphenated_dict_keys_are_allowed(self, state: DeclarativeWorkflowState) -> None: + """Conversation-id style paths use arbitrary dict keys (UUIDs / hyphens).""" + conv_id = "eb815014-06f1-4db6-b7c1-304ea135424f" + state.set(f"System.conversations.{conv_id}.messages", ["m1", "m2"]) + assert state.get(f"System.conversations.{conv_id}.messages") == ["m1", "m2"] + + +# --------------------------------------------------------------------------- +# set() / append(): dict-keyed operations accept arbitrary string keys +# --------------------------------------------------------------------------- + + +class TestSetAndAppend: + def test_set_allows_underscore_inside_identifier(self, state: DeclarativeWorkflowState) -> None: + state.set("Local.user_input", "ok") + assert state.get("Local.user_input") == "ok" + + def test_set_allows_uuid_and_hyphenated_dict_keys(self, state: DeclarativeWorkflowState) -> None: + conv_id = "conv-test-1" + state.set(f"System.conversations.{conv_id}.messages", []) + assert state.get(f"System.conversations.{conv_id}.messages") == [] + + def test_append_allows_uuid_and_hyphenated_dict_keys(self, state: DeclarativeWorkflowState) -> None: + conv_id = "conv-42" + state.append(f"System.conversations.{conv_id}.messages", {"role": "user", "text": "hi"}) + msgs = state.get(f"System.conversations.{conv_id}.messages") + assert msgs == [{"role": "user", "text": "hi"}] + + def test_workflow_inputs_still_read_only(self, state: DeclarativeWorkflowState) -> None: + with pytest.raises(ValueError, match="read-only"): + state.set("Workflow.Inputs.x", 1) + + +# --------------------------------------------------------------------------- +# set() / append(): malformed paths (empty segments) raise ValueError +# --------------------------------------------------------------------------- + + +class TestSetRejectsInvalidPaths: + @pytest.mark.parametrize("bad_path", ["", "Local.", "Local..foo", ".Local"]) + def test_set_rejects_empty_segment(self, state: DeclarativeWorkflowState, bad_path: str) -> None: + with pytest.raises(ValueError, match="empty segments are not allowed"): + state.set(bad_path, "x") + + @pytest.mark.parametrize("bad_path", ["", "Local.", "Local..foo", ".Local"]) + def test_append_rejects_empty_segment(self, state: DeclarativeWorkflowState, bad_path: str) -> None: + with pytest.raises(ValueError, match="empty segments are not allowed"): + state.append(bad_path, "x") + + def test_set_rejection_makes_no_partial_write(self, state: DeclarativeWorkflowState) -> None: + """Rejected set() must not create an unreachable entry in the state.""" + state.set("Local.user_input", "pre") + with pytest.raises(ValueError): + state.set("Local.", "value") + local = state.get_state_data().get("Local", {}) + assert "" not in local + assert local == {"user_input": "pre"} + assert state.get("Local.") is None + assert state.get("Local.user_input") == "pre" + + def test_append_rejection_makes_no_partial_write(self, state: DeclarativeWorkflowState) -> None: + """Rejected append() must not create an unreachable entry in the state.""" + state.set("Local.items", ["a"]) + with pytest.raises(ValueError): + state.append("Local.", "value") + local = state.get_state_data().get("Local", {}) + assert "" not in local + assert local == {"items": ["a"]} + + +# --------------------------------------------------------------------------- +# interpolate_string(): permissive matcher; get() enforces safety +# --------------------------------------------------------------------------- + + +class TestInterpolateString: + def test_ignores_dunder_payload(self, state: DeclarativeWorkflowState, monkeypatch) -> None: + sentinel = "agent-framework-interp-sentinel" + monkeypatch.setenv("AF_INTERP_SENTINEL", sentinel) + state.set("Local.obj", _PlainObj()) + + out = state.interpolate_string("X={Local.obj.__class__.__init__.__globals__.os.environ}") + + assert sentinel not in out + assert out == "X=" + + def test_unknown_path_reduces_to_empty(self, state: DeclarativeWorkflowState) -> None: + assert state.interpolate_string("v={Local._private}") == "v=" + + @pytest.mark.parametrize( + "literal", + ["{foo-bar}", "{Ctrl+C}", "{not:a:path}", "{Local.}", "{}"], + ) + def test_non_state_braced_tokens_left_literal(self, state: DeclarativeWorkflowState, literal: str) -> None: + assert state.interpolate_string(f"v={literal}") == f"v={literal}" + + def test_allows_underscore_inside_identifier(self, state: DeclarativeWorkflowState) -> None: + state.set("Local.user_input", "hello") + assert state.interpolate_string("v={Local.user_input}") == "v=hello" + + def test_resolves_nested_dict_path(self, state: DeclarativeWorkflowState) -> None: + state.set("Local.params", {"team": "alpha"}) + assert state.interpolate_string("team={Local.params.team}") == "team=alpha" + + @pytest.mark.parametrize( + ("key", "value"), + [ + ("_id", "abc123"), + ("1", "one"), + ("2025", "year-bucket"), + ], + ) + def test_resolves_dict_keyed_segments(self, state: DeclarativeWorkflowState, key: str, value: str) -> None: + state.set("Local.bag", {key: value}) + assert state.interpolate_string(f"v={{Local.bag.{key}}}") == f"v={value}" + + def test_resolves_uuid_conversation_key(self, state: DeclarativeWorkflowState) -> None: + conv_id = "eb815014-06f1-4db6-b7c1-304ea135424f" + state.set(f"System.conversations.{conv_id}.messages", ["hello"]) + out = state.interpolate_string(f"m={{System.conversations.{conv_id}.messages}}") + assert out == "m=['hello']" + + def test_end_to_end_send_activity_payload_neutralized( + self, + state: DeclarativeWorkflowState, + monkeypatch, + ) -> None: + sentinel = "agent-framework-e2e-sentinel" + monkeypatch.setenv("AF_E2E_SENTINEL", sentinel) + state.set("Local.toolResult", _PlainObj()) + + payload = "{Local.toolResult.__class__.__init__.__globals__.os.environ}" + evaluated = state.eval_if_expression(payload) + rendered = state.interpolate_string(evaluated) if isinstance(evaluated, str) else str(evaluated) + + assert sentinel not in rendered + assert rendered == "" + + +# --------------------------------------------------------------------------- +# Regressions: PowerFx and internal temp-variable handling still work +# --------------------------------------------------------------------------- + + +@_requires_powerfx +class TestPowerFxStillWorks: + def test_simple_powerfx_expression_evaluates(self, state: DeclarativeWorkflowState) -> None: + state.set("Local.x", 6) + state.set("Local.y", 7) + assert state.eval("=Local.x * Local.y") == 42 + + def test_internal_temp_message_text_still_works(self, state: DeclarativeWorkflowState) -> None: + """Long MessageText() results round-trip and the temp key is removed after eval.""" + long_text = "A" * 600 + state.set( + "Local.Messages", + [{"text": long_text, "contents": [{"type": "text", "text": long_text}]}], + ) + + result = state.eval("=Upper(MessageText(Local.Messages))") + assert result == "A" * 600 + + local = state.get_state_data().get("Local", {}) + remaining = sorted(k for k in local if k.startswith("_TempMessageText")) + assert not remaining, f"Temporary keys remain in Local: {remaining}" + + def test_message_text_eval_preserves_user_temp_value(self, state: DeclarativeWorkflowState) -> None: + """User state at the temp key path survives a long MessageText eval.""" + long_text = "A" * 600 + state.set("Local._TempMessageText0", "user-important-value") + state.set( + "Local.Messages", + [{"text": long_text, "contents": [{"type": "text", "text": long_text}]}], + ) + + result = state.eval("=Upper(MessageText(Local.Messages))") + assert result == "A" * 600 + assert state.get("Local._TempMessageText0") == "user-important-value" + + def test_message_text_eval_cleans_up_on_powerfx_failure( + self, + state: DeclarativeWorkflowState, + monkeypatch, + ) -> None: + """Temp key is removed even when PowerFx evaluation raises.""" + from agent_framework_declarative._workflows import _declarative_base as base + + class _FailingEngine: + def eval(self, *args: Any, **kwargs: Any) -> Any: + raise RuntimeError("boom") + + monkeypatch.setattr(base, "Engine", _FailingEngine) + + long_text = "A" * 600 + state.set( + "Local.Messages", + [{"text": long_text, "contents": [{"type": "text", "text": long_text}]}], + ) + + with pytest.raises(RuntimeError, match="boom"): + state.eval("=Upper(MessageText(Local.Messages))") + + local = state.get_state_data().get("Local", {}) + remaining = sorted(k for k in local if k.startswith("_TempMessageText")) + assert not remaining, f"Temporary keys remain in Local after PowerFx failure: {remaining}" diff --git a/python/packages/declarative/tests/test_graph_coverage.py b/python/packages/declarative/tests/test_graph_coverage.py index f114c8f0ae..bc20a27f09 100644 --- a/python/packages/declarative/tests/test_graph_coverage.py +++ b/python/packages/declarative/tests/test_graph_coverage.py @@ -2765,7 +2765,7 @@ class TestLongMessageTextHandling: assert temp_var is None async def test_long_message_text_stored_in_temp_variable(self, mock_state): - """Test that long MessageText results are stored in temp variables.""" + """Long MessageText results round-trip and the temp key is removed after eval.""" state = DeclarativeWorkflowState(mock_state) state.initialize() @@ -2777,9 +2777,9 @@ class TestLongMessageTextHandling: result = state.eval("=Upper(MessageText(Local.Messages))") assert result == "A" * 600 # Upper on 'A' is still 'A' - # A temp variable should have been created - temp_var = state.get("Local._TempMessageText0") - assert temp_var == long_text + local = state.get_state_data().get("Local", {}) + remaining = sorted(k for k in local if k.startswith("_TempMessageText")) + assert not remaining, f"Temporary keys remain in Local: {remaining}" async def test_find_with_long_message_text(self, mock_state): """Test Find function works with long MessageText stored in temp variable."""