Python: Move InMemory history provider injection to the first invocation (#5236)

* Move InMemory history provider injection to the first invocation

* Add tests
This commit is contained in:
Tao Chen
2026-04-14 00:13:42 -07:00
committed by GitHub
Unverified
parent f183f888a3
commit 7bb0feca59
2 changed files with 124 additions and 5 deletions
@@ -119,15 +119,11 @@ class WorkflowAgent(BaseAgent):
if not any(is_type_compatible(list[Message], input_type) for input_type in start_executor.input_types):
raise ValueError("Workflow's start executor cannot handle list[Message]")
resolved_context_providers = list(context_providers) if context_providers is not None else []
if not resolved_context_providers:
resolved_context_providers.append(InMemoryHistoryProvider())
super().__init__(
id=id,
name=name,
description=description,
context_providers=resolved_context_providers,
context_providers=context_providers,
**kwargs,
)
self._workflow: Workflow = workflow
@@ -261,6 +257,15 @@ class WorkflowAgent(BaseAgent):
An AgentResponse representing the workflow execution results.
"""
input_messages = normalize_messages_input(messages)
if (
not any(
provider.load_messages for provider in self.context_providers if isinstance(provider, HistoryProvider)
)
and session is not None
):
self.context_providers.append(InMemoryHistoryProvider())
provider_session = session
if provider_session is None and self.context_providers:
provider_session = AgentSession()
@@ -332,6 +337,15 @@ class WorkflowAgent(BaseAgent):
AgentResponseUpdate objects representing the workflow execution progress.
"""
input_messages = normalize_messages_input(messages)
if (
not any(
provider.load_messages for provider in self.context_providers if isinstance(provider, HistoryProvider)
)
and session is not None
):
self.context_providers.append(InMemoryHistoryProvider())
provider_session = session
if provider_session is None and self.context_providers:
provider_session = AgentSession()
@@ -14,6 +14,7 @@ from agent_framework import (
AgentSession,
Content,
Executor,
HistoryProvider,
InMemoryHistoryProvider,
Message,
ResponseStream,
@@ -678,6 +679,110 @@ class TestWorkflowAgent:
assert agent.context_providers == [explicit_provider]
async def test_no_history_provider_injected_when_session_is_none(self) -> None:
"""Test that InMemoryHistoryProvider is NOT injected when session is None."""
capturing_executor = ConversationHistoryCapturingExecutor(id="no_session_test")
workflow = WorkflowBuilder(start_executor=capturing_executor).build()
agent = WorkflowAgent(workflow=workflow, name="No Session Agent")
await agent.run("hello")
assert not any(isinstance(p, InMemoryHistoryProvider) for p in agent.context_providers)
async def test_no_history_provider_injected_when_session_is_none_streaming(self) -> None:
"""Test that InMemoryHistoryProvider is NOT injected when session is None (streaming)."""
capturing_executor = ConversationHistoryCapturingExecutor(id="no_session_stream_test")
workflow = WorkflowBuilder(start_executor=capturing_executor).build()
agent = WorkflowAgent(workflow=workflow, name="No Session Stream Agent")
async for _ in agent.run("hello", stream=True):
pass
assert not any(isinstance(p, InMemoryHistoryProvider) for p in agent.context_providers)
async def test_no_injection_when_history_provider_with_load_messages_exists(self) -> None:
"""Test that no InMemoryHistoryProvider is injected when an existing HistoryProvider has load_messages=True."""
capturing_executor = ConversationHistoryCapturingExecutor(id="existing_provider_test")
workflow = WorkflowBuilder(start_executor=capturing_executor).build()
existing_provider = InMemoryHistoryProvider("custom", load_messages=True)
agent = WorkflowAgent(
workflow=workflow,
name="Existing Provider Agent",
context_providers=[existing_provider],
)
session = AgentSession()
await agent.run("hello", session=session)
# Should still have only the original provider
history_providers = [p for p in agent.context_providers if isinstance(p, HistoryProvider)]
assert len(history_providers) == 1
assert history_providers[0] is existing_provider
async def test_injection_when_history_provider_with_load_messages_false(self) -> None:
"""Test that InMemoryHistoryProvider IS injected when existing HistoryProvider has load_messages=False."""
capturing_executor = ConversationHistoryCapturingExecutor(id="no_load_provider_test")
workflow = WorkflowBuilder(start_executor=capturing_executor).build()
audit_provider = InMemoryHistoryProvider("audit", load_messages=False)
agent = WorkflowAgent(
workflow=workflow,
name="Audit Provider Agent",
context_providers=[audit_provider],
)
session = AgentSession()
await agent.run("hello", session=session)
# Should have injected an additional InMemoryHistoryProvider with load_messages=True
history_providers = [p for p in agent.context_providers if isinstance(p, HistoryProvider)]
assert len(history_providers) == 2
loading_providers = [p for p in history_providers if p.load_messages]
assert len(loading_providers) == 1
assert isinstance(loading_providers[0], InMemoryHistoryProvider)
async def test_no_duplicate_injection_on_multiple_runs(self) -> None:
"""Test that calling run() multiple times does not keep adding InMemoryHistoryProvider."""
capturing_executor = ConversationHistoryCapturingExecutor(id="no_dup_test")
workflow = WorkflowBuilder(start_executor=capturing_executor).build()
agent = WorkflowAgent(workflow=workflow, name="No Dup Agent")
session = AgentSession()
await agent.run("first", session=session)
await agent.run("second", session=session)
await agent.run("third", session=session)
history_providers = [p for p in agent.context_providers if isinstance(p, InMemoryHistoryProvider)]
assert len(history_providers) == 1
async def test_no_duplicate_injection_on_multiple_runs_streaming(self) -> None:
"""Test that calling run(stream=True) multiple times does not keep adding InMemoryHistoryProvider."""
capturing_executor = ConversationHistoryCapturingExecutor(id="no_dup_stream_test")
workflow = WorkflowBuilder(start_executor=capturing_executor).build()
agent = WorkflowAgent(workflow=workflow, name="No Dup Stream Agent")
session = AgentSession()
async for _ in agent.run("first", stream=True, session=session):
pass
async for _ in agent.run("second", stream=True, session=session):
pass
async for _ in agent.run("third", stream=True, session=session):
pass
history_providers = [p for p in agent.context_providers if isinstance(p, InMemoryHistoryProvider)]
assert len(history_providers) == 1
async def test_injection_with_session_in_streaming_mode(self) -> None:
"""Test that InMemoryHistoryProvider is injected when session is provided in streaming mode."""
capturing_executor = ConversationHistoryCapturingExecutor(id="stream_inject_test")
workflow = WorkflowBuilder(start_executor=capturing_executor).build()
agent = WorkflowAgent(workflow=workflow, name="Stream Inject Agent")
session = AgentSession()
async for _ in agent.run("hello", stream=True, session=session):
pass
assert any(isinstance(p, InMemoryHistoryProvider) for p in agent.context_providers)
async def test_checkpoint_storage_passed_to_workflow(self) -> None:
"""Test that checkpoint_storage parameter is passed through to the workflow."""
from agent_framework import InMemoryCheckpointStorage