diff --git a/python/packages/core/agent_framework/_workflows/_agent.py b/python/packages/core/agent_framework/_workflows/_agent.py index 60c5ec3774..2fd3f35213 100644 --- a/python/packages/core/agent_framework/_workflows/_agent.py +++ b/python/packages/core/agent_framework/_workflows/_agent.py @@ -119,15 +119,11 @@ class WorkflowAgent(BaseAgent): if not any(is_type_compatible(list[Message], input_type) for input_type in start_executor.input_types): raise ValueError("Workflow's start executor cannot handle list[Message]") - resolved_context_providers = list(context_providers) if context_providers is not None else [] - if not resolved_context_providers: - resolved_context_providers.append(InMemoryHistoryProvider()) - super().__init__( id=id, name=name, description=description, - context_providers=resolved_context_providers, + context_providers=context_providers, **kwargs, ) self._workflow: Workflow = workflow @@ -261,6 +257,15 @@ class WorkflowAgent(BaseAgent): An AgentResponse representing the workflow execution results. """ input_messages = normalize_messages_input(messages) + + if ( + not any( + provider.load_messages for provider in self.context_providers if isinstance(provider, HistoryProvider) + ) + and session is not None + ): + self.context_providers.append(InMemoryHistoryProvider()) + provider_session = session if provider_session is None and self.context_providers: provider_session = AgentSession() @@ -332,6 +337,15 @@ class WorkflowAgent(BaseAgent): AgentResponseUpdate objects representing the workflow execution progress. """ input_messages = normalize_messages_input(messages) + + if ( + not any( + provider.load_messages for provider in self.context_providers if isinstance(provider, HistoryProvider) + ) + and session is not None + ): + self.context_providers.append(InMemoryHistoryProvider()) + provider_session = session if provider_session is None and self.context_providers: provider_session = AgentSession() diff --git a/python/packages/core/tests/workflow/test_workflow_agent.py b/python/packages/core/tests/workflow/test_workflow_agent.py index dd2100c1ae..9c7a655d23 100644 --- a/python/packages/core/tests/workflow/test_workflow_agent.py +++ b/python/packages/core/tests/workflow/test_workflow_agent.py @@ -14,6 +14,7 @@ from agent_framework import ( AgentSession, Content, Executor, + HistoryProvider, InMemoryHistoryProvider, Message, ResponseStream, @@ -678,6 +679,110 @@ class TestWorkflowAgent: assert agent.context_providers == [explicit_provider] + async def test_no_history_provider_injected_when_session_is_none(self) -> None: + """Test that InMemoryHistoryProvider is NOT injected when session is None.""" + capturing_executor = ConversationHistoryCapturingExecutor(id="no_session_test") + workflow = WorkflowBuilder(start_executor=capturing_executor).build() + agent = WorkflowAgent(workflow=workflow, name="No Session Agent") + + await agent.run("hello") + + assert not any(isinstance(p, InMemoryHistoryProvider) for p in agent.context_providers) + + async def test_no_history_provider_injected_when_session_is_none_streaming(self) -> None: + """Test that InMemoryHistoryProvider is NOT injected when session is None (streaming).""" + capturing_executor = ConversationHistoryCapturingExecutor(id="no_session_stream_test") + workflow = WorkflowBuilder(start_executor=capturing_executor).build() + agent = WorkflowAgent(workflow=workflow, name="No Session Stream Agent") + + async for _ in agent.run("hello", stream=True): + pass + + assert not any(isinstance(p, InMemoryHistoryProvider) for p in agent.context_providers) + + async def test_no_injection_when_history_provider_with_load_messages_exists(self) -> None: + """Test that no InMemoryHistoryProvider is injected when an existing HistoryProvider has load_messages=True.""" + capturing_executor = ConversationHistoryCapturingExecutor(id="existing_provider_test") + workflow = WorkflowBuilder(start_executor=capturing_executor).build() + existing_provider = InMemoryHistoryProvider("custom", load_messages=True) + agent = WorkflowAgent( + workflow=workflow, + name="Existing Provider Agent", + context_providers=[existing_provider], + ) + session = AgentSession() + + await agent.run("hello", session=session) + + # Should still have only the original provider + history_providers = [p for p in agent.context_providers if isinstance(p, HistoryProvider)] + assert len(history_providers) == 1 + assert history_providers[0] is existing_provider + + async def test_injection_when_history_provider_with_load_messages_false(self) -> None: + """Test that InMemoryHistoryProvider IS injected when existing HistoryProvider has load_messages=False.""" + capturing_executor = ConversationHistoryCapturingExecutor(id="no_load_provider_test") + workflow = WorkflowBuilder(start_executor=capturing_executor).build() + audit_provider = InMemoryHistoryProvider("audit", load_messages=False) + agent = WorkflowAgent( + workflow=workflow, + name="Audit Provider Agent", + context_providers=[audit_provider], + ) + session = AgentSession() + + await agent.run("hello", session=session) + + # Should have injected an additional InMemoryHistoryProvider with load_messages=True + history_providers = [p for p in agent.context_providers if isinstance(p, HistoryProvider)] + assert len(history_providers) == 2 + loading_providers = [p for p in history_providers if p.load_messages] + assert len(loading_providers) == 1 + assert isinstance(loading_providers[0], InMemoryHistoryProvider) + + async def test_no_duplicate_injection_on_multiple_runs(self) -> None: + """Test that calling run() multiple times does not keep adding InMemoryHistoryProvider.""" + capturing_executor = ConversationHistoryCapturingExecutor(id="no_dup_test") + workflow = WorkflowBuilder(start_executor=capturing_executor).build() + agent = WorkflowAgent(workflow=workflow, name="No Dup Agent") + session = AgentSession() + + await agent.run("first", session=session) + await agent.run("second", session=session) + await agent.run("third", session=session) + + history_providers = [p for p in agent.context_providers if isinstance(p, InMemoryHistoryProvider)] + assert len(history_providers) == 1 + + async def test_no_duplicate_injection_on_multiple_runs_streaming(self) -> None: + """Test that calling run(stream=True) multiple times does not keep adding InMemoryHistoryProvider.""" + capturing_executor = ConversationHistoryCapturingExecutor(id="no_dup_stream_test") + workflow = WorkflowBuilder(start_executor=capturing_executor).build() + agent = WorkflowAgent(workflow=workflow, name="No Dup Stream Agent") + session = AgentSession() + + async for _ in agent.run("first", stream=True, session=session): + pass + async for _ in agent.run("second", stream=True, session=session): + pass + async for _ in agent.run("third", stream=True, session=session): + pass + + history_providers = [p for p in agent.context_providers if isinstance(p, InMemoryHistoryProvider)] + assert len(history_providers) == 1 + + async def test_injection_with_session_in_streaming_mode(self) -> None: + """Test that InMemoryHistoryProvider is injected when session is provided in streaming mode.""" + capturing_executor = ConversationHistoryCapturingExecutor(id="stream_inject_test") + workflow = WorkflowBuilder(start_executor=capturing_executor).build() + agent = WorkflowAgent(workflow=workflow, name="Stream Inject Agent") + session = AgentSession() + + async for _ in agent.run("hello", stream=True, session=session): + pass + + assert any(isinstance(p, InMemoryHistoryProvider) for p in agent.context_providers) + async def test_checkpoint_storage_passed_to_workflow(self) -> None: """Test that checkpoint_storage parameter is passed through to the workflow.""" from agent_framework import InMemoryCheckpointStorage