mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Move InMemory history provider injection to the first invocation (#5236)
* Move InMemory history provider injection to the first invocation * Add tests
This commit is contained in:
committed by
GitHub
Unverified
parent
f183f888a3
commit
7bb0feca59
@@ -119,15 +119,11 @@ class WorkflowAgent(BaseAgent):
|
||||
if not any(is_type_compatible(list[Message], input_type) for input_type in start_executor.input_types):
|
||||
raise ValueError("Workflow's start executor cannot handle list[Message]")
|
||||
|
||||
resolved_context_providers = list(context_providers) if context_providers is not None else []
|
||||
if not resolved_context_providers:
|
||||
resolved_context_providers.append(InMemoryHistoryProvider())
|
||||
|
||||
super().__init__(
|
||||
id=id,
|
||||
name=name,
|
||||
description=description,
|
||||
context_providers=resolved_context_providers,
|
||||
context_providers=context_providers,
|
||||
**kwargs,
|
||||
)
|
||||
self._workflow: Workflow = workflow
|
||||
@@ -261,6 +257,15 @@ class WorkflowAgent(BaseAgent):
|
||||
An AgentResponse representing the workflow execution results.
|
||||
"""
|
||||
input_messages = normalize_messages_input(messages)
|
||||
|
||||
if (
|
||||
not any(
|
||||
provider.load_messages for provider in self.context_providers if isinstance(provider, HistoryProvider)
|
||||
)
|
||||
and session is not None
|
||||
):
|
||||
self.context_providers.append(InMemoryHistoryProvider())
|
||||
|
||||
provider_session = session
|
||||
if provider_session is None and self.context_providers:
|
||||
provider_session = AgentSession()
|
||||
@@ -332,6 +337,15 @@ class WorkflowAgent(BaseAgent):
|
||||
AgentResponseUpdate objects representing the workflow execution progress.
|
||||
"""
|
||||
input_messages = normalize_messages_input(messages)
|
||||
|
||||
if (
|
||||
not any(
|
||||
provider.load_messages for provider in self.context_providers if isinstance(provider, HistoryProvider)
|
||||
)
|
||||
and session is not None
|
||||
):
|
||||
self.context_providers.append(InMemoryHistoryProvider())
|
||||
|
||||
provider_session = session
|
||||
if provider_session is None and self.context_providers:
|
||||
provider_session = AgentSession()
|
||||
|
||||
@@ -14,6 +14,7 @@ from agent_framework import (
|
||||
AgentSession,
|
||||
Content,
|
||||
Executor,
|
||||
HistoryProvider,
|
||||
InMemoryHistoryProvider,
|
||||
Message,
|
||||
ResponseStream,
|
||||
@@ -678,6 +679,110 @@ class TestWorkflowAgent:
|
||||
|
||||
assert agent.context_providers == [explicit_provider]
|
||||
|
||||
async def test_no_history_provider_injected_when_session_is_none(self) -> None:
|
||||
"""Test that InMemoryHistoryProvider is NOT injected when session is None."""
|
||||
capturing_executor = ConversationHistoryCapturingExecutor(id="no_session_test")
|
||||
workflow = WorkflowBuilder(start_executor=capturing_executor).build()
|
||||
agent = WorkflowAgent(workflow=workflow, name="No Session Agent")
|
||||
|
||||
await agent.run("hello")
|
||||
|
||||
assert not any(isinstance(p, InMemoryHistoryProvider) for p in agent.context_providers)
|
||||
|
||||
async def test_no_history_provider_injected_when_session_is_none_streaming(self) -> None:
|
||||
"""Test that InMemoryHistoryProvider is NOT injected when session is None (streaming)."""
|
||||
capturing_executor = ConversationHistoryCapturingExecutor(id="no_session_stream_test")
|
||||
workflow = WorkflowBuilder(start_executor=capturing_executor).build()
|
||||
agent = WorkflowAgent(workflow=workflow, name="No Session Stream Agent")
|
||||
|
||||
async for _ in agent.run("hello", stream=True):
|
||||
pass
|
||||
|
||||
assert not any(isinstance(p, InMemoryHistoryProvider) for p in agent.context_providers)
|
||||
|
||||
async def test_no_injection_when_history_provider_with_load_messages_exists(self) -> None:
|
||||
"""Test that no InMemoryHistoryProvider is injected when an existing HistoryProvider has load_messages=True."""
|
||||
capturing_executor = ConversationHistoryCapturingExecutor(id="existing_provider_test")
|
||||
workflow = WorkflowBuilder(start_executor=capturing_executor).build()
|
||||
existing_provider = InMemoryHistoryProvider("custom", load_messages=True)
|
||||
agent = WorkflowAgent(
|
||||
workflow=workflow,
|
||||
name="Existing Provider Agent",
|
||||
context_providers=[existing_provider],
|
||||
)
|
||||
session = AgentSession()
|
||||
|
||||
await agent.run("hello", session=session)
|
||||
|
||||
# Should still have only the original provider
|
||||
history_providers = [p for p in agent.context_providers if isinstance(p, HistoryProvider)]
|
||||
assert len(history_providers) == 1
|
||||
assert history_providers[0] is existing_provider
|
||||
|
||||
async def test_injection_when_history_provider_with_load_messages_false(self) -> None:
|
||||
"""Test that InMemoryHistoryProvider IS injected when existing HistoryProvider has load_messages=False."""
|
||||
capturing_executor = ConversationHistoryCapturingExecutor(id="no_load_provider_test")
|
||||
workflow = WorkflowBuilder(start_executor=capturing_executor).build()
|
||||
audit_provider = InMemoryHistoryProvider("audit", load_messages=False)
|
||||
agent = WorkflowAgent(
|
||||
workflow=workflow,
|
||||
name="Audit Provider Agent",
|
||||
context_providers=[audit_provider],
|
||||
)
|
||||
session = AgentSession()
|
||||
|
||||
await agent.run("hello", session=session)
|
||||
|
||||
# Should have injected an additional InMemoryHistoryProvider with load_messages=True
|
||||
history_providers = [p for p in agent.context_providers if isinstance(p, HistoryProvider)]
|
||||
assert len(history_providers) == 2
|
||||
loading_providers = [p for p in history_providers if p.load_messages]
|
||||
assert len(loading_providers) == 1
|
||||
assert isinstance(loading_providers[0], InMemoryHistoryProvider)
|
||||
|
||||
async def test_no_duplicate_injection_on_multiple_runs(self) -> None:
|
||||
"""Test that calling run() multiple times does not keep adding InMemoryHistoryProvider."""
|
||||
capturing_executor = ConversationHistoryCapturingExecutor(id="no_dup_test")
|
||||
workflow = WorkflowBuilder(start_executor=capturing_executor).build()
|
||||
agent = WorkflowAgent(workflow=workflow, name="No Dup Agent")
|
||||
session = AgentSession()
|
||||
|
||||
await agent.run("first", session=session)
|
||||
await agent.run("second", session=session)
|
||||
await agent.run("third", session=session)
|
||||
|
||||
history_providers = [p for p in agent.context_providers if isinstance(p, InMemoryHistoryProvider)]
|
||||
assert len(history_providers) == 1
|
||||
|
||||
async def test_no_duplicate_injection_on_multiple_runs_streaming(self) -> None:
|
||||
"""Test that calling run(stream=True) multiple times does not keep adding InMemoryHistoryProvider."""
|
||||
capturing_executor = ConversationHistoryCapturingExecutor(id="no_dup_stream_test")
|
||||
workflow = WorkflowBuilder(start_executor=capturing_executor).build()
|
||||
agent = WorkflowAgent(workflow=workflow, name="No Dup Stream Agent")
|
||||
session = AgentSession()
|
||||
|
||||
async for _ in agent.run("first", stream=True, session=session):
|
||||
pass
|
||||
async for _ in agent.run("second", stream=True, session=session):
|
||||
pass
|
||||
async for _ in agent.run("third", stream=True, session=session):
|
||||
pass
|
||||
|
||||
history_providers = [p for p in agent.context_providers if isinstance(p, InMemoryHistoryProvider)]
|
||||
assert len(history_providers) == 1
|
||||
|
||||
async def test_injection_with_session_in_streaming_mode(self) -> None:
|
||||
"""Test that InMemoryHistoryProvider is injected when session is provided in streaming mode."""
|
||||
capturing_executor = ConversationHistoryCapturingExecutor(id="stream_inject_test")
|
||||
workflow = WorkflowBuilder(start_executor=capturing_executor).build()
|
||||
agent = WorkflowAgent(workflow=workflow, name="Stream Inject Agent")
|
||||
session = AgentSession()
|
||||
|
||||
async for _ in agent.run("hello", stream=True, session=session):
|
||||
pass
|
||||
|
||||
assert any(isinstance(p, InMemoryHistoryProvider) for p in agent.context_providers)
|
||||
|
||||
async def test_checkpoint_storage_passed_to_workflow(self) -> None:
|
||||
"""Test that checkpoint_storage parameter is passed through to the workflow."""
|
||||
from agent_framework import InMemoryCheckpointStorage
|
||||
|
||||
Reference in New Issue
Block a user