Python: [BREAKING] Scope provider state by source_id and standardize source IDs (#3995)

* Initial plan

* Add FoundryMemoryProvider and tests

Co-authored-by: eavanvalkenburg <13749212+eavanvalkenburg@users.noreply.github.com>

* Add sample and documentation for FoundryMemoryProvider

Co-authored-by: eavanvalkenburg <13749212+eavanvalkenburg@users.noreply.github.com>

* Address code review feedback for FoundryMemoryProvider

Co-authored-by: eavanvalkenburg <13749212+eavanvalkenburg@users.noreply.github.com>

* Address PR review comments: Add DEFAULT_SOURCE_ID, use logging.getLogger, move state to session.state

Co-authored-by: eavanvalkenburg <13749212+eavanvalkenburg@users.noreply.github.com>

* Fix Foundry memory ItemParam usage and exports

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Refactor provider hook state and standardize source IDs

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Support endpoint-based Foundry memory init

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Fix core README workflows link

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* updated implementation and sample

* Split out Foundry memory provider changes

Remove FoundryMemoryProvider implementation/tests/sample plus export and docs mentions from this branch so only non-Foundry changes remain.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Trigger CI rerun for PR #3995

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: eavanvalkenburg <13749212+eavanvalkenburg@users.noreply.github.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
Eduard van Valkenburg
2026-02-17 20:12:28 +01:00
committed by GitHub
Unverified
parent a5f948c215
commit cc98d5b6f7
28 changed files with 359 additions and 148 deletions
@@ -13,6 +13,7 @@ from agent_framework import (
ChatResponseUpdate,
Content,
FunctionInvocationLayer,
InMemoryHistoryProvider,
Message,
ResponseStream,
)
@@ -195,7 +196,7 @@ async def main() -> None:
print(f"Agent: {result.messages[0].text}\n")
# Check conversation history
memory_state = session.state.get("memory", {})
memory_state = session.state.get(InMemoryHistoryProvider.DEFAULT_SOURCE_ID, {})
session_messages = memory_state.get("messages", [])
if session_messages:
print(f"Session contains {len(session_messages)} messages")
@@ -17,7 +17,9 @@ class UserInfo(BaseModel):
class UserInfoMemory(BaseContextProvider):
def __init__(self, source_id: str = "user-info-memory", *, client: SupportsChatGetResponse, **kwargs: Any):
DEFAULT_SOURCE_ID = "user_info_memory"
def __init__(self, source_id: str = DEFAULT_SOURCE_ID, *, client: SupportsChatGetResponse, **kwargs: Any):
"""Create the memory.
If you pass in kwargs, they will be attempted to be used to create a UserInfo object.
@@ -39,9 +41,7 @@ class UserInfoMemory(BaseContextProvider):
# Check if we need to extract user info from user messages
user_messages = [msg for msg in request_messages if hasattr(msg, "role") and msg.role == "user"] # type: ignore
if (
state[self.source_id]["user_info"].name is None or state[self.source_id]["user_info"].age is None
) and user_messages:
if (state["user_info"].name is None or state["user_info"].age is None) and user_messages:
with suppress(Exception):
# Use the chat client to extract structured information
result = await self._chat_client.get_response(
@@ -54,10 +54,10 @@ class UserInfoMemory(BaseContextProvider):
# Update user info with extracted data
with suppress(Exception):
extracted = result.value
if state[self.source_id]["user_info"].name is None and extracted.name:
state[self.source_id]["user_info"].name = extracted.name
if state[self.source_id]["user_info"].age is None and extracted.age:
state[self.source_id]["user_info"].age = extracted.age
if state["user_info"].name is None and extracted.name:
state["user_info"].name = extracted.name
if state["user_info"].age is None and extracted.age:
state["user_info"].age = extracted.age
async def before_run(
self,
@@ -68,20 +68,19 @@ class UserInfoMemory(BaseContextProvider):
state: dict[str, Any],
) -> None:
"""Provide user information context before each agent call."""
if state.setdefault(self.source_id, None) is None:
state[self.source_id] = {"user_info": UserInfo()}
state.setdefault("user_info", UserInfo())
context.extend_instructions(
self.source_id,
"Ask the user for their name and politely decline to answer any questions until they provide it."
if state[self.source_id]["user_info"].name is None
else f"The user's name is {state[self.source_id]['user_info'].name}.",
if state["user_info"].name is None
else f"The user's name is {state['user_info'].name}.",
)
context.extend_instructions(
self.source_id,
"Ask the user for their age and politely decline to answer any questions until they provide it."
if state[self.source_id]["user_info"].age is None
else f"The user's age is {state[self.source_id]['user_info'].age}.",
if state["user_info"].age is None
else f"The user's age is {state['user_info'].age}.",
)
@@ -92,7 +91,7 @@ async def main():
credential=AzureCliCredential(),
)
context_name = "user-info-memory"
context_name = UserInfoMemory.DEFAULT_SOURCE_ID
# Create the memory provider
memory_provider = UserInfoMemory(context_name, client=client)
@@ -6,6 +6,7 @@ from typing import Annotated
from agent_framework import (
AgentContext,
InMemoryHistoryProvider,
tool,
)
from agent_framework.azure import AzureOpenAIChatClient
@@ -50,7 +51,7 @@ async def thread_tracking_middleware(
"""MiddlewareTypes that tracks and logs session behavior across runs."""
session_message_count = 0
if context.session:
memory_state = context.session.state.get("memory", {})
memory_state = context.session.state.get(InMemoryHistoryProvider.DEFAULT_SOURCE_ID, {})
session_message_count = len(memory_state.get("messages", []))
print(f"[MiddlewareTypes pre-execution] Current input messages: {len(context.messages)}")
@@ -62,7 +63,7 @@ async def thread_tracking_middleware(
# Check session state after agent execution
updated_session_message_count = 0
if context.session:
memory_state = context.session.state.get("memory", {})
memory_state = context.session.state.get(InMemoryHistoryProvider.DEFAULT_SOURCE_ID, {})
updated_session_message_count = len(memory_state.get("messages", []))
print(f"[MiddlewareTypes post-execution] Updated session messages: {updated_session_message_count}")
@@ -4,7 +4,7 @@ import asyncio
from random import randint
from typing import Annotated
from agent_framework import Agent, AgentSession, tool
from agent_framework import Agent, AgentSession, InMemoryHistoryProvider, tool
from agent_framework.azure import AzureOpenAIChatClient
from azure.identity import AzureCliCredential
from pydantic import Field
@@ -112,7 +112,7 @@ async def example_with_existing_session_messages() -> None:
print(f"Agent: {result1.text}")
# The session now contains the conversation history in state
memory_state = session.state.get("memory", {})
memory_state = session.state.get(InMemoryHistoryProvider.DEFAULT_SOURCE_ID, {})
messages = memory_state.get("messages", [])
if messages:
print(f"Session contains {len(messages)} messages")
@@ -10,6 +10,7 @@ from agent_framework import (
AgentSession,
BaseAgent,
Content,
InMemoryHistoryProvider,
Message,
Role,
normalize_messages,
@@ -93,7 +94,9 @@ class EchoAgent(BaseAgent):
if not normalized_messages:
response_message = Message(
role=Role.ASSISTANT,
contents=[Content.from_text(text="Hello! I'm a custom echo agent. Send me a message and I'll echo it back.")],
contents=[
Content.from_text(text="Hello! I'm a custom echo agent. Send me a message and I'll echo it back.")
],
)
else:
# For simplicity, echo the last user message
@@ -199,7 +202,7 @@ async def main() -> None:
print(f"Agent: {result2.messages[0].text}")
# Check conversation history
memory_state = session.state.get("memory", {})
memory_state = session.state.get(InMemoryHistoryProvider.DEFAULT_SOURCE_ID, {})
messages = memory_state.get("messages", [])
if messages:
print(f"\nSession contains {len(messages)} messages in history")
@@ -4,7 +4,7 @@ import asyncio
from random import randint
from typing import Annotated
from agent_framework import Agent, AgentSession, tool
from agent_framework import Agent, AgentSession, InMemoryHistoryProvider, tool
from agent_framework.openai import OpenAIChatClient
from pydantic import Field
@@ -105,7 +105,7 @@ async def example_with_existing_session_messages() -> None:
print(f"Agent: {result1.text}")
# The session now contains the conversation history in state
memory_state = session.state.get("memory", {})
memory_state = session.state.get(InMemoryHistoryProvider.DEFAULT_SOURCE_ID, {})
messages = memory_state.get("messages", [])
if messages:
print(f"Session contains {len(messages)} messages")
@@ -59,17 +59,17 @@ async def main() -> None:
credential=AzureCliCredential(),
)
# set the same context provider, with the same source_id, for both agents to share the thread
# set the same context provider (same default source_id) for both agents to share the thread
writer = client.as_agent(
instructions=("You are a concise copywriter. Provide a single, punchy marketing sentence based on the prompt."),
name="writer",
context_providers=[InMemoryHistoryProvider("memory")],
context_providers=[InMemoryHistoryProvider()],
)
reviewer = client.as_agent(
instructions=("You are a thoughtful reviewer. Give brief feedback on the previous assistant message."),
name="reviewer",
context_providers=[InMemoryHistoryProvider("memory")],
context_providers=[InMemoryHistoryProvider()],
)
# Create the shared session
@@ -96,7 +96,7 @@ async def main() -> None:
# The shared session now contains the conversation between the writer and reviewer. Print it out.
print("=== Shared Session Conversation ===")
memory_state = shared_session.state.get("memory", {})
memory_state = shared_session.state.get(InMemoryHistoryProvider.DEFAULT_SOURCE_ID, {})
for message in memory_state.get("messages", []):
print(f"{message.author_name or message.role}: {message.text}")
@@ -3,7 +3,7 @@
import asyncio
import os
from agent_framework import AgentSession
from agent_framework import AgentSession, InMemoryHistoryProvider
from agent_framework.azure import AzureOpenAIResponsesClient
from agent_framework.orchestrations import SequentialBuilder
from azure.identity import AzureCliCredential
@@ -109,7 +109,7 @@ async def main() -> None:
print("\n" + "=" * 60)
print("Full Session History")
print("=" * 60)
memory_state = session.state.get("memory", {})
memory_state = session.state.get(InMemoryHistoryProvider.DEFAULT_SOURCE_ID, {})
history = memory_state.get("messages", [])
for i, msg in enumerate(history, start=1):
role = msg.role if hasattr(msg.role, "value") else str(msg.role)
@@ -29,6 +29,7 @@ import os
from agent_framework import (
InMemoryCheckpointStorage,
InMemoryHistoryProvider,
)
from agent_framework.azure import AzureOpenAIResponsesClient
from agent_framework.orchestrations import SequentialBuilder
@@ -122,7 +123,7 @@ async def checkpointing_with_thread() -> None:
checkpoints = await checkpoint_storage.list_checkpoints(workflow_name=workflow.name)
print(f"\nTotal checkpoints across both turns: {len(checkpoints)}")
memory_state = session.state.get("memory", {})
memory_state = session.state.get(InMemoryHistoryProvider.DEFAULT_SOURCE_ID, {})
history = memory_state.get("messages", [])
print(f"Messages in session history: {len(history)}")
@@ -23,12 +23,14 @@ class UserInfo(BaseModel):
class UserInfoMemory(BaseContextProvider):
"""Context provider that extracts and remembers user info (name, age).
State is stored in ``session.state["user-info-memory"]`` so it survives
State is stored in ``session.state["user_info_memory"]`` so it survives
serialization via ``session.to_dict()`` / ``AgentSession.from_dict()``.
"""
DEFAULT_SOURCE_ID = "user_info_memory"
def __init__(self, client: SupportsChatGetResponse):
super().__init__("user-info-memory")
super().__init__(self.DEFAULT_SOURCE_ID)
self._chat_client = client
async def before_run(
@@ -40,8 +42,7 @@ class UserInfoMemory(BaseContextProvider):
state: dict[str, Any],
) -> None:
"""Provide user information context before each agent call."""
my_state = state.setdefault(self.source_id, {})
user_info = my_state.setdefault("user_info", UserInfo())
user_info = state.setdefault("user_info", UserInfo())
instructions: list[str] = []
@@ -70,8 +71,7 @@ class UserInfoMemory(BaseContextProvider):
state: dict[str, Any],
) -> None:
"""Extract user information from messages after each agent call."""
my_state = state.setdefault(self.source_id, {})
user_info = my_state.setdefault("user_info", UserInfo())
user_info = state.setdefault("user_info", UserInfo())
if user_info.name is not None and user_info.age is not None:
return # Already have everything
@@ -92,7 +92,7 @@ class UserInfoMemory(BaseContextProvider):
user_info.name = extracted.name
if extracted and user_info.age is None and extracted.age:
user_info.age = extracted.age
state.setdefault(self.source_id, {})["user_info"] = user_info
state["user_info"] = user_info
except Exception:
pass # Failed to extract, continue without updating
@@ -113,7 +113,7 @@ async def main():
print(await agent.run("I am 20 years old", session=session))
# Inspect extracted user info from session state
user_info = session.state.get("user-info-memory", {}).get("user_info", UserInfo())
user_info = session.state.get(UserInfoMemory.DEFAULT_SOURCE_ID, {}).get("user_info", UserInfo())
print()
print(f"MEMORY - User Name: {user_info.name}")
print(f"MEMORY - User Age: {user_info.age}")