Python: Fix GitHubCopilotAgent to invoke context provider before_run/after_run hooks (#5013)

* Fix GitHubCopilotAgent not calling context provider hooks (#3984)

GitHubCopilotAgent accepted context_providers in its constructor but
never called before_run()/after_run() on them in _run_impl() or
_stream_updates(), silently ignoring all context providers.

Add _run_before_providers() helper to create SessionContext and invoke
before_run on each provider. Both _run_impl() and _stream_updates() now
run the full provider lifecycle: before_run before sending the prompt
(with provider instructions prepended) and after_run after receiving the
response. This follows the same pattern used by A2AAgent.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Python: Fix GitHubCopilotAgent to invoke context provider before_run/after_run hooks

Fixes #3984

* fix(#3984): address review feedback for context provider integration

- Build prompt from session_context.get_messages(include_input=True) so
  provider-injected context_messages are included in both non-streaming
  and streaming paths (review comments #1, #2)
- Preserve timeout in opts (use get instead of pop) so providers can
  observe it via context.options (review comment #3)
- Eliminate streaming double-buffer: move after_run invocation to a
  ResponseStream result_hook (matching Agent class pattern) instead of
  maintaining a separate updates list in the generator (review comment #4)
- Improve _run_before_providers docstring

Add tests for:
- Context messages included in prompt (non-streaming + streaming)
- Error path: after_run NOT called when send_and_wait/streaming raises
- Multiple providers: forward before_run, reverse after_run ordering
- BaseHistoryProvider with load_messages=False is skipped
- Streaming after_run response contains aggregated updates
- Streaming with no updates still sets empty response
- Timeout preserved in session context options for providers

Note: _run_before_providers remains on GitHubCopilotAgent for now. A
follow-up PR should extract it to BaseAgent so subclasses can reuse it
without duplicating the provider iteration logic.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Address review feedback for #3984: Python: [Bug]: GitHubCopilotAgent Memory Example

* refactor(#3984): promote _run_before_providers to BaseAgent

Move _run_before_providers from GitHubCopilotAgent into BaseAgent,
mirroring the existing _run_after_providers helper. Agent's
_prepare_session_and_messages now delegates to the shared base method,
eliminating the near-duplicate provider iteration logic that could
drift as the provider contract evolves.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Address review feedback for #3984: Python: [Bug]: GitHubCopilotAgent Memory Example

* revert: keep _run_before_providers in GitHubCopilotAgent only

Undo the promotion of _run_before_providers to BaseAgent. The method
stays in GitHubCopilotAgent where it is needed, and _agents.py
retains its original inline provider iteration in RawAgent.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* fix: replace deprecated BaseContextProvider/BaseHistoryProvider with ContextProvider/HistoryProvider

Update imports and usages in GitHubCopilotAgent and its tests to use
the new non-deprecated class names from the core package.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* fix: address review feedback - reorder providers before session, wrap streaming after_run in try/except, assert after_run on skipped HistoryProvider

- Move _run_before_providers before _get_or_create_session so provider
  contributions can affect session configuration
- Wrap _run_after_providers in try/except in streaming _after_run_hook
  to prevent provider errors from replacing successful responses
- Add after_run assertion to test_history_provider_skip_when_load_messages_false

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

---------

Co-authored-by: Copilot <copilot@github.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
Giles Odigwe
2026-04-02 02:43:02 -07:00
committed by GitHub
Unverified
parent 62595b233f
commit 339e76d51f
2 changed files with 624 additions and 7 deletions
@@ -17,8 +17,10 @@ from agent_framework import (
BaseAgent,
Content,
ContextProvider,
HistoryProvider,
Message,
ResponseStream,
SessionContext,
normalize_messages,
)
from agent_framework._settings import load_settings
@@ -352,13 +354,25 @@ class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
AgentException: If the request fails.
"""
if stream:
ctx_holder: dict[str, Any] = {}
async def _after_run_hook(response: AgentResponse) -> None:
session_context = ctx_holder.get("session_context")
sess = ctx_holder.get("session")
if session_context is not None and sess is not None:
session_context._response = response
try:
await self._run_after_providers(session=sess, context=session_context)
except Exception:
logger.exception("Error running after_run providers in streaming result hook")
def _finalize(updates: Sequence[AgentResponseUpdate]) -> AgentResponse:
return AgentResponse.from_updates(updates)
return ResponseStream(
self._stream_updates(messages=messages, session=session, options=options),
self._stream_updates(messages=messages, session=session, options=options, _ctx_holder=ctx_holder),
finalizer=_finalize,
result_hooks=[_after_run_hook],
)
return self._run_impl(messages=messages, session=session, options=options)
@@ -377,11 +391,22 @@ class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
session = self.create_session()
opts: dict[str, Any] = dict(options) if options else {}
timeout = opts.pop("timeout", None) or self._settings.get("timeout") or DEFAULT_TIMEOUT_SECONDS
timeout = opts.get("timeout") or self._settings.get("timeout") or DEFAULT_TIMEOUT_SECONDS
copilot_session = await self._get_or_create_session(session, streaming=False, runtime_options=opts)
input_messages = normalize_messages(messages)
prompt = "\n".join([message.text for message in input_messages])
session_context = await self._run_before_providers(session=session, input_messages=input_messages, options=opts)
# NOTE: session is created after providers run so that future provider-contributed
# tools/config could be folded into runtime_options before session creation.
copilot_session = await self._get_or_create_session(session, streaming=False, runtime_options=opts)
# Build the prompt from the full set of messages in the session context,
# so that any context/history provider-injected messages are included.
context_messages = session_context.get_messages(include_input=True)
prompt = "\n".join([message.text for message in context_messages])
if session_context.instructions:
prompt = "\n".join(session_context.instructions) + "\n" + prompt
message_options = cast(MessageOptions, {"prompt": prompt})
try:
@@ -408,7 +433,10 @@ class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
)
response_id = message_id
return AgentResponse(messages=response_messages, response_id=response_id)
response = AgentResponse(messages=response_messages, response_id=response_id)
session_context._response = response # type: ignore[assignment]
await self._run_after_providers(session=session, context=session_context)
return response
async def _stream_updates(
self,
@@ -416,6 +444,7 @@ class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
*,
session: AgentSession | None = None,
options: OptionsT | None = None,
_ctx_holder: dict[str, Any] | None = None,
) -> AsyncIterable[AgentResponseUpdate]:
"""Internal method to stream updates from GitHub Copilot.
@@ -425,6 +454,9 @@ class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
Keyword Args:
session: The conversation session associated with the message(s).
options: Runtime options (model, timeout, etc.).
_ctx_holder: Internal dict populated with session_context and session
so that the caller (via a ResponseStream result_hook) can run
after_run providers without duplicating the updates buffer.
Yields:
AgentResponseUpdate items.
@@ -440,9 +472,23 @@ class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
opts: dict[str, Any] = dict(options) if options else {}
copilot_session = await self._get_or_create_session(session, streaming=True, runtime_options=opts)
input_messages = normalize_messages(messages)
prompt = "\n".join([message.text for message in input_messages])
session_context = await self._run_before_providers(session=session, input_messages=input_messages, options=opts)
# NOTE: session is created after providers run so that future provider-contributed
# tools/config could be folded into runtime_options before session creation.
copilot_session = await self._get_or_create_session(session, streaming=True, runtime_options=opts)
if _ctx_holder is not None:
_ctx_holder["session_context"] = session_context
_ctx_holder["session"] = session
# Build the prompt from the full session context so provider-injected messages are included.
context_messages = session_context.get_messages(include_input=True)
prompt = "\n".join([message.text for message in context_messages])
if session_context.instructions:
prompt = "\n".join(session_context.instructions) + "\n" + prompt
message_options = cast(MessageOptions, {"prompt": prompt})
queue: asyncio.Queue[AgentResponseUpdate | Exception | None] = asyncio.Queue()
@@ -513,6 +559,46 @@ class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
finally:
unsubscribe()
async def _run_before_providers(
self,
*,
session: AgentSession,
input_messages: list[Message],
options: dict[str, Any],
) -> SessionContext:
"""Run before_run on all context providers and return the session context.
Creates a SessionContext and invokes ``before_run`` on each provider in
forward order. ``HistoryProvider`` instances with
``load_messages=False`` are skipped.
Keyword Args:
session: The conversation session.
input_messages: The normalized input messages.
options: Runtime options dict.
Returns:
The SessionContext with provider context populated.
"""
session_context = SessionContext(
session_id=session.session_id,
service_session_id=session.service_session_id,
input_messages=input_messages,
options=options,
)
for provider in self.context_providers:
if isinstance(provider, HistoryProvider) and not provider.load_messages:
continue
await provider.before_run(
agent=self, # type: ignore[arg-type]
session=session,
context=session_context,
state=session.state.setdefault(provider.source_id, {}),
)
return session_context
@staticmethod
def _prepare_system_message(
instructions: str | None,
@@ -17,6 +17,8 @@ from agent_framework import (
AgentResponseUpdate,
AgentSession,
Content,
ContextProvider,
HistoryProvider,
Message,
)
from agent_framework.exceptions import AgentException
@@ -1367,3 +1369,532 @@ class TestGitHubCopilotAgentPermissions:
call_args = mock_client.create_session.call_args
config = call_args[0][0]
assert "on_permission_request" not in config
class SpyContextProvider(ContextProvider):
"""A context provider that records whether its hooks are called."""
def __init__(self) -> None:
super().__init__(source_id="spy-provider")
self.before_run_called = False
self.after_run_called = False
self.before_run_context: Any = None
self.after_run_context: Any = None
async def before_run(
self,
*,
agent: Any,
session: AgentSession,
context: Any,
state: dict[str, Any],
) -> None:
self.before_run_called = True
self.before_run_context = context
context.instructions.append("Injected by spy provider")
async def after_run(
self,
*,
agent: Any,
session: AgentSession,
context: Any,
state: dict[str, Any],
) -> None:
self.after_run_called = True
self.after_run_context = context
class TestGitHubCopilotAgentContextProviders:
"""Test cases for context provider integration."""
async def test_before_run_called_on_run(
self,
mock_client: MagicMock,
mock_session: MagicMock,
assistant_message_event: SessionEvent,
) -> None:
"""Test that before_run is called on context providers during run()."""
mock_session.send_and_wait.return_value = assistant_message_event
spy = SpyContextProvider()
agent = GitHubCopilotAgent(client=mock_client, context_providers=[spy])
session = agent.create_session()
await agent.run("Hello", session=session)
assert spy.before_run_called
async def test_after_run_called_on_run(
self,
mock_client: MagicMock,
mock_session: MagicMock,
assistant_message_event: SessionEvent,
) -> None:
"""Test that after_run is called on context providers after run()."""
mock_session.send_and_wait.return_value = assistant_message_event
spy = SpyContextProvider()
agent = GitHubCopilotAgent(client=mock_client, context_providers=[spy])
session = agent.create_session()
await agent.run("Hello", session=session)
assert spy.after_run_called
async def test_provider_instructions_included_in_prompt(
self,
mock_client: MagicMock,
mock_session: MagicMock,
assistant_message_event: SessionEvent,
) -> None:
"""Test that instructions added by context providers are included in the prompt."""
mock_session.send_and_wait.return_value = assistant_message_event
spy = SpyContextProvider()
agent = GitHubCopilotAgent(client=mock_client, context_providers=[spy])
session = agent.create_session()
await agent.run("Hello", session=session)
sent_prompt = mock_session.send_and_wait.call_args[0][0]["prompt"]
assert "Injected by spy provider" in sent_prompt
async def test_after_run_receives_response(
self,
mock_client: MagicMock,
mock_session: MagicMock,
assistant_message_event: SessionEvent,
) -> None:
"""Test that after_run context contains the agent response."""
mock_session.send_and_wait.return_value = assistant_message_event
spy = SpyContextProvider()
agent = GitHubCopilotAgent(client=mock_client, context_providers=[spy])
session = agent.create_session()
await agent.run("Hello", session=session)
assert spy.after_run_context is not None
assert spy.after_run_context.response is not None
async def test_before_run_called_on_streaming(
self,
mock_client: MagicMock,
mock_session: MagicMock,
assistant_delta_event: SessionEvent,
session_idle_event: SessionEvent,
) -> None:
"""Test that before_run is called on context providers during streaming."""
events = [assistant_delta_event, session_idle_event]
def mock_on(handler: Any) -> Any:
for event in events:
handler(event)
return lambda: None
mock_session.on = mock_on
spy = SpyContextProvider()
agent = GitHubCopilotAgent(client=mock_client, context_providers=[spy])
session = agent.create_session()
async for _ in agent.run("Hello", stream=True, session=session):
pass
assert spy.before_run_called
async def test_after_run_called_on_streaming(
self,
mock_client: MagicMock,
mock_session: MagicMock,
assistant_delta_event: SessionEvent,
session_idle_event: SessionEvent,
) -> None:
"""Test that after_run is called on context providers after streaming."""
events = [assistant_delta_event, session_idle_event]
def mock_on(handler: Any) -> Any:
for event in events:
handler(event)
return lambda: None
mock_session.on = mock_on
spy = SpyContextProvider()
agent = GitHubCopilotAgent(client=mock_client, context_providers=[spy])
session = agent.create_session()
async for _ in agent.run("Hello", stream=True, session=session):
pass
assert spy.after_run_called
async def test_provider_instructions_included_in_streaming_prompt(
self,
mock_client: MagicMock,
mock_session: MagicMock,
assistant_delta_event: SessionEvent,
session_idle_event: SessionEvent,
) -> None:
"""Test that instructions from context providers are included in the streaming prompt."""
events = [assistant_delta_event, session_idle_event]
def mock_on(handler: Any) -> Any:
for event in events:
handler(event)
return lambda: None
mock_session.on = mock_on
spy = SpyContextProvider()
agent = GitHubCopilotAgent(client=mock_client, context_providers=[spy])
session = agent.create_session()
async for _ in agent.run("Hello", stream=True, session=session):
pass
sent_prompt = mock_session.send.call_args[0][0]["prompt"]
assert "Injected by spy provider" in sent_prompt
async def test_context_preserved_across_runs(
self,
mock_client: MagicMock,
mock_session: MagicMock,
assistant_message_event: SessionEvent,
) -> None:
"""Test that provider state is preserved across multiple runs with the same session."""
mock_session.send_and_wait.return_value = assistant_message_event
spy = SpyContextProvider()
agent = GitHubCopilotAgent(client=mock_client, context_providers=[spy])
session = agent.create_session()
await agent.run("Hello", session=session)
assert spy.before_run_called
spy.before_run_called = False
await agent.run("Hello again", session=session)
assert spy.before_run_called
async def test_context_messages_included_in_prompt(
self,
mock_client: MagicMock,
mock_session: MagicMock,
assistant_message_event: SessionEvent,
) -> None:
"""Test that context messages added by providers via extend_messages are included in the prompt."""
mock_session.send_and_wait.return_value = assistant_message_event
class MessageInjectingProvider(ContextProvider):
def __init__(self) -> None:
super().__init__(source_id="msg-injector")
async def before_run(
self,
*,
agent: Any,
session: AgentSession,
context: Any,
state: dict[str, Any],
) -> None:
context.extend_messages(self, [Message(role="user", contents=[Content.from_text("History message")])])
async def after_run(
self,
*,
agent: Any,
session: AgentSession,
context: Any,
state: dict[str, Any],
) -> None:
pass
provider = MessageInjectingProvider()
agent = GitHubCopilotAgent(client=mock_client, context_providers=[provider])
session = agent.create_session()
await agent.run("Hello", session=session)
sent_prompt = mock_session.send_and_wait.call_args[0][0]["prompt"]
assert "History message" in sent_prompt
assert "Hello" in sent_prompt
async def test_context_messages_included_in_streaming_prompt(
self,
mock_client: MagicMock,
mock_session: MagicMock,
assistant_delta_event: SessionEvent,
session_idle_event: SessionEvent,
) -> None:
"""Test that context messages added by providers are included in the streaming prompt."""
events = [assistant_delta_event, session_idle_event]
def mock_on(handler: Any) -> Any:
for event in events:
handler(event)
return lambda: None
mock_session.on = mock_on
class MessageInjectingProvider(ContextProvider):
def __init__(self) -> None:
super().__init__(source_id="msg-injector")
async def before_run(
self,
*,
agent: Any,
session: AgentSession,
context: Any,
state: dict[str, Any],
) -> None:
context.extend_messages(self, [Message(role="user", contents=[Content.from_text("History message")])])
async def after_run(
self,
*,
agent: Any,
session: AgentSession,
context: Any,
state: dict[str, Any],
) -> None:
pass
provider = MessageInjectingProvider()
agent = GitHubCopilotAgent(client=mock_client, context_providers=[provider])
session = agent.create_session()
async for _ in agent.run("Hello", stream=True, session=session):
pass
sent_prompt = mock_session.send.call_args[0][0]["prompt"]
assert "History message" in sent_prompt
assert "Hello" in sent_prompt
async def test_after_run_not_called_on_error(
self,
mock_client: MagicMock,
mock_session: MagicMock,
) -> None:
"""Test that after_run is NOT called when send_and_wait raises."""
mock_session.send_and_wait.side_effect = Exception("Request failed")
spy = SpyContextProvider()
agent = GitHubCopilotAgent(client=mock_client, context_providers=[spy])
session = agent.create_session()
with pytest.raises(AgentException):
await agent.run("Hello", session=session)
assert spy.before_run_called
assert not spy.after_run_called
async def test_after_run_not_called_on_streaming_error(
self,
mock_client: MagicMock,
mock_session: MagicMock,
session_error_event: SessionEvent,
) -> None:
"""Test that after_run is NOT called when streaming encounters an error."""
events = [session_error_event]
def mock_on(handler: Any) -> Any:
for event in events:
handler(event)
return lambda: None
mock_session.on = mock_on
spy = SpyContextProvider()
agent = GitHubCopilotAgent(client=mock_client, context_providers=[spy])
session = agent.create_session()
with pytest.raises(AgentException):
async for _ in agent.run("Hello", stream=True, session=session):
pass
assert spy.before_run_called
assert not spy.after_run_called
async def test_multiple_providers_ordering(
self,
mock_client: MagicMock,
mock_session: MagicMock,
assistant_message_event: SessionEvent,
) -> None:
"""Test that before_run is called in forward order and after_run in reverse order."""
mock_session.send_and_wait.return_value = assistant_message_event
call_order: list[str] = []
class OrderedProvider(ContextProvider):
def __init__(self, name: str) -> None:
super().__init__(source_id=name)
self.name = name
async def before_run(
self,
*,
agent: Any,
session: AgentSession,
context: Any,
state: dict[str, Any],
) -> None:
call_order.append(f"before:{self.name}")
async def after_run(
self,
*,
agent: Any,
session: AgentSession,
context: Any,
state: dict[str, Any],
) -> None:
call_order.append(f"after:{self.name}")
providers = [OrderedProvider("A"), OrderedProvider("B"), OrderedProvider("C")]
agent = GitHubCopilotAgent(client=mock_client, context_providers=providers)
session = agent.create_session()
await agent.run("Hello", session=session)
assert call_order == ["before:A", "before:B", "before:C", "after:C", "after:B", "after:A"]
async def test_history_provider_skip_when_load_messages_false(
self,
mock_client: MagicMock,
mock_session: MagicMock,
assistant_message_event: SessionEvent,
) -> None:
"""Test that HistoryProvider with load_messages=False is skipped in before_run."""
mock_session.send_and_wait.return_value = assistant_message_event
class StubHistoryProvider(HistoryProvider):
def __init__(self, *, load_messages: bool = True) -> None:
super().__init__(source_id="stub-history", load_messages=load_messages)
self.before_run_called = False
async def before_run(
self,
*,
agent: Any,
session: AgentSession,
context: Any,
state: dict[str, Any],
) -> None:
self.before_run_called = True
async def after_run(
self,
*,
agent: Any,
session: AgentSession,
context: Any,
state: dict[str, Any],
) -> None:
self.after_run_called = True
async def get_messages(self, *, session_id: str, **kwargs: Any) -> list[Message]:
return []
async def save_messages(self, *, session_id: str, messages: list[Message], **kwargs: Any) -> None:
pass
skipped_provider = StubHistoryProvider(load_messages=False)
active_provider = StubHistoryProvider(load_messages=True)
# Use unique source_ids
skipped_provider._source_id = "skipped-history"
active_provider._source_id = "active-history"
agent = GitHubCopilotAgent(client=mock_client, context_providers=[skipped_provider, active_provider])
session = agent.create_session()
await agent.run("Hello", session=session)
assert not skipped_provider.before_run_called
assert active_provider.before_run_called
# after_run should still be called even when load_messages=False
assert skipped_provider.after_run_called
assert active_provider.after_run_called
async def test_streaming_after_run_response_has_updates(
self,
mock_client: MagicMock,
mock_session: MagicMock,
assistant_delta_event: SessionEvent,
session_idle_event: SessionEvent,
) -> None:
"""Test that streaming after_run context.response contains the aggregated updates."""
events = [assistant_delta_event, session_idle_event]
def mock_on(handler: Any) -> Any:
for event in events:
handler(event)
return lambda: None
mock_session.on = mock_on
spy = SpyContextProvider()
agent = GitHubCopilotAgent(client=mock_client, context_providers=[spy])
session = agent.create_session()
async for _ in agent.run("Hello", stream=True, session=session):
pass
assert spy.after_run_context is not None
assert spy.after_run_context.response is not None
assert len(spy.after_run_context.response.messages) > 0
assert spy.after_run_context.response.messages[0].text == "Hello"
async def test_streaming_after_run_sets_empty_response_on_no_updates(
self,
mock_client: MagicMock,
mock_session: MagicMock,
session_idle_event: SessionEvent,
) -> None:
"""Test that streaming after_run sets an empty response when no updates are yielded."""
events = [session_idle_event]
def mock_on(handler: Any) -> Any:
for event in events:
handler(event)
return lambda: None
mock_session.on = mock_on
spy = SpyContextProvider()
agent = GitHubCopilotAgent(client=mock_client, context_providers=[spy])
session = agent.create_session()
async for _ in agent.run("Hello", stream=True, session=session):
pass
assert spy.after_run_called
assert spy.after_run_context.response is not None
assert len(spy.after_run_context.response.messages) == 0
async def test_timeout_preserved_in_session_context_options(
self,
mock_client: MagicMock,
mock_session: MagicMock,
assistant_message_event: SessionEvent,
) -> None:
"""Test that timeout is preserved in session context options for providers."""
mock_session.send_and_wait.return_value = assistant_message_event
observed_options: dict[str, Any] = {}
class OptionsObserverProvider(ContextProvider):
def __init__(self) -> None:
super().__init__(source_id="options-observer")
async def before_run(
self,
*,
agent: Any,
session: AgentSession,
context: Any,
state: dict[str, Any],
) -> None:
observed_options.update(context.options)
async def after_run(
self,
*,
agent: Any,
session: AgentSession,
context: Any,
state: dict[str, Any],
) -> None:
pass
provider = OptionsObserverProvider()
agent = GitHubCopilotAgent(client=mock_client, context_providers=[provider])
session = agent.create_session()
await agent.run("Hello", session=session, options={"timeout": 120})
assert observed_options.get("timeout") == 120