Python: [BREAKING] Scope provider state by source_id and standardize source IDs (#3995)

* Initial plan

* Add FoundryMemoryProvider and tests

Co-authored-by: eavanvalkenburg <13749212+eavanvalkenburg@users.noreply.github.com>

* Add sample and documentation for FoundryMemoryProvider

Co-authored-by: eavanvalkenburg <13749212+eavanvalkenburg@users.noreply.github.com>

* Address code review feedback for FoundryMemoryProvider

Co-authored-by: eavanvalkenburg <13749212+eavanvalkenburg@users.noreply.github.com>

* Address PR review comments: Add DEFAULT_SOURCE_ID, use logging.getLogger, move state to session.state

Co-authored-by: eavanvalkenburg <13749212+eavanvalkenburg@users.noreply.github.com>

* Fix Foundry memory ItemParam usage and exports

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Refactor provider hook state and standardize source IDs

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Support endpoint-based Foundry memory init

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Fix core README workflows link

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* updated implementation and sample

* Split out Foundry memory provider changes

Remove FoundryMemoryProvider implementation/tests/sample plus export and docs mentions from this branch so only non-Foundry changes remain.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Trigger CI rerun for PR #3995

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: eavanvalkenburg <13749212+eavanvalkenburg@users.noreply.github.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
Eduard van Valkenburg
2026-02-17 20:12:28 +01:00
committed by GitHub
Unverified
parent a5f948c215
commit cc98d5b6f7
28 changed files with 359 additions and 148 deletions
+1 -1
View File
@@ -220,7 +220,7 @@ if __name__ == "__main__":
- [Getting Started with Agents](../../samples/02-agents): Basic agent creation and tool usage
- [Chat Client Examples](../../samples/02-agents/chat_client): Direct chat client usage patterns
- [Azure AI Integration](https://github.com/microsoft/agent-framework/tree/main/python/packages/azure-ai): Azure AI integration
- [.NET Workflows Samples](https://github.com/microsoft/agent-framework/tree/main/dotnet/samples/GettingStarted/Workflows): Advanced multi-agent patterns (.NET)
- [.NET Workflows Samples](../../../dotnet/samples/GettingStarted/Workflows): Advanced multi-agent patterns (.NET)
## Agent Framework Documentation
+29 -13
View File
@@ -420,13 +420,18 @@ class BaseAgent(SerializationMixin):
session: The conversation session.
context: The invocation context with response populated.
"""
state = session.state if session else {}
provider_session = session
if provider_session is None and self.context_providers:
provider_session = AgentSession()
for provider in reversed(self.context_providers):
if provider_session is None:
raise RuntimeError("Provider session must be available when context providers are configured.")
await provider.after_run(
agent=self, # type: ignore[arg-type]
session=session, # type: ignore[arg-type]
session=provider_session,
context=context,
state=state,
state=provider_session.state.setdefault(provider.source_id, {}),
)
def as_tool(
@@ -988,10 +993,14 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
and not opts.get("store")
and not (getattr(self.client, "STORES_BY_DEFAULT", False) and opts.get("store") is not False)
):
self.context_providers.append(InMemoryHistoryProvider("memory"))
self.context_providers.append(InMemoryHistoryProvider())
active_session = session
if active_session is None and self.context_providers:
active_session = AgentSession()
session_context, chat_options = await self._prepare_session_and_messages(
session=session,
session=active_session,
input_messages=input_messages,
options=opts,
)
@@ -1018,7 +1027,9 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
# Build options dict from run() options merged with provided options
run_opts: dict[str, Any] = {
"model_id": opts.pop("model_id", None),
"conversation_id": session.service_session_id if session else opts.pop("conversation_id", None),
"conversation_id": active_session.service_session_id
if active_session
else opts.pop("conversation_id", None),
"allow_multiple_tool_calls": opts.pop("allow_multiple_tool_calls", None),
"additional_function_arguments": opts.pop("additional_function_arguments", None),
"frequency_penalty": opts.pop("frequency_penalty", None),
@@ -1046,12 +1057,12 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
# Ensure session is forwarded in kwargs for tool invocation
finalize_kwargs = dict(kwargs)
finalize_kwargs["session"] = session
finalize_kwargs["session"] = active_session
# Filter chat_options from kwargs to prevent duplicate keyword argument
filtered_kwargs = {k: v for k, v in finalize_kwargs.items() if k != "chat_options"}
return {
"session": session,
"session": active_session,
"session_context": session_context,
"input_messages": input_messages,
"session_messages": session_messages,
@@ -1129,23 +1140,28 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
else:
chat_options = {}
provider_session = session
if provider_session is None and self.context_providers:
provider_session = AgentSession()
session_context = SessionContext(
session_id=session.session_id if session else None,
service_session_id=session.service_session_id if session else None,
session_id=provider_session.session_id if provider_session else None,
service_session_id=provider_session.service_session_id if provider_session else None,
input_messages=input_messages or [],
options=options or {},
)
# Run before_run providers (forward order, skip BaseHistoryProvider with load_messages=False)
state = session.state if session else {}
for provider in self.context_providers:
if isinstance(provider, BaseHistoryProvider) and not provider.load_messages:
continue
if provider_session is None:
raise RuntimeError("Provider session must be available when context providers are configured.")
await provider.before_run(
agent=self, # type: ignore[arg-type]
session=session, # type: ignore[arg-type]
session=provider_session,
context=session_context,
state=state,
state=provider_session.state.setdefault(provider.source_id, {}),
)
# Merge provider-contributed tools into chat_options
@@ -16,7 +16,7 @@ import copy
import uuid
from abc import abstractmethod
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, ClassVar
from ._types import AgentResponse, Message
@@ -310,7 +310,8 @@ class BaseContextProvider:
agent: The agent running this invocation.
session: The current session.
context: The invocation context - add messages/instructions/tools here.
state: The session's mutable state dict.
state: The provider-scoped mutable state dict for this provider.
Full cross-provider state remains available at ``session.state``.
"""
async def after_run(
@@ -330,7 +331,8 @@ class BaseContextProvider:
agent: The agent that ran this invocation.
session: The current session.
context: The invocation context with response populated.
state: The session's mutable state dict.
state: The provider-scoped mutable state dict for this provider.
Full cross-provider state remains available at ``session.state``.
"""
@@ -520,25 +522,56 @@ class AgentSession:
class InMemoryHistoryProvider(BaseHistoryProvider):
"""Built-in history provider that stores messages in session.state.
Messages are stored in ``state[source_id]["messages"]`` as a list of
Messages are stored in ``state["messages"]`` as a list of
``Message`` objects. Serialization to/from dicts is handled by
``AgentSession.to_dict()``/``from_dict()`` using ``SerializationProtocol``.
This provider holds no instance state — all data lives in the session's
state dict, passed as a named ``state`` parameter to ``get_messages``/``save_messages``.
This is the default provider auto-added by the agent when no providers
are configured and ``conversation_id`` or ``store=True`` is set.
This is the default provider auto-added by the agent for local sessions
when no providers are configured and service-side storage is not requested.
"""
DEFAULT_SOURCE_ID: ClassVar[str] = "in_memory"
def __init__(
self,
source_id: str | None = None,
*,
load_messages: bool = True,
store_inputs: bool = True,
store_context_messages: bool = False,
store_context_from: set[str] | None = None,
store_outputs: bool = True,
) -> None:
"""Initialize the in-memory history provider.
Args:
source_id: Unique identifier for this provider instance.
Defaults to DEFAULT_SOURCE_ID when not provided.
load_messages: Whether to load messages before invocation.
store_inputs: Whether to store input messages.
store_context_messages: Whether to store context from other providers.
store_context_from: If set, only store context from these source_ids.
store_outputs: Whether to store response messages.
"""
super().__init__(
source_id=source_id or self.DEFAULT_SOURCE_ID,
load_messages=load_messages,
store_inputs=store_inputs,
store_context_messages=store_context_messages,
store_context_from=store_context_from,
store_outputs=store_outputs,
)
async def get_messages(
self, session_id: str | None, *, state: dict[str, Any] | None = None, **kwargs: Any
) -> list[Message]:
"""Retrieve messages from session state."""
if state is None:
return []
my_state = state.get(self.source_id, {})
return list(my_state.get("messages", []))
return list(state.get("messages", []))
async def save_messages(
self,
@@ -551,6 +584,5 @@ class InMemoryHistoryProvider(BaseHistoryProvider):
"""Persist messages to session state."""
if state is None:
return
my_state = state.setdefault(self.source_id, {})
existing = my_state.get("messages", [])
my_state["messages"] = [*existing, *messages]
existing = state.get("messages", [])
state["messages"] = [*existing, *messages]
@@ -121,7 +121,7 @@ class WorkflowAgent(BaseAgent):
resolved_context_providers = list(context_providers) if context_providers is not None else []
if not resolved_context_providers:
resolved_context_providers.append(InMemoryHistoryProvider("memory"))
resolved_context_providers.append(InMemoryHistoryProvider())
super().__init__(
id=id,
@@ -237,23 +237,27 @@ class WorkflowAgent(BaseAgent):
An AgentResponse representing the workflow execution results.
"""
input_messages = normalize_messages_input(messages)
provider_session = session
if provider_session is None and self.context_providers:
provider_session = AgentSession()
# run the context providers with the session
session_context = SessionContext(
session_id=session.session_id if session else None,
service_session_id=session.service_session_id if session else None,
session_id=provider_session.session_id if provider_session else None,
service_session_id=provider_session.service_session_id if provider_session else None,
input_messages=input_messages or [],
options={},
)
state = session.state if session else {}
for provider in self.context_providers:
if isinstance(provider, BaseHistoryProvider) and not provider.load_messages:
continue
if provider_session is None:
raise RuntimeError("Provider session must be available when context providers are configured.")
await provider.before_run(
agent=self, # type: ignore[arg-type]
session=session, # type: ignore[arg-type]
session=provider_session,
context=session_context,
state=state,
state=provider_session.state.setdefault(provider.source_id, {}),
)
# combine the messages
session_messages: list[Message] = session_context.get_messages(include_input=True)
@@ -266,7 +270,7 @@ class WorkflowAgent(BaseAgent):
output_events.append(event)
result = self._convert_workflow_events_to_agent_response(response_id, output_events)
await self._run_after_providers(session=session, context=session_context)
await self._run_after_providers(session=provider_session, context=session_context)
return result
async def _run_stream_impl(
@@ -293,23 +297,27 @@ class WorkflowAgent(BaseAgent):
AgentResponseUpdate objects representing the workflow execution progress.
"""
input_messages = normalize_messages_input(messages)
provider_session = session
if provider_session is None and self.context_providers:
provider_session = AgentSession()
# run the context providers with the session
session_context = SessionContext(
session_id=session.session_id if session else None,
service_session_id=session.service_session_id if session else None,
session_id=provider_session.session_id if provider_session else None,
service_session_id=provider_session.service_session_id if provider_session else None,
input_messages=input_messages or [],
options={},
)
state = session.state if session else {}
for provider in self.context_providers:
if isinstance(provider, BaseHistoryProvider) and not provider.load_messages:
continue
if provider_session is None:
raise RuntimeError("Provider session must be available when context providers are configured.")
await provider.before_run(
agent=self, # type: ignore[arg-type]
session=session, # type: ignore[arg-type]
session=provider_session,
context=session_context,
state=state,
state=provider_session.state.setdefault(provider.source_id, {}),
)
# combine the messages
@@ -320,7 +328,7 @@ class WorkflowAgent(BaseAgent):
updates = self._convert_workflow_event_to_agent_response_updates(response_id, event)
for update in updates:
yield update
await self._run_after_providers(session=session, context=session_context)
await self._run_after_providers(session=provider_session, context=session_context)
async def _run_core(
self,
@@ -107,10 +107,10 @@ async def test_chat_client_agent_create_session(client: SupportsChatGetResponse)
async def test_chat_client_agent_prepare_session_and_messages(client: SupportsChatGetResponse) -> None:
from agent_framework._sessions import InMemoryHistoryProvider
agent = Agent(client=client, context_providers=[InMemoryHistoryProvider("memory")])
agent = Agent(client=client, context_providers=[InMemoryHistoryProvider()])
message = Message(role="user", text="Hello")
session = AgentSession()
session.state["memory"] = {"messages": [message]}
session.state[InMemoryHistoryProvider.DEFAULT_SOURCE_ID] = {"messages": [message]}
session_context, _ = await agent._prepare_session_and_messages( # type: ignore[reportPrivateUsage]
session=session,
@@ -267,6 +267,8 @@ async def test_chat_client_agent_update_session_id_streaming_does_not_use_respon
async def test_chat_client_agent_update_session_messages(client: SupportsChatGetResponse) -> None:
from agent_framework._sessions import InMemoryHistoryProvider
agent = Agent(client=client)
session = agent.create_session()
@@ -275,7 +277,7 @@ async def test_chat_client_agent_update_session_messages(client: SupportsChatGet
assert session.service_session_id is None
chat_messages: list[Message] = session.state.get("memory", {}).get("messages", [])
chat_messages: list[Message] = session.state.get(InMemoryHistoryProvider.DEFAULT_SOURCE_ID, {}).get("messages", [])
assert chat_messages is not None
assert len(chat_messages) == 2
@@ -27,6 +27,7 @@ from agent_framework import (
chat_middleware,
function_middleware,
)
from agent_framework._sessions import InMemoryHistoryProvider
from .conftest import MockBaseChatClient, MockChatClient
@@ -1416,8 +1417,10 @@ class TestChatAgentSessionBehavior:
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
# Capture state before next() call
thread_messages = []
if context.session and context.session.state.get("memory"):
thread_messages = context.session.state.get("memory", {}).get("messages", [])
if context.session and context.session.state.get(InMemoryHistoryProvider.DEFAULT_SOURCE_ID):
thread_messages = context.session.state.get(InMemoryHistoryProvider.DEFAULT_SOURCE_ID, {}).get(
"messages", []
)
before_state = {
"before_next": True,
@@ -1432,8 +1435,10 @@ class TestChatAgentSessionBehavior:
# Capture state after next() call
thread_messages_after = []
if context.session and context.session.state.get("memory"):
thread_messages_after = context.session.state.get("memory", {}).get("messages", [])
if context.session and context.session.state.get(InMemoryHistoryProvider.DEFAULT_SOURCE_ID):
thread_messages_after = context.session.state.get(
InMemoryHistoryProvider.DEFAULT_SOURCE_ID, {}
).get("messages", [])
after_state = {
"before_next": False,
@@ -359,30 +359,50 @@ class TestAgentSession:
class TestInMemoryHistoryProvider:
async def test_empty_state_returns_no_messages(self) -> None:
provider = InMemoryHistoryProvider("memory")
provider = InMemoryHistoryProvider()
session = AgentSession()
ctx = SessionContext(session_id="s1", input_messages=[])
await provider.before_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type]
assert ctx.context_messages.get("memory", []) == []
await provider.before_run( # type: ignore[arg-type]
agent=None,
session=session,
context=ctx,
state=session.state.setdefault(provider.source_id, {}),
)
assert ctx.context_messages.get(provider.source_id, []) == []
async def test_stores_and_loads_messages(self) -> None:
from agent_framework import AgentResponse
provider = InMemoryHistoryProvider("memory")
provider = InMemoryHistoryProvider()
session = AgentSession()
# First run: send input, get response
input_msg = Message(role="user", contents=["hello"])
resp_msg = Message(role="assistant", contents=["hi there"])
ctx1 = SessionContext(session_id="s1", input_messages=[input_msg])
await provider.before_run(agent=None, session=session, context=ctx1, state=session.state) # type: ignore[arg-type]
await provider.before_run( # type: ignore[arg-type]
agent=None,
session=session,
context=ctx1,
state=session.state.setdefault(provider.source_id, {}),
)
ctx1._response = AgentResponse(messages=[resp_msg])
await provider.after_run(agent=None, session=session, context=ctx1, state=session.state) # type: ignore[arg-type]
await provider.after_run( # type: ignore[arg-type]
agent=None,
session=session,
context=ctx1,
state=session.state.setdefault(provider.source_id, {}),
)
# Second run: should load previous messages
ctx2 = SessionContext(session_id="s1", input_messages=[Message(role="user", contents=["again"])])
await provider.before_run(agent=None, session=session, context=ctx2, state=session.state) # type: ignore[arg-type]
loaded = ctx2.context_messages.get("memory", [])
await provider.before_run( # type: ignore[arg-type]
agent=None,
session=session,
context=ctx2,
state=session.state.setdefault(provider.source_id, {}),
)
loaded = ctx2.context_messages.get(provider.source_id, [])
assert len(loaded) == 2
assert loaded[0].text == "hello"
assert loaded[1].text == "hi there"
@@ -390,17 +410,27 @@ class TestInMemoryHistoryProvider:
async def test_state_is_serializable(self) -> None:
from agent_framework import AgentResponse
provider = InMemoryHistoryProvider("memory")
provider = InMemoryHistoryProvider()
session = AgentSession()
input_msg = Message(role="user", contents=["test"])
ctx = SessionContext(session_id="s1", input_messages=[input_msg])
await provider.before_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type]
await provider.before_run( # type: ignore[arg-type]
agent=None,
session=session,
context=ctx,
state=session.state.setdefault(provider.source_id, {}),
)
ctx._response = AgentResponse(messages=[Message(role="assistant", contents=["reply"])])
await provider.after_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type]
await provider.after_run( # type: ignore[arg-type]
agent=None,
session=session,
context=ctx,
state=session.state.setdefault(provider.source_id, {}),
)
# State contains Message objects (not dicts)
assert isinstance(session.state["memory"]["messages"][0], Message)
assert isinstance(session.state[provider.source_id]["messages"][0], Message)
# to_dict() serializes them via SerializationProtocol
session_dict = session.to_dict()
@@ -409,9 +439,9 @@ class TestInMemoryHistoryProvider:
# Round-trip through session serialization restores Message objects
restored = AgentSession.from_dict(json.loads(json_str))
assert isinstance(restored.state["memory"]["messages"][0], Message)
assert restored.state["memory"]["messages"][0].text == "test"
assert restored.state["memory"]["messages"][1].text == "reply"
assert isinstance(restored.state[provider.source_id]["messages"][0], Message)
assert restored.state[provider.source_id]["messages"][0].text == "test"
assert restored.state[provider.source_id]["messages"][1].text == "reply"
async def test_source_id_attribution(self) -> None:
provider = InMemoryHistoryProvider("custom-source")