Python: Add propagate_session to as_tool() for session sharing in agent-as-tool scenarios (#4439)

* Python: Add propagate_session parameter to as_tool() for session sharing

Add opt-in session propagation in agent-as-tool scenarios. When
propagate_session=True, the parent agent's AgentSession is forwarded
to the sub-agent's run() call, allowing both agents to share session
state (history, metadata, session_id).

- Add propagate_session parameter to BaseAgent.as_tool() (default False)
- Include session in additional_function_arguments so it flows to tools
- Add 3 tests for propagation on/off and shared state verification
- Add sample showing session propagation with observability middleware

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Clarify propagate_session docstring per review feedback

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

---------

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
Giles Odigwe
2026-03-04 15:14:17 -08:00
committed by GitHub
Unverified
parent 23644ac6a7
commit d02051dbb6
3 changed files with 191 additions and 5 deletions
@@ -454,6 +454,7 @@ class BaseAgent(SerializationMixin):
stream_callback: Callable[[AgentResponseUpdate], None]
| Callable[[AgentResponseUpdate], Awaitable[None]]
| None = None,
propagate_session: bool = False,
) -> FunctionTool:
"""Create a FunctionTool that wraps this agent.
@@ -464,6 +465,12 @@ class BaseAgent(SerializationMixin):
arg_description: The description for the function argument.
If None, defaults to "Task for {tool_name}".
stream_callback: Optional callback for streaming responses. If provided, uses run(..., stream=True).
propagate_session: If True, the parent agent's ``AgentSession`` is
forwarded to this sub-agent's ``run()`` call, so both agents
operate within the same logical session (sharing the same
``session_id`` and provider-managed state, such as any stored
conversation history or metadata). Defaults to False, meaning
the sub-agent runs with a new, independent session.
Returns:
A FunctionTool that can be used as a tool by other agents.
@@ -480,9 +487,12 @@ class BaseAgent(SerializationMixin):
# Create an agent
agent = Agent(client=client, name="research-agent", description="Performs research tasks")
# Convert the agent to a tool
# Convert the agent to a tool (independent session)
research_tool = agent.as_tool()
# Convert the agent to a tool (shared session with parent)
research_tool = agent.as_tool(propagate_session=True)
# Use the tool with another agent
coordinator = Agent(client=client, name="coordinator", tools=research_tool)
"""
@@ -509,16 +519,21 @@ class BaseAgent(SerializationMixin):
# Extract the input from kwargs using the specified arg_name
input_text = kwargs.get(arg_name, "")
# Forward runtime context kwargs, excluding arg_name and conversation_id.
forwarded_kwargs = {k: v for k, v in kwargs.items() if k not in (arg_name, "conversation_id", "options")}
# Extract parent session when propagate_session is enabled
parent_session = kwargs.get("session") if propagate_session else None
# Forward runtime context kwargs, excluding framework-internal keys.
forwarded_kwargs = {
k: v for k, v in kwargs.items() if k not in (arg_name, "conversation_id", "options", "session")
}
if stream_callback is None:
# Use non-streaming mode
return (await self.run(input_text, stream=False, **forwarded_kwargs)).text
return (await self.run(input_text, stream=False, session=parent_session, **forwarded_kwargs)).text
# Use streaming mode - accumulate updates and create final response
response_updates: list[AgentResponseUpdate] = []
async for update in self.run(input_text, stream=True, **forwarded_kwargs):
async for update in self.run(input_text, stream=True, session=parent_session, **forwarded_kwargs):
response_updates.append(update)
if is_async_callback:
await stream_callback(update) # type: ignore[misc]
@@ -1061,6 +1076,9 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
# in function middleware context and tool invocation.
existing_additional_args = opts.pop("additional_function_arguments", None) or {}
additional_function_arguments = {**kwargs, **existing_additional_args}
# Include session so as_tool() wrappers with propagate_session=True can access it.
if active_session is not None:
additional_function_arguments["session"] = active_session
# Build options dict from run() options merged with provided options
run_opts: dict[str, Any] = {
@@ -707,6 +707,81 @@ async def test_chat_agent_as_tool_name_sanitization(client: SupportsChatGetRespo
assert tool.name == expected_tool_name, f"Expected {expected_tool_name}, got {tool.name} for input {agent_name}"
async def test_chat_agent_as_tool_propagate_session_true(client: SupportsChatGetResponse) -> None:
"""Test that propagate_session=True forwards the parent's session to the sub-agent."""
agent = Agent(client=client, name="SubAgent", description="Sub agent")
tool = agent.as_tool(propagate_session=True)
parent_session = AgentSession(session_id="parent-session-123")
parent_session.state["shared_key"] = "shared_value"
# Spy on the agent's run method to capture the session argument
original_run = agent.run
captured_session = None
def capturing_run(*args: Any, **kwargs: Any) -> Any:
nonlocal captured_session
captured_session = kwargs.get("session")
return original_run(*args, **kwargs)
agent.run = capturing_run # type: ignore[assignment, method-assign]
await tool.invoke(arguments=tool.input_model(task="Hello"), session=parent_session)
assert captured_session is parent_session
assert captured_session.session_id == "parent-session-123"
assert captured_session.state["shared_key"] == "shared_value"
async def test_chat_agent_as_tool_propagate_session_false_by_default(client: SupportsChatGetResponse) -> None:
"""Test that propagate_session defaults to False and does not forward the session."""
agent = Agent(client=client, name="SubAgent", description="Sub agent")
tool = agent.as_tool() # default: propagate_session=False
parent_session = AgentSession(session_id="parent-session-456")
original_run = agent.run
captured_session = None
def capturing_run(*args: Any, **kwargs: Any) -> Any:
nonlocal captured_session
captured_session = kwargs.get("session")
return original_run(*args, **kwargs)
agent.run = capturing_run # type: ignore[assignment, method-assign]
await tool.invoke(arguments=tool.input_model(task="Hello"), session=parent_session)
assert captured_session is None
async def test_chat_agent_as_tool_propagate_session_shares_state(client: SupportsChatGetResponse) -> None:
"""Test that shared session allows the sub-agent to read and write parent's state."""
agent = Agent(client=client, name="SubAgent", description="Sub agent")
tool = agent.as_tool(propagate_session=True)
parent_session = AgentSession(session_id="shared-session")
parent_session.state["counter"] = 0
# The sub-agent receives the same session object, so mutations are shared
original_run = agent.run
captured_session = None
def capturing_run(*args: Any, **kwargs: Any) -> Any:
nonlocal captured_session
captured_session = kwargs.get("session")
if captured_session:
captured_session.state["counter"] += 1
return original_run(*args, **kwargs)
agent.run = capturing_run # type: ignore[assignment, method-assign]
await tool.invoke(arguments=tool.input_model(task="Hello"), session=parent_session)
# The parent's state should reflect the sub-agent's mutation
assert parent_session.state["counter"] == 1
async def test_chat_agent_as_mcp_server_basic(client: SupportsChatGetResponse) -> None:
"""Test basic as_mcp_server functionality."""
agent = Agent(client=client, name="TestAgent", description="Test agent for MCP")
@@ -0,0 +1,93 @@
# Copyright (c) Microsoft. All rights reserved.
import asyncio
from collections.abc import Awaitable, Callable
from agent_framework import AgentContext, AgentSession
from agent_framework.openai import OpenAIResponsesClient
from dotenv import load_dotenv
load_dotenv()
"""
Agent-as-Tool: Session Propagation Example
Demonstrates how to share an AgentSession between a coordinator agent and a
sub-agent invoked as a tool using ``propagate_session=True``.
When session propagation is enabled, both agents share the same session object,
including session_id and the mutable state dict. This allows correlated
conversation tracking and shared state across the agent hierarchy.
The middleware functions below are purely for observability — they are NOT
required for session propagation to work.
"""
async def log_session(
context: AgentContext,
call_next: Callable[[], Awaitable[None]],
) -> None:
"""Agent middleware that logs the session received by each agent.
NOT required for session propagation — only used to observe the flow.
If propagation is working, both agents will show the same session_id.
"""
session: AgentSession | None = context.session
agent_name = context.agent.name or "unknown"
session_id = session.session_id if session else None
state = dict(session.state) if session else {}
print(f" [{agent_name}] session_id={session_id}, state={state}")
await call_next()
async def main() -> None:
print("=== Agent-as-Tool: Session Propagation ===\n")
client = OpenAIResponsesClient()
# --- Sub-agent: a research specialist ---
# The sub-agent has the same log_session middleware to prove it receives the session.
research_agent = client.as_agent(
name="ResearchAgent",
instructions="You are a research assistant. Provide concise answers.",
middleware=[log_session],
)
# propagate_session=True: the coordinator's session will be forwarded
research_tool = research_agent.as_tool(
name="research",
description="Research a topic and return findings",
arg_name="query",
arg_description="The research query",
propagate_session=True,
)
# --- Coordinator agent ---
coordinator = client.as_agent(
name="CoordinatorAgent",
instructions="You coordinate research. Use the 'research' tool to look up information.",
tools=[research_tool],
middleware=[log_session],
)
# Create a shared session and put some state in it
session = coordinator.create_session()
session.state["request_source"] = "demo"
print(f"Session ID: {session.session_id}")
print(f"Session state before run: {session.state}\n")
query = "What are the latest developments in quantum computing?"
print(f"User: {query}\n")
result = await coordinator.run(query, session=session)
print(f"\nCoordinator: {result}\n")
print(f"Session state after run: {session.state}")
print(
"\nIf both agents show the same session_id above, session propagation is working."
)
if __name__ == "__main__":
asyncio.run(main())