From 8b0405de1bb0f6852ae543a0c46c05c2b7921ccd Mon Sep 17 00:00:00 2001 From: Giles Odigwe <79032838+giles17@users.noreply.github.com> Date: Thu, 11 Jun 2026 11:18:05 -0700 Subject: [PATCH 1/4] .NET: Fix CopySessionConfig() and CopyResumeSessionConfig() to preserve SessionConfig.Streaming value (#6463) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fix CopySessionConfig and CopyResumeSessionConfig ignoring Streaming value (#4732) CopySessionConfig() and CopyResumeSessionConfig() hardcoded Streaming = true, ignoring the caller's explicitly set SessionConfig.Streaming value. This made it impossible to disable streaming when using AsAIAgent() with the GitHub Copilot SDK. Changed both methods to use source.Streaming ?? true (and source?.Streaming ?? true for the nullable overload), preserving the caller's value when set while maintaining backward compatibility by defaulting to true when unset. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Fix non-streaming response path for SessionConfig.Streaming=false (#4732) The config-copy fix (preserving Streaming=false via null-coalescing) was already in place, but ConvertToAgentResponseUpdate(AssistantMessageEvent) always emitted raw AIContent without text—assuming delta events had already delivered it. When streaming is disabled there are no delta events, so the assistant's final text was silently dropped. Changes: - Add isStreaming parameter to ConvertToAgentResponseUpdate for AssistantMessageEvent so it emits TextContent in non-streaming mode. - Capture the resolved streaming flag in RunCoreStreamingAsync and pass it through the event subscription closure. - Add/update unit tests for both streaming and non-streaming paths. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Add test for null Data path in ConvertToAgentResponseUpdate (#4732) Add a regression test covering the null-propagation path where AssistantMessageEvent.Data is null. The production code already handles this via ?. operators, but no test previously verified the behavior. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../GitHubCopilotAgent.cs | 32 ++-- .../GitHubCopilotAgentTests.cs | 148 +++++++++++++++++- 2 files changed, 164 insertions(+), 16 deletions(-) diff --git a/dotnet/src/Microsoft.Agents.AI.GitHub.Copilot/GitHubCopilotAgent.cs b/dotnet/src/Microsoft.Agents.AI.GitHub.Copilot/GitHubCopilotAgent.cs index 053a08f725..a85c126fd6 100644 --- a/dotnet/src/Microsoft.Agents.AI.GitHub.Copilot/GitHubCopilotAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI.GitHub.Copilot/GitHubCopilotAgent.cs @@ -145,11 +145,12 @@ public sealed class GitHubCopilotAgent : AIAgent, IAsyncDisposable // Ensure the client is started await this.EnsureClientStartedAsync(cancellationToken).ConfigureAwait(false); - // Create or resume a session with streaming enabled + // Create or resume a session with streaming enabled by default SessionConfig sessionConfig = this._sessionConfig != null ? CopySessionConfig(this._sessionConfig) : new SessionConfig { Streaming = true }; + bool isStreaming = sessionConfig.Streaming ?? true; CopilotSession copilotSession; if (typedSession.SessionId is not null) { @@ -178,7 +179,7 @@ public sealed class GitHubCopilotAgent : AIAgent, IAsyncDisposable break; case AssistantMessageEvent assistantMessage: - channel.Writer.TryWrite(this.ConvertToAgentResponseUpdate(assistantMessage)); + channel.Writer.TryWrite(this.ConvertToAgentResponseUpdate(assistantMessage, isStreaming)); break; case AssistantUsageEvent usageEvent: @@ -271,19 +272,20 @@ public sealed class GitHubCopilotAgent : AIAgent, IAsyncDisposable } /// - /// Copies all supported properties from a source into a new instance - /// with set to true. + /// Copies all supported properties from a source into a new instance, + /// preserving from the source (defaulting to true if unset). /// internal static SessionConfig CopySessionConfig(SessionConfig source) { SessionConfig copy = source.Clone(); - copy.Streaming = true; + copy.Streaming = source.Streaming ?? true; return copy; } /// /// Copies all supported properties from a source into a new - /// with set to true. + /// , preserving + /// from the source (defaulting to true if unset). /// internal static ResumeSessionConfig CopyResumeSessionConfig(SessionConfig? source) { @@ -306,7 +308,7 @@ public sealed class GitHubCopilotAgent : AIAgent, IAsyncDisposable SkillDirectories = source?.SkillDirectories, DisabledSkills = source?.DisabledSkills, InfiniteSessions = source?.InfiniteSessions, - Streaming = true + Streaming = source?.Streaming ?? true }; } @@ -325,12 +327,18 @@ public sealed class GitHubCopilotAgent : AIAgent, IAsyncDisposable }; } - internal AgentResponseUpdate ConvertToAgentResponseUpdate(AssistantMessageEvent assistantMessage) + /// + /// Converts an to an . + /// When streaming is enabled, text was already delivered via delta events, so only raw metadata is emitted. + /// When streaming is disabled, the full message text is emitted as . + /// + internal AgentResponseUpdate ConvertToAgentResponseUpdate(AssistantMessageEvent assistantMessage, bool isStreaming) { - AIContent content = new() - { - RawRepresentation = assistantMessage - }; + // When streaming, text was already delivered via AssistantMessageDeltaEvent. + // When not streaming, this is the only opportunity to emit the response text. + AIContent content = isStreaming + ? new AIContent { RawRepresentation = assistantMessage } + : new TextContent(assistantMessage.Data?.Content ?? string.Empty) { RawRepresentation = assistantMessage }; return new AgentResponseUpdate(ChatRole.Assistant, [content]) { diff --git a/dotnet/tests/Microsoft.Agents.AI.GitHub.Copilot.UnitTests/GitHubCopilotAgentTests.cs b/dotnet/tests/Microsoft.Agents.AI.GitHub.Copilot.UnitTests/GitHubCopilotAgentTests.cs index 944f2f30ab..8faa842eb0 100644 --- a/dotnet/tests/Microsoft.Agents.AI.GitHub.Copilot.UnitTests/GitHubCopilotAgentTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.GitHub.Copilot.UnitTests/GitHubCopilotAgentTests.cs @@ -2,6 +2,7 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Threading.Tasks; using GitHub.Copilot; using GitHub.Copilot.Rpc; @@ -222,7 +223,73 @@ public sealed class GitHubCopilotAgentTests } [Fact] - public void ConvertToAgentResponseUpdate_AssistantMessageEvent_DoesNotEmitTextContent() + public void CopySessionConfig_WithStreamingDisabled_PreservesStreamingValue() + { + // Arrange + var source = new SessionConfig + { + Streaming = false, + Model = "gpt-4o", + }; + + // Act + SessionConfig result = GitHubCopilotAgent.CopySessionConfig(source); + + // Assert + Assert.False(result.Streaming); + } + + [Fact] + public void CopySessionConfig_WithStreamingNull_DefaultsToTrue() + { + // Arrange + var source = new SessionConfig + { + Model = "gpt-4o", + }; + + // Act + SessionConfig result = GitHubCopilotAgent.CopySessionConfig(source); + + // Assert + Assert.True(result.Streaming); + } + + [Fact] + public void CopyResumeSessionConfig_WithStreamingDisabled_PreservesStreamingValue() + { + // Arrange + var source = new SessionConfig + { + Streaming = false, + Model = "gpt-4o", + }; + + // Act + ResumeSessionConfig result = GitHubCopilotAgent.CopyResumeSessionConfig(source); + + // Assert + Assert.False(result.Streaming); + } + + [Fact] + public void CopyResumeSessionConfig_WithStreamingNull_DefaultsToTrue() + { + // Arrange + var source = new SessionConfig + { + Model = "gpt-4o", + }; + + // Act + ResumeSessionConfig result = GitHubCopilotAgent.CopyResumeSessionConfig(source); + + // Assert + Assert.True(result.Streaming); + } + + [Fact] + public void ConvertToAgentResponseUpdate_AssistantMessageEventWhenStreaming_DoesNotEmitTextContent() { var assistantMessage = new AssistantMessageEvent { @@ -235,11 +302,84 @@ public sealed class GitHubCopilotAgentTests CopilotClient copilotClient = new(new CopilotClientOptions()); const string TestId = "agent-id"; var agent = new GitHubCopilotAgent(copilotClient, ownsClient: false, id: TestId, tools: null); - AgentResponseUpdate result = agent.ConvertToAgentResponseUpdate(assistantMessage); + AgentResponseUpdate result = agent.ConvertToAgentResponseUpdate(assistantMessage, isStreaming: true); - // result.Text need to be empty because the content was already delivered via delta events, and we want to avoid emitting duplicate content in the response update. - // The content should be delivered through TextContent in the Contents collection instead. + // result.Text should be empty because content was already delivered via delta events. Assert.Empty(result.Text); Assert.DoesNotContain(result.Contents, c => c is TextContent); } + + [Fact] + public void ConvertToAgentResponseUpdate_AssistantMessageEventWhenNotStreaming_EmitsTextContent() + { + // Arrange + const string ExpectedContent = "Full response text from non-streaming session"; + var assistantMessage = new AssistantMessageEvent + { + Data = new AssistantMessageData + { + MessageId = "msg-789", + Content = ExpectedContent + } + }; + CopilotClient copilotClient = new(new CopilotClientOptions()); + const string TestId = "agent-id"; + var agent = new GitHubCopilotAgent(copilotClient, ownsClient: false, id: TestId, tools: null); + + // Act + AgentResponseUpdate result = agent.ConvertToAgentResponseUpdate(assistantMessage, isStreaming: false); + + // Assert - text must be emitted since no delta events precede it in non-streaming mode. + Assert.Equal(ExpectedContent, result.Text); + Assert.Contains(result.Contents, c => c is TextContent); + TextContent textContent = (TextContent)result.Contents.Single(c => c is TextContent); + Assert.Equal(ExpectedContent, textContent.Text); + Assert.Same(assistantMessage, textContent.RawRepresentation); + } + + [Fact] + public void ConvertToAgentResponseUpdate_AssistantMessageEventWhenNotStreaming_HandlesEmptyContent() + { + // Arrange + var assistantMessage = new AssistantMessageEvent + { + Data = new AssistantMessageData + { + MessageId = "msg-000", + Content = string.Empty + } + }; + CopilotClient copilotClient = new(new CopilotClientOptions()); + const string TestId = "agent-id"; + var agent = new GitHubCopilotAgent(copilotClient, ownsClient: false, id: TestId, tools: null); + + // Act + AgentResponseUpdate result = agent.ConvertToAgentResponseUpdate(assistantMessage, isStreaming: false); + + // Assert - should emit empty TextContent rather than throwing. + Assert.Empty(result.Text); + Assert.Contains(result.Contents, c => c is TextContent); + } + + [Fact] + public void ConvertToAgentResponseUpdate_AssistantMessageEventWhenNotStreaming_HandlesNullData() + { + // Arrange + var assistantMessage = new AssistantMessageEvent + { + Data = null! + }; + CopilotClient copilotClient = new(new CopilotClientOptions()); + const string TestId = "agent-id"; + var agent = new GitHubCopilotAgent(copilotClient, ownsClient: false, id: TestId, tools: null); + + // Act + AgentResponseUpdate result = agent.ConvertToAgentResponseUpdate(assistantMessage, isStreaming: false); + + // Assert - null Data should produce empty TextContent via null-propagation fallback. + Assert.Empty(result.Text); + Assert.Contains(result.Contents, c => c is TextContent); + Assert.Null(result.MessageId); + Assert.Null(result.ResponseId); + } } From 3d5421edc1415d85159aa936ff8c16437328a783 Mon Sep 17 00:00:00 2001 From: westey <164392973+westey-m@users.noreply.github.com> Date: Thu, 11 Jun 2026 21:51:59 +0100 Subject: [PATCH 2/4] Python: Integrate shell tool into harness agent (#6451) * Integrate shell tool into AgentHarness * Validate shell_executor exposes as_function() with a clear TypeError Addresses PR review feedback: a public factory should fail fast with an actionable error rather than a cryptic AttributeError when an incompatible shell_executor is supplied. Validation happens upfront, regardless of whether the client supports shell tools. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Type shell harness params via TYPE_CHECKING import Addresses PR review feedback: type shell_executor and shell_environment_provider_options instead of Any, using a TYPE_CHECKING import from agent_framework_tools.shell. The import never executes at runtime, so there is no circular dependency, and the lazy runtime import of ShellEnvironmentProvider is retained. Since ShellExecutor is a protocol without as_function(), the validated getattr result is invoked directly. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../packages/core/agent_framework/__init__.py | 2 + .../packages/core/agent_framework/_clients.py | 30 +++++ .../core/agent_framework/_harness/_agent.py | 74 ++++++++++- .../core/tests/core/test_harness_agent.py | 124 ++++++++++++++++++ python/samples/02-agents/harness/README.md | 23 ++++ 5 files changed, 252 insertions(+), 1 deletion(-) diff --git a/python/packages/core/agent_framework/__init__.py b/python/packages/core/agent_framework/__init__.py index 478adb2dc7..03a32f1a9c 100644 --- a/python/packages/core/agent_framework/__init__.py +++ b/python/packages/core/agent_framework/__init__.py @@ -27,6 +27,7 @@ from ._clients import ( SupportsGetEmbeddings, SupportsImageGenerationTool, SupportsMCPTool, + SupportsShellTool, SupportsWebSearchTool, ) from ._compaction import ( @@ -506,6 +507,7 @@ __all__ = [ "SupportsGetEmbeddings", "SupportsImageGenerationTool", "SupportsMCPTool", + "SupportsShellTool", "SupportsWebSearchTool", "SwitchCaseEdgeGroup", "SwitchCaseEdgeGroupCase", diff --git a/python/packages/core/agent_framework/_clients.py b/python/packages/core/agent_framework/_clients.py index 746427bffd..f0bd051980 100644 --- a/python/packages/core/agent_framework/_clients.py +++ b/python/packages/core/agent_framework/_clients.py @@ -819,6 +819,36 @@ class SupportsFileSearchTool(Protocol): ... +@runtime_checkable +class SupportsShellTool(Protocol): + """Protocol for clients that support shell tools. + + This protocol enables runtime checking to determine if a client + supports executing shell commands. + + Examples: + .. code-block:: python + + from agent_framework import SupportsShellTool + + if isinstance(client, SupportsShellTool): + tool = client.get_shell_tool(func=shell.as_function()) + agent = ChatAgent(client, tools=[tool]) + """ + + @staticmethod + def get_shell_tool(**kwargs: Any) -> Any: + """Create a shell tool configuration. + + Keyword Args: + **kwargs: Provider-specific configuration options. + + Returns: + A tool configuration ready to pass to ChatAgent. + """ + ... + + # endregion diff --git a/python/packages/core/agent_framework/_harness/_agent.py b/python/packages/core/agent_framework/_harness/_agent.py index 0ae0c73032..1a6178b54d 100644 --- a/python/packages/core/agent_framework/_harness/_agent.py +++ b/python/packages/core/agent_framework/_harness/_agent.py @@ -15,7 +15,7 @@ from collections.abc import Callable, Sequence from typing import TYPE_CHECKING, Any from .._agents import Agent, SupportsAgentRun -from .._clients import SupportsWebSearchTool +from .._clients import SupportsShellTool, SupportsWebSearchTool from .._compaction import CompactionProvider, ContextWindowCompactionStrategy, ToolResultCompactionStrategy from .._feature_stage import ExperimentalFeature, experimental from .._sessions import ContextProvider, HistoryProvider, InMemoryHistoryProvider @@ -28,6 +28,8 @@ from ._todo import TodoProvider if TYPE_CHECKING: from collections.abc import Mapping + from agent_framework_tools.shell import ShellEnvironmentProviderOptions, ShellExecutor + from .._clients import SupportsChatGetResponse from .._compaction import CompactionStrategy, TokenizerProtocol from .._middleware import MiddlewareTypes @@ -128,6 +130,7 @@ def _assemble_context_providers( skills_paths: Sequence[str] | None, background_agents: Sequence[SupportsAgentRun] | None, background_agents_instructions: str | None, + shell_context_provider: ContextProvider | None, extra_context_providers: Sequence[ContextProvider] | None, ) -> list[ContextProvider]: """Assemble the ordered list of context providers.""" @@ -159,6 +162,10 @@ def _assemble_context_providers( if background_agents: providers.append(BackgroundAgentsProvider(background_agents, instructions=background_agents_instructions)) + # Shell environment provider is opt-in: only added when a shell tool was wired. + if shell_context_provider is not None: + providers.append(shell_context_provider) + # Append any user-supplied additional providers. if extra_context_providers: providers.extend(extra_context_providers) @@ -166,6 +173,50 @@ def _assemble_context_providers( return providers +def _assemble_shell( + client: SupportsChatGetResponse[Any], + shell_executor: ShellExecutor | None, + shell_environment_provider_options: ShellEnvironmentProviderOptions | None, +) -> tuple[ToolTypes | None, ContextProvider | None]: + """Build the shell tool and environment provider when a shell executor is supplied. + + Returns a ``(tool, provider)`` tuple. Both are ``None`` when no shell executor is + provided, or when the client does not support shell tools (a warning is logged in the + latter case, since the environment provider is not useful without an execution path). + + Raises: + TypeError: If ``shell_executor`` does not expose a callable ``as_function()`` method. + """ + if shell_executor is None: + return None, None + + # ShellExecutor is a protocol without ``as_function()``, so the + # contract is validated at runtime: a shell tool such as LocalShellTool/DockerShellTool exposes it. + as_function = getattr(shell_executor, "as_function", None) + if not callable(as_function): + raise TypeError( + f"shell_executor must expose a callable 'as_function()' method " + f"(e.g. a LocalShellTool or DockerShellTool from agent-framework-tools), " + f"but got {type(shell_executor).__name__}." + ) + + if not isinstance(client, SupportsShellTool): + logger.warning( + "Shell tool not available: client %r does not implement SupportsShellTool. " + "Skipping the shell tool and environment provider.", + type(client).__name__, + ) + return None, None + + # Imported lazily: the shell types live in the separate agent-framework-tools package, + # which depends on core, so core cannot import them at module load time. + from agent_framework_tools.shell import ShellEnvironmentProvider + + shell_tool = client.get_shell_tool(func=as_function()) + shell_provider = ShellEnvironmentProvider(shell_executor, shell_environment_provider_options) + return shell_tool, shell_provider + + HARNESS_AGENT_PROVIDER_NAME = "microsoft.agent_framework.harness" @@ -196,6 +247,8 @@ def create_harness_agent( skills_paths: Sequence[str] | None = None, background_agents: Sequence[SupportsAgentRun] | None = None, background_agents_instructions: str | None = None, + shell_executor: ShellExecutor | None = None, + shell_environment_provider_options: ShellEnvironmentProviderOptions | None = None, disable_web_search: bool = False, otel_provider_name: str | None = None, context_providers: Sequence[ContextProvider] | None = None, @@ -298,6 +351,15 @@ def create_harness_agent( background_agents_instructions: Optional instruction override for the ``BackgroundAgentsProvider``. May include ``{background_agents}`` placeholder which will be replaced with the agent listing. + shell_executor: Optional shell tool that enables shell command execution. When + provided, the shell tool and a ``ShellEnvironmentProvider`` are automatically + added (provided the client supports shell tools; otherwise a warning is logged + and both are skipped). The object must expose ``as_function()`` and satisfy the + ``ShellExecutor`` protocol -- e.g. a ``LocalShellTool`` or ``DockerShellTool`` from + the ``agent-framework-tools`` package. The caller owns the executor's lifecycle. + shell_environment_provider_options: Optional ``ShellEnvironmentProviderOptions`` + (from ``agent-framework-tools``) used to customize the ``ShellEnvironmentProvider`` + environment probing and instructions. Only used when ``shell_executor`` is provided. disable_web_search: When True, skip automatic web search tool inclusion. When False (default), the web search tool is automatically added if the client implements SupportsWebSearchTool. A warning is logged if the client @@ -340,6 +402,13 @@ def create_harness_agent( tokenizer=tokenizer, ) + # Build the shell tool and environment provider (opt-in via shell_executor). + shell_tool, shell_provider = _assemble_shell( + client, + shell_executor, + shell_environment_provider_options, + ) + # Build context providers. assembled_providers = _assemble_context_providers( history_provider=resolved_history, @@ -354,6 +423,7 @@ def create_harness_agent( skills_paths=skills_paths, background_agents=background_agents, background_agents_instructions=background_agents_instructions, + shell_context_provider=shell_provider, extra_context_providers=context_providers, ) @@ -371,6 +441,8 @@ def create_harness_agent( "Set disable_web_search=True to suppress this warning.", type(client).__name__, ) + if shell_tool is not None: + assembled_tools.append(shell_tool) if tools is not None: if isinstance(tools, Sequence): assembled_tools.extend(tools) # pyright: ignore[reportUnknownArgumentType] diff --git a/python/packages/core/tests/core/test_harness_agent.py b/python/packages/core/tests/core/test_harness_agent.py index 7da1bdbf36..0a280e87ed 100644 --- a/python/packages/core/tests/core/test_harness_agent.py +++ b/python/packages/core/tests/core/test_harness_agent.py @@ -543,3 +543,127 @@ def test_create_harness_agent_empty_background_agents_list() -> None: ) providers = agent.context_providers or [] assert not any(isinstance(p, BackgroundAgentsProvider) for p in providers) + + +# --- Shell Tool Tests --- + + +class _FakeShellTool: + """Fake shell executor/tool exposing as_function().""" + + def as_function(self) -> str: + return "shell_fn" + + +class _FakeShellClient(_FakeChatClient): + """Fake client that supports the shell tool.""" + + def __init__(self) -> None: + self.shell_func: Any = None + + def get_shell_tool(self, *, func: Any = None, **kwargs: Any) -> str: + self.shell_func = func + return "shell_tool_instance" + + +def test_create_harness_agent_adds_shell_tool_and_provider() -> None: + """Shell tool and ShellEnvironmentProvider should be added when a shell executor is supplied.""" + from agent_framework_tools.shell import ShellEnvironmentProvider + + client = _FakeShellClient() + agent = create_harness_agent( + client=client, # type: ignore[arg-type] + max_context_window_tokens=128_000, + max_output_tokens=16_384, + disable_web_search=True, + shell_executor=_FakeShellTool(), + ) + tools = agent.default_options.get("tools", []) + assert "shell_tool_instance" in tools + assert client.shell_func == "shell_fn" + providers = agent.context_providers or [] + assert any(isinstance(p, ShellEnvironmentProvider) for p in providers) + + +def test_create_harness_agent_shell_passes_custom_options() -> None: + """Custom ShellEnvironmentProviderOptions should be forwarded to the provider.""" + from agent_framework_tools.shell import ShellEnvironmentProvider, ShellEnvironmentProviderOptions + + options = ShellEnvironmentProviderOptions(probe_tools=("git",)) + agent = create_harness_agent( + client=_FakeShellClient(), # type: ignore[arg-type] + max_context_window_tokens=128_000, + max_output_tokens=16_384, + disable_web_search=True, + shell_executor=_FakeShellTool(), + shell_environment_provider_options=options, + ) + providers = agent.context_providers or [] + provider = next(p for p in providers if isinstance(p, ShellEnvironmentProvider)) + assert provider._options is options + + +def test_create_harness_agent_shell_skipped_when_unsupported(caplog: pytest.LogCaptureFixture) -> None: + """When the client lacks get_shell_tool, both the tool and provider are skipped with a warning.""" + import logging + + from agent_framework_tools.shell import ShellEnvironmentProvider + + with caplog.at_level(logging.WARNING, logger="agent_framework._harness._agent"): + agent = create_harness_agent( + client=_FakeChatClient(), # type: ignore[arg-type] + max_context_window_tokens=128_000, + max_output_tokens=16_384, + disable_web_search=True, + shell_executor=_FakeShellTool(), + ) + assert any("SupportsShellTool" in msg for msg in caplog.messages) + providers = agent.context_providers or [] + assert not any(isinstance(p, ShellEnvironmentProvider) for p in providers) + assert "tools" not in agent.default_options or not agent.default_options.get("tools") + + +def test_create_harness_agent_no_shell_by_default() -> None: + """No shell tool or provider should be added when shell_executor is not provided.""" + from agent_framework_tools.shell import ShellEnvironmentProvider + + agent = create_harness_agent( + client=_FakeShellClient(), # type: ignore[arg-type] + max_context_window_tokens=128_000, + max_output_tokens=16_384, + disable_web_search=True, + ) + providers = agent.context_providers or [] + assert not any(isinstance(p, ShellEnvironmentProvider) for p in providers) + + +def test_create_harness_agent_shell_executor_without_as_function_raises() -> None: + """A shell_executor lacking a callable as_function() should raise a clear TypeError.""" + + class _BadExecutor: + pass + + with pytest.raises(TypeError, match="as_function"): + create_harness_agent( + client=_FakeShellClient(), # type: ignore[arg-type] + max_context_window_tokens=128_000, + max_output_tokens=16_384, + disable_web_search=True, + shell_executor=_BadExecutor(), + ) + + +def test_create_harness_agent_shell_executor_validated_before_client_check() -> None: + """The as_function() contract is validated upfront, even when the client lacks shell support.""" + + class _BadExecutor: + pass + + with pytest.raises(TypeError, match="as_function"): + create_harness_agent( + client=_FakeChatClient(), # type: ignore[arg-type] + max_context_window_tokens=128_000, + max_output_tokens=16_384, + disable_web_search=True, + shell_executor=_BadExecutor(), + ) diff --git a/python/samples/02-agents/harness/README.md b/python/samples/02-agents/harness/README.md index 15424e1422..e162c1ff45 100644 --- a/python/samples/02-agents/harness/README.md +++ b/python/samples/02-agents/harness/README.md @@ -17,6 +17,7 @@ from a chat client. | AgentModeProvider | Plan/execute mode tracking | | MemoryContextProvider | File-based durable memory (when `memory_store` provided) | | SkillsProvider | File-based skill discovery and progressive loading | +| Shell tool | Shell command execution + environment probing (when `shell_executor` provided) | | OpenTelemetry | Built-in observability | Each feature can be disabled or customized via keyword arguments. @@ -91,3 +92,25 @@ agent = create_harness_agent( The `AgentModeProvider` enables a two-phase workflow: 1. **Plan mode** — Interactive: the agent asks questions, creates todos, gets approval 2. **Execute mode** — Autonomous: the agent works through todos independently + +### Shell Tool + +Pass a shell executor (e.g. `LocalShellTool` from `agent-framework-tools`) to enable shell +command execution plus automatic environment probing via a `ShellEnvironmentProvider`. The +tool is only wired when the chat client supports shell tools; otherwise a warning is logged +and the shell tool/provider are skipped. The caller owns the executor's lifecycle. + +```python +from agent_framework_tools.shell import LocalShellTool, ShellEnvironmentProviderOptions + +async with LocalShellTool(acknowledge_unsafe=True) as shell: + agent = create_harness_agent( + client=client, + max_context_window_tokens=128_000, + max_output_tokens=16_384, + shell_executor=shell, + # Optional: customize environment probing. + shell_environment_provider_options=ShellEnvironmentProviderOptions(probe_tools=("git", "python")), + ) +``` + From e7937947d91ffc129d8e885644c8a5f365be075a Mon Sep 17 00:00:00 2001 From: Peter Ibekwe <109177538+peibekwe@users.noreply.github.com> Date: Thu, 11 Jun 2026 15:34:15 -0700 Subject: [PATCH 3/4] Python: Bug fix for declarative workflows (#6468) * Fix declarative object parsing bug * Remove unnecessary comment * Address PR comments * Address PR comments. * Fix CI failures. --- .../_workflows/_declarative_base.py | 181 ++++++--- .../test_declarative_state_path_safety.py | 364 ++++++++++++++++++ .../declarative/tests/test_graph_coverage.py | 8 +- 3 files changed, 489 insertions(+), 64 deletions(-) create mode 100644 python/packages/declarative/tests/test_declarative_state_path_safety.py diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_base.py b/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_base.py index e6fc0a820d..6a035a448a 100644 --- a/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_base.py +++ b/python/packages/declarative/agent_framework_declarative/_workflows/_declarative_base.py @@ -63,6 +63,9 @@ logger = logging.getLogger(__name__) _ENV_REFERENCE_RE = re.compile(r"\bEnv\.([A-Za-z_][A-Za-z0-9_]*)") +# Allowed identifier shape for object-attribute steps in declarative state paths +_SAFE_PATH_SEGMENT_RE = re.compile(r"^[A-Za-z][A-Za-z0-9_]*$") + @dataclass(frozen=True) class DeclarativeEnvConfig: @@ -266,6 +269,9 @@ class DeclarativeWorkflowState: - Conversation: Conversation history """ + # Sentinel marking "no prior value" for temporary-key bookkeeping. + _MISSING: Any = object() + def __init__(self, state: State, env_config: DeclarativeEnvConfig | None = None): """Initialize with a State instance. @@ -331,16 +337,21 @@ class DeclarativeWorkflowState: def get(self, path: str, default: Any = None) -> Any: """Get a value from the state using a dot-notated path. + Dict-keyed segments may use arbitrary string keys (e.g. UUIDs in + ``System.conversations..messages``). Segments that would resolve + via object-attribute access must be valid declarative identifiers + (``[A-Za-z][A-Za-z0-9_]*``); other shapes return ``default``. + Args: path: Dot-notated path like 'Local.results' or 'Workflow.Inputs.query' default: Default value if path doesn't exist Returns: - The value at the path, or default if not found + The value at the path, or default if not found or unreachable. """ state_data = self.get_state_data() parts = path.split(".") - if not parts: + if not parts or any(not p for p in parts): return default namespace = parts[0] @@ -377,10 +388,19 @@ class DeclarativeWorkflowState: obj = obj.get(part, default) # type: ignore[union-attr] if obj is default: return default - elif hasattr(obj, part): # type: ignore[arg-type] - obj = getattr(obj, part) # type: ignore[arg-type] else: - return default + # Attribute access is only allowed for safe declarative identifiers. + if not _SAFE_PATH_SEGMENT_RE.match(part): + logger.warning( + "DeclarativeWorkflowState.get: rejecting attribute segment %r in path %r", + part, + path, + ) + return default + if hasattr(obj, part): # type: ignore[arg-type] + obj = getattr(obj, part) # type: ignore[arg-type] + else: + return default return obj # type: ignore[return-value] @@ -392,12 +412,14 @@ class DeclarativeWorkflowState: value: The value to set Raises: - ValueError: If attempting to set Workflow.Inputs (which is read-only) + ValueError: If ``path`` is empty or contains empty segments + (e.g. ``"Local."``, ``"Local..foo"``), or if attempting to set + ``Workflow.Inputs`` (which is read-only). """ state_data = self.get_state_data() parts = path.split(".") - if not parts: - return + if not parts or any(not p for p in parts): + raise ValueError(f"Invalid path {path!r}: empty segments are not allowed") namespace = parts[0] remaining = parts[1:] @@ -453,7 +475,16 @@ class DeclarativeWorkflowState: Args: path: Dot-notated path to a list value: The value to append + + Raises: + ValueError: If ``path`` is empty or contains empty segments + (e.g. ``"Local."``, ``"Local..foo"``), or if the existing + value at ``path`` is not a list. """ + parts = path.split(".") + if not parts or any(not p for p in parts): + raise ValueError(f"Invalid path {path!r}: empty segments are not allowed") + existing = self.get(path) if existing is None: self.set(path, [value]) @@ -464,6 +495,15 @@ class DeclarativeWorkflowState: else: raise ValueError(f"Cannot append to non-list at path '{path}'") + def _clear_local_path(self, name: str) -> None: + """Remove ``name`` from the ``Local`` namespace, if present.""" + state_data = self.get_state_data() + local = state_data.get("Local") + if local is None or name not in local: + return + local.pop(name, None) + self.set_state_data(state_data) + def eval(self, expression: str) -> Any: """Evaluate a PowerFx expression with the current state. @@ -504,53 +544,64 @@ class DeclarativeWorkflowState: return result # Pre-process nested custom functions (e.g., Upper(MessageText(...))) - # Replace them with their evaluated results before sending to PowerFx - formula = self._preprocess_custom_functions(formula) + # and run PowerFx. The finally below restores any temporary state + # written during preprocessing, regardless of where execution exits. + temp_writes: list[tuple[str, Any]] = [] - if Engine is None: - raise RuntimeError( - f"PowerFx is not available (dotnet runtime not installed). " - f"Expression '={formula[:80]}' cannot be evaluated. " - f"Install dotnet and the powerfx package for full PowerFx support." - ) - - symbols = self._to_powerfx_symbols() - # Use setlocale(category) query form so we can restore the exact prior value. - # getlocale() returns a normalized tuple and is not always a lossless - # round-trip for setlocale across platforms/locales. - original_numeric_locale = locale.setlocale(locale.LC_NUMERIC) try: - for locale_candidate in _POWERFX_NUMERIC_LOCALE_CANDIDATES: - try: - locale.setlocale(locale.LC_NUMERIC, locale_candidate) - break - except locale.Error: - continue + formula = self._preprocess_custom_functions(formula, temp_writes) - engine = Engine() - try: - from System.Globalization import ( # pyright: ignore[reportMissingImports] - CultureInfo, # pyright: ignore[reportUnknownVariableType] + if Engine is None: + raise RuntimeError( + f"PowerFx is not available (dotnet runtime not installed). " + f"Expression '={formula[:80]}' cannot be evaluated. " + f"Install dotnet and the powerfx package for full PowerFx support." ) - except ImportError: - return engine.eval(formula, symbols=symbols, locale=_POWERFX_EVAL_LOCALE) - original_culture = cast(Any, CultureInfo.CurrentCulture) # pyright: ignore[reportUnknownMemberType] + symbols = self._to_powerfx_symbols() + # Use setlocale(category) query form so we can restore the exact prior value. + # getlocale() returns a normalized tuple and is not always a lossless + # round-trip for setlocale across platforms/locales. + original_numeric_locale = locale.setlocale(locale.LC_NUMERIC) try: - CultureInfo.CurrentCulture = CultureInfo(_POWERFX_EVAL_LOCALE) # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] - return engine.eval(formula, symbols=symbols, locale=_POWERFX_EVAL_LOCALE) + for locale_candidate in _POWERFX_NUMERIC_LOCALE_CANDIDATES: + try: + locale.setlocale(locale.LC_NUMERIC, locale_candidate) + break + except locale.Error: + continue + + engine = Engine() + try: + from System.Globalization import ( # pyright: ignore[reportMissingImports] + CultureInfo, # pyright: ignore[reportUnknownVariableType] + ) + except ImportError: + return engine.eval(formula, symbols=symbols, locale=_POWERFX_EVAL_LOCALE) + + original_culture = cast(Any, CultureInfo.CurrentCulture) # pyright: ignore[reportUnknownMemberType] + try: + CultureInfo.CurrentCulture = CultureInfo(_POWERFX_EVAL_LOCALE) # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] + return engine.eval(formula, symbols=symbols, locale=_POWERFX_EVAL_LOCALE) + finally: + CultureInfo.CurrentCulture = original_culture # pyright: ignore[reportUnknownMemberType] + except ValueError as e: + error_msg = str(e) + # Handle undefined variable errors gracefully by returning None + # This matches the behavior of the legacy fallback parser + if "isn't recognized" in error_msg or "Name isn't valid" in error_msg: + logger.debug(f"PowerFx: undefined variable in expression '{formula}', returning None") + return None + raise finally: - CultureInfo.CurrentCulture = original_culture # pyright: ignore[reportUnknownMemberType] - except ValueError as e: - error_msg = str(e) - # Handle undefined variable errors gracefully by returning None - # This matches the behavior of the legacy fallback parser - if "isn't recognized" in error_msg or "Name isn't valid" in error_msg: - logger.debug(f"PowerFx: undefined variable in expression '{formula}', returning None") - return None - raise + locale.setlocale(locale.LC_NUMERIC, original_numeric_locale) finally: - locale.setlocale(locale.LC_NUMERIC, original_numeric_locale) + # Restore each temporary key to its prior value (or remove it). + for path, previous in reversed(temp_writes): + if previous is self._MISSING: + self._clear_local_path(path.removeprefix("Local.")) + else: + self.set(path, previous) def _eval_custom_function(self, formula: str) -> Any | None: """Handle custom functions not supported by the Python PowerFx library. @@ -609,7 +660,7 @@ class DeclarativeWorkflowState: return None - def _preprocess_custom_functions(self, formula: str) -> str: + def _preprocess_custom_functions(self, formula: str, temp_writes: list[tuple[str, Any]]) -> str: """Pre-process custom functions nested inside other PowerFx functions. Custom functions like MessageText() are not supported by the PowerFx engine. @@ -624,9 +675,14 @@ class DeclarativeWorkflowState: Args: formula: The PowerFx formula to pre-process + temp_writes: Caller-owned list. Each write to a temporary key + appends a ``(path, previous_value)`` entry where + ``previous_value`` is the value at ``path`` before the write + or :attr:`_MISSING` if none. The caller must restore every + entry, including when this method raises mid-write. Returns: - The formula with custom function calls replaced by their evaluated results + The rewritten formula. """ import re @@ -635,7 +691,6 @@ class DeclarativeWorkflowState: # We use 500 to leave room for the rest of the expression around the replaced value. MAX_INLINE_LENGTH = 500 - # Counter for generating unique temp variable names temp_var_counter = 0 # Custom functions that need pre-processing: (regex pattern, handler) @@ -691,11 +746,14 @@ class DeclarativeWorkflowState: # Replace in formula if isinstance(replacement, str): if len(replacement) > MAX_INLINE_LENGTH: - # Store long strings in a temp variable to avoid PowerFx expression limit + # Store long results in an underscore-prefixed temp key; + # record the prior value so eval() can restore it. temp_var_name = f"_TempMessageText{temp_var_counter}" temp_var_counter += 1 - self.set(f"Local.{temp_var_name}", replacement) - replacement_str = f"Local.{temp_var_name}" + temp_var_path = f"Local.{temp_var_name}" + temp_writes.append((temp_var_path, self.get(temp_var_path, default=self._MISSING))) + self.set(temp_var_path, replacement) + replacement_str = temp_var_path logger.debug( f"Stored long MessageText result ({len(replacement)} chars) " f"in temp variable {temp_var_name}" @@ -847,11 +905,13 @@ class DeclarativeWorkflowState: return value def interpolate_string(self, text: str) -> str: - """Interpolate {Variable.Path} references in a string. + """Interpolate ``{Variable.Path}`` references in a string. - This handles template-style variable substitution like: - - "Created ticket #{Local.TicketParameters.TicketId}" - - "Routing to {Local.RoutingParameters.TeamName}" + Captures brace-delimited tokens whose root segment is an identifier + (``[A-Za-z][A-Za-z0-9_]*``) followed by zero or more ``.`` separated + dict-key segments. Resolution is delegated to :meth:`get`; unresolved + tokens are replaced with the empty string. Tokens that do not look + like state paths (e.g. ``{foo-bar}``, ``{Ctrl+C}``) are left literal. Args: text: Text that may contain {Variable.Path} references @@ -866,10 +926,11 @@ class DeclarativeWorkflowState: value = self.get(var_path) return str(value) if value is not None else "" - # Match {Variable.Path} patterns - pattern = r"\{([A-Za-z][A-Za-z0-9_.]*)\}" + # Root segment must be an identifier; follow-on segments accept any + # non-empty dict-key (e.g. ``_id``, ``1``, UUIDs). ``get()`` enforces + # per-segment safety on attribute traversal. + pattern = r"\{([A-Za-z][A-Za-z0-9_]*(?:\.[^{}\s.]+)*)\}" - # Replace all matches result = text for match in re.finditer(pattern, text): replacement = replace_var(match) diff --git a/python/packages/declarative/tests/test_declarative_state_path_safety.py b/python/packages/declarative/tests/test_declarative_state_path_safety.py new file mode 100644 index 0000000000..2446fc3cf4 --- /dev/null +++ b/python/packages/declarative/tests/test_declarative_state_path_safety.py @@ -0,0 +1,364 @@ +# Copyright (c) Microsoft. All rights reserved. +# pyright: reportUnknownParameterType=false, reportUnknownArgumentType=false +# pyright: reportMissingParameterType=false, reportUnknownMemberType=false +# pyright: reportPrivateUsage=false, reportUnknownVariableType=false +# pyright: reportGeneralTypeIssues=false + +"""Path-segment validation tests for DeclarativeWorkflowState. + +Path segments handed to ``get``/``set``/``append`` and ``{Variable.Path}`` +placeholders in ``interpolate_string`` are subject to three distinct rules +that this module pins: + +- **Empty segments** (e.g. ``""``, ``"Local."``, ``"Local..foo"``) are rejected + by all of ``get``/``set``/``append`` and ``interpolate_string``. ``get`` and + ``interpolate_string`` return their default / leave the placeholder literal; + ``set`` and ``append`` raise ``ValueError``. +- **Object-attribute segments** — segments that ``get`` would resolve via + ``getattr`` because the parent is a non-dict object — must match the safe + identifier shape ``[A-Za-z][A-Za-z0-9_]*``. Other shapes are rejected with a + warning log and the default is returned. +- **Dict-keyed segments** — segments that resolve via dict lookup because the + parent is a ``dict`` — may use arbitrary non-empty string keys (e.g. UUIDs + or hyphenated identifiers like ``System.conversations..messages``). +""" + +import logging +from dataclasses import dataclass +from typing import Any +from unittest.mock import MagicMock + +import pytest + +from agent_framework_declarative._workflows import DeclarativeWorkflowState + +try: + import powerfx # noqa: F401 + + _powerfx_available = True +except (ImportError, RuntimeError): + _powerfx_available = False + +_requires_powerfx = pytest.mark.skipif(not _powerfx_available, reason="PowerFx engine not available") + + +@pytest.fixture +def mock_state() -> MagicMock: + """In-memory mock for the underlying State.""" + ms = MagicMock() + ms._data = {} + + def get(key: str, default: Any = None) -> Any: + return ms._data.get(key, default) + + def set_(key: str, value: Any) -> None: + ms._data[key] = value + + def has(key: str) -> bool: + return key in ms._data + + def delete(key: str) -> None: + ms._data.pop(key, None) + + ms.get = MagicMock(side_effect=get) + ms.set = MagicMock(side_effect=set_) + ms.has = MagicMock(side_effect=has) + ms.delete = MagicMock(side_effect=delete) + return ms + + +@pytest.fixture +def state(mock_state: MagicMock) -> DeclarativeWorkflowState: + s = DeclarativeWorkflowState(mock_state) + s.initialize() + return s + + +@dataclass +class _PlainObj: + """Non-dict object so ``get`` falls through to attribute access.""" + + text: str = "hi" + + +# --------------------------------------------------------------------------- +# get(): invalid paths return default +# --------------------------------------------------------------------------- + + +class TestGetRejectsInvalidPaths: + def test_rejects_dunder_segment_via_attribute_access(self, state: DeclarativeWorkflowState) -> None: + state.set("Local.obj", _PlainObj()) + assert state.get("Local.obj.__class__") is None + assert state.get("Local.obj.__class__", default="DEF") == "DEF" + + def test_rejects_full_env_exfil_chain(self, state: DeclarativeWorkflowState, monkeypatch) -> None: + sentinel = "agent-framework-path-safety-sentinel" + monkeypatch.setenv("AF_PATH_SAFETY_SENTINEL", sentinel) + state.set("Local.obj", _PlainObj()) + + result = state.get("Local.obj.__class__.__init__.__globals__.os.environ") + + assert result is None + assert sentinel not in str(result) + + def test_rejects_leading_underscore_via_attribute_access(self, state: DeclarativeWorkflowState) -> None: + state.set("Local.obj", _PlainObj()) + assert state.get("Local.obj._private") is None + + def test_rejects_invalid_chars_via_attribute_access(self, state: DeclarativeWorkflowState) -> None: + state.set("Local.obj", _PlainObj()) + assert state.get("Local.obj.text bar") is None + assert state.get("Local.obj.text-bar") is None + + def test_rejects_empty_path_and_empty_segments(self, state: DeclarativeWorkflowState) -> None: + assert state.get("") is None + assert state.get(".") is None + assert state.get("Local.") is None + assert state.get(".Local") is None + + def test_warning_logged_on_rejected_attribute_segment( + self, + state: DeclarativeWorkflowState, + caplog: pytest.LogCaptureFixture, + ) -> None: + state.set("Local.obj", _PlainObj()) + with caplog.at_level(logging.WARNING, logger="agent_framework_declarative._workflows._declarative_base"): + state.get("Local.obj.__class__") + assert any("rejecting attribute segment" in r.message for r in caplog.records) + + def test_dict_keyed_dunder_is_not_attribute_access(self, state: DeclarativeWorkflowState) -> None: + """A literal dunder dict key is harmless because dict lookup never reaches getattr.""" + state.set("Local.bag", {"__class__": "harmless-string"}) + assert state.get("Local.bag.__class__") == "harmless-string" + + +# --------------------------------------------------------------------------- +# get(): legitimate paths continue to work +# --------------------------------------------------------------------------- + + +class TestGetAllowsValidPaths: + def test_underscore_inside_identifier(self, state: DeclarativeWorkflowState) -> None: + state.set("Local.user_input", "ok") + assert state.get("Local.user_input") == "ok" + + def test_mixed_case_identifiers(self, state: DeclarativeWorkflowState) -> None: + state.set("Local.UserInput", "u1") + state.set("Local.userInput", "u2") + assert state.get("Local.UserInput") == "u1" + assert state.get("Local.userInput") == "u2" + + def test_object_attribute_traversal_still_works(self, state: DeclarativeWorkflowState) -> None: + state.set("Local.msg", _PlainObj(text="hello")) + assert state.get("Local.msg.text") == "hello" + + def test_nested_dict_traversal_still_works(self, state: DeclarativeWorkflowState) -> None: + state.set("Local.params", {"team": {"name": "alpha"}}) + assert state.get("Local.params.team.name") == "alpha" + + def test_uuid_and_hyphenated_dict_keys_are_allowed(self, state: DeclarativeWorkflowState) -> None: + """Conversation-id style paths use arbitrary dict keys (UUIDs / hyphens).""" + conv_id = "eb815014-06f1-4db6-b7c1-304ea135424f" + state.set(f"System.conversations.{conv_id}.messages", ["m1", "m2"]) + assert state.get(f"System.conversations.{conv_id}.messages") == ["m1", "m2"] + + +# --------------------------------------------------------------------------- +# set() / append(): dict-keyed operations accept arbitrary string keys +# --------------------------------------------------------------------------- + + +class TestSetAndAppend: + def test_set_allows_underscore_inside_identifier(self, state: DeclarativeWorkflowState) -> None: + state.set("Local.user_input", "ok") + assert state.get("Local.user_input") == "ok" + + def test_set_allows_uuid_and_hyphenated_dict_keys(self, state: DeclarativeWorkflowState) -> None: + conv_id = "conv-test-1" + state.set(f"System.conversations.{conv_id}.messages", []) + assert state.get(f"System.conversations.{conv_id}.messages") == [] + + def test_append_allows_uuid_and_hyphenated_dict_keys(self, state: DeclarativeWorkflowState) -> None: + conv_id = "conv-42" + state.append(f"System.conversations.{conv_id}.messages", {"role": "user", "text": "hi"}) + msgs = state.get(f"System.conversations.{conv_id}.messages") + assert msgs == [{"role": "user", "text": "hi"}] + + def test_workflow_inputs_still_read_only(self, state: DeclarativeWorkflowState) -> None: + with pytest.raises(ValueError, match="read-only"): + state.set("Workflow.Inputs.x", 1) + + +# --------------------------------------------------------------------------- +# set() / append(): malformed paths (empty segments) raise ValueError +# --------------------------------------------------------------------------- + + +class TestSetRejectsInvalidPaths: + @pytest.mark.parametrize("bad_path", ["", "Local.", "Local..foo", ".Local"]) + def test_set_rejects_empty_segment(self, state: DeclarativeWorkflowState, bad_path: str) -> None: + with pytest.raises(ValueError, match="empty segments are not allowed"): + state.set(bad_path, "x") + + @pytest.mark.parametrize("bad_path", ["", "Local.", "Local..foo", ".Local"]) + def test_append_rejects_empty_segment(self, state: DeclarativeWorkflowState, bad_path: str) -> None: + with pytest.raises(ValueError, match="empty segments are not allowed"): + state.append(bad_path, "x") + + def test_set_rejection_makes_no_partial_write(self, state: DeclarativeWorkflowState) -> None: + """Rejected set() must not create an unreachable entry in the state.""" + state.set("Local.user_input", "pre") + with pytest.raises(ValueError): + state.set("Local.", "value") + local = state.get_state_data().get("Local", {}) + assert "" not in local + assert local == {"user_input": "pre"} + assert state.get("Local.") is None + assert state.get("Local.user_input") == "pre" + + def test_append_rejection_makes_no_partial_write(self, state: DeclarativeWorkflowState) -> None: + """Rejected append() must not create an unreachable entry in the state.""" + state.set("Local.items", ["a"]) + with pytest.raises(ValueError): + state.append("Local.", "value") + local = state.get_state_data().get("Local", {}) + assert "" not in local + assert local == {"items": ["a"]} + + +# --------------------------------------------------------------------------- +# interpolate_string(): permissive matcher; get() enforces safety +# --------------------------------------------------------------------------- + + +class TestInterpolateString: + def test_ignores_dunder_payload(self, state: DeclarativeWorkflowState, monkeypatch) -> None: + sentinel = "agent-framework-interp-sentinel" + monkeypatch.setenv("AF_INTERP_SENTINEL", sentinel) + state.set("Local.obj", _PlainObj()) + + out = state.interpolate_string("X={Local.obj.__class__.__init__.__globals__.os.environ}") + + assert sentinel not in out + assert out == "X=" + + def test_unknown_path_reduces_to_empty(self, state: DeclarativeWorkflowState) -> None: + assert state.interpolate_string("v={Local._private}") == "v=" + + @pytest.mark.parametrize( + "literal", + ["{foo-bar}", "{Ctrl+C}", "{not:a:path}", "{Local.}", "{}"], + ) + def test_non_state_braced_tokens_left_literal(self, state: DeclarativeWorkflowState, literal: str) -> None: + assert state.interpolate_string(f"v={literal}") == f"v={literal}" + + def test_allows_underscore_inside_identifier(self, state: DeclarativeWorkflowState) -> None: + state.set("Local.user_input", "hello") + assert state.interpolate_string("v={Local.user_input}") == "v=hello" + + def test_resolves_nested_dict_path(self, state: DeclarativeWorkflowState) -> None: + state.set("Local.params", {"team": "alpha"}) + assert state.interpolate_string("team={Local.params.team}") == "team=alpha" + + @pytest.mark.parametrize( + ("key", "value"), + [ + ("_id", "abc123"), + ("1", "one"), + ("2025", "year-bucket"), + ], + ) + def test_resolves_dict_keyed_segments(self, state: DeclarativeWorkflowState, key: str, value: str) -> None: + state.set("Local.bag", {key: value}) + assert state.interpolate_string(f"v={{Local.bag.{key}}}") == f"v={value}" + + def test_resolves_uuid_conversation_key(self, state: DeclarativeWorkflowState) -> None: + conv_id = "eb815014-06f1-4db6-b7c1-304ea135424f" + state.set(f"System.conversations.{conv_id}.messages", ["hello"]) + out = state.interpolate_string(f"m={{System.conversations.{conv_id}.messages}}") + assert out == "m=['hello']" + + def test_end_to_end_send_activity_payload_neutralized( + self, + state: DeclarativeWorkflowState, + monkeypatch, + ) -> None: + sentinel = "agent-framework-e2e-sentinel" + monkeypatch.setenv("AF_E2E_SENTINEL", sentinel) + state.set("Local.toolResult", _PlainObj()) + + payload = "{Local.toolResult.__class__.__init__.__globals__.os.environ}" + evaluated = state.eval_if_expression(payload) + rendered = state.interpolate_string(evaluated) if isinstance(evaluated, str) else str(evaluated) + + assert sentinel not in rendered + assert rendered == "" + + +# --------------------------------------------------------------------------- +# Regressions: PowerFx and internal temp-variable handling still work +# --------------------------------------------------------------------------- + + +@_requires_powerfx +class TestPowerFxStillWorks: + def test_simple_powerfx_expression_evaluates(self, state: DeclarativeWorkflowState) -> None: + state.set("Local.x", 6) + state.set("Local.y", 7) + assert state.eval("=Local.x * Local.y") == 42 + + def test_internal_temp_message_text_still_works(self, state: DeclarativeWorkflowState) -> None: + """Long MessageText() results round-trip and the temp key is removed after eval.""" + long_text = "A" * 600 + state.set( + "Local.Messages", + [{"text": long_text, "contents": [{"type": "text", "text": long_text}]}], + ) + + result = state.eval("=Upper(MessageText(Local.Messages))") + assert result == "A" * 600 + + local = state.get_state_data().get("Local", {}) + remaining = sorted(k for k in local if k.startswith("_TempMessageText")) + assert not remaining, f"Temporary keys remain in Local: {remaining}" + + def test_message_text_eval_preserves_user_temp_value(self, state: DeclarativeWorkflowState) -> None: + """User state at the temp key path survives a long MessageText eval.""" + long_text = "A" * 600 + state.set("Local._TempMessageText0", "user-important-value") + state.set( + "Local.Messages", + [{"text": long_text, "contents": [{"type": "text", "text": long_text}]}], + ) + + result = state.eval("=Upper(MessageText(Local.Messages))") + assert result == "A" * 600 + assert state.get("Local._TempMessageText0") == "user-important-value" + + def test_message_text_eval_cleans_up_on_powerfx_failure( + self, + state: DeclarativeWorkflowState, + monkeypatch, + ) -> None: + """Temp key is removed even when PowerFx evaluation raises.""" + from agent_framework_declarative._workflows import _declarative_base as base + + class _FailingEngine: + def eval(self, *args: Any, **kwargs: Any) -> Any: + raise RuntimeError("boom") + + monkeypatch.setattr(base, "Engine", _FailingEngine) + + long_text = "A" * 600 + state.set( + "Local.Messages", + [{"text": long_text, "contents": [{"type": "text", "text": long_text}]}], + ) + + with pytest.raises(RuntimeError, match="boom"): + state.eval("=Upper(MessageText(Local.Messages))") + + local = state.get_state_data().get("Local", {}) + remaining = sorted(k for k in local if k.startswith("_TempMessageText")) + assert not remaining, f"Temporary keys remain in Local after PowerFx failure: {remaining}" diff --git a/python/packages/declarative/tests/test_graph_coverage.py b/python/packages/declarative/tests/test_graph_coverage.py index f114c8f0ae..bc20a27f09 100644 --- a/python/packages/declarative/tests/test_graph_coverage.py +++ b/python/packages/declarative/tests/test_graph_coverage.py @@ -2765,7 +2765,7 @@ class TestLongMessageTextHandling: assert temp_var is None async def test_long_message_text_stored_in_temp_variable(self, mock_state): - """Test that long MessageText results are stored in temp variables.""" + """Long MessageText results round-trip and the temp key is removed after eval.""" state = DeclarativeWorkflowState(mock_state) state.initialize() @@ -2777,9 +2777,9 @@ class TestLongMessageTextHandling: result = state.eval("=Upper(MessageText(Local.Messages))") assert result == "A" * 600 # Upper on 'A' is still 'A' - # A temp variable should have been created - temp_var = state.get("Local._TempMessageText0") - assert temp_var == long_text + local = state.get_state_data().get("Local", {}) + remaining = sorted(k for k in local if k.startswith("_TempMessageText")) + assert not remaining, f"Temporary keys remain in Local: {remaining}" async def test_find_with_long_message_text(self, mock_state): """Test Find function works with long MessageText stored in temp variable.""" From 4c1b9efa8c500a0478b483adc0a41d6839aa87b3 Mon Sep 17 00:00:00 2001 From: Yufeng He <40085740+he-yufeng@users.noreply.github.com> Date: Fri, 12 Jun 2026 06:35:57 +0800 Subject: [PATCH 4/4] .NET: fix: filter filesystem checkpoint index by session (#6132) * fix: filter filesystem checkpoint index by session * fix: filter checkpoint index by parent * .NET: preserve legacy checkpoint index discovery --- .../FileSystemJsonCheckpointStore.cs | 30 +++- .../FileSystemJsonCheckpointStoreTests.cs | 128 ++++++++++++++++++ 2 files changed, 155 insertions(+), 3 deletions(-) diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Checkpointing/FileSystemJsonCheckpointStore.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Checkpointing/FileSystemJsonCheckpointStore.cs index 9a2ecd8c23..0c3d57976d 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/Checkpointing/FileSystemJsonCheckpointStore.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/Checkpointing/FileSystemJsonCheckpointStore.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; using System.IO; +using System.Linq; using System.Text; using System.Text.Json; using System.Text.Json.Serialization.Metadata; @@ -11,7 +12,11 @@ using System.Threading.Tasks; namespace Microsoft.Agents.AI.Workflows.Checkpointing; -internal record CheckpointFileIndexEntry(CheckpointInfo CheckpointInfo, string FileName); +internal record CheckpointFileIndexEntry( + CheckpointInfo CheckpointInfo, + string FileName, + string? ParentCheckpointId = null, + bool HasParentMetadata = false); /// /// Provides a file system-based implementation of a JSON checkpoint store that persists checkpoint data and index @@ -30,6 +35,8 @@ public sealed class FileSystemJsonCheckpointStore : JsonCheckpointStore, IDispos internal DirectoryInfo Directory { get; } internal HashSet CheckpointIndex { get; } + private Dictionary CheckpointParents { get; } = []; + private HashSet CheckpointsWithKnownParent { get; } = []; private static JsonTypeInfo EntryTypeInfo => WorkflowsJsonUtilities.JsonContext.Default.CheckpointFileIndexEntry; @@ -74,6 +81,11 @@ public sealed class FileSystemJsonCheckpointStore : JsonCheckpointStore, IDispos // We never actually use the file names from the index entries since they can be derived from the CheckpointInfo, but it is useful to // have the UrlEncoded file names in the index file for human readability this.CheckpointIndex.Add(entry.CheckpointInfo); + this.CheckpointParents[entry.CheckpointInfo] = entry.ParentCheckpointId; + if (entry.HasParentMetadata) + { + this.CheckpointsWithKnownParent.Add(entry.CheckpointInfo); + } } } } @@ -137,7 +149,11 @@ public sealed class FileSystemJsonCheckpointStore : JsonCheckpointStore, IDispos using Utf8JsonWriter jsonWriter = new(checkpointStream, new JsonWriterOptions() { Indented = false }); value.WriteTo(jsonWriter); - CheckpointFileIndexEntry entry = new(key, fileName); + string? parentCheckpointId = parent?.CheckpointId; + this.CheckpointParents[key] = parentCheckpointId; + this.CheckpointsWithKnownParent.Add(key); + + CheckpointFileIndexEntry entry = new(key, fileName, parentCheckpointId, HasParentMetadata: true); JsonSerializer.Serialize(this._indexFile!, entry, EntryTypeInfo); byte[] bytes = Encoding.UTF8.GetBytes(Environment.NewLine); await this._indexFile!.WriteAsync(bytes, 0, bytes.Length, CancellationToken.None).ConfigureAwait(false); @@ -148,6 +164,8 @@ public sealed class FileSystemJsonCheckpointStore : JsonCheckpointStore, IDispos catch (Exception ex) { this.CheckpointIndex.Remove(key); + this.CheckpointParents.Remove(key); + this.CheckpointsWithKnownParent.Remove(key); try { @@ -184,6 +202,12 @@ public sealed class FileSystemJsonCheckpointStore : JsonCheckpointStore, IDispos { this.CheckDisposed(); - return new(this.CheckpointIndex); + return new(this.CheckpointIndex + .Where(checkpoint => checkpoint.SessionId == sessionId && + (withParent is null || + !this.CheckpointsWithKnownParent.Contains(checkpoint) || + (this.CheckpointParents.TryGetValue(checkpoint, out string? parentCheckpointId) && + parentCheckpointId == withParent.CheckpointId))) + .ToArray()); } } diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/FileSystemJsonCheckpointStoreTests.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/FileSystemJsonCheckpointStoreTests.cs index f0058e7390..90758be481 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/FileSystemJsonCheckpointStoreTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/FileSystemJsonCheckpointStoreTests.cs @@ -2,6 +2,7 @@ using System; using System.IO; +using System.Linq; using System.Text.Json; using System.Threading.Tasks; using FluentAssertions; @@ -197,4 +198,131 @@ public sealed class FileSystemJsonCheckpointStoreTests retrieved.GetProperty("name").GetString().Should().Be("test"); retrieved.GetProperty("value").GetInt32().Should().Be(42); } + + [Fact] + public async Task RetrieveIndexAsync_ShouldOnlyReturnCheckpointsForRequestedSessionAsync() + { + // Arrange + using TempDirectory tempDirectory = new(); + string firstSessionId = Guid.NewGuid().ToString("N"); + string secondSessionId = Guid.NewGuid().ToString("N"); + CheckpointInfo firstCheckpoint; + CheckpointInfo secondCheckpoint; + + using (FileSystemJsonCheckpointStore store = new(tempDirectory)) + { + firstCheckpoint = await store.CreateCheckpointAsync(firstSessionId, TestData); + secondCheckpoint = await store.CreateCheckpointAsync(secondSessionId, TestData); + + // Act + CheckpointInfo[] firstSessionIndex = (await store.RetrieveIndexAsync(firstSessionId)).ToArray(); + + // Assert + firstSessionIndex.Should().ContainSingle().Which.Should().Be(firstCheckpoint); + firstSessionIndex.Should().NotContain(secondCheckpoint); + } + + using (FileSystemJsonCheckpointStore reopenedStore = new(tempDirectory)) + { + CheckpointInfo[] secondSessionIndex = (await reopenedStore.RetrieveIndexAsync(secondSessionId)).ToArray(); + + secondSessionIndex.Should().ContainSingle().Which.Should().Be(secondCheckpoint); + secondSessionIndex.Should().NotContain(firstCheckpoint); + } + } + + [Fact] + public async Task RetrieveIndexAsync_ShouldFilterByParentCheckpointAsync() + { + // Arrange + using TempDirectory tempDirectory = new(); + string sessionId = Guid.NewGuid().ToString("N"); + CheckpointInfo parentCheckpoint; + CheckpointInfo childCheckpoint; + CheckpointInfo unrelatedCheckpoint; + + using (FileSystemJsonCheckpointStore store = new(tempDirectory)) + { + parentCheckpoint = await store.CreateCheckpointAsync(sessionId, TestData); + childCheckpoint = await store.CreateCheckpointAsync(sessionId, TestData, parentCheckpoint); + unrelatedCheckpoint = await store.CreateCheckpointAsync(sessionId, TestData); + + // Act + CheckpointInfo[] childIndex = (await store.RetrieveIndexAsync(sessionId, parentCheckpoint)).ToArray(); + + // Assert + childIndex.Should().ContainSingle().Which.Should().Be(childCheckpoint); + childIndex.Should().NotContain(parentCheckpoint); + childIndex.Should().NotContain(unrelatedCheckpoint); + } + + using (FileSystemJsonCheckpointStore reopenedStore = new(tempDirectory)) + { + CheckpointInfo[] childIndex = (await reopenedStore.RetrieveIndexAsync(sessionId, parentCheckpoint)).ToArray(); + + childIndex.Should().ContainSingle().Which.Should().Be(childCheckpoint); + childIndex.Should().NotContain(parentCheckpoint); + childIndex.Should().NotContain(unrelatedCheckpoint); + } + } + + [Fact] + public async Task RetrieveIndexAsync_ShouldKeepLegacyEntriesDiscoverableWithParentFilterAsync() + { + // Arrange + using TempDirectory tempDirectory = new(); + string sessionId = Guid.NewGuid().ToString("N"); + CheckpointInfo parentCheckpoint; + CheckpointInfo childCheckpoint; + string childFileName; + + using (FileSystemJsonCheckpointStore store = new(tempDirectory)) + { + parentCheckpoint = await store.CreateCheckpointAsync(sessionId, TestData); + childCheckpoint = await store.CreateCheckpointAsync(sessionId, TestData, parentCheckpoint); + childFileName = store.GetFileNameForCheckpoint(sessionId, childCheckpoint); + } + + string indexPath = Path.Combine(tempDirectory.FullName, "index.jsonl"); + string legacyEntry = JsonSerializer.Serialize(new CheckpointFileIndexEntry(childCheckpoint, childFileName)); + File.WriteAllText(indexPath, legacyEntry + Environment.NewLine); + + // Act + using FileSystemJsonCheckpointStore reopenedStore = new(tempDirectory); + CheckpointInfo[] childIndex = (await reopenedStore.RetrieveIndexAsync(sessionId, parentCheckpoint)).ToArray(); + + // Assert + childIndex.Should().ContainSingle().Which.Should().Be(childCheckpoint); + } + + [Fact] + public async Task RetrieveIndexAsync_ShouldKeepLegacyChildDiscoverableWithUnrelatedParentFilterAsync() + { + // Arrange + using TempDirectory tempDirectory = new(); + string sessionId = Guid.NewGuid().ToString("N"); + CheckpointInfo parentCheckpoint; + CheckpointInfo childCheckpoint; + CheckpointInfo unrelatedCheckpoint; + string childFileName; + + using (FileSystemJsonCheckpointStore store = new(tempDirectory)) + { + parentCheckpoint = await store.CreateCheckpointAsync(sessionId, TestData); + childCheckpoint = await store.CreateCheckpointAsync(sessionId, TestData, parentCheckpoint); + unrelatedCheckpoint = await store.CreateCheckpointAsync(sessionId, TestData); + childFileName = store.GetFileNameForCheckpoint(sessionId, childCheckpoint); + } + + string indexPath = Path.Combine(tempDirectory.FullName, "index.jsonl"); + string legacyEntry = JsonSerializer.Serialize(new CheckpointFileIndexEntry(childCheckpoint, childFileName)); + File.WriteAllText(indexPath, legacyEntry + Environment.NewLine); + + // Act + using FileSystemJsonCheckpointStore reopenedStore = new(tempDirectory); + CheckpointInfo[] childIndex = (await reopenedStore.RetrieveIndexAsync(sessionId, unrelatedCheckpoint)).ToArray(); + + // Assert + childIndex.Should().ContainSingle().Which.Should().Be(childCheckpoint); + } }