mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Add propagate_session to as_tool() for session sharing in agent-as-tool scenarios (#4439)
* Python: Add propagate_session parameter to as_tool() for session sharing Add opt-in session propagation in agent-as-tool scenarios. When propagate_session=True, the parent agent's AgentSession is forwarded to the sub-agent's run() call, allowing both agents to share session state (history, metadata, session_id). - Add propagate_session parameter to BaseAgent.as_tool() (default False) - Include session in additional_function_arguments so it flows to tools - Add 3 tests for propagation on/off and shared state verification - Add sample showing session propagation with observability middleware Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Clarify propagate_session docstring per review feedback Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
23644ac6a7
commit
d02051dbb6
@@ -454,6 +454,7 @@ class BaseAgent(SerializationMixin):
|
||||
stream_callback: Callable[[AgentResponseUpdate], None]
|
||||
| Callable[[AgentResponseUpdate], Awaitable[None]]
|
||||
| None = None,
|
||||
propagate_session: bool = False,
|
||||
) -> FunctionTool:
|
||||
"""Create a FunctionTool that wraps this agent.
|
||||
|
||||
@@ -464,6 +465,12 @@ class BaseAgent(SerializationMixin):
|
||||
arg_description: The description for the function argument.
|
||||
If None, defaults to "Task for {tool_name}".
|
||||
stream_callback: Optional callback for streaming responses. If provided, uses run(..., stream=True).
|
||||
propagate_session: If True, the parent agent's ``AgentSession`` is
|
||||
forwarded to this sub-agent's ``run()`` call, so both agents
|
||||
operate within the same logical session (sharing the same
|
||||
``session_id`` and provider-managed state, such as any stored
|
||||
conversation history or metadata). Defaults to False, meaning
|
||||
the sub-agent runs with a new, independent session.
|
||||
|
||||
Returns:
|
||||
A FunctionTool that can be used as a tool by other agents.
|
||||
@@ -480,9 +487,12 @@ class BaseAgent(SerializationMixin):
|
||||
# Create an agent
|
||||
agent = Agent(client=client, name="research-agent", description="Performs research tasks")
|
||||
|
||||
# Convert the agent to a tool
|
||||
# Convert the agent to a tool (independent session)
|
||||
research_tool = agent.as_tool()
|
||||
|
||||
# Convert the agent to a tool (shared session with parent)
|
||||
research_tool = agent.as_tool(propagate_session=True)
|
||||
|
||||
# Use the tool with another agent
|
||||
coordinator = Agent(client=client, name="coordinator", tools=research_tool)
|
||||
"""
|
||||
@@ -509,16 +519,21 @@ class BaseAgent(SerializationMixin):
|
||||
# Extract the input from kwargs using the specified arg_name
|
||||
input_text = kwargs.get(arg_name, "")
|
||||
|
||||
# Forward runtime context kwargs, excluding arg_name and conversation_id.
|
||||
forwarded_kwargs = {k: v for k, v in kwargs.items() if k not in (arg_name, "conversation_id", "options")}
|
||||
# Extract parent session when propagate_session is enabled
|
||||
parent_session = kwargs.get("session") if propagate_session else None
|
||||
|
||||
# Forward runtime context kwargs, excluding framework-internal keys.
|
||||
forwarded_kwargs = {
|
||||
k: v for k, v in kwargs.items() if k not in (arg_name, "conversation_id", "options", "session")
|
||||
}
|
||||
|
||||
if stream_callback is None:
|
||||
# Use non-streaming mode
|
||||
return (await self.run(input_text, stream=False, **forwarded_kwargs)).text
|
||||
return (await self.run(input_text, stream=False, session=parent_session, **forwarded_kwargs)).text
|
||||
|
||||
# Use streaming mode - accumulate updates and create final response
|
||||
response_updates: list[AgentResponseUpdate] = []
|
||||
async for update in self.run(input_text, stream=True, **forwarded_kwargs):
|
||||
async for update in self.run(input_text, stream=True, session=parent_session, **forwarded_kwargs):
|
||||
response_updates.append(update)
|
||||
if is_async_callback:
|
||||
await stream_callback(update) # type: ignore[misc]
|
||||
@@ -1061,6 +1076,9 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
# in function middleware context and tool invocation.
|
||||
existing_additional_args = opts.pop("additional_function_arguments", None) or {}
|
||||
additional_function_arguments = {**kwargs, **existing_additional_args}
|
||||
# Include session so as_tool() wrappers with propagate_session=True can access it.
|
||||
if active_session is not None:
|
||||
additional_function_arguments["session"] = active_session
|
||||
|
||||
# Build options dict from run() options merged with provided options
|
||||
run_opts: dict[str, Any] = {
|
||||
|
||||
@@ -707,6 +707,81 @@ async def test_chat_agent_as_tool_name_sanitization(client: SupportsChatGetRespo
|
||||
assert tool.name == expected_tool_name, f"Expected {expected_tool_name}, got {tool.name} for input {agent_name}"
|
||||
|
||||
|
||||
async def test_chat_agent_as_tool_propagate_session_true(client: SupportsChatGetResponse) -> None:
|
||||
"""Test that propagate_session=True forwards the parent's session to the sub-agent."""
|
||||
agent = Agent(client=client, name="SubAgent", description="Sub agent")
|
||||
tool = agent.as_tool(propagate_session=True)
|
||||
|
||||
parent_session = AgentSession(session_id="parent-session-123")
|
||||
parent_session.state["shared_key"] = "shared_value"
|
||||
|
||||
# Spy on the agent's run method to capture the session argument
|
||||
original_run = agent.run
|
||||
captured_session = None
|
||||
|
||||
def capturing_run(*args: Any, **kwargs: Any) -> Any:
|
||||
nonlocal captured_session
|
||||
captured_session = kwargs.get("session")
|
||||
return original_run(*args, **kwargs)
|
||||
|
||||
agent.run = capturing_run # type: ignore[assignment, method-assign]
|
||||
|
||||
await tool.invoke(arguments=tool.input_model(task="Hello"), session=parent_session)
|
||||
|
||||
assert captured_session is parent_session
|
||||
assert captured_session.session_id == "parent-session-123"
|
||||
assert captured_session.state["shared_key"] == "shared_value"
|
||||
|
||||
|
||||
async def test_chat_agent_as_tool_propagate_session_false_by_default(client: SupportsChatGetResponse) -> None:
|
||||
"""Test that propagate_session defaults to False and does not forward the session."""
|
||||
agent = Agent(client=client, name="SubAgent", description="Sub agent")
|
||||
tool = agent.as_tool() # default: propagate_session=False
|
||||
|
||||
parent_session = AgentSession(session_id="parent-session-456")
|
||||
|
||||
original_run = agent.run
|
||||
captured_session = None
|
||||
|
||||
def capturing_run(*args: Any, **kwargs: Any) -> Any:
|
||||
nonlocal captured_session
|
||||
captured_session = kwargs.get("session")
|
||||
return original_run(*args, **kwargs)
|
||||
|
||||
agent.run = capturing_run # type: ignore[assignment, method-assign]
|
||||
|
||||
await tool.invoke(arguments=tool.input_model(task="Hello"), session=parent_session)
|
||||
|
||||
assert captured_session is None
|
||||
|
||||
|
||||
async def test_chat_agent_as_tool_propagate_session_shares_state(client: SupportsChatGetResponse) -> None:
|
||||
"""Test that shared session allows the sub-agent to read and write parent's state."""
|
||||
agent = Agent(client=client, name="SubAgent", description="Sub agent")
|
||||
tool = agent.as_tool(propagate_session=True)
|
||||
|
||||
parent_session = AgentSession(session_id="shared-session")
|
||||
parent_session.state["counter"] = 0
|
||||
|
||||
# The sub-agent receives the same session object, so mutations are shared
|
||||
original_run = agent.run
|
||||
captured_session = None
|
||||
|
||||
def capturing_run(*args: Any, **kwargs: Any) -> Any:
|
||||
nonlocal captured_session
|
||||
captured_session = kwargs.get("session")
|
||||
if captured_session:
|
||||
captured_session.state["counter"] += 1
|
||||
return original_run(*args, **kwargs)
|
||||
|
||||
agent.run = capturing_run # type: ignore[assignment, method-assign]
|
||||
|
||||
await tool.invoke(arguments=tool.input_model(task="Hello"), session=parent_session)
|
||||
|
||||
# The parent's state should reflect the sub-agent's mutation
|
||||
assert parent_session.state["counter"] == 1
|
||||
|
||||
|
||||
async def test_chat_agent_as_mcp_server_basic(client: SupportsChatGetResponse) -> None:
|
||||
"""Test basic as_mcp_server functionality."""
|
||||
agent = Agent(client=client, name="TestAgent", description="Test agent for MCP")
|
||||
|
||||
@@ -0,0 +1,93 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
from agent_framework import AgentContext, AgentSession
|
||||
from agent_framework.openai import OpenAIResponsesClient
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
"""
|
||||
Agent-as-Tool: Session Propagation Example
|
||||
|
||||
Demonstrates how to share an AgentSession between a coordinator agent and a
|
||||
sub-agent invoked as a tool using ``propagate_session=True``.
|
||||
|
||||
When session propagation is enabled, both agents share the same session object,
|
||||
including session_id and the mutable state dict. This allows correlated
|
||||
conversation tracking and shared state across the agent hierarchy.
|
||||
|
||||
The middleware functions below are purely for observability — they are NOT
|
||||
required for session propagation to work.
|
||||
"""
|
||||
|
||||
|
||||
async def log_session(
|
||||
context: AgentContext,
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
"""Agent middleware that logs the session received by each agent.
|
||||
|
||||
NOT required for session propagation — only used to observe the flow.
|
||||
If propagation is working, both agents will show the same session_id.
|
||||
"""
|
||||
session: AgentSession | None = context.session
|
||||
agent_name = context.agent.name or "unknown"
|
||||
session_id = session.session_id if session else None
|
||||
state = dict(session.state) if session else {}
|
||||
print(f" [{agent_name}] session_id={session_id}, state={state}")
|
||||
await call_next()
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
print("=== Agent-as-Tool: Session Propagation ===\n")
|
||||
|
||||
client = OpenAIResponsesClient()
|
||||
|
||||
# --- Sub-agent: a research specialist ---
|
||||
# The sub-agent has the same log_session middleware to prove it receives the session.
|
||||
research_agent = client.as_agent(
|
||||
name="ResearchAgent",
|
||||
instructions="You are a research assistant. Provide concise answers.",
|
||||
middleware=[log_session],
|
||||
)
|
||||
|
||||
# propagate_session=True: the coordinator's session will be forwarded
|
||||
research_tool = research_agent.as_tool(
|
||||
name="research",
|
||||
description="Research a topic and return findings",
|
||||
arg_name="query",
|
||||
arg_description="The research query",
|
||||
propagate_session=True,
|
||||
)
|
||||
|
||||
# --- Coordinator agent ---
|
||||
coordinator = client.as_agent(
|
||||
name="CoordinatorAgent",
|
||||
instructions="You coordinate research. Use the 'research' tool to look up information.",
|
||||
tools=[research_tool],
|
||||
middleware=[log_session],
|
||||
)
|
||||
|
||||
# Create a shared session and put some state in it
|
||||
session = coordinator.create_session()
|
||||
session.state["request_source"] = "demo"
|
||||
print(f"Session ID: {session.session_id}")
|
||||
print(f"Session state before run: {session.state}\n")
|
||||
|
||||
query = "What are the latest developments in quantum computing?"
|
||||
print(f"User: {query}\n")
|
||||
|
||||
result = await coordinator.run(query, session=session)
|
||||
|
||||
print(f"\nCoordinator: {result}\n")
|
||||
print(f"Session state after run: {session.state}")
|
||||
print(
|
||||
"\nIf both agents show the same session_id above, session propagation is working."
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
Reference in New Issue
Block a user