Python: Fix GitHubCopilotAgent to include tools added by ContextProvider.before_run in session creation (#5780)

* Fix GitHubCopilotAgent ignoring tools from context providers (#5736)

_create_session and _resume_session only forwarded self._tools (constructor
tools) to CopilotClient.create_session, dropping any tools contributed by
context providers via session_context.extend_tools() during before_run.

Merge provider-contributed tools into runtime_options in both _run_impl and
_stream_updates before session creation, mirroring how RawAgent handles the
merge at lines 1435-1440 in _agents.py. Update _create_session and
_resume_session to combine self._tools with the merged runtime tools.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Python: Fix GitHubCopilotAgent to include tools added by ContextProvider.before_run in session creation

Fixes #5736

* Fix provider tool merge to avoid mutating caller's list

- Replace in-place .extend() with fresh list creation in both
  _run_impl and _stream_updates paths to prevent mutating the
  caller-provided options['tools'] list (shallow copy issue)
- Also handles immutable Sequence types (e.g. tuple) correctly
- Add test for provider tools forwarded via _resume_session path

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Address review feedback for #5736: review comment fixes

---------

Co-authored-by: Copilot <copilot@github.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
Giles Odigwe
2026-05-15 07:59:22 -07:00
committed by GitHub
Unverified
parent d81a8753d7
commit 0d09d40f0f
4 changed files with 244 additions and 10 deletions
@@ -157,9 +157,7 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
self.client = factory.create(agent_card, interceptors=interceptors) # type: ignore
except Exception as transport_error:
# Transport negotiation failed - fall back to minimal agent card with JSONRPC
fallback_url = (
agent_card.supported_interfaces[0].url if agent_card.supported_interfaces else url
)
fallback_url = agent_card.supported_interfaces[0].url if agent_card.supported_interfaces else url
if not fallback_url:
raise ValueError(
"A2A transport negotiation failed and no fallback URL is available. "
@@ -520,8 +520,11 @@ class RawGitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
session_context = await self._run_before_providers(session=session, input_messages=input_messages, options=opts)
# NOTE: session is created after providers run so that future provider-contributed
# tools/config could be folded into runtime_options before session creation.
# Merge provider-contributed tools into runtime_options before session creation.
if session_context.tools:
existing = list(opts.get("tools") or [])
opts["tools"] = existing + list(session_context.tools)
copilot_session = await self._get_or_create_session(session, streaming=False, runtime_options=opts)
# Build the prompt from the full set of messages in the session context,
@@ -605,8 +608,11 @@ class RawGitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
session_context = await self._run_before_providers(session=session, input_messages=input_messages, options=opts)
# NOTE: session is created after providers run so that future provider-contributed
# tools/config could be folded into runtime_options before session creation.
# Merge provider-contributed tools into runtime_options before session creation.
if session_context.tools:
existing = list(opts.get("tools") or [])
opts["tools"] = existing + list(session_context.tools)
copilot_session = await self._get_or_create_session(session, streaming=True, runtime_options=opts)
if _ctx_holder is not None:
@@ -891,7 +897,8 @@ class RawGitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
mcp_servers = opts.get("mcp_servers") or self._mcp_servers or None
provider = opts.get("provider") or self._provider or None
instruction_directories = opts.get("instruction_directories", self._instruction_directories)
tools = self._prepare_tools(self._tools) if self._tools else None
all_tools = list(self._tools or []) + list(opts.get("tools") or [])
tools = self._prepare_tools(all_tools) if all_tools else None
return await self._client.create_session(
on_permission_request=permission_handler,
@@ -929,7 +936,8 @@ class RawGitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
mcp_servers = opts.get("mcp_servers") or self._mcp_servers or None
provider = opts.get("provider") or self._provider or None
instruction_directories = opts.get("instruction_directories", self._instruction_directories)
tools = self._prepare_tools(self._tools) if self._tools else None
all_tools = list(self._tools or []) + list(opts.get("tools") or [])
tools = self._prepare_tools(all_tools) if all_tools else None
return await self._client.resume_session(
session_id,
@@ -2477,3 +2477,231 @@ class TestGitHubCopilotAgentContextProviders:
with pytest.raises(ValueError, match="on_function_approval"):
async for _ in agent.run("hello", stream=True, options={"on_function_approval": lambda _c: True}):
pass
async def test_provider_tools_forwarded_to_session(
self,
mock_client: MagicMock,
mock_session: MagicMock,
assistant_message_event: SessionEvent,
) -> None:
"""Test that tools added by context providers are forwarded to session creation."""
mock_session.send_and_wait.return_value = assistant_message_event
class ToolInjectingProvider(ContextProvider):
def __init__(self) -> None:
super().__init__(source_id="tool-injector")
async def before_run(
self,
*,
agent: Any,
session: AgentSession,
context: Any,
state: dict[str, Any],
) -> None:
from agent_framework._tools import normalize_tools
def load_skill(skill_name: str) -> str:
"""Load a skill by name."""
return f"Loaded: {skill_name}"
context.extend_tools(self.source_id, normalize_tools([load_skill]))
provider = ToolInjectingProvider()
agent = GitHubCopilotAgent(client=mock_client, context_providers=[provider])
session = agent.create_session()
await agent.run("Hello", session=session)
call_kwargs = mock_client.create_session.call_args.kwargs
assert call_kwargs.get("tools") is not None
tool_names = [t.name for t in call_kwargs["tools"]]
assert "load_skill" in tool_names
async def test_provider_tools_merged_with_constructor_tools(
self,
mock_client: MagicMock,
mock_session: MagicMock,
assistant_message_event: SessionEvent,
) -> None:
"""Test that provider tools are merged with constructor tools, not replacing them."""
mock_session.send_and_wait.return_value = assistant_message_event
def my_tool(x: str) -> str:
"""A constructor tool."""
return x
class ToolInjectingProvider(ContextProvider):
def __init__(self) -> None:
super().__init__(source_id="tool-injector")
async def before_run(
self,
*,
agent: Any,
session: AgentSession,
context: Any,
state: dict[str, Any],
) -> None:
from agent_framework._tools import normalize_tools
def load_skill(skill_name: str) -> str:
"""Load a skill by name."""
return f"Loaded: {skill_name}"
context.extend_tools(self.source_id, normalize_tools([load_skill]))
provider = ToolInjectingProvider()
agent = GitHubCopilotAgent(
client=mock_client,
tools=[my_tool],
context_providers=[provider],
)
session = agent.create_session()
await agent.run("Hello", session=session)
call_kwargs = mock_client.create_session.call_args.kwargs
assert call_kwargs.get("tools") is not None
tool_names = [t.name for t in call_kwargs["tools"]]
assert "my_tool" in tool_names
assert "load_skill" in tool_names
async def test_provider_tools_forwarded_in_streaming(
self,
mock_client: MagicMock,
mock_session: MagicMock,
assistant_delta_event: SessionEvent,
session_idle_event: SessionEvent,
) -> None:
"""Test that provider tools are forwarded in the streaming path."""
events = [assistant_delta_event, session_idle_event]
def mock_on(handler: Any) -> Any:
for event in events:
handler(event)
return lambda: None
mock_session.on = mock_on
class ToolInjectingProvider(ContextProvider):
def __init__(self) -> None:
super().__init__(source_id="tool-injector")
async def before_run(
self,
*,
agent: Any,
session: AgentSession,
context: Any,
state: dict[str, Any],
) -> None:
from agent_framework._tools import normalize_tools
def load_skill(skill_name: str) -> str:
"""Load a skill by name."""
return f"Loaded: {skill_name}"
context.extend_tools(self.source_id, normalize_tools([load_skill]))
provider = ToolInjectingProvider()
agent = GitHubCopilotAgent(client=mock_client, context_providers=[provider])
session = agent.create_session()
async for _ in agent.run("Hello", stream=True, session=session):
pass
call_kwargs = mock_client.create_session.call_args.kwargs
assert call_kwargs.get("tools") is not None
tool_names = [t.name for t in call_kwargs["tools"]]
assert "load_skill" in tool_names
async def test_provider_tools_forwarded_to_resume_session(
self,
mock_client: MagicMock,
mock_session: MagicMock,
assistant_message_event: SessionEvent,
) -> None:
"""Test that provider tools are forwarded when resuming an existing session."""
mock_session.send_and_wait.return_value = assistant_message_event
class ToolInjectingProvider(ContextProvider):
def __init__(self) -> None:
super().__init__(source_id="tool-injector")
async def before_run(
self,
*,
agent: Any,
session: AgentSession,
context: Any,
state: dict[str, Any],
) -> None:
from agent_framework._tools import normalize_tools
def load_skill(skill_name: str) -> str:
"""Load a skill by name."""
return f"Loaded: {skill_name}"
context.extend_tools(self.source_id, normalize_tools([load_skill]))
provider = ToolInjectingProvider()
agent = GitHubCopilotAgent(client=mock_client, context_providers=[provider])
session = agent.create_session()
session.service_session_id = "existing-id"
await agent.run("Hello", session=session)
mock_client.create_session.assert_not_called()
mock_client.resume_session.assert_called_once()
call_kwargs = mock_client.resume_session.call_args.kwargs
assert call_kwargs.get("tools") is not None
tool_names = [t.name for t in call_kwargs["tools"]]
assert "load_skill" in tool_names
async def test_provider_tools_forwarded_to_resume_session_streaming(
self,
mock_client: MagicMock,
mock_session: MagicMock,
assistant_delta_event: SessionEvent,
session_idle_event: SessionEvent,
) -> None:
"""Test that provider tools are forwarded when resuming an existing session in streaming mode."""
events = [assistant_delta_event, session_idle_event]
def mock_on(handler: Any) -> Any:
for event in events:
handler(event)
return lambda: None
mock_session.on = mock_on
class ToolInjectingProvider(ContextProvider):
def __init__(self) -> None:
super().__init__(source_id="tool-injector")
async def before_run(
self,
*,
agent: Any,
session: AgentSession,
context: Any,
state: dict[str, Any],
) -> None:
from agent_framework._tools import normalize_tools
def load_skill(skill_name: str) -> str:
"""Load a skill by name."""
return f"Loaded: {skill_name}"
context.extend_tools(self.source_id, normalize_tools([load_skill]))
provider = ToolInjectingProvider()
agent = GitHubCopilotAgent(client=mock_client, context_providers=[provider])
session = agent.create_session()
session.service_session_id = "existing-id"
async for _ in agent.run("Hello", stream=True, session=session):
pass
mock_client.create_session.assert_not_called()
mock_client.resume_session.assert_called_once()
call_kwargs = mock_client.resume_session.call_args.kwargs
assert call_kwargs.get("tools") is not None
tool_names = [t.name for t in call_kwargs["tools"]]
assert "load_skill" in tool_names
+1 -1
View File
@@ -602,7 +602,7 @@ dependencies = [
[package.metadata]
requires-dist = [
{ name = "agent-framework-core", editable = "packages/core" },
{ name = "github-copilot-sdk", marker = "python_full_version >= '3.11'", specifier = "<=1.0.0b2,>=1.0.0b2" },
{ name = "github-copilot-sdk", marker = "python_full_version >= '3.11'", specifier = ">=1.0.0b2,<=1.0.0b2" },
]
[[package]]