Python: [BREAKING] PR2 — Wire context provider pipeline, remove old types, update all consumers (#3850)

* PR2: Wire context provider pipeline and update all internal consumers

- Replace AgentThread with AgentSession across all packages
- Replace ContextProvider with BaseContextProvider across all packages
- Replace context_provider param with context_providers (Sequence)
- Replace thread= with session= in run() signatures
- Replace get_new_thread() with create_session()
- Add get_session(service_session_id) to agent interface
- DurableAgentThread -> DurableAgentSession
- Remove _notify_thread_of_new_messages from WorkflowAgent
- Wire before_run/after_run context provider pipeline in RawAgent
- Auto-inject InMemoryHistoryProvider when no providers configured

* fix: update all tests for context provider pipeline, fix lazy-loaders, remove old test files

* refactor: update all sample files for context provider pipeline (AgentThread→AgentSession, ContextProvider→BaseContextProvider)

* fix: update remaining ag-ui references (client docstring, getting_started sample)

* fix: make get_session service_session_id keyword-only to avoid confusion with session_id

* refactor: rename _RunContext.thread_messages to session_messages

* refactor: remove _threads.py, _memory.py, and old provider files; migrate devui to use plain message lists

* rename: remove _new_ prefix from test files

* refactor: rewrite SlidingWindowChatMessageStore as SlidingWindowHistoryProvider(InMemoryHistoryProvider)

* fix: read full history from session state directly instead of reaching into provider internals

* fix: update stale .pyi stubs, sample imports, and README references for new provider types

* fix: remove stale message_store, _notify_thread_of_new_messages, and session_id.key references in samples

* refactor: merge context_providers and sessions sample folders into sessions, remove aggregate_context_provider

* refactor: UserInfoMemory stores state in session.state instead of instance attributes

* feat: add Pydantic BaseModel support to session state serialization

Pydantic models stored in session.state are now automatically serialized
via model_dump() and restored via model_validate() during to_dict()/from_dict()
round-trips. Models are auto-registered on first serialization; use
register_state_type() for cold-start deserialization.

Also export register_state_type as a public API.

* fix mem0

* Update sample README links and descriptions for session terminology

- Replace 'thread' with 'session' in sample descriptions across all READMEs
- Update file links for renamed samples (mem0_sessions, redis_sessions, etc.)
- Fix Threads section → Sessions section in main samples/README.md
- Update tools, middleware, workflows, durabletask, azure_functions READMEs
- Update architecture diagrams in concepts/tools/README.md
- Update migration guides (autogen, semantic-kernel)

* Fix broken Redis README link to renamed sample

* Fix Mem0 OSS client search: pass scoping params as direct kwargs

AsyncMemory (OSS) expects user_id/agent_id/run_id as direct kwargs,
while AsyncMemoryClient (Platform) expects them in a filters dict.
Adds tests for both client types.

Port of fix from #3844 to new Mem0ContextProvider.

* Fix rebase issues: restore missing _conversation_state.py and checkpoint decode logic

- Add back _conversation_state.py (encode/decode_chat_messages) lost in rebase
- Fix on_checkpoint_restore to decode cache/conversation with decode_chat_messages
- Fix on_checkpoint_restore to use decode_checkpoint_value for pending requests
- Add tests/workflow/__init__.py for relative import support
- Fix test_agent_executor checkpoint selection (checkpoints[1] not superstep)

* Add STORES_BY_DEFAULT ClassVar to skip redundant InMemoryHistoryProvider injection

Chat clients that store history server-side by default (OpenAI Responses API,
Azure AI Agent) now declare STORES_BY_DEFAULT = True. The agent checks this
during auto-injection and skips InMemoryHistoryProvider unless the user
explicitly sets store=False.

* Fix broken markdown links in azure_ai and redis READMEs

* Fix getting-started samples to use session API instead of removed thread/ContextProvider API

* updates to workflow as agent

* fix group chat import

* Rename Thread→Session throughout, fix service_session_id propagation, remove stale AGUIThread

- Fix: Propagate conversation_id from ChatResponse back to session.service_session_id
  in both streaming and non-streaming paths in _agents.py
- Rename AgentThreadException → AgentSessionException
- Remove stale AGUIThread from ag_ui lazy-loader
- Rename use_service_thread → use_service_session in ag-ui package
- Rename test functions from *_thread_* to *_session_*
- Rename sample files from *_thread* to *_session*
- Update docstrings and comments: thread → session
- Update _mcp.py kwargs filter: add 'session' alongside 'thread'
- Fix ContinuationToken docstring example: thread=thread → session=session
- Fix _clients.py docstring: 'Agent threads' → 'Agent sessions'

* Fix broken markdown links after thread→session file renames

* fix azure ai test
This commit is contained in:
Eduard van Valkenburg
2026-02-12 22:00:32 +01:00
committed by GitHub
Unverified
parent 0c67dbbce5
commit 1e350ea22f
312 changed files with 6669 additions and 11423 deletions
@@ -1,63 +1,57 @@
# Copyright (c) Microsoft. All rights reserved.
import json
from collections.abc import Sequence
from typing import Any
import tiktoken
from agent_framework import ChatMessageStore, Message
from agent_framework import InMemoryHistoryProvider, Message
from loguru import logger
class SlidingWindowChatMessageStore(ChatMessageStore):
"""A token-aware sliding window implementation of ChatMessageStore.
class SlidingWindowHistoryProvider(InMemoryHistoryProvider):
"""A token-aware sliding window implementation of InMemoryHistoryProvider.
Maintains two message lists: complete history and truncated window.
Automatically removes oldest messages when token limit is exceeded.
Also removes leading tool messages to ensure valid conversation flow.
Stores all messages in session state but returns a truncated window from
``get_messages`` that fits within ``max_tokens``. Automatically removes
oldest messages and leading tool messages to ensure valid conversation flow.
"""
def __init__(
self,
messages: Sequence[Message] | None = None,
source_id: str = "memory",
*,
max_tokens: int = 3800,
system_message: str | None = None,
tool_definitions: Any | None = None,
):
super().__init__(messages=messages)
self.truncated_messages = self.messages.copy()
super().__init__(source_id)
self.max_tokens = max_tokens
self.system_message = system_message # Included in token count
self.tool_definitions = tool_definitions
# An estimation based on a commonly used vocab table
self.encoding = tiktoken.get_encoding("o200k_base")
async def add_messages(self, messages: Sequence[Message]) -> None:
await super().add_messages(messages)
async def get_messages(
self, session_id: str | None, *, state: dict[str, Any] | None = None, **kwargs: Any
) -> list[Message]:
"""Retrieve messages from session state, truncated to fit within max_tokens."""
all_messages = await super().get_messages(session_id, state=state, **kwargs)
return self._truncate(list(all_messages))
self.truncated_messages = self.messages.copy()
self.truncate_messages()
async def list_messages(self) -> list[Message]:
"""Get the current list of messages, which may be truncated."""
return self.truncated_messages
async def list_all_messages(self) -> list[Message]:
"""Get all messages from the store including the truncated ones."""
return self.messages
def truncate_messages(self) -> None:
while len(self.truncated_messages) > 0 and self.get_token_count() > self.max_tokens:
def _truncate(self, messages: list[Message]) -> list[Message]:
"""Truncate messages to fit within max_tokens and remove leading tool messages."""
while len(messages) > 0 and self._get_token_count(messages) > self.max_tokens:
logger.warning("Messages exceed max tokens. Truncating oldest message.")
self.truncated_messages.pop(0)
messages.pop(0)
# Remove leading tool messages
while len(self.truncated_messages) > 0:
if self.truncated_messages[0].role != "tool":
while len(messages) > 0:
if messages[0].role != "tool":
break
logger.warning("Removing leading tool message because tool result cannot be the first message.")
self.truncated_messages.pop(0)
messages.pop(0)
return messages
def get_token_count(self) -> int:
def _get_token_count(self, messages: list[Message]) -> int:
"""Estimate token count for a list of messages using tiktoken.
Returns:
@@ -70,7 +64,7 @@ class SlidingWindowChatMessageStore(ChatMessageStore):
total_tokens += len(self.encoding.encode(self.system_message))
total_tokens += 4 # Extra tokens for system message formatting
for msg in self.truncated_messages:
for msg in messages:
# Add 4 tokens per message for role, formatting, etc.
total_tokens += 4
@@ -87,7 +81,7 @@ class SlidingWindowChatMessageStore(ChatMessageStore):
"name": content.name,
"arguments": content.arguments,
}
total_tokens += self.estimate_any_object_token_count(func_call_data)
total_tokens += self._estimate_any_object_token_count(func_call_data)
elif content.type == "function_result":
total_tokens += 4
# Serialize function result and count tokens
@@ -95,19 +89,16 @@ class SlidingWindowChatMessageStore(ChatMessageStore):
"call_id": content.call_id,
"result": content.result,
}
total_tokens += self.estimate_any_object_token_count(func_result_data)
total_tokens += self._estimate_any_object_token_count(func_result_data)
else:
# For other content types, serialize the whole content
total_tokens += self.estimate_any_object_token_count(content)
total_tokens += self._estimate_any_object_token_count(content)
else:
# Content without type, treat as text
total_tokens += self.estimate_any_object_token_count(content)
total_tokens += self._estimate_any_object_token_count(content)
elif hasattr(msg, "text") and msg.text:
# Simple text message
total_tokens += self.estimate_any_object_token_count(msg.text)
else:
# Skip it
pass
total_tokens += self._estimate_any_object_token_count(msg.text)
if total_tokens > self.max_tokens / 2:
logger.opt(colors=True).warning(
@@ -122,7 +113,7 @@ class SlidingWindowChatMessageStore(ChatMessageStore):
return total_tokens
def estimate_any_object_token_count(self, obj: Any) -> int:
def _estimate_any_object_token_count(self, obj: Any) -> int:
try:
serialized = json.dumps(obj)
except Exception:
@@ -3,7 +3,7 @@
from __future__ import annotations
import uuid
from typing import cast
from typing import Any
from agent_framework import (
Agent,
@@ -32,7 +32,7 @@ from tau2.user.user_simulator import ( # type: ignore[import-untyped]
from tau2.utils.utils import get_now # type: ignore[import-untyped]
from ._message_utils import flip_messages, log_messages
from ._sliding_window import SlidingWindowChatMessageStore
from ._sliding_window import SlidingWindowHistoryProvider
from ._tau2_utils import convert_agent_framework_messages_to_tau2_messages, convert_tau2_tool_to_function_tool
__all__ = ["ASSISTANT_AGENT_ID", "ORCHESTRATOR_ID", "USER_SIMULATOR_ID", "TaskRunner"]
@@ -201,11 +201,13 @@ class TaskRunner:
instructions=assistant_system_prompt,
tools=tools,
temperature=self.assistant_sampling_temperature,
chat_message_store_factory=lambda: SlidingWindowChatMessageStore(
system_message=assistant_system_prompt,
tool_definitions=[tool.openai_schema for tool in tools],
max_tokens=self.assistant_window_size,
),
context_providers=[
SlidingWindowHistoryProvider(
system_message=assistant_system_prompt,
tool_definitions=[tool.openai_schema for tool in tools],
max_tokens=self.assistant_window_size,
)
],
)
def user_simulator(self, user_simuator_chat_client: SupportsChatGetResponse, task: Task) -> Agent:
@@ -354,11 +356,11 @@ class TaskRunner:
# STEP 5: Ensemble the conversation history needed for evaluation.
# It's coming from three parts:
# 1. The initial greeting
# 2. The assistant's message store (not just the truncated window)
# 2. The assistant's session state (full history, not just the truncated window)
# 3. The final user message (if any)
assistant_executor = cast(AgentExecutor, self._assistant_executor)
message_store = cast(SlidingWindowChatMessageStore, assistant_executor._agent_thread.message_store)
full_conversation = [first_message] + await message_store.list_all_messages()
session_state: dict[str, Any] = self._assistant_executor._session.state # type: ignore
all_messages: list[Message] = list(session_state.get("memory", {}).get("messages", [])) # type: ignore
full_conversation = [first_message, *all_messages]
if self._final_user_message is not None:
full_conversation.extend(self._final_user_message)
@@ -1,145 +1,120 @@
# Copyright (c) Microsoft. All rights reserved.
"""Tests for sliding window message list."""
"""Tests for sliding window history provider."""
from unittest.mock import patch
from agent_framework._types import Content, Message
from agent_framework_lab_tau2._sliding_window import SlidingWindowChatMessageStore
from agent_framework_lab_tau2._sliding_window import SlidingWindowHistoryProvider
def test_initialization_empty():
"""Test initializing with no messages."""
sliding_window = SlidingWindowChatMessageStore(max_tokens=1000)
assert sliding_window.max_tokens == 1000
assert sliding_window.system_message is None
assert sliding_window.tool_definitions is None
assert len(sliding_window.messages) == 0
assert len(sliding_window.truncated_messages) == 0
def _make_state(provider: SlidingWindowHistoryProvider, messages: list[Message] | None = None) -> dict:
"""Helper to create a session state dict with messages pre-loaded."""
state: dict = {}
if messages:
state[provider.source_id] = {"messages": list(messages)}
return state
def test_initialization_with_parameters():
"""Test initializing with system message and tool definitions."""
system_msg = "You are a helpful assistant"
tool_defs = [{"name": "test_tool", "description": "A test tool"}]
sliding_window = SlidingWindowChatMessageStore(
max_tokens=2000, system_message=system_msg, tool_definitions=tool_defs
def test_initialization():
"""Test initializing with parameters."""
provider = SlidingWindowHistoryProvider(
max_tokens=2000,
system_message="You are a helpful assistant",
tool_definitions=[{"name": "test_tool"}],
)
assert sliding_window.max_tokens == 2000
assert sliding_window.system_message == system_msg
assert sliding_window.tool_definitions == tool_defs
assert provider.max_tokens == 2000
assert provider.system_message == "You are a helpful assistant"
assert provider.tool_definitions == [{"name": "test_tool"}]
assert provider.source_id == "memory"
def test_initialization_with_messages():
"""Test initializing with existing messages."""
messages = [
Message(role="user", contents=[Content.from_text(text="Hello")]),
Message(role="assistant", contents=[Content.from_text(text="Hi there!")]),
]
sliding_window = SlidingWindowChatMessageStore(messages=messages, max_tokens=1000)
assert len(sliding_window.messages) == 2
assert len(sliding_window.truncated_messages) == 2
async def test_get_messages_empty():
"""Test getting messages from empty state."""
provider = SlidingWindowHistoryProvider(max_tokens=1000)
messages = await provider.get_messages(None, state={})
assert messages == []
async def test_add_messages_simple():
"""Test adding messages without truncation."""
sliding_window = SlidingWindowChatMessageStore(max_tokens=10000) # Large limit
new_messages = [
async def test_get_messages_simple():
"""Test getting messages without truncation."""
provider = SlidingWindowHistoryProvider(max_tokens=10000)
msgs = [
Message(role="user", contents=[Content.from_text(text="What's the weather?")]),
Message(role="assistant", contents=[Content.from_text(text="I can help with that.")]),
]
state = _make_state(provider, msgs)
await sliding_window.add_messages(new_messages)
messages = await sliding_window.list_messages()
assert len(messages) == 2
assert messages[0].text == "What's the weather?"
assert messages[1].text == "I can help with that."
result = await provider.get_messages(None, state=state)
assert len(result) == 2
assert result[0].text == "What's the weather?"
assert result[1].text == "I can help with that."
async def test_list_all_messages_vs_list_messages():
"""Test difference between list_all_messages and list_messages."""
sliding_window = SlidingWindowChatMessageStore(max_tokens=50) # Small limit to force truncation
async def test_save_and_get_messages():
"""Test saving then getting messages with truncation."""
provider = SlidingWindowHistoryProvider(max_tokens=50)
state: dict = {}
# Add many messages to trigger truncation
messages = [
# Save many messages
msgs = [
Message(role="user", contents=[Content.from_text(text=f"Message {i} with some content")]) for i in range(10)
]
await provider.save_messages(None, msgs, state=state)
await sliding_window.add_messages(messages)
# get_messages returns truncated
truncated = await provider.get_messages(None, state=state)
# Full history is in session state
all_msgs = state[provider.source_id]["messages"]
truncated_messages = await sliding_window.list_messages()
all_messages = await sliding_window.list_all_messages()
# All messages should contain everything
assert len(all_messages) == 10
# Truncated messages should be fewer due to token limit
assert len(truncated_messages) < len(all_messages)
assert len(all_msgs) == 10
assert len(truncated) < len(all_msgs)
def test_get_token_count_basic():
"""Test basic token counting."""
sliding_window = SlidingWindowChatMessageStore(max_tokens=1000)
sliding_window.truncated_messages = [Message(role="user", contents=[Content.from_text(text="Hello")])]
provider = SlidingWindowHistoryProvider(max_tokens=1000)
messages = [Message(role="user", contents=[Content.from_text(text="Hello")])]
token_count = sliding_window.get_token_count()
# Should be more than 0 (exact count depends on encoding)
token_count = provider._get_token_count(messages)
assert token_count > 0
def test_get_token_count_with_system_message():
"""Test token counting includes system message."""
system_msg = "You are a helpful assistant"
sliding_window = SlidingWindowChatMessageStore(max_tokens=1000, system_message=system_msg)
provider = SlidingWindowHistoryProvider(max_tokens=1000, system_message="You are a helpful assistant")
# Without messages
token_count_empty = sliding_window.get_token_count()
count_empty = provider._get_token_count([])
count_with_msg = provider._get_token_count([Message(role="user", contents=[Content.from_text(text="Hello")])])
# Add a message
sliding_window.truncated_messages = [Message(role="user", contents=[Content.from_text(text="Hello")])]
token_count_with_message = sliding_window.get_token_count()
# With message should be more tokens
assert token_count_with_message > token_count_empty
assert token_count_empty > 0 # System message contributes tokens
assert count_with_msg > count_empty
assert count_empty > 0 # System message contributes tokens
def test_get_token_count_function_call():
"""Test token counting with function calls."""
function_call = Content.from_function_call(call_id="call_123", name="test_function", arguments={"param": "value"})
provider = SlidingWindowHistoryProvider(max_tokens=1000)
sliding_window = SlidingWindowChatMessageStore(max_tokens=1000)
sliding_window.truncated_messages = [Message(role="assistant", contents=[function_call])]
token_count = sliding_window.get_token_count()
token_count = provider._get_token_count([Message(role="assistant", contents=[function_call])])
assert token_count > 0
def test_get_token_count_function_result():
"""Test token counting with function results."""
function_result = Content.from_function_result(call_id="call_123", result={"success": True, "data": "result"})
provider = SlidingWindowHistoryProvider(max_tokens=1000)
sliding_window = SlidingWindowChatMessageStore(max_tokens=1000)
sliding_window.truncated_messages = [Message(role="tool", contents=[function_result])]
token_count = sliding_window.get_token_count()
token_count = provider._get_token_count([Message(role="tool", contents=[function_result])])
assert token_count > 0
@patch("agent_framework_lab_tau2._sliding_window.logger")
def test_truncate_messages_removes_old_messages(mock_logger):
def test_truncate_removes_old_messages(mock_logger):
"""Test that truncation removes old messages when token limit exceeded."""
sliding_window = SlidingWindowChatMessageStore(max_tokens=20) # Very small limit
provider = SlidingWindowHistoryProvider(max_tokens=20)
# Create messages that will exceed the limit
messages = [
Message(
role="user",
@@ -154,80 +129,45 @@ def test_truncate_messages_removes_old_messages(mock_logger):
Message(role="user", contents=[Content.from_text(text="Short msg")]),
]
sliding_window.truncated_messages = messages.copy()
sliding_window.truncate_messages()
# Should have fewer messages after truncation
assert len(sliding_window.truncated_messages) < len(messages)
# Should have logged warnings
result = provider._truncate(list(messages))
assert len(result) < len(messages)
assert mock_logger.warning.called
@patch("agent_framework_lab_tau2._sliding_window.logger")
def test_truncate_messages_removes_leading_tool_messages(mock_logger):
def test_truncate_removes_leading_tool_messages(mock_logger):
"""Test that truncation removes leading tool messages."""
sliding_window = SlidingWindowChatMessageStore(max_tokens=10000) # Large limit
provider = SlidingWindowHistoryProvider(max_tokens=10000)
# Create messages starting with tool message
tool_message = Message(role="tool", contents=[Content.from_function_result(call_id="call_123", result="result")])
user_message = Message(role="user", contents=[Content.from_text(text="Hello")])
sliding_window.truncated_messages = [tool_message, user_message]
sliding_window.truncate_messages()
# Tool message should be removed from the beginning
assert len(sliding_window.truncated_messages) == 1
assert sliding_window.truncated_messages[0].role == "user"
# Should have logged warning about removing tool message
result = provider._truncate([tool_message, user_message])
assert len(result) == 1
assert result[0].role == "user"
mock_logger.warning.assert_called()
def test_estimate_any_object_token_count_dict():
"""Test token counting for dictionary objects."""
sliding_window = SlidingWindowChatMessageStore(max_tokens=1000)
def test_estimate_any_object_token_count():
"""Test token counting for various object types."""
provider = SlidingWindowHistoryProvider(max_tokens=1000)
test_dict = {"key": "value", "number": 42}
token_count = sliding_window.estimate_any_object_token_count(test_dict)
assert provider._estimate_any_object_token_count({"key": "value"}) > 0
assert provider._estimate_any_object_token_count("test string") > 0
assert token_count > 0
def test_estimate_any_object_token_count_string():
"""Test token counting for string objects."""
sliding_window = SlidingWindowChatMessageStore(max_tokens=1000)
test_string = "This is a test string"
token_count = sliding_window.estimate_any_object_token_count(test_string)
assert token_count > 0
def test_estimate_any_object_token_count_non_serializable():
"""Test token counting for non-JSON-serializable objects."""
sliding_window = SlidingWindowChatMessageStore(max_tokens=1000)
# Create an object that can't be JSON serialized
class CustomObject:
# Non-serializable falls back to str()
class Custom:
def __str__(self):
return "CustomObject instance"
return "Custom instance"
custom_obj = CustomObject()
token_count = sliding_window.estimate_any_object_token_count(custom_obj)
# Should fall back to string representation
assert token_count > 0
assert provider._estimate_any_object_token_count(Custom()) > 0
async def test_real_world_scenario():
"""Test a realistic conversation scenario."""
sliding_window = SlidingWindowChatMessageStore(
max_tokens=30,
system_message="You are a helpful assistant", # Moderate limit
)
provider = SlidingWindowHistoryProvider(max_tokens=30, system_message="You are a helpful assistant")
state: dict = {}
# Simulate a conversation
conversation = [
Message(role="user", contents=[Content.from_text(text="Hello, how are you?")]),
Message(
@@ -253,18 +193,13 @@ async def test_real_world_scenario():
),
]
await sliding_window.add_messages(conversation)
await provider.save_messages(None, conversation, state=state)
current_messages = await sliding_window.list_messages()
all_messages = await sliding_window.list_all_messages()
truncated = await provider.get_messages(None, state=state)
all_msgs = state[provider.source_id]["messages"]
# All messages should be preserved
assert len(all_messages) == 6
assert len(all_msgs) == 6
assert len(truncated) <= 6
# Current messages might be truncated
assert len(current_messages) <= 6
# Token count should be within or close to limit
token_count = sliding_window.get_token_count()
# Allow some margin since truncation happens when exceeded
assert token_count <= sliding_window.max_tokens * 1.1
token_count = provider._get_token_count(truncated)
assert token_count <= provider.max_tokens * 1.1