mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Fix GitHubCopilotAgent to invoke context provider before_run/after_run hooks (#5013)
* Fix GitHubCopilotAgent not calling context provider hooks (#3984) GitHubCopilotAgent accepted context_providers in its constructor but never called before_run()/after_run() on them in _run_impl() or _stream_updates(), silently ignoring all context providers. Add _run_before_providers() helper to create SessionContext and invoke before_run on each provider. Both _run_impl() and _stream_updates() now run the full provider lifecycle: before_run before sending the prompt (with provider instructions prepended) and after_run after receiving the response. This follows the same pattern used by A2AAgent. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Python: Fix GitHubCopilotAgent to invoke context provider before_run/after_run hooks Fixes #3984 * fix(#3984): address review feedback for context provider integration - Build prompt from session_context.get_messages(include_input=True) so provider-injected context_messages are included in both non-streaming and streaming paths (review comments #1, #2) - Preserve timeout in opts (use get instead of pop) so providers can observe it via context.options (review comment #3) - Eliminate streaming double-buffer: move after_run invocation to a ResponseStream result_hook (matching Agent class pattern) instead of maintaining a separate updates list in the generator (review comment #4) - Improve _run_before_providers docstring Add tests for: - Context messages included in prompt (non-streaming + streaming) - Error path: after_run NOT called when send_and_wait/streaming raises - Multiple providers: forward before_run, reverse after_run ordering - BaseHistoryProvider with load_messages=False is skipped - Streaming after_run response contains aggregated updates - Streaming with no updates still sets empty response - Timeout preserved in session context options for providers Note: _run_before_providers remains on GitHubCopilotAgent for now. A follow-up PR should extract it to BaseAgent so subclasses can reuse it without duplicating the provider iteration logic. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Address review feedback for #3984: Python: [Bug]: GitHubCopilotAgent Memory Example * refactor(#3984): promote _run_before_providers to BaseAgent Move _run_before_providers from GitHubCopilotAgent into BaseAgent, mirroring the existing _run_after_providers helper. Agent's _prepare_session_and_messages now delegates to the shared base method, eliminating the near-duplicate provider iteration logic that could drift as the provider contract evolves. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Address review feedback for #3984: Python: [Bug]: GitHubCopilotAgent Memory Example * revert: keep _run_before_providers in GitHubCopilotAgent only Undo the promotion of _run_before_providers to BaseAgent. The method stays in GitHubCopilotAgent where it is needed, and _agents.py retains its original inline provider iteration in RawAgent. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * fix: replace deprecated BaseContextProvider/BaseHistoryProvider with ContextProvider/HistoryProvider Update imports and usages in GitHubCopilotAgent and its tests to use the new non-deprecated class names from the core package. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * fix: address review feedback - reorder providers before session, wrap streaming after_run in try/except, assert after_run on skipped HistoryProvider - Move _run_before_providers before _get_or_create_session so provider contributions can affect session configuration - Wrap _run_after_providers in try/except in streaming _after_run_hook to prevent provider errors from replacing successful responses - Add after_run assertion to test_history_provider_skip_when_load_messages_false Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot <copilot@github.com> Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
62595b233f
commit
339e76d51f
@@ -17,8 +17,10 @@ from agent_framework import (
|
||||
BaseAgent,
|
||||
Content,
|
||||
ContextProvider,
|
||||
HistoryProvider,
|
||||
Message,
|
||||
ResponseStream,
|
||||
SessionContext,
|
||||
normalize_messages,
|
||||
)
|
||||
from agent_framework._settings import load_settings
|
||||
@@ -352,13 +354,25 @@ class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
|
||||
AgentException: If the request fails.
|
||||
"""
|
||||
if stream:
|
||||
ctx_holder: dict[str, Any] = {}
|
||||
|
||||
async def _after_run_hook(response: AgentResponse) -> None:
|
||||
session_context = ctx_holder.get("session_context")
|
||||
sess = ctx_holder.get("session")
|
||||
if session_context is not None and sess is not None:
|
||||
session_context._response = response
|
||||
try:
|
||||
await self._run_after_providers(session=sess, context=session_context)
|
||||
except Exception:
|
||||
logger.exception("Error running after_run providers in streaming result hook")
|
||||
|
||||
def _finalize(updates: Sequence[AgentResponseUpdate]) -> AgentResponse:
|
||||
return AgentResponse.from_updates(updates)
|
||||
|
||||
return ResponseStream(
|
||||
self._stream_updates(messages=messages, session=session, options=options),
|
||||
self._stream_updates(messages=messages, session=session, options=options, _ctx_holder=ctx_holder),
|
||||
finalizer=_finalize,
|
||||
result_hooks=[_after_run_hook],
|
||||
)
|
||||
return self._run_impl(messages=messages, session=session, options=options)
|
||||
|
||||
@@ -377,11 +391,22 @@ class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
|
||||
session = self.create_session()
|
||||
|
||||
opts: dict[str, Any] = dict(options) if options else {}
|
||||
timeout = opts.pop("timeout", None) or self._settings.get("timeout") or DEFAULT_TIMEOUT_SECONDS
|
||||
timeout = opts.get("timeout") or self._settings.get("timeout") or DEFAULT_TIMEOUT_SECONDS
|
||||
|
||||
copilot_session = await self._get_or_create_session(session, streaming=False, runtime_options=opts)
|
||||
input_messages = normalize_messages(messages)
|
||||
prompt = "\n".join([message.text for message in input_messages])
|
||||
|
||||
session_context = await self._run_before_providers(session=session, input_messages=input_messages, options=opts)
|
||||
|
||||
# NOTE: session is created after providers run so that future provider-contributed
|
||||
# tools/config could be folded into runtime_options before session creation.
|
||||
copilot_session = await self._get_or_create_session(session, streaming=False, runtime_options=opts)
|
||||
|
||||
# Build the prompt from the full set of messages in the session context,
|
||||
# so that any context/history provider-injected messages are included.
|
||||
context_messages = session_context.get_messages(include_input=True)
|
||||
prompt = "\n".join([message.text for message in context_messages])
|
||||
if session_context.instructions:
|
||||
prompt = "\n".join(session_context.instructions) + "\n" + prompt
|
||||
message_options = cast(MessageOptions, {"prompt": prompt})
|
||||
|
||||
try:
|
||||
@@ -408,7 +433,10 @@ class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
|
||||
)
|
||||
response_id = message_id
|
||||
|
||||
return AgentResponse(messages=response_messages, response_id=response_id)
|
||||
response = AgentResponse(messages=response_messages, response_id=response_id)
|
||||
session_context._response = response # type: ignore[assignment]
|
||||
await self._run_after_providers(session=session, context=session_context)
|
||||
return response
|
||||
|
||||
async def _stream_updates(
|
||||
self,
|
||||
@@ -416,6 +444,7 @@ class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
|
||||
*,
|
||||
session: AgentSession | None = None,
|
||||
options: OptionsT | None = None,
|
||||
_ctx_holder: dict[str, Any] | None = None,
|
||||
) -> AsyncIterable[AgentResponseUpdate]:
|
||||
"""Internal method to stream updates from GitHub Copilot.
|
||||
|
||||
@@ -425,6 +454,9 @@ class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
|
||||
Keyword Args:
|
||||
session: The conversation session associated with the message(s).
|
||||
options: Runtime options (model, timeout, etc.).
|
||||
_ctx_holder: Internal dict populated with session_context and session
|
||||
so that the caller (via a ResponseStream result_hook) can run
|
||||
after_run providers without duplicating the updates buffer.
|
||||
|
||||
Yields:
|
||||
AgentResponseUpdate items.
|
||||
@@ -440,9 +472,23 @@ class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
|
||||
|
||||
opts: dict[str, Any] = dict(options) if options else {}
|
||||
|
||||
copilot_session = await self._get_or_create_session(session, streaming=True, runtime_options=opts)
|
||||
input_messages = normalize_messages(messages)
|
||||
prompt = "\n".join([message.text for message in input_messages])
|
||||
|
||||
session_context = await self._run_before_providers(session=session, input_messages=input_messages, options=opts)
|
||||
|
||||
# NOTE: session is created after providers run so that future provider-contributed
|
||||
# tools/config could be folded into runtime_options before session creation.
|
||||
copilot_session = await self._get_or_create_session(session, streaming=True, runtime_options=opts)
|
||||
|
||||
if _ctx_holder is not None:
|
||||
_ctx_holder["session_context"] = session_context
|
||||
_ctx_holder["session"] = session
|
||||
|
||||
# Build the prompt from the full session context so provider-injected messages are included.
|
||||
context_messages = session_context.get_messages(include_input=True)
|
||||
prompt = "\n".join([message.text for message in context_messages])
|
||||
if session_context.instructions:
|
||||
prompt = "\n".join(session_context.instructions) + "\n" + prompt
|
||||
message_options = cast(MessageOptions, {"prompt": prompt})
|
||||
|
||||
queue: asyncio.Queue[AgentResponseUpdate | Exception | None] = asyncio.Queue()
|
||||
@@ -513,6 +559,46 @@ class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
|
||||
finally:
|
||||
unsubscribe()
|
||||
|
||||
async def _run_before_providers(
|
||||
self,
|
||||
*,
|
||||
session: AgentSession,
|
||||
input_messages: list[Message],
|
||||
options: dict[str, Any],
|
||||
) -> SessionContext:
|
||||
"""Run before_run on all context providers and return the session context.
|
||||
|
||||
Creates a SessionContext and invokes ``before_run`` on each provider in
|
||||
forward order. ``HistoryProvider`` instances with
|
||||
``load_messages=False`` are skipped.
|
||||
|
||||
Keyword Args:
|
||||
session: The conversation session.
|
||||
input_messages: The normalized input messages.
|
||||
options: Runtime options dict.
|
||||
|
||||
Returns:
|
||||
The SessionContext with provider context populated.
|
||||
"""
|
||||
session_context = SessionContext(
|
||||
session_id=session.session_id,
|
||||
service_session_id=session.service_session_id,
|
||||
input_messages=input_messages,
|
||||
options=options,
|
||||
)
|
||||
|
||||
for provider in self.context_providers:
|
||||
if isinstance(provider, HistoryProvider) and not provider.load_messages:
|
||||
continue
|
||||
await provider.before_run(
|
||||
agent=self, # type: ignore[arg-type]
|
||||
session=session,
|
||||
context=session_context,
|
||||
state=session.state.setdefault(provider.source_id, {}),
|
||||
)
|
||||
|
||||
return session_context
|
||||
|
||||
@staticmethod
|
||||
def _prepare_system_message(
|
||||
instructions: str | None,
|
||||
|
||||
@@ -17,6 +17,8 @@ from agent_framework import (
|
||||
AgentResponseUpdate,
|
||||
AgentSession,
|
||||
Content,
|
||||
ContextProvider,
|
||||
HistoryProvider,
|
||||
Message,
|
||||
)
|
||||
from agent_framework.exceptions import AgentException
|
||||
@@ -1367,3 +1369,532 @@ class TestGitHubCopilotAgentPermissions:
|
||||
call_args = mock_client.create_session.call_args
|
||||
config = call_args[0][0]
|
||||
assert "on_permission_request" not in config
|
||||
|
||||
|
||||
class SpyContextProvider(ContextProvider):
|
||||
"""A context provider that records whether its hooks are called."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__(source_id="spy-provider")
|
||||
self.before_run_called = False
|
||||
self.after_run_called = False
|
||||
self.before_run_context: Any = None
|
||||
self.after_run_context: Any = None
|
||||
|
||||
async def before_run(
|
||||
self,
|
||||
*,
|
||||
agent: Any,
|
||||
session: AgentSession,
|
||||
context: Any,
|
||||
state: dict[str, Any],
|
||||
) -> None:
|
||||
self.before_run_called = True
|
||||
self.before_run_context = context
|
||||
context.instructions.append("Injected by spy provider")
|
||||
|
||||
async def after_run(
|
||||
self,
|
||||
*,
|
||||
agent: Any,
|
||||
session: AgentSession,
|
||||
context: Any,
|
||||
state: dict[str, Any],
|
||||
) -> None:
|
||||
self.after_run_called = True
|
||||
self.after_run_context = context
|
||||
|
||||
|
||||
class TestGitHubCopilotAgentContextProviders:
|
||||
"""Test cases for context provider integration."""
|
||||
|
||||
async def test_before_run_called_on_run(
|
||||
self,
|
||||
mock_client: MagicMock,
|
||||
mock_session: MagicMock,
|
||||
assistant_message_event: SessionEvent,
|
||||
) -> None:
|
||||
"""Test that before_run is called on context providers during run()."""
|
||||
mock_session.send_and_wait.return_value = assistant_message_event
|
||||
spy = SpyContextProvider()
|
||||
|
||||
agent = GitHubCopilotAgent(client=mock_client, context_providers=[spy])
|
||||
session = agent.create_session()
|
||||
await agent.run("Hello", session=session)
|
||||
|
||||
assert spy.before_run_called
|
||||
|
||||
async def test_after_run_called_on_run(
|
||||
self,
|
||||
mock_client: MagicMock,
|
||||
mock_session: MagicMock,
|
||||
assistant_message_event: SessionEvent,
|
||||
) -> None:
|
||||
"""Test that after_run is called on context providers after run()."""
|
||||
mock_session.send_and_wait.return_value = assistant_message_event
|
||||
spy = SpyContextProvider()
|
||||
|
||||
agent = GitHubCopilotAgent(client=mock_client, context_providers=[spy])
|
||||
session = agent.create_session()
|
||||
await agent.run("Hello", session=session)
|
||||
|
||||
assert spy.after_run_called
|
||||
|
||||
async def test_provider_instructions_included_in_prompt(
|
||||
self,
|
||||
mock_client: MagicMock,
|
||||
mock_session: MagicMock,
|
||||
assistant_message_event: SessionEvent,
|
||||
) -> None:
|
||||
"""Test that instructions added by context providers are included in the prompt."""
|
||||
mock_session.send_and_wait.return_value = assistant_message_event
|
||||
spy = SpyContextProvider()
|
||||
|
||||
agent = GitHubCopilotAgent(client=mock_client, context_providers=[spy])
|
||||
session = agent.create_session()
|
||||
await agent.run("Hello", session=session)
|
||||
|
||||
sent_prompt = mock_session.send_and_wait.call_args[0][0]["prompt"]
|
||||
assert "Injected by spy provider" in sent_prompt
|
||||
|
||||
async def test_after_run_receives_response(
|
||||
self,
|
||||
mock_client: MagicMock,
|
||||
mock_session: MagicMock,
|
||||
assistant_message_event: SessionEvent,
|
||||
) -> None:
|
||||
"""Test that after_run context contains the agent response."""
|
||||
mock_session.send_and_wait.return_value = assistant_message_event
|
||||
spy = SpyContextProvider()
|
||||
|
||||
agent = GitHubCopilotAgent(client=mock_client, context_providers=[spy])
|
||||
session = agent.create_session()
|
||||
await agent.run("Hello", session=session)
|
||||
|
||||
assert spy.after_run_context is not None
|
||||
assert spy.after_run_context.response is not None
|
||||
|
||||
async def test_before_run_called_on_streaming(
|
||||
self,
|
||||
mock_client: MagicMock,
|
||||
mock_session: MagicMock,
|
||||
assistant_delta_event: SessionEvent,
|
||||
session_idle_event: SessionEvent,
|
||||
) -> None:
|
||||
"""Test that before_run is called on context providers during streaming."""
|
||||
events = [assistant_delta_event, session_idle_event]
|
||||
|
||||
def mock_on(handler: Any) -> Any:
|
||||
for event in events:
|
||||
handler(event)
|
||||
return lambda: None
|
||||
|
||||
mock_session.on = mock_on
|
||||
spy = SpyContextProvider()
|
||||
|
||||
agent = GitHubCopilotAgent(client=mock_client, context_providers=[spy])
|
||||
session = agent.create_session()
|
||||
async for _ in agent.run("Hello", stream=True, session=session):
|
||||
pass
|
||||
|
||||
assert spy.before_run_called
|
||||
|
||||
async def test_after_run_called_on_streaming(
|
||||
self,
|
||||
mock_client: MagicMock,
|
||||
mock_session: MagicMock,
|
||||
assistant_delta_event: SessionEvent,
|
||||
session_idle_event: SessionEvent,
|
||||
) -> None:
|
||||
"""Test that after_run is called on context providers after streaming."""
|
||||
events = [assistant_delta_event, session_idle_event]
|
||||
|
||||
def mock_on(handler: Any) -> Any:
|
||||
for event in events:
|
||||
handler(event)
|
||||
return lambda: None
|
||||
|
||||
mock_session.on = mock_on
|
||||
spy = SpyContextProvider()
|
||||
|
||||
agent = GitHubCopilotAgent(client=mock_client, context_providers=[spy])
|
||||
session = agent.create_session()
|
||||
async for _ in agent.run("Hello", stream=True, session=session):
|
||||
pass
|
||||
|
||||
assert spy.after_run_called
|
||||
|
||||
async def test_provider_instructions_included_in_streaming_prompt(
|
||||
self,
|
||||
mock_client: MagicMock,
|
||||
mock_session: MagicMock,
|
||||
assistant_delta_event: SessionEvent,
|
||||
session_idle_event: SessionEvent,
|
||||
) -> None:
|
||||
"""Test that instructions from context providers are included in the streaming prompt."""
|
||||
events = [assistant_delta_event, session_idle_event]
|
||||
|
||||
def mock_on(handler: Any) -> Any:
|
||||
for event in events:
|
||||
handler(event)
|
||||
return lambda: None
|
||||
|
||||
mock_session.on = mock_on
|
||||
spy = SpyContextProvider()
|
||||
|
||||
agent = GitHubCopilotAgent(client=mock_client, context_providers=[spy])
|
||||
session = agent.create_session()
|
||||
async for _ in agent.run("Hello", stream=True, session=session):
|
||||
pass
|
||||
|
||||
sent_prompt = mock_session.send.call_args[0][0]["prompt"]
|
||||
assert "Injected by spy provider" in sent_prompt
|
||||
|
||||
async def test_context_preserved_across_runs(
|
||||
self,
|
||||
mock_client: MagicMock,
|
||||
mock_session: MagicMock,
|
||||
assistant_message_event: SessionEvent,
|
||||
) -> None:
|
||||
"""Test that provider state is preserved across multiple runs with the same session."""
|
||||
mock_session.send_and_wait.return_value = assistant_message_event
|
||||
spy = SpyContextProvider()
|
||||
|
||||
agent = GitHubCopilotAgent(client=mock_client, context_providers=[spy])
|
||||
session = agent.create_session()
|
||||
|
||||
await agent.run("Hello", session=session)
|
||||
assert spy.before_run_called
|
||||
|
||||
spy.before_run_called = False
|
||||
await agent.run("Hello again", session=session)
|
||||
assert spy.before_run_called
|
||||
|
||||
async def test_context_messages_included_in_prompt(
|
||||
self,
|
||||
mock_client: MagicMock,
|
||||
mock_session: MagicMock,
|
||||
assistant_message_event: SessionEvent,
|
||||
) -> None:
|
||||
"""Test that context messages added by providers via extend_messages are included in the prompt."""
|
||||
mock_session.send_and_wait.return_value = assistant_message_event
|
||||
|
||||
class MessageInjectingProvider(ContextProvider):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(source_id="msg-injector")
|
||||
|
||||
async def before_run(
|
||||
self,
|
||||
*,
|
||||
agent: Any,
|
||||
session: AgentSession,
|
||||
context: Any,
|
||||
state: dict[str, Any],
|
||||
) -> None:
|
||||
context.extend_messages(self, [Message(role="user", contents=[Content.from_text("History message")])])
|
||||
|
||||
async def after_run(
|
||||
self,
|
||||
*,
|
||||
agent: Any,
|
||||
session: AgentSession,
|
||||
context: Any,
|
||||
state: dict[str, Any],
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
provider = MessageInjectingProvider()
|
||||
agent = GitHubCopilotAgent(client=mock_client, context_providers=[provider])
|
||||
session = agent.create_session()
|
||||
await agent.run("Hello", session=session)
|
||||
|
||||
sent_prompt = mock_session.send_and_wait.call_args[0][0]["prompt"]
|
||||
assert "History message" in sent_prompt
|
||||
assert "Hello" in sent_prompt
|
||||
|
||||
async def test_context_messages_included_in_streaming_prompt(
|
||||
self,
|
||||
mock_client: MagicMock,
|
||||
mock_session: MagicMock,
|
||||
assistant_delta_event: SessionEvent,
|
||||
session_idle_event: SessionEvent,
|
||||
) -> None:
|
||||
"""Test that context messages added by providers are included in the streaming prompt."""
|
||||
events = [assistant_delta_event, session_idle_event]
|
||||
|
||||
def mock_on(handler: Any) -> Any:
|
||||
for event in events:
|
||||
handler(event)
|
||||
return lambda: None
|
||||
|
||||
mock_session.on = mock_on
|
||||
|
||||
class MessageInjectingProvider(ContextProvider):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(source_id="msg-injector")
|
||||
|
||||
async def before_run(
|
||||
self,
|
||||
*,
|
||||
agent: Any,
|
||||
session: AgentSession,
|
||||
context: Any,
|
||||
state: dict[str, Any],
|
||||
) -> None:
|
||||
context.extend_messages(self, [Message(role="user", contents=[Content.from_text("History message")])])
|
||||
|
||||
async def after_run(
|
||||
self,
|
||||
*,
|
||||
agent: Any,
|
||||
session: AgentSession,
|
||||
context: Any,
|
||||
state: dict[str, Any],
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
provider = MessageInjectingProvider()
|
||||
agent = GitHubCopilotAgent(client=mock_client, context_providers=[provider])
|
||||
session = agent.create_session()
|
||||
async for _ in agent.run("Hello", stream=True, session=session):
|
||||
pass
|
||||
|
||||
sent_prompt = mock_session.send.call_args[0][0]["prompt"]
|
||||
assert "History message" in sent_prompt
|
||||
assert "Hello" in sent_prompt
|
||||
|
||||
async def test_after_run_not_called_on_error(
|
||||
self,
|
||||
mock_client: MagicMock,
|
||||
mock_session: MagicMock,
|
||||
) -> None:
|
||||
"""Test that after_run is NOT called when send_and_wait raises."""
|
||||
mock_session.send_and_wait.side_effect = Exception("Request failed")
|
||||
spy = SpyContextProvider()
|
||||
|
||||
agent = GitHubCopilotAgent(client=mock_client, context_providers=[spy])
|
||||
session = agent.create_session()
|
||||
with pytest.raises(AgentException):
|
||||
await agent.run("Hello", session=session)
|
||||
|
||||
assert spy.before_run_called
|
||||
assert not spy.after_run_called
|
||||
|
||||
async def test_after_run_not_called_on_streaming_error(
|
||||
self,
|
||||
mock_client: MagicMock,
|
||||
mock_session: MagicMock,
|
||||
session_error_event: SessionEvent,
|
||||
) -> None:
|
||||
"""Test that after_run is NOT called when streaming encounters an error."""
|
||||
events = [session_error_event]
|
||||
|
||||
def mock_on(handler: Any) -> Any:
|
||||
for event in events:
|
||||
handler(event)
|
||||
return lambda: None
|
||||
|
||||
mock_session.on = mock_on
|
||||
spy = SpyContextProvider()
|
||||
|
||||
agent = GitHubCopilotAgent(client=mock_client, context_providers=[spy])
|
||||
session = agent.create_session()
|
||||
with pytest.raises(AgentException):
|
||||
async for _ in agent.run("Hello", stream=True, session=session):
|
||||
pass
|
||||
|
||||
assert spy.before_run_called
|
||||
assert not spy.after_run_called
|
||||
|
||||
async def test_multiple_providers_ordering(
|
||||
self,
|
||||
mock_client: MagicMock,
|
||||
mock_session: MagicMock,
|
||||
assistant_message_event: SessionEvent,
|
||||
) -> None:
|
||||
"""Test that before_run is called in forward order and after_run in reverse order."""
|
||||
mock_session.send_and_wait.return_value = assistant_message_event
|
||||
call_order: list[str] = []
|
||||
|
||||
class OrderedProvider(ContextProvider):
|
||||
def __init__(self, name: str) -> None:
|
||||
super().__init__(source_id=name)
|
||||
self.name = name
|
||||
|
||||
async def before_run(
|
||||
self,
|
||||
*,
|
||||
agent: Any,
|
||||
session: AgentSession,
|
||||
context: Any,
|
||||
state: dict[str, Any],
|
||||
) -> None:
|
||||
call_order.append(f"before:{self.name}")
|
||||
|
||||
async def after_run(
|
||||
self,
|
||||
*,
|
||||
agent: Any,
|
||||
session: AgentSession,
|
||||
context: Any,
|
||||
state: dict[str, Any],
|
||||
) -> None:
|
||||
call_order.append(f"after:{self.name}")
|
||||
|
||||
providers = [OrderedProvider("A"), OrderedProvider("B"), OrderedProvider("C")]
|
||||
agent = GitHubCopilotAgent(client=mock_client, context_providers=providers)
|
||||
session = agent.create_session()
|
||||
await agent.run("Hello", session=session)
|
||||
|
||||
assert call_order == ["before:A", "before:B", "before:C", "after:C", "after:B", "after:A"]
|
||||
|
||||
async def test_history_provider_skip_when_load_messages_false(
|
||||
self,
|
||||
mock_client: MagicMock,
|
||||
mock_session: MagicMock,
|
||||
assistant_message_event: SessionEvent,
|
||||
) -> None:
|
||||
"""Test that HistoryProvider with load_messages=False is skipped in before_run."""
|
||||
mock_session.send_and_wait.return_value = assistant_message_event
|
||||
|
||||
class StubHistoryProvider(HistoryProvider):
|
||||
def __init__(self, *, load_messages: bool = True) -> None:
|
||||
super().__init__(source_id="stub-history", load_messages=load_messages)
|
||||
self.before_run_called = False
|
||||
|
||||
async def before_run(
|
||||
self,
|
||||
*,
|
||||
agent: Any,
|
||||
session: AgentSession,
|
||||
context: Any,
|
||||
state: dict[str, Any],
|
||||
) -> None:
|
||||
self.before_run_called = True
|
||||
|
||||
async def after_run(
|
||||
self,
|
||||
*,
|
||||
agent: Any,
|
||||
session: AgentSession,
|
||||
context: Any,
|
||||
state: dict[str, Any],
|
||||
) -> None:
|
||||
self.after_run_called = True
|
||||
|
||||
async def get_messages(self, *, session_id: str, **kwargs: Any) -> list[Message]:
|
||||
return []
|
||||
|
||||
async def save_messages(self, *, session_id: str, messages: list[Message], **kwargs: Any) -> None:
|
||||
pass
|
||||
|
||||
skipped_provider = StubHistoryProvider(load_messages=False)
|
||||
active_provider = StubHistoryProvider(load_messages=True)
|
||||
# Use unique source_ids
|
||||
skipped_provider._source_id = "skipped-history"
|
||||
active_provider._source_id = "active-history"
|
||||
|
||||
agent = GitHubCopilotAgent(client=mock_client, context_providers=[skipped_provider, active_provider])
|
||||
session = agent.create_session()
|
||||
await agent.run("Hello", session=session)
|
||||
|
||||
assert not skipped_provider.before_run_called
|
||||
assert active_provider.before_run_called
|
||||
# after_run should still be called even when load_messages=False
|
||||
assert skipped_provider.after_run_called
|
||||
assert active_provider.after_run_called
|
||||
|
||||
async def test_streaming_after_run_response_has_updates(
|
||||
self,
|
||||
mock_client: MagicMock,
|
||||
mock_session: MagicMock,
|
||||
assistant_delta_event: SessionEvent,
|
||||
session_idle_event: SessionEvent,
|
||||
) -> None:
|
||||
"""Test that streaming after_run context.response contains the aggregated updates."""
|
||||
events = [assistant_delta_event, session_idle_event]
|
||||
|
||||
def mock_on(handler: Any) -> Any:
|
||||
for event in events:
|
||||
handler(event)
|
||||
return lambda: None
|
||||
|
||||
mock_session.on = mock_on
|
||||
spy = SpyContextProvider()
|
||||
|
||||
agent = GitHubCopilotAgent(client=mock_client, context_providers=[spy])
|
||||
session = agent.create_session()
|
||||
async for _ in agent.run("Hello", stream=True, session=session):
|
||||
pass
|
||||
|
||||
assert spy.after_run_context is not None
|
||||
assert spy.after_run_context.response is not None
|
||||
assert len(spy.after_run_context.response.messages) > 0
|
||||
assert spy.after_run_context.response.messages[0].text == "Hello"
|
||||
|
||||
async def test_streaming_after_run_sets_empty_response_on_no_updates(
|
||||
self,
|
||||
mock_client: MagicMock,
|
||||
mock_session: MagicMock,
|
||||
session_idle_event: SessionEvent,
|
||||
) -> None:
|
||||
"""Test that streaming after_run sets an empty response when no updates are yielded."""
|
||||
events = [session_idle_event]
|
||||
|
||||
def mock_on(handler: Any) -> Any:
|
||||
for event in events:
|
||||
handler(event)
|
||||
return lambda: None
|
||||
|
||||
mock_session.on = mock_on
|
||||
spy = SpyContextProvider()
|
||||
|
||||
agent = GitHubCopilotAgent(client=mock_client, context_providers=[spy])
|
||||
session = agent.create_session()
|
||||
async for _ in agent.run("Hello", stream=True, session=session):
|
||||
pass
|
||||
|
||||
assert spy.after_run_called
|
||||
assert spy.after_run_context.response is not None
|
||||
assert len(spy.after_run_context.response.messages) == 0
|
||||
|
||||
async def test_timeout_preserved_in_session_context_options(
|
||||
self,
|
||||
mock_client: MagicMock,
|
||||
mock_session: MagicMock,
|
||||
assistant_message_event: SessionEvent,
|
||||
) -> None:
|
||||
"""Test that timeout is preserved in session context options for providers."""
|
||||
mock_session.send_and_wait.return_value = assistant_message_event
|
||||
observed_options: dict[str, Any] = {}
|
||||
|
||||
class OptionsObserverProvider(ContextProvider):
|
||||
def __init__(self) -> None:
|
||||
super().__init__(source_id="options-observer")
|
||||
|
||||
async def before_run(
|
||||
self,
|
||||
*,
|
||||
agent: Any,
|
||||
session: AgentSession,
|
||||
context: Any,
|
||||
state: dict[str, Any],
|
||||
) -> None:
|
||||
observed_options.update(context.options)
|
||||
|
||||
async def after_run(
|
||||
self,
|
||||
*,
|
||||
agent: Any,
|
||||
session: AgentSession,
|
||||
context: Any,
|
||||
state: dict[str, Any],
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
provider = OptionsObserverProvider()
|
||||
agent = GitHubCopilotAgent(client=mock_client, context_providers=[provider])
|
||||
session = agent.create_session()
|
||||
await agent.run("Hello", session=session, options={"timeout": 120})
|
||||
|
||||
assert observed_options.get("timeout") == 120
|
||||
|
||||
Reference in New Issue
Block a user