mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Fix GitHubCopilotAgent to include tools added by ContextProvider.before_run in session creation (#5780)
* Fix GitHubCopilotAgent ignoring tools from context providers (#5736) _create_session and _resume_session only forwarded self._tools (constructor tools) to CopilotClient.create_session, dropping any tools contributed by context providers via session_context.extend_tools() during before_run. Merge provider-contributed tools into runtime_options in both _run_impl and _stream_updates before session creation, mirroring how RawAgent handles the merge at lines 1435-1440 in _agents.py. Update _create_session and _resume_session to combine self._tools with the merged runtime tools. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Python: Fix GitHubCopilotAgent to include tools added by ContextProvider.before_run in session creation Fixes #5736 * Fix provider tool merge to avoid mutating caller's list - Replace in-place .extend() with fresh list creation in both _run_impl and _stream_updates paths to prevent mutating the caller-provided options['tools'] list (shallow copy issue) - Also handles immutable Sequence types (e.g. tuple) correctly - Add test for provider tools forwarded via _resume_session path Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Address review feedback for #5736: review comment fixes --------- Co-authored-by: Copilot <copilot@github.com> Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
d81a8753d7
commit
0d09d40f0f
@@ -157,9 +157,7 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
|
||||
self.client = factory.create(agent_card, interceptors=interceptors) # type: ignore
|
||||
except Exception as transport_error:
|
||||
# Transport negotiation failed - fall back to minimal agent card with JSONRPC
|
||||
fallback_url = (
|
||||
agent_card.supported_interfaces[0].url if agent_card.supported_interfaces else url
|
||||
)
|
||||
fallback_url = agent_card.supported_interfaces[0].url if agent_card.supported_interfaces else url
|
||||
if not fallback_url:
|
||||
raise ValueError(
|
||||
"A2A transport negotiation failed and no fallback URL is available. "
|
||||
|
||||
@@ -520,8 +520,11 @@ class RawGitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
|
||||
|
||||
session_context = await self._run_before_providers(session=session, input_messages=input_messages, options=opts)
|
||||
|
||||
# NOTE: session is created after providers run so that future provider-contributed
|
||||
# tools/config could be folded into runtime_options before session creation.
|
||||
# Merge provider-contributed tools into runtime_options before session creation.
|
||||
if session_context.tools:
|
||||
existing = list(opts.get("tools") or [])
|
||||
opts["tools"] = existing + list(session_context.tools)
|
||||
|
||||
copilot_session = await self._get_or_create_session(session, streaming=False, runtime_options=opts)
|
||||
|
||||
# Build the prompt from the full set of messages in the session context,
|
||||
@@ -605,8 +608,11 @@ class RawGitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
|
||||
|
||||
session_context = await self._run_before_providers(session=session, input_messages=input_messages, options=opts)
|
||||
|
||||
# NOTE: session is created after providers run so that future provider-contributed
|
||||
# tools/config could be folded into runtime_options before session creation.
|
||||
# Merge provider-contributed tools into runtime_options before session creation.
|
||||
if session_context.tools:
|
||||
existing = list(opts.get("tools") or [])
|
||||
opts["tools"] = existing + list(session_context.tools)
|
||||
|
||||
copilot_session = await self._get_or_create_session(session, streaming=True, runtime_options=opts)
|
||||
|
||||
if _ctx_holder is not None:
|
||||
@@ -891,7 +897,8 @@ class RawGitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
|
||||
mcp_servers = opts.get("mcp_servers") or self._mcp_servers or None
|
||||
provider = opts.get("provider") or self._provider or None
|
||||
instruction_directories = opts.get("instruction_directories", self._instruction_directories)
|
||||
tools = self._prepare_tools(self._tools) if self._tools else None
|
||||
all_tools = list(self._tools or []) + list(opts.get("tools") or [])
|
||||
tools = self._prepare_tools(all_tools) if all_tools else None
|
||||
|
||||
return await self._client.create_session(
|
||||
on_permission_request=permission_handler,
|
||||
@@ -929,7 +936,8 @@ class RawGitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
|
||||
mcp_servers = opts.get("mcp_servers") or self._mcp_servers or None
|
||||
provider = opts.get("provider") or self._provider or None
|
||||
instruction_directories = opts.get("instruction_directories", self._instruction_directories)
|
||||
tools = self._prepare_tools(self._tools) if self._tools else None
|
||||
all_tools = list(self._tools or []) + list(opts.get("tools") or [])
|
||||
tools = self._prepare_tools(all_tools) if all_tools else None
|
||||
|
||||
return await self._client.resume_session(
|
||||
session_id,
|
||||
|
||||
@@ -2477,3 +2477,231 @@ class TestGitHubCopilotAgentContextProviders:
|
||||
with pytest.raises(ValueError, match="on_function_approval"):
|
||||
async for _ in agent.run("hello", stream=True, options={"on_function_approval": lambda _c: True}):
|
||||
pass
|
||||
|
||||
async def test_provider_tools_forwarded_to_session(
|
||||
self,
|
||||
mock_client: MagicMock,
|
||||
mock_session: MagicMock,
|
||||
assistant_message_event: SessionEvent,
|
||||
) -> None:
|
||||
"""Test that tools added by context providers are forwarded to session creation."""
|
||||
mock_session.send_and_wait.return_value = assistant_message_event
|
||||
|
||||
class ToolInjectingProvider(ContextProvider):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(source_id="tool-injector")
|
||||
|
||||
async def before_run(
|
||||
self,
|
||||
*,
|
||||
agent: Any,
|
||||
session: AgentSession,
|
||||
context: Any,
|
||||
state: dict[str, Any],
|
||||
) -> None:
|
||||
from agent_framework._tools import normalize_tools
|
||||
|
||||
def load_skill(skill_name: str) -> str:
|
||||
"""Load a skill by name."""
|
||||
return f"Loaded: {skill_name}"
|
||||
|
||||
context.extend_tools(self.source_id, normalize_tools([load_skill]))
|
||||
|
||||
provider = ToolInjectingProvider()
|
||||
agent = GitHubCopilotAgent(client=mock_client, context_providers=[provider])
|
||||
session = agent.create_session()
|
||||
await agent.run("Hello", session=session)
|
||||
|
||||
call_kwargs = mock_client.create_session.call_args.kwargs
|
||||
assert call_kwargs.get("tools") is not None
|
||||
tool_names = [t.name for t in call_kwargs["tools"]]
|
||||
assert "load_skill" in tool_names
|
||||
|
||||
async def test_provider_tools_merged_with_constructor_tools(
|
||||
self,
|
||||
mock_client: MagicMock,
|
||||
mock_session: MagicMock,
|
||||
assistant_message_event: SessionEvent,
|
||||
) -> None:
|
||||
"""Test that provider tools are merged with constructor tools, not replacing them."""
|
||||
mock_session.send_and_wait.return_value = assistant_message_event
|
||||
|
||||
def my_tool(x: str) -> str:
|
||||
"""A constructor tool."""
|
||||
return x
|
||||
|
||||
class ToolInjectingProvider(ContextProvider):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(source_id="tool-injector")
|
||||
|
||||
async def before_run(
|
||||
self,
|
||||
*,
|
||||
agent: Any,
|
||||
session: AgentSession,
|
||||
context: Any,
|
||||
state: dict[str, Any],
|
||||
) -> None:
|
||||
from agent_framework._tools import normalize_tools
|
||||
|
||||
def load_skill(skill_name: str) -> str:
|
||||
"""Load a skill by name."""
|
||||
return f"Loaded: {skill_name}"
|
||||
|
||||
context.extend_tools(self.source_id, normalize_tools([load_skill]))
|
||||
|
||||
provider = ToolInjectingProvider()
|
||||
agent = GitHubCopilotAgent(
|
||||
client=mock_client,
|
||||
tools=[my_tool],
|
||||
context_providers=[provider],
|
||||
)
|
||||
session = agent.create_session()
|
||||
await agent.run("Hello", session=session)
|
||||
|
||||
call_kwargs = mock_client.create_session.call_args.kwargs
|
||||
assert call_kwargs.get("tools") is not None
|
||||
tool_names = [t.name for t in call_kwargs["tools"]]
|
||||
assert "my_tool" in tool_names
|
||||
assert "load_skill" in tool_names
|
||||
|
||||
async def test_provider_tools_forwarded_in_streaming(
|
||||
self,
|
||||
mock_client: MagicMock,
|
||||
mock_session: MagicMock,
|
||||
assistant_delta_event: SessionEvent,
|
||||
session_idle_event: SessionEvent,
|
||||
) -> None:
|
||||
"""Test that provider tools are forwarded in the streaming path."""
|
||||
events = [assistant_delta_event, session_idle_event]
|
||||
|
||||
def mock_on(handler: Any) -> Any:
|
||||
for event in events:
|
||||
handler(event)
|
||||
return lambda: None
|
||||
|
||||
mock_session.on = mock_on
|
||||
|
||||
class ToolInjectingProvider(ContextProvider):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(source_id="tool-injector")
|
||||
|
||||
async def before_run(
|
||||
self,
|
||||
*,
|
||||
agent: Any,
|
||||
session: AgentSession,
|
||||
context: Any,
|
||||
state: dict[str, Any],
|
||||
) -> None:
|
||||
from agent_framework._tools import normalize_tools
|
||||
|
||||
def load_skill(skill_name: str) -> str:
|
||||
"""Load a skill by name."""
|
||||
return f"Loaded: {skill_name}"
|
||||
|
||||
context.extend_tools(self.source_id, normalize_tools([load_skill]))
|
||||
|
||||
provider = ToolInjectingProvider()
|
||||
agent = GitHubCopilotAgent(client=mock_client, context_providers=[provider])
|
||||
session = agent.create_session()
|
||||
async for _ in agent.run("Hello", stream=True, session=session):
|
||||
pass
|
||||
|
||||
call_kwargs = mock_client.create_session.call_args.kwargs
|
||||
assert call_kwargs.get("tools") is not None
|
||||
tool_names = [t.name for t in call_kwargs["tools"]]
|
||||
assert "load_skill" in tool_names
|
||||
|
||||
async def test_provider_tools_forwarded_to_resume_session(
|
||||
self,
|
||||
mock_client: MagicMock,
|
||||
mock_session: MagicMock,
|
||||
assistant_message_event: SessionEvent,
|
||||
) -> None:
|
||||
"""Test that provider tools are forwarded when resuming an existing session."""
|
||||
mock_session.send_and_wait.return_value = assistant_message_event
|
||||
|
||||
class ToolInjectingProvider(ContextProvider):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(source_id="tool-injector")
|
||||
|
||||
async def before_run(
|
||||
self,
|
||||
*,
|
||||
agent: Any,
|
||||
session: AgentSession,
|
||||
context: Any,
|
||||
state: dict[str, Any],
|
||||
) -> None:
|
||||
from agent_framework._tools import normalize_tools
|
||||
|
||||
def load_skill(skill_name: str) -> str:
|
||||
"""Load a skill by name."""
|
||||
return f"Loaded: {skill_name}"
|
||||
|
||||
context.extend_tools(self.source_id, normalize_tools([load_skill]))
|
||||
|
||||
provider = ToolInjectingProvider()
|
||||
agent = GitHubCopilotAgent(client=mock_client, context_providers=[provider])
|
||||
session = agent.create_session()
|
||||
session.service_session_id = "existing-id"
|
||||
await agent.run("Hello", session=session)
|
||||
|
||||
mock_client.create_session.assert_not_called()
|
||||
mock_client.resume_session.assert_called_once()
|
||||
call_kwargs = mock_client.resume_session.call_args.kwargs
|
||||
assert call_kwargs.get("tools") is not None
|
||||
tool_names = [t.name for t in call_kwargs["tools"]]
|
||||
assert "load_skill" in tool_names
|
||||
|
||||
async def test_provider_tools_forwarded_to_resume_session_streaming(
|
||||
self,
|
||||
mock_client: MagicMock,
|
||||
mock_session: MagicMock,
|
||||
assistant_delta_event: SessionEvent,
|
||||
session_idle_event: SessionEvent,
|
||||
) -> None:
|
||||
"""Test that provider tools are forwarded when resuming an existing session in streaming mode."""
|
||||
events = [assistant_delta_event, session_idle_event]
|
||||
|
||||
def mock_on(handler: Any) -> Any:
|
||||
for event in events:
|
||||
handler(event)
|
||||
return lambda: None
|
||||
|
||||
mock_session.on = mock_on
|
||||
|
||||
class ToolInjectingProvider(ContextProvider):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(source_id="tool-injector")
|
||||
|
||||
async def before_run(
|
||||
self,
|
||||
*,
|
||||
agent: Any,
|
||||
session: AgentSession,
|
||||
context: Any,
|
||||
state: dict[str, Any],
|
||||
) -> None:
|
||||
from agent_framework._tools import normalize_tools
|
||||
|
||||
def load_skill(skill_name: str) -> str:
|
||||
"""Load a skill by name."""
|
||||
return f"Loaded: {skill_name}"
|
||||
|
||||
context.extend_tools(self.source_id, normalize_tools([load_skill]))
|
||||
|
||||
provider = ToolInjectingProvider()
|
||||
agent = GitHubCopilotAgent(client=mock_client, context_providers=[provider])
|
||||
session = agent.create_session()
|
||||
session.service_session_id = "existing-id"
|
||||
async for _ in agent.run("Hello", stream=True, session=session):
|
||||
pass
|
||||
|
||||
mock_client.create_session.assert_not_called()
|
||||
mock_client.resume_session.assert_called_once()
|
||||
call_kwargs = mock_client.resume_session.call_args.kwargs
|
||||
assert call_kwargs.get("tools") is not None
|
||||
tool_names = [t.name for t in call_kwargs["tools"]]
|
||||
assert "load_skill" in tool_names
|
||||
|
||||
Reference in New Issue
Block a user