From 4b0aeb76a5e4d613f340bf7006016c466eb20184 Mon Sep 17 00:00:00 2001 From: Peter Ibekwe Date: Thu, 11 Jun 2026 10:35:34 -0700 Subject: [PATCH] Address PR comments. --- .../_workflows/_declarative_base.py | 142 +++++++++++------- .../test_declarative_state_path_safety.py | 89 +++++++++-- .../declarative/tests/test_graph_coverage.py | 10 +- 3 files changed, 168 insertions(+), 73 deletions(-) diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_base.py b/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_base.py index ae2ed94f44..8451ab65fa 100644 --- a/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_base.py +++ b/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_base.py @@ -269,6 +269,9 @@ class DeclarativeWorkflowState: - Conversation: Conversation history """ + # Sentinel marking "no prior value" for temporary-key bookkeeping. + _MISSING: Any = object() + def __init__(self, state: State, env_config: DeclarativeEnvConfig | None = None): """Initialize with a State instance. @@ -492,6 +495,15 @@ class DeclarativeWorkflowState: else: raise ValueError(f"Cannot append to non-list at path '{path}'") + def _clear_local_path(self, name: str) -> None: + """Remove ``name`` from the ``Local`` namespace, if present.""" + state_data = self.get_state_data() + local = cast(dict[str, Any], state_data.get("Local")) + if local is None or name not in local: + return + local.pop(name, None) + self.set_state_data(state_data) + def eval(self, expression: str) -> Any: """Evaluate a PowerFx expression with the current state. @@ -532,53 +544,64 @@ class DeclarativeWorkflowState: return result # Pre-process nested custom functions (e.g., Upper(MessageText(...))) - # Replace them with their evaluated results before sending to PowerFx - formula = self._preprocess_custom_functions(formula) + # and run PowerFx. The finally below restores any temporary state + # written during preprocessing, regardless of where execution exits. + temp_writes: list[tuple[str, Any]] = [] - if Engine is None: - raise RuntimeError( - f"PowerFx is not available (dotnet runtime not installed). " - f"Expression '={formula[:80]}' cannot be evaluated. " - f"Install dotnet and the powerfx package for full PowerFx support." - ) - - symbols = self._to_powerfx_symbols() - # Use setlocale(category) query form so we can restore the exact prior value. - # getlocale() returns a normalized tuple and is not always a lossless - # round-trip for setlocale across platforms/locales. - original_numeric_locale = locale.setlocale(locale.LC_NUMERIC) try: - for locale_candidate in _POWERFX_NUMERIC_LOCALE_CANDIDATES: - try: - locale.setlocale(locale.LC_NUMERIC, locale_candidate) - break - except locale.Error: - continue + formula = self._preprocess_custom_functions(formula, temp_writes) - engine = Engine() - try: - from System.Globalization import ( # pyright: ignore[reportMissingImports] - CultureInfo, # pyright: ignore[reportUnknownVariableType] + if Engine is None: + raise RuntimeError( + f"PowerFx is not available (dotnet runtime not installed). " + f"Expression '={formula[:80]}' cannot be evaluated. " + f"Install dotnet and the powerfx package for full PowerFx support." ) - except ImportError: - return engine.eval(formula, symbols=symbols, locale=_POWERFX_EVAL_LOCALE) - original_culture = cast(Any, CultureInfo.CurrentCulture) # pyright: ignore[reportUnknownMemberType] + symbols = self._to_powerfx_symbols() + # Use setlocale(category) query form so we can restore the exact prior value. + # getlocale() returns a normalized tuple and is not always a lossless + # round-trip for setlocale across platforms/locales. + original_numeric_locale = locale.setlocale(locale.LC_NUMERIC) try: - CultureInfo.CurrentCulture = CultureInfo(_POWERFX_EVAL_LOCALE) # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] - return engine.eval(formula, symbols=symbols, locale=_POWERFX_EVAL_LOCALE) + for locale_candidate in _POWERFX_NUMERIC_LOCALE_CANDIDATES: + try: + locale.setlocale(locale.LC_NUMERIC, locale_candidate) + break + except locale.Error: + continue + + engine = Engine() + try: + from System.Globalization import ( # pyright: ignore[reportMissingImports] + CultureInfo, # pyright: ignore[reportUnknownVariableType] + ) + except ImportError: + return engine.eval(formula, symbols=symbols, locale=_POWERFX_EVAL_LOCALE) + + original_culture = cast(Any, CultureInfo.CurrentCulture) # pyright: ignore[reportUnknownMemberType] + try: + CultureInfo.CurrentCulture = CultureInfo(_POWERFX_EVAL_LOCALE) # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] + return engine.eval(formula, symbols=symbols, locale=_POWERFX_EVAL_LOCALE) + finally: + CultureInfo.CurrentCulture = original_culture # pyright: ignore[reportUnknownMemberType] + except ValueError as e: + error_msg = str(e) + # Handle undefined variable errors gracefully by returning None + # This matches the behavior of the legacy fallback parser + if "isn't recognized" in error_msg or "Name isn't valid" in error_msg: + logger.debug(f"PowerFx: undefined variable in expression '{formula}', returning None") + return None + raise finally: - CultureInfo.CurrentCulture = original_culture # pyright: ignore[reportUnknownMemberType] - except ValueError as e: - error_msg = str(e) - # Handle undefined variable errors gracefully by returning None - # This matches the behavior of the legacy fallback parser - if "isn't recognized" in error_msg or "Name isn't valid" in error_msg: - logger.debug(f"PowerFx: undefined variable in expression '{formula}', returning None") - return None - raise + locale.setlocale(locale.LC_NUMERIC, original_numeric_locale) finally: - locale.setlocale(locale.LC_NUMERIC, original_numeric_locale) + # Restore each temporary key to its prior value (or remove it). + for path, previous in reversed(temp_writes): + if previous is self._MISSING: + self._clear_local_path(path.removeprefix("Local.")) + else: + self.set(path, previous) def _eval_custom_function(self, formula: str) -> Any | None: """Handle custom functions not supported by the Python PowerFx library. @@ -637,7 +660,7 @@ class DeclarativeWorkflowState: return None - def _preprocess_custom_functions(self, formula: str) -> str: + def _preprocess_custom_functions(self, formula: str, temp_writes: list[tuple[str, Any]]) -> str: """Pre-process custom functions nested inside other PowerFx functions. Custom functions like MessageText() are not supported by the PowerFx engine. @@ -652,9 +675,14 @@ class DeclarativeWorkflowState: Args: formula: The PowerFx formula to pre-process + temp_writes: Caller-owned list. Each write to a temporary key + appends a ``(path, previous_value)`` entry where + ``previous_value`` is the value at ``path`` before the write + or :attr:`_MISSING` if none. The caller must restore every + entry, including when this method raises mid-write. Returns: - The formula with custom function calls replaced by their evaluated results + The rewritten formula. """ import re @@ -663,7 +691,6 @@ class DeclarativeWorkflowState: # We use 500 to leave room for the rest of the expression around the replaced value. MAX_INLINE_LENGTH = 500 - # Counter for generating unique temp variable names temp_var_counter = 0 # Custom functions that need pre-processing: (regex pattern, handler) @@ -719,11 +746,14 @@ class DeclarativeWorkflowState: # Replace in formula if isinstance(replacement, str): if len(replacement) > MAX_INLINE_LENGTH: - # Store long strings in a temp variable to avoid PowerFx expression limit - temp_var_name = f"TempMessageText{temp_var_counter}" + # Store long results in an underscore-prefixed temp key; + # record the prior value so eval() can restore it. + temp_var_name = f"_TempMessageText{temp_var_counter}" temp_var_counter += 1 - self.set(f"Local.{temp_var_name}", replacement) - replacement_str = f"Local.{temp_var_name}" + temp_var_path = f"Local.{temp_var_name}" + temp_writes.append((temp_var_path, self.get(temp_var_path, default=self._MISSING))) + self.set(temp_var_path, replacement) + replacement_str = temp_var_path logger.debug( f"Stored long MessageText result ({len(replacement)} chars) " f"in temp variable {temp_var_name}" @@ -875,14 +905,13 @@ class DeclarativeWorkflowState: return value def interpolate_string(self, text: str) -> str: - """Interpolate {Variable.Path} references in a string. + """Interpolate ``{Variable.Path}`` references in a string. - Matched path segments must be valid declarative identifiers - (``[A-Za-z][A-Za-z0-9_]*``); other braced tokens are left as-is. - - This handles template-style variable substitution like: - - "Created ticket #{Local.TicketParameters.TicketId}" - - "Routing to {Local.RoutingParameters.TeamName}" + Captures brace-delimited tokens whose root segment is an identifier + (``[A-Za-z][A-Za-z0-9_]*``) followed by zero or more ``.`` separated + dict-key segments. Resolution is delegated to :meth:`get`; unresolved + tokens are replaced with the empty string. Tokens that do not look + like state paths (e.g. ``{foo-bar}``, ``{Ctrl+C}``) are left literal. Args: text: Text that may contain {Variable.Path} references @@ -897,10 +926,11 @@ class DeclarativeWorkflowState: value = self.get(var_path) return str(value) if value is not None else "" - # Match {Variable.Path} patterns where each segment is a declarative identifier. - pattern = r"\{([A-Za-z][A-Za-z0-9_]*(?:\.[A-Za-z][A-Za-z0-9_]*)*)\}" + # Root segment must be an identifier; follow-on segments accept any + # non-empty dict-key (e.g. ``_id``, ``1``, UUIDs). ``get()`` enforces + # per-segment safety on attribute traversal. + pattern = r"\{([A-Za-z][A-Za-z0-9_]*(?:\.[^{}\s.]+)*)\}" - # Replace all matches result = text for match in re.finditer(pattern, text): replacement = replace_var(match) diff --git a/python/packages/declarative/tests/test_declarative_state_path_safety.py b/python/packages/declarative/tests/test_declarative_state_path_safety.py index 783607a87d..2446fc3cf4 100644 --- a/python/packages/declarative/tests/test_declarative_state_path_safety.py +++ b/python/packages/declarative/tests/test_declarative_state_path_safety.py @@ -210,7 +210,7 @@ class TestSetRejectsInvalidPaths: """Rejected set() must not create an unreachable entry in the state.""" state.set("Local.user_input", "pre") with pytest.raises(ValueError): - state.set("Local.", "leak") + state.set("Local.", "value") local = state.get_state_data().get("Local", {}) assert "" not in local assert local == {"user_input": "pre"} @@ -221,14 +221,14 @@ class TestSetRejectsInvalidPaths: """Rejected append() must not create an unreachable entry in the state.""" state.set("Local.items", ["a"]) with pytest.raises(ValueError): - state.append("Local.", "leak") + state.append("Local.", "value") local = state.get_state_data().get("Local", {}) assert "" not in local assert local == {"items": ["a"]} # --------------------------------------------------------------------------- -# interpolate_string(): invalid placeholders left intact, valid ones resolved +# interpolate_string(): permissive matcher; get() enforces safety # --------------------------------------------------------------------------- @@ -241,11 +241,17 @@ class TestInterpolateString: out = state.interpolate_string("X={Local.obj.__class__.__init__.__globals__.os.environ}") assert sentinel not in out - assert "{Local.obj.__class__" in out # placeholder left as literal text + assert out == "X=" - def test_ignores_leading_underscore_segment(self, state: DeclarativeWorkflowState) -> None: - out = state.interpolate_string("v={Local._private}") - assert out == "v={Local._private}" + def test_unknown_path_reduces_to_empty(self, state: DeclarativeWorkflowState) -> None: + assert state.interpolate_string("v={Local._private}") == "v=" + + @pytest.mark.parametrize( + "literal", + ["{foo-bar}", "{Ctrl+C}", "{not:a:path}", "{Local.}", "{}"], + ) + def test_non_state_braced_tokens_left_literal(self, state: DeclarativeWorkflowState, literal: str) -> None: + assert state.interpolate_string(f"v={literal}") == f"v={literal}" def test_allows_underscore_inside_identifier(self, state: DeclarativeWorkflowState) -> None: state.set("Local.user_input", "hello") @@ -255,12 +261,29 @@ class TestInterpolateString: state.set("Local.params", {"team": "alpha"}) assert state.interpolate_string("team={Local.params.team}") == "team=alpha" - def test_end_to_end_send_activity_literal_placeholder( + @pytest.mark.parametrize( + ("key", "value"), + [ + ("_id", "abc123"), + ("1", "one"), + ("2025", "year-bucket"), + ], + ) + def test_resolves_dict_keyed_segments(self, state: DeclarativeWorkflowState, key: str, value: str) -> None: + state.set("Local.bag", {key: value}) + assert state.interpolate_string(f"v={{Local.bag.{key}}}") == f"v={value}" + + def test_resolves_uuid_conversation_key(self, state: DeclarativeWorkflowState) -> None: + conv_id = "eb815014-06f1-4db6-b7c1-304ea135424f" + state.set(f"System.conversations.{conv_id}.messages", ["hello"]) + out = state.interpolate_string(f"m={{System.conversations.{conv_id}.messages}}") + assert out == "m=['hello']" + + def test_end_to_end_send_activity_payload_neutralized( self, state: DeclarativeWorkflowState, monkeypatch, ) -> None: - """Mirror the SendActivity flow: eval_if_expression then interpolate_string.""" sentinel = "agent-framework-e2e-sentinel" monkeypatch.setenv("AF_E2E_SENTINEL", sentinel) state.set("Local.toolResult", _PlainObj()) @@ -269,8 +292,8 @@ class TestInterpolateString: evaluated = state.eval_if_expression(payload) rendered = state.interpolate_string(evaluated) if isinstance(evaluated, str) else str(evaluated) - assert rendered == payload assert sentinel not in rendered + assert rendered == "" # --------------------------------------------------------------------------- @@ -286,7 +309,7 @@ class TestPowerFxStillWorks: assert state.eval("=Local.x * Local.y") == 42 def test_internal_temp_message_text_still_works(self, state: DeclarativeWorkflowState) -> None: - """Long MessageText() results stored in TempMessageText{n} still round-trip.""" + """Long MessageText() results round-trip and the temp key is removed after eval.""" long_text = "A" * 600 state.set( "Local.Messages", @@ -296,4 +319,46 @@ class TestPowerFxStillWorks: result = state.eval("=Upper(MessageText(Local.Messages))") assert result == "A" * 600 - assert state.get("Local.TempMessageText0") == long_text + local = state.get_state_data().get("Local", {}) + remaining = sorted(k for k in local if k.startswith("_TempMessageText")) + assert not remaining, f"Temporary keys remain in Local: {remaining}" + + def test_message_text_eval_preserves_user_temp_value(self, state: DeclarativeWorkflowState) -> None: + """User state at the temp key path survives a long MessageText eval.""" + long_text = "A" * 600 + state.set("Local._TempMessageText0", "user-important-value") + state.set( + "Local.Messages", + [{"text": long_text, "contents": [{"type": "text", "text": long_text}]}], + ) + + result = state.eval("=Upper(MessageText(Local.Messages))") + assert result == "A" * 600 + assert state.get("Local._TempMessageText0") == "user-important-value" + + def test_message_text_eval_cleans_up_on_powerfx_failure( + self, + state: DeclarativeWorkflowState, + monkeypatch, + ) -> None: + """Temp key is removed even when PowerFx evaluation raises.""" + from agent_framework_declarative._workflows import _declarative_base as base + + class _FailingEngine: + def eval(self, *args: Any, **kwargs: Any) -> Any: + raise RuntimeError("boom") + + monkeypatch.setattr(base, "Engine", _FailingEngine) + + long_text = "A" * 600 + state.set( + "Local.Messages", + [{"text": long_text, "contents": [{"type": "text", "text": long_text}]}], + ) + + with pytest.raises(RuntimeError, match="boom"): + state.eval("=Upper(MessageText(Local.Messages))") + + local = state.get_state_data().get("Local", {}) + remaining = sorted(k for k in local if k.startswith("_TempMessageText")) + assert not remaining, f"Temporary keys remain in Local after PowerFx failure: {remaining}" diff --git a/python/packages/declarative/tests/test_graph_coverage.py b/python/packages/declarative/tests/test_graph_coverage.py index 47742f2f69..bc20a27f09 100644 --- a/python/packages/declarative/tests/test_graph_coverage.py +++ b/python/packages/declarative/tests/test_graph_coverage.py @@ -2761,11 +2761,11 @@ class TestLongMessageTextHandling: assert result == "HELLO WORLD" # No temp variable should be created for short strings - temp_var = state.get("Local.TempMessageText0") + temp_var = state.get("Local._TempMessageText0") assert temp_var is None async def test_long_message_text_stored_in_temp_variable(self, mock_state): - """Test that long MessageText results are stored in temp variables.""" + """Long MessageText results round-trip and the temp key is removed after eval.""" state = DeclarativeWorkflowState(mock_state) state.initialize() @@ -2777,9 +2777,9 @@ class TestLongMessageTextHandling: result = state.eval("=Upper(MessageText(Local.Messages))") assert result == "A" * 600 # Upper on 'A' is still 'A' - # A temp variable should have been created - temp_var = state.get("Local.TempMessageText0") - assert temp_var == long_text + local = state.get_state_data().get("Local", {}) + remaining = sorted(k for k in local if k.startswith("_TempMessageText")) + assert not remaining, f"Temporary keys remain in Local: {remaining}" async def test_find_with_long_message_text(self, mock_state): """Test Find function works with long MessageText stored in temp variable."""