.NET: Improve resolving AITool from DI (#3175)

* remove localagenttoolregistry

* also give the factory method API
This commit is contained in:
Korolev Dmitry
2026-01-12 17:13:44 +01:00
committed by GitHub
Unverified
parent 3e13909e59
commit c7cb5be231
4 changed files with 170 additions and 67 deletions
@@ -1,8 +1,7 @@
// Copyright (c) Microsoft. All rights reserved.
using System;
using System.Collections.Generic;
using Microsoft.Agents.AI.Hosting.Local;
using System.Linq;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Shared.Diagnostics;
@@ -29,7 +28,7 @@ public static class AgentHostingServiceCollectionExtensions
return services.AddAIAgent(name, (sp, key) =>
{
var chatClient = sp.GetRequiredService<IChatClient>();
var tools = GetRegisteredToolsForAgent(sp, name);
var tools = sp.GetKeyedServices<AITool>(name).ToList();
return new ChatClientAgent(chatClient, instructions, key, tools: tools);
});
}
@@ -49,7 +48,7 @@ public static class AgentHostingServiceCollectionExtensions
Throw.IfNullOrEmpty(name);
return services.AddAIAgent(name, (sp, key) =>
{
var tools = GetRegisteredToolsForAgent(sp, name);
var tools = sp.GetKeyedServices<AITool>(name).ToList();
return new ChatClientAgent(chatClient, instructions, key, tools: tools);
});
}
@@ -70,7 +69,7 @@ public static class AgentHostingServiceCollectionExtensions
return services.AddAIAgent(name, (sp, key) =>
{
var chatClient = chatClientServiceKey is null ? sp.GetRequiredService<IChatClient>() : sp.GetRequiredKeyedService<IChatClient>(chatClientServiceKey);
var tools = GetRegisteredToolsForAgent(sp, name);
var tools = sp.GetKeyedServices<AITool>(name).ToList();
return new ChatClientAgent(chatClient, instructions, key, tools: tools);
});
}
@@ -92,7 +91,7 @@ public static class AgentHostingServiceCollectionExtensions
return services.AddAIAgent(name, (sp, key) =>
{
var chatClient = chatClientServiceKey is null ? sp.GetRequiredService<IChatClient>() : sp.GetRequiredKeyedService<IChatClient>(chatClientServiceKey);
var tools = GetRegisteredToolsForAgent(sp, name);
var tools = sp.GetKeyedServices<AITool>(name).ToList();
return new ChatClientAgent(chatClient, instructions: instructions, name: key, description: description, tools: tools);
});
}
@@ -127,10 +126,4 @@ public static class AgentHostingServiceCollectionExtensions
return new HostedAgentBuilder(name, services);
}
private static IList<AITool> GetRegisteredToolsForAgent(IServiceProvider serviceProvider, string agentName)
{
var registry = serviceProvider.GetService<LocalAgentToolRegistry>();
return registry?.GetTools(agentName) ?? [];
}
}
@@ -1,8 +1,6 @@
// Copyright (c) Microsoft. All rights reserved.
using System;
using System.Linq;
using Microsoft.Agents.AI.Hosting.Local;
using Microsoft.Extensions.AI;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Shared.Diagnostics;
@@ -70,18 +68,7 @@ public static class HostedAgentBuilderExtensions
Throw.IfNull(builder);
Throw.IfNull(tool);
var agentName = builder.Name;
var services = builder.ServiceCollection;
// Get or create the agent tool registry
var descriptor = services.FirstOrDefault(sd => !sd.IsKeyedService && sd.ServiceType.Equals(typeof(LocalAgentToolRegistry)));
if (descriptor?.ImplementationInstance is not LocalAgentToolRegistry toolRegistry)
{
toolRegistry = new();
services.Add(ServiceDescriptor.Singleton(toolRegistry));
}
toolRegistry.AddTool(agentName, tool);
builder.ServiceCollection.AddKeyedSingleton(builder.Name, tool);
return builder;
}
@@ -105,4 +92,19 @@ public static class HostedAgentBuilderExtensions
return builder;
}
/// <summary>
/// Adds AI tool to an agent being configured with the service collection.
/// </summary>
/// <param name="builder">The hosted agent builder.</param>
/// <param name="factory">A factory function that creates a AI tool using the provided service provider.</param>
public static IHostedAgentBuilder WithAITool(this IHostedAgentBuilder builder, Func<IServiceProvider, AITool> factory)
{
Throw.IfNull(builder);
Throw.IfNull(factory);
builder.ServiceCollection.AddKeyedSingleton(builder.Name, (sp, name) => factory(sp));
return builder;
}
}
@@ -1,27 +0,0 @@
// Copyright (c) Microsoft. All rights reserved.
using System.Collections.Generic;
using Microsoft.Extensions.AI;
namespace Microsoft.Agents.AI.Hosting.Local;
internal sealed class LocalAgentToolRegistry
{
private readonly Dictionary<string, List<AITool>> _toolsByAgentName = [];
public void AddTool(string agentName, AITool tool)
{
if (!this._toolsByAgentName.TryGetValue(agentName, out var tools))
{
tools = [];
this._toolsByAgentName[agentName] = tools;
}
tools.Add(tool);
}
public IList<AITool> GetTools(string agentName)
{
return this._toolsByAgentName.TryGetValue(agentName, out var tools) ? tools : [];
}
}
@@ -2,6 +2,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.AI;
@@ -17,49 +18,40 @@ public sealed class HostedAgentBuilderToolsExtensionsTests
[Fact]
public void WithAITool_ThrowsWhenBuilderIsNull()
{
// Arrange
var tool = new DummyAITool();
// Act & Assert
Assert.Throws<ArgumentNullException>(() => HostedAgentBuilderExtensions.WithAITool(null!, tool));
}
[Fact]
public void WithAITool_ThrowsWhenToolIsNull()
{
// Arrange
var services = new ServiceCollection();
var builder = services.AddAIAgent("test-agent", "Test instructions");
// Act & Assert
Assert.Throws<ArgumentNullException>(() => builder.WithAITool(null!));
Assert.Throws<ArgumentNullException>(() => builder.WithAITool(tool: null!));
}
[Fact]
public void WithAITools_ThrowsWhenBuilderIsNull()
{
// Arrange
var tools = new[] { new DummyAITool() };
// Act & Assert
Assert.Throws<ArgumentNullException>(() => HostedAgentBuilderExtensions.WithAITools(null!, tools));
}
[Fact]
public void WithAITools_ThrowsWhenToolsArrayIsNull()
{
// Arrange
var services = new ServiceCollection();
var builder = services.AddAIAgent("test-agent", "Test instructions");
// Act & Assert
Assert.Throws<ArgumentNullException>(() => builder.WithAITools(null!));
}
[Fact]
public void RegisteredTools_ResolvesAllToolsForAgent()
{
// Arrange
var services = new ServiceCollection();
services.AddSingleton<IChatClient>(new MockChatClient());
@@ -73,9 +65,13 @@ public sealed class HostedAgentBuilderToolsExtensionsTests
var serviceProvider = services.BuildServiceProvider();
var agent1Tools = ResolveAgentTools(serviceProvider, "test-agent");
var agent1Tools = ResolveToolsFromAgent(serviceProvider, "test-agent");
Assert.Contains(tool1, agent1Tools);
Assert.Contains(tool2, agent1Tools);
var agent1ToolsDI = ResolveToolsFromDI(serviceProvider, "test-agent");
Assert.Contains(tool1, agent1ToolsDI);
Assert.Contains(tool2, agent1ToolsDI);
}
[Fact]
@@ -100,21 +96,160 @@ public sealed class HostedAgentBuilderToolsExtensionsTests
var serviceProvider = services.BuildServiceProvider();
var agent1Tools = ResolveAgentTools(serviceProvider, "agent1");
var agent2Tools = ResolveAgentTools(serviceProvider, "agent2");
var agent1Tools = ResolveToolsFromAgent(serviceProvider, "agent1");
var agent2Tools = ResolveToolsFromAgent(serviceProvider, "agent2");
var agent1ToolsDI = ResolveToolsFromDI(serviceProvider, "agent1");
var agent2ToolsDI = ResolveToolsFromDI(serviceProvider, "agent2");
Assert.Contains(tool1, agent1Tools);
Assert.Contains(tool2, agent1Tools);
Assert.Contains(tool1, agent1ToolsDI);
Assert.Contains(tool2, agent1ToolsDI);
Assert.Contains(tool3, agent2Tools);
Assert.Contains(tool3, agent2ToolsDI);
}
private static IList<AITool> ResolveAgentTools(IServiceProvider serviceProvider, string name)
private static IList<AITool> ResolveToolsFromAgent(IServiceProvider serviceProvider, string name)
{
var agent = serviceProvider.GetRequiredKeyedService<AIAgent>(name) as ChatClientAgent;
Assert.NotNull(agent?.ChatOptions?.Tools);
return agent.ChatOptions.Tools;
}
private static List<AITool> ResolveToolsFromDI(IServiceProvider serviceProvider, string name)
{
var tools = serviceProvider.GetKeyedServices<AITool>(name);
Assert.NotNull(tools);
return tools.ToList();
}
[Fact]
public void WithAIToolFactory_ThrowsWhenBuilderIsNull()
{
Assert.Throws<ArgumentNullException>(() => HostedAgentBuilderExtensions.WithAITool(null!, CreateTool));
static AITool CreateTool(IServiceProvider _) => new DummyAITool();
}
[Fact]
public void WithAIToolFactory_ThrowsWhenFactoryIsNull()
{
var services = new ServiceCollection();
var builder = services.AddAIAgent("test-agent", "Test instructions");
Assert.Throws<ArgumentNullException>(() => builder.WithAITool(factory: null!));
}
[Fact]
public void WithAIToolFactory_RegistersToolFromFactory()
{
var services = new ServiceCollection();
services.AddSingleton<IChatClient>(new MockChatClient());
DummyAITool? createdTool = null;
var builder = services.AddAIAgent("test-agent", "Test instructions");
builder.WithAITool(sp =>
{
createdTool = new DummyAITool();
return createdTool;
});
var serviceProvider = services.BuildServiceProvider();
var tools = ResolveToolsFromDI(serviceProvider, "test-agent");
Assert.Single(tools);
Assert.Same(createdTool, tools[0]);
}
[Fact]
public void WithAIToolFactory_CanAccessServicesFromFactory()
{
var services = new ServiceCollection();
var mockChatClient = new MockChatClient();
services.AddSingleton<IChatClient>(mockChatClient);
IChatClient? resolvedChatClient = null;
var builder = services.AddAIAgent("test-agent", "Test instructions");
builder.WithAITool(sp =>
{
resolvedChatClient = sp.GetService<IChatClient>();
return new DummyAITool();
});
var serviceProvider = services.BuildServiceProvider();
_ = ResolveToolsFromDI(serviceProvider, "test-agent");
Assert.Same(mockChatClient, resolvedChatClient);
}
[Fact]
public void WithAIToolFactory_ToolsAreIsolatedPerAgent()
{
var services = new ServiceCollection();
services.AddSingleton<IChatClient>(new MockChatClient());
var tool1 = new DummyAITool();
var tool2 = new DummyAITool();
var builder1 = services.AddAIAgent("agent1", "Agent 1 instructions");
var builder2 = services.AddAIAgent("agent2", "Agent 2 instructions");
builder1.WithAITool(_ => tool1);
builder2.WithAITool(_ => tool2);
var serviceProvider = services.BuildServiceProvider();
var agent1Tools = ResolveToolsFromDI(serviceProvider, "agent1");
var agent2Tools = ResolveToolsFromDI(serviceProvider, "agent2");
Assert.Single(agent1Tools);
Assert.Contains(tool1, agent1Tools);
Assert.DoesNotContain(tool2, agent1Tools);
Assert.Single(agent2Tools);
Assert.Contains(tool2, agent2Tools);
Assert.DoesNotContain(tool1, agent2Tools);
}
[Fact]
public void WithAIToolFactory_CanCombineWithDirectToolRegistration()
{
var services = new ServiceCollection();
services.AddSingleton<IChatClient>(new MockChatClient());
var directTool = new DummyAITool();
var factoryTool = new DummyAITool();
var builder = services.AddAIAgent("test-agent", "Test instructions");
builder
.WithAITool(directTool)
.WithAITool(_ => factoryTool);
var serviceProvider = services.BuildServiceProvider();
var tools = ResolveToolsFromDI(serviceProvider, "test-agent");
Assert.Equal(2, tools.Count);
Assert.Contains(directTool, tools);
Assert.Contains(factoryTool, tools);
}
[Fact]
public void WithAIToolFactory_ToolsAvailableOnAgent()
{
var services = new ServiceCollection();
services.AddSingleton<IChatClient>(new MockChatClient());
var factoryTool = new DummyAITool();
var builder = services.AddAIAgent("test-agent", "Test instructions");
builder.WithAITool(_ => factoryTool);
var serviceProvider = services.BuildServiceProvider();
var agentTools = ResolveToolsFromAgent(serviceProvider, "test-agent");
Assert.Contains(factoryTool, agentTools);
}
/// <summary>
/// Dummy AITool implementation for testing.
/// </summary>