diff --git a/dotnet/src/Microsoft.Agents.AI.GitHub.Copilot/GitHubCopilotAgent.cs b/dotnet/src/Microsoft.Agents.AI.GitHub.Copilot/GitHubCopilotAgent.cs index 053a08f725..a85c126fd6 100644 --- a/dotnet/src/Microsoft.Agents.AI.GitHub.Copilot/GitHubCopilotAgent.cs +++ b/dotnet/src/Microsoft.Agents.AI.GitHub.Copilot/GitHubCopilotAgent.cs @@ -145,11 +145,12 @@ public sealed class GitHubCopilotAgent : AIAgent, IAsyncDisposable // Ensure the client is started await this.EnsureClientStartedAsync(cancellationToken).ConfigureAwait(false); - // Create or resume a session with streaming enabled + // Create or resume a session with streaming enabled by default SessionConfig sessionConfig = this._sessionConfig != null ? CopySessionConfig(this._sessionConfig) : new SessionConfig { Streaming = true }; + bool isStreaming = sessionConfig.Streaming ?? true; CopilotSession copilotSession; if (typedSession.SessionId is not null) { @@ -178,7 +179,7 @@ public sealed class GitHubCopilotAgent : AIAgent, IAsyncDisposable break; case AssistantMessageEvent assistantMessage: - channel.Writer.TryWrite(this.ConvertToAgentResponseUpdate(assistantMessage)); + channel.Writer.TryWrite(this.ConvertToAgentResponseUpdate(assistantMessage, isStreaming)); break; case AssistantUsageEvent usageEvent: @@ -271,19 +272,20 @@ public sealed class GitHubCopilotAgent : AIAgent, IAsyncDisposable } /// - /// Copies all supported properties from a source into a new instance - /// with set to true. + /// Copies all supported properties from a source into a new instance, + /// preserving from the source (defaulting to true if unset). /// internal static SessionConfig CopySessionConfig(SessionConfig source) { SessionConfig copy = source.Clone(); - copy.Streaming = true; + copy.Streaming = source.Streaming ?? true; return copy; } /// /// Copies all supported properties from a source into a new - /// with set to true. + /// , preserving + /// from the source (defaulting to true if unset). /// internal static ResumeSessionConfig CopyResumeSessionConfig(SessionConfig? source) { @@ -306,7 +308,7 @@ public sealed class GitHubCopilotAgent : AIAgent, IAsyncDisposable SkillDirectories = source?.SkillDirectories, DisabledSkills = source?.DisabledSkills, InfiniteSessions = source?.InfiniteSessions, - Streaming = true + Streaming = source?.Streaming ?? true }; } @@ -325,12 +327,18 @@ public sealed class GitHubCopilotAgent : AIAgent, IAsyncDisposable }; } - internal AgentResponseUpdate ConvertToAgentResponseUpdate(AssistantMessageEvent assistantMessage) + /// + /// Converts an to an . + /// When streaming is enabled, text was already delivered via delta events, so only raw metadata is emitted. + /// When streaming is disabled, the full message text is emitted as . + /// + internal AgentResponseUpdate ConvertToAgentResponseUpdate(AssistantMessageEvent assistantMessage, bool isStreaming) { - AIContent content = new() - { - RawRepresentation = assistantMessage - }; + // When streaming, text was already delivered via AssistantMessageDeltaEvent. + // When not streaming, this is the only opportunity to emit the response text. + AIContent content = isStreaming + ? new AIContent { RawRepresentation = assistantMessage } + : new TextContent(assistantMessage.Data?.Content ?? string.Empty) { RawRepresentation = assistantMessage }; return new AgentResponseUpdate(ChatRole.Assistant, [content]) { diff --git a/dotnet/src/Microsoft.Agents.AI.Workflows/Checkpointing/FileSystemJsonCheckpointStore.cs b/dotnet/src/Microsoft.Agents.AI.Workflows/Checkpointing/FileSystemJsonCheckpointStore.cs index 9a2ecd8c23..0c3d57976d 100644 --- a/dotnet/src/Microsoft.Agents.AI.Workflows/Checkpointing/FileSystemJsonCheckpointStore.cs +++ b/dotnet/src/Microsoft.Agents.AI.Workflows/Checkpointing/FileSystemJsonCheckpointStore.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; using System.IO; +using System.Linq; using System.Text; using System.Text.Json; using System.Text.Json.Serialization.Metadata; @@ -11,7 +12,11 @@ using System.Threading.Tasks; namespace Microsoft.Agents.AI.Workflows.Checkpointing; -internal record CheckpointFileIndexEntry(CheckpointInfo CheckpointInfo, string FileName); +internal record CheckpointFileIndexEntry( + CheckpointInfo CheckpointInfo, + string FileName, + string? ParentCheckpointId = null, + bool HasParentMetadata = false); /// /// Provides a file system-based implementation of a JSON checkpoint store that persists checkpoint data and index @@ -30,6 +35,8 @@ public sealed class FileSystemJsonCheckpointStore : JsonCheckpointStore, IDispos internal DirectoryInfo Directory { get; } internal HashSet CheckpointIndex { get; } + private Dictionary CheckpointParents { get; } = []; + private HashSet CheckpointsWithKnownParent { get; } = []; private static JsonTypeInfo EntryTypeInfo => WorkflowsJsonUtilities.JsonContext.Default.CheckpointFileIndexEntry; @@ -74,6 +81,11 @@ public sealed class FileSystemJsonCheckpointStore : JsonCheckpointStore, IDispos // We never actually use the file names from the index entries since they can be derived from the CheckpointInfo, but it is useful to // have the UrlEncoded file names in the index file for human readability this.CheckpointIndex.Add(entry.CheckpointInfo); + this.CheckpointParents[entry.CheckpointInfo] = entry.ParentCheckpointId; + if (entry.HasParentMetadata) + { + this.CheckpointsWithKnownParent.Add(entry.CheckpointInfo); + } } } } @@ -137,7 +149,11 @@ public sealed class FileSystemJsonCheckpointStore : JsonCheckpointStore, IDispos using Utf8JsonWriter jsonWriter = new(checkpointStream, new JsonWriterOptions() { Indented = false }); value.WriteTo(jsonWriter); - CheckpointFileIndexEntry entry = new(key, fileName); + string? parentCheckpointId = parent?.CheckpointId; + this.CheckpointParents[key] = parentCheckpointId; + this.CheckpointsWithKnownParent.Add(key); + + CheckpointFileIndexEntry entry = new(key, fileName, parentCheckpointId, HasParentMetadata: true); JsonSerializer.Serialize(this._indexFile!, entry, EntryTypeInfo); byte[] bytes = Encoding.UTF8.GetBytes(Environment.NewLine); await this._indexFile!.WriteAsync(bytes, 0, bytes.Length, CancellationToken.None).ConfigureAwait(false); @@ -148,6 +164,8 @@ public sealed class FileSystemJsonCheckpointStore : JsonCheckpointStore, IDispos catch (Exception ex) { this.CheckpointIndex.Remove(key); + this.CheckpointParents.Remove(key); + this.CheckpointsWithKnownParent.Remove(key); try { @@ -184,6 +202,12 @@ public sealed class FileSystemJsonCheckpointStore : JsonCheckpointStore, IDispos { this.CheckDisposed(); - return new(this.CheckpointIndex); + return new(this.CheckpointIndex + .Where(checkpoint => checkpoint.SessionId == sessionId && + (withParent is null || + !this.CheckpointsWithKnownParent.Contains(checkpoint) || + (this.CheckpointParents.TryGetValue(checkpoint, out string? parentCheckpointId) && + parentCheckpointId == withParent.CheckpointId))) + .ToArray()); } } diff --git a/dotnet/tests/Microsoft.Agents.AI.GitHub.Copilot.UnitTests/GitHubCopilotAgentTests.cs b/dotnet/tests/Microsoft.Agents.AI.GitHub.Copilot.UnitTests/GitHubCopilotAgentTests.cs index 944f2f30ab..8faa842eb0 100644 --- a/dotnet/tests/Microsoft.Agents.AI.GitHub.Copilot.UnitTests/GitHubCopilotAgentTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.GitHub.Copilot.UnitTests/GitHubCopilotAgentTests.cs @@ -2,6 +2,7 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Threading.Tasks; using GitHub.Copilot; using GitHub.Copilot.Rpc; @@ -222,7 +223,73 @@ public sealed class GitHubCopilotAgentTests } [Fact] - public void ConvertToAgentResponseUpdate_AssistantMessageEvent_DoesNotEmitTextContent() + public void CopySessionConfig_WithStreamingDisabled_PreservesStreamingValue() + { + // Arrange + var source = new SessionConfig + { + Streaming = false, + Model = "gpt-4o", + }; + + // Act + SessionConfig result = GitHubCopilotAgent.CopySessionConfig(source); + + // Assert + Assert.False(result.Streaming); + } + + [Fact] + public void CopySessionConfig_WithStreamingNull_DefaultsToTrue() + { + // Arrange + var source = new SessionConfig + { + Model = "gpt-4o", + }; + + // Act + SessionConfig result = GitHubCopilotAgent.CopySessionConfig(source); + + // Assert + Assert.True(result.Streaming); + } + + [Fact] + public void CopyResumeSessionConfig_WithStreamingDisabled_PreservesStreamingValue() + { + // Arrange + var source = new SessionConfig + { + Streaming = false, + Model = "gpt-4o", + }; + + // Act + ResumeSessionConfig result = GitHubCopilotAgent.CopyResumeSessionConfig(source); + + // Assert + Assert.False(result.Streaming); + } + + [Fact] + public void CopyResumeSessionConfig_WithStreamingNull_DefaultsToTrue() + { + // Arrange + var source = new SessionConfig + { + Model = "gpt-4o", + }; + + // Act + ResumeSessionConfig result = GitHubCopilotAgent.CopyResumeSessionConfig(source); + + // Assert + Assert.True(result.Streaming); + } + + [Fact] + public void ConvertToAgentResponseUpdate_AssistantMessageEventWhenStreaming_DoesNotEmitTextContent() { var assistantMessage = new AssistantMessageEvent { @@ -235,11 +302,84 @@ public sealed class GitHubCopilotAgentTests CopilotClient copilotClient = new(new CopilotClientOptions()); const string TestId = "agent-id"; var agent = new GitHubCopilotAgent(copilotClient, ownsClient: false, id: TestId, tools: null); - AgentResponseUpdate result = agent.ConvertToAgentResponseUpdate(assistantMessage); + AgentResponseUpdate result = agent.ConvertToAgentResponseUpdate(assistantMessage, isStreaming: true); - // result.Text need to be empty because the content was already delivered via delta events, and we want to avoid emitting duplicate content in the response update. - // The content should be delivered through TextContent in the Contents collection instead. + // result.Text should be empty because content was already delivered via delta events. Assert.Empty(result.Text); Assert.DoesNotContain(result.Contents, c => c is TextContent); } + + [Fact] + public void ConvertToAgentResponseUpdate_AssistantMessageEventWhenNotStreaming_EmitsTextContent() + { + // Arrange + const string ExpectedContent = "Full response text from non-streaming session"; + var assistantMessage = new AssistantMessageEvent + { + Data = new AssistantMessageData + { + MessageId = "msg-789", + Content = ExpectedContent + } + }; + CopilotClient copilotClient = new(new CopilotClientOptions()); + const string TestId = "agent-id"; + var agent = new GitHubCopilotAgent(copilotClient, ownsClient: false, id: TestId, tools: null); + + // Act + AgentResponseUpdate result = agent.ConvertToAgentResponseUpdate(assistantMessage, isStreaming: false); + + // Assert - text must be emitted since no delta events precede it in non-streaming mode. + Assert.Equal(ExpectedContent, result.Text); + Assert.Contains(result.Contents, c => c is TextContent); + TextContent textContent = (TextContent)result.Contents.Single(c => c is TextContent); + Assert.Equal(ExpectedContent, textContent.Text); + Assert.Same(assistantMessage, textContent.RawRepresentation); + } + + [Fact] + public void ConvertToAgentResponseUpdate_AssistantMessageEventWhenNotStreaming_HandlesEmptyContent() + { + // Arrange + var assistantMessage = new AssistantMessageEvent + { + Data = new AssistantMessageData + { + MessageId = "msg-000", + Content = string.Empty + } + }; + CopilotClient copilotClient = new(new CopilotClientOptions()); + const string TestId = "agent-id"; + var agent = new GitHubCopilotAgent(copilotClient, ownsClient: false, id: TestId, tools: null); + + // Act + AgentResponseUpdate result = agent.ConvertToAgentResponseUpdate(assistantMessage, isStreaming: false); + + // Assert - should emit empty TextContent rather than throwing. + Assert.Empty(result.Text); + Assert.Contains(result.Contents, c => c is TextContent); + } + + [Fact] + public void ConvertToAgentResponseUpdate_AssistantMessageEventWhenNotStreaming_HandlesNullData() + { + // Arrange + var assistantMessage = new AssistantMessageEvent + { + Data = null! + }; + CopilotClient copilotClient = new(new CopilotClientOptions()); + const string TestId = "agent-id"; + var agent = new GitHubCopilotAgent(copilotClient, ownsClient: false, id: TestId, tools: null); + + // Act + AgentResponseUpdate result = agent.ConvertToAgentResponseUpdate(assistantMessage, isStreaming: false); + + // Assert - null Data should produce empty TextContent via null-propagation fallback. + Assert.Empty(result.Text); + Assert.Contains(result.Contents, c => c is TextContent); + Assert.Null(result.MessageId); + Assert.Null(result.ResponseId); + } } diff --git a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/FileSystemJsonCheckpointStoreTests.cs b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/FileSystemJsonCheckpointStoreTests.cs index f0058e7390..90758be481 100644 --- a/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/FileSystemJsonCheckpointStoreTests.cs +++ b/dotnet/tests/Microsoft.Agents.AI.Workflows.UnitTests/FileSystemJsonCheckpointStoreTests.cs @@ -2,6 +2,7 @@ using System; using System.IO; +using System.Linq; using System.Text.Json; using System.Threading.Tasks; using FluentAssertions; @@ -197,4 +198,131 @@ public sealed class FileSystemJsonCheckpointStoreTests retrieved.GetProperty("name").GetString().Should().Be("test"); retrieved.GetProperty("value").GetInt32().Should().Be(42); } + + [Fact] + public async Task RetrieveIndexAsync_ShouldOnlyReturnCheckpointsForRequestedSessionAsync() + { + // Arrange + using TempDirectory tempDirectory = new(); + string firstSessionId = Guid.NewGuid().ToString("N"); + string secondSessionId = Guid.NewGuid().ToString("N"); + CheckpointInfo firstCheckpoint; + CheckpointInfo secondCheckpoint; + + using (FileSystemJsonCheckpointStore store = new(tempDirectory)) + { + firstCheckpoint = await store.CreateCheckpointAsync(firstSessionId, TestData); + secondCheckpoint = await store.CreateCheckpointAsync(secondSessionId, TestData); + + // Act + CheckpointInfo[] firstSessionIndex = (await store.RetrieveIndexAsync(firstSessionId)).ToArray(); + + // Assert + firstSessionIndex.Should().ContainSingle().Which.Should().Be(firstCheckpoint); + firstSessionIndex.Should().NotContain(secondCheckpoint); + } + + using (FileSystemJsonCheckpointStore reopenedStore = new(tempDirectory)) + { + CheckpointInfo[] secondSessionIndex = (await reopenedStore.RetrieveIndexAsync(secondSessionId)).ToArray(); + + secondSessionIndex.Should().ContainSingle().Which.Should().Be(secondCheckpoint); + secondSessionIndex.Should().NotContain(firstCheckpoint); + } + } + + [Fact] + public async Task RetrieveIndexAsync_ShouldFilterByParentCheckpointAsync() + { + // Arrange + using TempDirectory tempDirectory = new(); + string sessionId = Guid.NewGuid().ToString("N"); + CheckpointInfo parentCheckpoint; + CheckpointInfo childCheckpoint; + CheckpointInfo unrelatedCheckpoint; + + using (FileSystemJsonCheckpointStore store = new(tempDirectory)) + { + parentCheckpoint = await store.CreateCheckpointAsync(sessionId, TestData); + childCheckpoint = await store.CreateCheckpointAsync(sessionId, TestData, parentCheckpoint); + unrelatedCheckpoint = await store.CreateCheckpointAsync(sessionId, TestData); + + // Act + CheckpointInfo[] childIndex = (await store.RetrieveIndexAsync(sessionId, parentCheckpoint)).ToArray(); + + // Assert + childIndex.Should().ContainSingle().Which.Should().Be(childCheckpoint); + childIndex.Should().NotContain(parentCheckpoint); + childIndex.Should().NotContain(unrelatedCheckpoint); + } + + using (FileSystemJsonCheckpointStore reopenedStore = new(tempDirectory)) + { + CheckpointInfo[] childIndex = (await reopenedStore.RetrieveIndexAsync(sessionId, parentCheckpoint)).ToArray(); + + childIndex.Should().ContainSingle().Which.Should().Be(childCheckpoint); + childIndex.Should().NotContain(parentCheckpoint); + childIndex.Should().NotContain(unrelatedCheckpoint); + } + } + + [Fact] + public async Task RetrieveIndexAsync_ShouldKeepLegacyEntriesDiscoverableWithParentFilterAsync() + { + // Arrange + using TempDirectory tempDirectory = new(); + string sessionId = Guid.NewGuid().ToString("N"); + CheckpointInfo parentCheckpoint; + CheckpointInfo childCheckpoint; + string childFileName; + + using (FileSystemJsonCheckpointStore store = new(tempDirectory)) + { + parentCheckpoint = await store.CreateCheckpointAsync(sessionId, TestData); + childCheckpoint = await store.CreateCheckpointAsync(sessionId, TestData, parentCheckpoint); + childFileName = store.GetFileNameForCheckpoint(sessionId, childCheckpoint); + } + + string indexPath = Path.Combine(tempDirectory.FullName, "index.jsonl"); + string legacyEntry = JsonSerializer.Serialize(new CheckpointFileIndexEntry(childCheckpoint, childFileName)); + File.WriteAllText(indexPath, legacyEntry + Environment.NewLine); + + // Act + using FileSystemJsonCheckpointStore reopenedStore = new(tempDirectory); + CheckpointInfo[] childIndex = (await reopenedStore.RetrieveIndexAsync(sessionId, parentCheckpoint)).ToArray(); + + // Assert + childIndex.Should().ContainSingle().Which.Should().Be(childCheckpoint); + } + + [Fact] + public async Task RetrieveIndexAsync_ShouldKeepLegacyChildDiscoverableWithUnrelatedParentFilterAsync() + { + // Arrange + using TempDirectory tempDirectory = new(); + string sessionId = Guid.NewGuid().ToString("N"); + CheckpointInfo parentCheckpoint; + CheckpointInfo childCheckpoint; + CheckpointInfo unrelatedCheckpoint; + string childFileName; + + using (FileSystemJsonCheckpointStore store = new(tempDirectory)) + { + parentCheckpoint = await store.CreateCheckpointAsync(sessionId, TestData); + childCheckpoint = await store.CreateCheckpointAsync(sessionId, TestData, parentCheckpoint); + unrelatedCheckpoint = await store.CreateCheckpointAsync(sessionId, TestData); + childFileName = store.GetFileNameForCheckpoint(sessionId, childCheckpoint); + } + + string indexPath = Path.Combine(tempDirectory.FullName, "index.jsonl"); + string legacyEntry = JsonSerializer.Serialize(new CheckpointFileIndexEntry(childCheckpoint, childFileName)); + File.WriteAllText(indexPath, legacyEntry + Environment.NewLine); + + // Act + using FileSystemJsonCheckpointStore reopenedStore = new(tempDirectory); + CheckpointInfo[] childIndex = (await reopenedStore.RetrieveIndexAsync(sessionId, unrelatedCheckpoint)).ToArray(); + + // Assert + childIndex.Should().ContainSingle().Which.Should().Be(childCheckpoint); + } } diff --git a/python/packages/core/agent_framework/__init__.py b/python/packages/core/agent_framework/__init__.py index 478adb2dc7..03a32f1a9c 100644 --- a/python/packages/core/agent_framework/__init__.py +++ b/python/packages/core/agent_framework/__init__.py @@ -27,6 +27,7 @@ from ._clients import ( SupportsGetEmbeddings, SupportsImageGenerationTool, SupportsMCPTool, + SupportsShellTool, SupportsWebSearchTool, ) from ._compaction import ( @@ -506,6 +507,7 @@ __all__ = [ "SupportsGetEmbeddings", "SupportsImageGenerationTool", "SupportsMCPTool", + "SupportsShellTool", "SupportsWebSearchTool", "SwitchCaseEdgeGroup", "SwitchCaseEdgeGroupCase", diff --git a/python/packages/core/agent_framework/_clients.py b/python/packages/core/agent_framework/_clients.py index 746427bffd..f0bd051980 100644 --- a/python/packages/core/agent_framework/_clients.py +++ b/python/packages/core/agent_framework/_clients.py @@ -819,6 +819,36 @@ class SupportsFileSearchTool(Protocol): ... +@runtime_checkable +class SupportsShellTool(Protocol): + """Protocol for clients that support shell tools. + + This protocol enables runtime checking to determine if a client + supports executing shell commands. + + Examples: + .. code-block:: python + + from agent_framework import SupportsShellTool + + if isinstance(client, SupportsShellTool): + tool = client.get_shell_tool(func=shell.as_function()) + agent = ChatAgent(client, tools=[tool]) + """ + + @staticmethod + def get_shell_tool(**kwargs: Any) -> Any: + """Create a shell tool configuration. + + Keyword Args: + **kwargs: Provider-specific configuration options. + + Returns: + A tool configuration ready to pass to ChatAgent. + """ + ... + + # endregion diff --git a/python/packages/core/agent_framework/_harness/_agent.py b/python/packages/core/agent_framework/_harness/_agent.py index 0ae0c73032..1a6178b54d 100644 --- a/python/packages/core/agent_framework/_harness/_agent.py +++ b/python/packages/core/agent_framework/_harness/_agent.py @@ -15,7 +15,7 @@ from collections.abc import Callable, Sequence from typing import TYPE_CHECKING, Any from .._agents import Agent, SupportsAgentRun -from .._clients import SupportsWebSearchTool +from .._clients import SupportsShellTool, SupportsWebSearchTool from .._compaction import CompactionProvider, ContextWindowCompactionStrategy, ToolResultCompactionStrategy from .._feature_stage import ExperimentalFeature, experimental from .._sessions import ContextProvider, HistoryProvider, InMemoryHistoryProvider @@ -28,6 +28,8 @@ from ._todo import TodoProvider if TYPE_CHECKING: from collections.abc import Mapping + from agent_framework_tools.shell import ShellEnvironmentProviderOptions, ShellExecutor + from .._clients import SupportsChatGetResponse from .._compaction import CompactionStrategy, TokenizerProtocol from .._middleware import MiddlewareTypes @@ -128,6 +130,7 @@ def _assemble_context_providers( skills_paths: Sequence[str] | None, background_agents: Sequence[SupportsAgentRun] | None, background_agents_instructions: str | None, + shell_context_provider: ContextProvider | None, extra_context_providers: Sequence[ContextProvider] | None, ) -> list[ContextProvider]: """Assemble the ordered list of context providers.""" @@ -159,6 +162,10 @@ def _assemble_context_providers( if background_agents: providers.append(BackgroundAgentsProvider(background_agents, instructions=background_agents_instructions)) + # Shell environment provider is opt-in: only added when a shell tool was wired. + if shell_context_provider is not None: + providers.append(shell_context_provider) + # Append any user-supplied additional providers. if extra_context_providers: providers.extend(extra_context_providers) @@ -166,6 +173,50 @@ def _assemble_context_providers( return providers +def _assemble_shell( + client: SupportsChatGetResponse[Any], + shell_executor: ShellExecutor | None, + shell_environment_provider_options: ShellEnvironmentProviderOptions | None, +) -> tuple[ToolTypes | None, ContextProvider | None]: + """Build the shell tool and environment provider when a shell executor is supplied. + + Returns a ``(tool, provider)`` tuple. Both are ``None`` when no shell executor is + provided, or when the client does not support shell tools (a warning is logged in the + latter case, since the environment provider is not useful without an execution path). + + Raises: + TypeError: If ``shell_executor`` does not expose a callable ``as_function()`` method. + """ + if shell_executor is None: + return None, None + + # ShellExecutor is a protocol without ``as_function()``, so the + # contract is validated at runtime: a shell tool such as LocalShellTool/DockerShellTool exposes it. + as_function = getattr(shell_executor, "as_function", None) + if not callable(as_function): + raise TypeError( + f"shell_executor must expose a callable 'as_function()' method " + f"(e.g. a LocalShellTool or DockerShellTool from agent-framework-tools), " + f"but got {type(shell_executor).__name__}." + ) + + if not isinstance(client, SupportsShellTool): + logger.warning( + "Shell tool not available: client %r does not implement SupportsShellTool. " + "Skipping the shell tool and environment provider.", + type(client).__name__, + ) + return None, None + + # Imported lazily: the shell types live in the separate agent-framework-tools package, + # which depends on core, so core cannot import them at module load time. + from agent_framework_tools.shell import ShellEnvironmentProvider + + shell_tool = client.get_shell_tool(func=as_function()) + shell_provider = ShellEnvironmentProvider(shell_executor, shell_environment_provider_options) + return shell_tool, shell_provider + + HARNESS_AGENT_PROVIDER_NAME = "microsoft.agent_framework.harness" @@ -196,6 +247,8 @@ def create_harness_agent( skills_paths: Sequence[str] | None = None, background_agents: Sequence[SupportsAgentRun] | None = None, background_agents_instructions: str | None = None, + shell_executor: ShellExecutor | None = None, + shell_environment_provider_options: ShellEnvironmentProviderOptions | None = None, disable_web_search: bool = False, otel_provider_name: str | None = None, context_providers: Sequence[ContextProvider] | None = None, @@ -298,6 +351,15 @@ def create_harness_agent( background_agents_instructions: Optional instruction override for the ``BackgroundAgentsProvider``. May include ``{background_agents}`` placeholder which will be replaced with the agent listing. + shell_executor: Optional shell tool that enables shell command execution. When + provided, the shell tool and a ``ShellEnvironmentProvider`` are automatically + added (provided the client supports shell tools; otherwise a warning is logged + and both are skipped). The object must expose ``as_function()`` and satisfy the + ``ShellExecutor`` protocol -- e.g. a ``LocalShellTool`` or ``DockerShellTool`` from + the ``agent-framework-tools`` package. The caller owns the executor's lifecycle. + shell_environment_provider_options: Optional ``ShellEnvironmentProviderOptions`` + (from ``agent-framework-tools``) used to customize the ``ShellEnvironmentProvider`` + environment probing and instructions. Only used when ``shell_executor`` is provided. disable_web_search: When True, skip automatic web search tool inclusion. When False (default), the web search tool is automatically added if the client implements SupportsWebSearchTool. A warning is logged if the client @@ -340,6 +402,13 @@ def create_harness_agent( tokenizer=tokenizer, ) + # Build the shell tool and environment provider (opt-in via shell_executor). + shell_tool, shell_provider = _assemble_shell( + client, + shell_executor, + shell_environment_provider_options, + ) + # Build context providers. assembled_providers = _assemble_context_providers( history_provider=resolved_history, @@ -354,6 +423,7 @@ def create_harness_agent( skills_paths=skills_paths, background_agents=background_agents, background_agents_instructions=background_agents_instructions, + shell_context_provider=shell_provider, extra_context_providers=context_providers, ) @@ -371,6 +441,8 @@ def create_harness_agent( "Set disable_web_search=True to suppress this warning.", type(client).__name__, ) + if shell_tool is not None: + assembled_tools.append(shell_tool) if tools is not None: if isinstance(tools, Sequence): assembled_tools.extend(tools) # pyright: ignore[reportUnknownArgumentType] diff --git a/python/packages/core/tests/core/test_harness_agent.py b/python/packages/core/tests/core/test_harness_agent.py index 7da1bdbf36..0a280e87ed 100644 --- a/python/packages/core/tests/core/test_harness_agent.py +++ b/python/packages/core/tests/core/test_harness_agent.py @@ -543,3 +543,127 @@ def test_create_harness_agent_empty_background_agents_list() -> None: ) providers = agent.context_providers or [] assert not any(isinstance(p, BackgroundAgentsProvider) for p in providers) + + +# --- Shell Tool Tests --- + + +class _FakeShellTool: + """Fake shell executor/tool exposing as_function().""" + + def as_function(self) -> str: + return "shell_fn" + + +class _FakeShellClient(_FakeChatClient): + """Fake client that supports the shell tool.""" + + def __init__(self) -> None: + self.shell_func: Any = None + + def get_shell_tool(self, *, func: Any = None, **kwargs: Any) -> str: + self.shell_func = func + return "shell_tool_instance" + + +def test_create_harness_agent_adds_shell_tool_and_provider() -> None: + """Shell tool and ShellEnvironmentProvider should be added when a shell executor is supplied.""" + from agent_framework_tools.shell import ShellEnvironmentProvider + + client = _FakeShellClient() + agent = create_harness_agent( + client=client, # type: ignore[arg-type] + max_context_window_tokens=128_000, + max_output_tokens=16_384, + disable_web_search=True, + shell_executor=_FakeShellTool(), + ) + tools = agent.default_options.get("tools", []) + assert "shell_tool_instance" in tools + assert client.shell_func == "shell_fn" + providers = agent.context_providers or [] + assert any(isinstance(p, ShellEnvironmentProvider) for p in providers) + + +def test_create_harness_agent_shell_passes_custom_options() -> None: + """Custom ShellEnvironmentProviderOptions should be forwarded to the provider.""" + from agent_framework_tools.shell import ShellEnvironmentProvider, ShellEnvironmentProviderOptions + + options = ShellEnvironmentProviderOptions(probe_tools=("git",)) + agent = create_harness_agent( + client=_FakeShellClient(), # type: ignore[arg-type] + max_context_window_tokens=128_000, + max_output_tokens=16_384, + disable_web_search=True, + shell_executor=_FakeShellTool(), + shell_environment_provider_options=options, + ) + providers = agent.context_providers or [] + provider = next(p for p in providers if isinstance(p, ShellEnvironmentProvider)) + assert provider._options is options + + +def test_create_harness_agent_shell_skipped_when_unsupported(caplog: pytest.LogCaptureFixture) -> None: + """When the client lacks get_shell_tool, both the tool and provider are skipped with a warning.""" + import logging + + from agent_framework_tools.shell import ShellEnvironmentProvider + + with caplog.at_level(logging.WARNING, logger="agent_framework._harness._agent"): + agent = create_harness_agent( + client=_FakeChatClient(), # type: ignore[arg-type] + max_context_window_tokens=128_000, + max_output_tokens=16_384, + disable_web_search=True, + shell_executor=_FakeShellTool(), + ) + assert any("SupportsShellTool" in msg for msg in caplog.messages) + providers = agent.context_providers or [] + assert not any(isinstance(p, ShellEnvironmentProvider) for p in providers) + assert "tools" not in agent.default_options or not agent.default_options.get("tools") + + +def test_create_harness_agent_no_shell_by_default() -> None: + """No shell tool or provider should be added when shell_executor is not provided.""" + from agent_framework_tools.shell import ShellEnvironmentProvider + + agent = create_harness_agent( + client=_FakeShellClient(), # type: ignore[arg-type] + max_context_window_tokens=128_000, + max_output_tokens=16_384, + disable_web_search=True, + ) + providers = agent.context_providers or [] + assert not any(isinstance(p, ShellEnvironmentProvider) for p in providers) + + +def test_create_harness_agent_shell_executor_without_as_function_raises() -> None: + """A shell_executor lacking a callable as_function() should raise a clear TypeError.""" + + class _BadExecutor: + pass + + with pytest.raises(TypeError, match="as_function"): + create_harness_agent( + client=_FakeShellClient(), # type: ignore[arg-type] + max_context_window_tokens=128_000, + max_output_tokens=16_384, + disable_web_search=True, + shell_executor=_BadExecutor(), + ) + + +def test_create_harness_agent_shell_executor_validated_before_client_check() -> None: + """The as_function() contract is validated upfront, even when the client lacks shell support.""" + + class _BadExecutor: + pass + + with pytest.raises(TypeError, match="as_function"): + create_harness_agent( + client=_FakeChatClient(), # type: ignore[arg-type] + max_context_window_tokens=128_000, + max_output_tokens=16_384, + disable_web_search=True, + shell_executor=_BadExecutor(), + ) diff --git a/python/samples/02-agents/harness/README.md b/python/samples/02-agents/harness/README.md index 15424e1422..e162c1ff45 100644 --- a/python/samples/02-agents/harness/README.md +++ b/python/samples/02-agents/harness/README.md @@ -17,6 +17,7 @@ from a chat client. | AgentModeProvider | Plan/execute mode tracking | | MemoryContextProvider | File-based durable memory (when `memory_store` provided) | | SkillsProvider | File-based skill discovery and progressive loading | +| Shell tool | Shell command execution + environment probing (when `shell_executor` provided) | | OpenTelemetry | Built-in observability | Each feature can be disabled or customized via keyword arguments. @@ -91,3 +92,25 @@ agent = create_harness_agent( The `AgentModeProvider` enables a two-phase workflow: 1. **Plan mode** — Interactive: the agent asks questions, creates todos, gets approval 2. **Execute mode** — Autonomous: the agent works through todos independently + +### Shell Tool + +Pass a shell executor (e.g. `LocalShellTool` from `agent-framework-tools`) to enable shell +command execution plus automatic environment probing via a `ShellEnvironmentProvider`. The +tool is only wired when the chat client supports shell tools; otherwise a warning is logged +and the shell tool/provider are skipped. The caller owns the executor's lifecycle. + +```python +from agent_framework_tools.shell import LocalShellTool, ShellEnvironmentProviderOptions + +async with LocalShellTool(acknowledge_unsafe=True) as shell: + agent = create_harness_agent( + client=client, + max_context_window_tokens=128_000, + max_output_tokens=16_384, + shell_executor=shell, + # Optional: customize environment probing. + shell_environment_provider_options=ShellEnvironmentProviderOptions(probe_tools=("git", "python")), + ) +``` +