mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
.NET: Improve resolving AITool from DI (#3175)
* remove localagenttoolregistry * also give the factory method API
This commit is contained in:
committed by
GitHub
Unverified
parent
3e13909e59
commit
c7cb5be231
@@ -1,8 +1,7 @@
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using Microsoft.Agents.AI.Hosting.Local;
|
||||
using System.Linq;
|
||||
using Microsoft.Extensions.AI;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
using Microsoft.Shared.Diagnostics;
|
||||
@@ -29,7 +28,7 @@ public static class AgentHostingServiceCollectionExtensions
|
||||
return services.AddAIAgent(name, (sp, key) =>
|
||||
{
|
||||
var chatClient = sp.GetRequiredService<IChatClient>();
|
||||
var tools = GetRegisteredToolsForAgent(sp, name);
|
||||
var tools = sp.GetKeyedServices<AITool>(name).ToList();
|
||||
return new ChatClientAgent(chatClient, instructions, key, tools: tools);
|
||||
});
|
||||
}
|
||||
@@ -49,7 +48,7 @@ public static class AgentHostingServiceCollectionExtensions
|
||||
Throw.IfNullOrEmpty(name);
|
||||
return services.AddAIAgent(name, (sp, key) =>
|
||||
{
|
||||
var tools = GetRegisteredToolsForAgent(sp, name);
|
||||
var tools = sp.GetKeyedServices<AITool>(name).ToList();
|
||||
return new ChatClientAgent(chatClient, instructions, key, tools: tools);
|
||||
});
|
||||
}
|
||||
@@ -70,7 +69,7 @@ public static class AgentHostingServiceCollectionExtensions
|
||||
return services.AddAIAgent(name, (sp, key) =>
|
||||
{
|
||||
var chatClient = chatClientServiceKey is null ? sp.GetRequiredService<IChatClient>() : sp.GetRequiredKeyedService<IChatClient>(chatClientServiceKey);
|
||||
var tools = GetRegisteredToolsForAgent(sp, name);
|
||||
var tools = sp.GetKeyedServices<AITool>(name).ToList();
|
||||
return new ChatClientAgent(chatClient, instructions, key, tools: tools);
|
||||
});
|
||||
}
|
||||
@@ -92,7 +91,7 @@ public static class AgentHostingServiceCollectionExtensions
|
||||
return services.AddAIAgent(name, (sp, key) =>
|
||||
{
|
||||
var chatClient = chatClientServiceKey is null ? sp.GetRequiredService<IChatClient>() : sp.GetRequiredKeyedService<IChatClient>(chatClientServiceKey);
|
||||
var tools = GetRegisteredToolsForAgent(sp, name);
|
||||
var tools = sp.GetKeyedServices<AITool>(name).ToList();
|
||||
return new ChatClientAgent(chatClient, instructions: instructions, name: key, description: description, tools: tools);
|
||||
});
|
||||
}
|
||||
@@ -127,10 +126,4 @@ public static class AgentHostingServiceCollectionExtensions
|
||||
|
||||
return new HostedAgentBuilder(name, services);
|
||||
}
|
||||
|
||||
private static IList<AITool> GetRegisteredToolsForAgent(IServiceProvider serviceProvider, string agentName)
|
||||
{
|
||||
var registry = serviceProvider.GetService<LocalAgentToolRegistry>();
|
||||
return registry?.GetTools(agentName) ?? [];
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
using System;
|
||||
using System.Linq;
|
||||
using Microsoft.Agents.AI.Hosting.Local;
|
||||
using Microsoft.Extensions.AI;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
using Microsoft.Shared.Diagnostics;
|
||||
@@ -70,18 +68,7 @@ public static class HostedAgentBuilderExtensions
|
||||
Throw.IfNull(builder);
|
||||
Throw.IfNull(tool);
|
||||
|
||||
var agentName = builder.Name;
|
||||
var services = builder.ServiceCollection;
|
||||
|
||||
// Get or create the agent tool registry
|
||||
var descriptor = services.FirstOrDefault(sd => !sd.IsKeyedService && sd.ServiceType.Equals(typeof(LocalAgentToolRegistry)));
|
||||
if (descriptor?.ImplementationInstance is not LocalAgentToolRegistry toolRegistry)
|
||||
{
|
||||
toolRegistry = new();
|
||||
services.Add(ServiceDescriptor.Singleton(toolRegistry));
|
||||
}
|
||||
|
||||
toolRegistry.AddTool(agentName, tool);
|
||||
builder.ServiceCollection.AddKeyedSingleton(builder.Name, tool);
|
||||
|
||||
return builder;
|
||||
}
|
||||
@@ -105,4 +92,19 @@ public static class HostedAgentBuilderExtensions
|
||||
|
||||
return builder;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Adds AI tool to an agent being configured with the service collection.
|
||||
/// </summary>
|
||||
/// <param name="builder">The hosted agent builder.</param>
|
||||
/// <param name="factory">A factory function that creates a AI tool using the provided service provider.</param>
|
||||
public static IHostedAgentBuilder WithAITool(this IHostedAgentBuilder builder, Func<IServiceProvider, AITool> factory)
|
||||
{
|
||||
Throw.IfNull(builder);
|
||||
Throw.IfNull(factory);
|
||||
|
||||
builder.ServiceCollection.AddKeyedSingleton(builder.Name, (sp, name) => factory(sp));
|
||||
|
||||
return builder;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
using System.Collections.Generic;
|
||||
using Microsoft.Extensions.AI;
|
||||
|
||||
namespace Microsoft.Agents.AI.Hosting.Local;
|
||||
|
||||
internal sealed class LocalAgentToolRegistry
|
||||
{
|
||||
private readonly Dictionary<string, List<AITool>> _toolsByAgentName = [];
|
||||
|
||||
public void AddTool(string agentName, AITool tool)
|
||||
{
|
||||
if (!this._toolsByAgentName.TryGetValue(agentName, out var tools))
|
||||
{
|
||||
tools = [];
|
||||
this._toolsByAgentName[agentName] = tools;
|
||||
}
|
||||
|
||||
tools.Add(tool);
|
||||
}
|
||||
|
||||
public IList<AITool> GetTools(string agentName)
|
||||
{
|
||||
return this._toolsByAgentName.TryGetValue(agentName, out var tools) ? tools : [];
|
||||
}
|
||||
}
|
||||
+149
-14
@@ -2,6 +2,7 @@
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.Extensions.AI;
|
||||
@@ -17,49 +18,40 @@ public sealed class HostedAgentBuilderToolsExtensionsTests
|
||||
[Fact]
|
||||
public void WithAITool_ThrowsWhenBuilderIsNull()
|
||||
{
|
||||
// Arrange
|
||||
var tool = new DummyAITool();
|
||||
|
||||
// Act & Assert
|
||||
Assert.Throws<ArgumentNullException>(() => HostedAgentBuilderExtensions.WithAITool(null!, tool));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void WithAITool_ThrowsWhenToolIsNull()
|
||||
{
|
||||
// Arrange
|
||||
var services = new ServiceCollection();
|
||||
var builder = services.AddAIAgent("test-agent", "Test instructions");
|
||||
|
||||
// Act & Assert
|
||||
Assert.Throws<ArgumentNullException>(() => builder.WithAITool(null!));
|
||||
Assert.Throws<ArgumentNullException>(() => builder.WithAITool(tool: null!));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void WithAITools_ThrowsWhenBuilderIsNull()
|
||||
{
|
||||
// Arrange
|
||||
var tools = new[] { new DummyAITool() };
|
||||
|
||||
// Act & Assert
|
||||
Assert.Throws<ArgumentNullException>(() => HostedAgentBuilderExtensions.WithAITools(null!, tools));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void WithAITools_ThrowsWhenToolsArrayIsNull()
|
||||
{
|
||||
// Arrange
|
||||
var services = new ServiceCollection();
|
||||
var builder = services.AddAIAgent("test-agent", "Test instructions");
|
||||
|
||||
// Act & Assert
|
||||
Assert.Throws<ArgumentNullException>(() => builder.WithAITools(null!));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void RegisteredTools_ResolvesAllToolsForAgent()
|
||||
{
|
||||
// Arrange
|
||||
var services = new ServiceCollection();
|
||||
services.AddSingleton<IChatClient>(new MockChatClient());
|
||||
|
||||
@@ -73,9 +65,13 @@ public sealed class HostedAgentBuilderToolsExtensionsTests
|
||||
|
||||
var serviceProvider = services.BuildServiceProvider();
|
||||
|
||||
var agent1Tools = ResolveAgentTools(serviceProvider, "test-agent");
|
||||
var agent1Tools = ResolveToolsFromAgent(serviceProvider, "test-agent");
|
||||
Assert.Contains(tool1, agent1Tools);
|
||||
Assert.Contains(tool2, agent1Tools);
|
||||
|
||||
var agent1ToolsDI = ResolveToolsFromDI(serviceProvider, "test-agent");
|
||||
Assert.Contains(tool1, agent1ToolsDI);
|
||||
Assert.Contains(tool2, agent1ToolsDI);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
@@ -100,21 +96,160 @@ public sealed class HostedAgentBuilderToolsExtensionsTests
|
||||
|
||||
var serviceProvider = services.BuildServiceProvider();
|
||||
|
||||
var agent1Tools = ResolveAgentTools(serviceProvider, "agent1");
|
||||
var agent2Tools = ResolveAgentTools(serviceProvider, "agent2");
|
||||
var agent1Tools = ResolveToolsFromAgent(serviceProvider, "agent1");
|
||||
var agent2Tools = ResolveToolsFromAgent(serviceProvider, "agent2");
|
||||
|
||||
var agent1ToolsDI = ResolveToolsFromDI(serviceProvider, "agent1");
|
||||
var agent2ToolsDI = ResolveToolsFromDI(serviceProvider, "agent2");
|
||||
|
||||
Assert.Contains(tool1, agent1Tools);
|
||||
Assert.Contains(tool2, agent1Tools);
|
||||
Assert.Contains(tool1, agent1ToolsDI);
|
||||
Assert.Contains(tool2, agent1ToolsDI);
|
||||
|
||||
Assert.Contains(tool3, agent2Tools);
|
||||
Assert.Contains(tool3, agent2ToolsDI);
|
||||
}
|
||||
|
||||
private static IList<AITool> ResolveAgentTools(IServiceProvider serviceProvider, string name)
|
||||
private static IList<AITool> ResolveToolsFromAgent(IServiceProvider serviceProvider, string name)
|
||||
{
|
||||
var agent = serviceProvider.GetRequiredKeyedService<AIAgent>(name) as ChatClientAgent;
|
||||
Assert.NotNull(agent?.ChatOptions?.Tools);
|
||||
return agent.ChatOptions.Tools;
|
||||
}
|
||||
|
||||
private static List<AITool> ResolveToolsFromDI(IServiceProvider serviceProvider, string name)
|
||||
{
|
||||
var tools = serviceProvider.GetKeyedServices<AITool>(name);
|
||||
Assert.NotNull(tools);
|
||||
return tools.ToList();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void WithAIToolFactory_ThrowsWhenBuilderIsNull()
|
||||
{
|
||||
Assert.Throws<ArgumentNullException>(() => HostedAgentBuilderExtensions.WithAITool(null!, CreateTool));
|
||||
|
||||
static AITool CreateTool(IServiceProvider _) => new DummyAITool();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void WithAIToolFactory_ThrowsWhenFactoryIsNull()
|
||||
{
|
||||
var services = new ServiceCollection();
|
||||
var builder = services.AddAIAgent("test-agent", "Test instructions");
|
||||
|
||||
Assert.Throws<ArgumentNullException>(() => builder.WithAITool(factory: null!));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void WithAIToolFactory_RegistersToolFromFactory()
|
||||
{
|
||||
var services = new ServiceCollection();
|
||||
services.AddSingleton<IChatClient>(new MockChatClient());
|
||||
|
||||
DummyAITool? createdTool = null;
|
||||
var builder = services.AddAIAgent("test-agent", "Test instructions");
|
||||
builder.WithAITool(sp =>
|
||||
{
|
||||
createdTool = new DummyAITool();
|
||||
return createdTool;
|
||||
});
|
||||
|
||||
var serviceProvider = services.BuildServiceProvider();
|
||||
var tools = ResolveToolsFromDI(serviceProvider, "test-agent");
|
||||
|
||||
Assert.Single(tools);
|
||||
Assert.Same(createdTool, tools[0]);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void WithAIToolFactory_CanAccessServicesFromFactory()
|
||||
{
|
||||
var services = new ServiceCollection();
|
||||
var mockChatClient = new MockChatClient();
|
||||
services.AddSingleton<IChatClient>(mockChatClient);
|
||||
|
||||
IChatClient? resolvedChatClient = null;
|
||||
var builder = services.AddAIAgent("test-agent", "Test instructions");
|
||||
builder.WithAITool(sp =>
|
||||
{
|
||||
resolvedChatClient = sp.GetService<IChatClient>();
|
||||
return new DummyAITool();
|
||||
});
|
||||
|
||||
var serviceProvider = services.BuildServiceProvider();
|
||||
_ = ResolveToolsFromDI(serviceProvider, "test-agent");
|
||||
|
||||
Assert.Same(mockChatClient, resolvedChatClient);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void WithAIToolFactory_ToolsAreIsolatedPerAgent()
|
||||
{
|
||||
var services = new ServiceCollection();
|
||||
services.AddSingleton<IChatClient>(new MockChatClient());
|
||||
|
||||
var tool1 = new DummyAITool();
|
||||
var tool2 = new DummyAITool();
|
||||
|
||||
var builder1 = services.AddAIAgent("agent1", "Agent 1 instructions");
|
||||
var builder2 = services.AddAIAgent("agent2", "Agent 2 instructions");
|
||||
|
||||
builder1.WithAITool(_ => tool1);
|
||||
builder2.WithAITool(_ => tool2);
|
||||
|
||||
var serviceProvider = services.BuildServiceProvider();
|
||||
var agent1Tools = ResolveToolsFromDI(serviceProvider, "agent1");
|
||||
var agent2Tools = ResolveToolsFromDI(serviceProvider, "agent2");
|
||||
|
||||
Assert.Single(agent1Tools);
|
||||
Assert.Contains(tool1, agent1Tools);
|
||||
Assert.DoesNotContain(tool2, agent1Tools);
|
||||
|
||||
Assert.Single(agent2Tools);
|
||||
Assert.Contains(tool2, agent2Tools);
|
||||
Assert.DoesNotContain(tool1, agent2Tools);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void WithAIToolFactory_CanCombineWithDirectToolRegistration()
|
||||
{
|
||||
var services = new ServiceCollection();
|
||||
services.AddSingleton<IChatClient>(new MockChatClient());
|
||||
|
||||
var directTool = new DummyAITool();
|
||||
var factoryTool = new DummyAITool();
|
||||
|
||||
var builder = services.AddAIAgent("test-agent", "Test instructions");
|
||||
builder
|
||||
.WithAITool(directTool)
|
||||
.WithAITool(_ => factoryTool);
|
||||
|
||||
var serviceProvider = services.BuildServiceProvider();
|
||||
var tools = ResolveToolsFromDI(serviceProvider, "test-agent");
|
||||
|
||||
Assert.Equal(2, tools.Count);
|
||||
Assert.Contains(directTool, tools);
|
||||
Assert.Contains(factoryTool, tools);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void WithAIToolFactory_ToolsAvailableOnAgent()
|
||||
{
|
||||
var services = new ServiceCollection();
|
||||
services.AddSingleton<IChatClient>(new MockChatClient());
|
||||
|
||||
var factoryTool = new DummyAITool();
|
||||
var builder = services.AddAIAgent("test-agent", "Test instructions");
|
||||
builder.WithAITool(_ => factoryTool);
|
||||
|
||||
var serviceProvider = services.BuildServiceProvider();
|
||||
var agentTools = ResolveToolsFromAgent(serviceProvider, "test-agent");
|
||||
|
||||
Assert.Contains(factoryTool, agentTools);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Dummy AITool implementation for testing.
|
||||
/// </summary>
|
||||
|
||||
Reference in New Issue
Block a user