Mirror of https://github.com/microsoft/agent-framework.git (synced 2026-06-16 21:04:09 +08:00)
Python: [BREAKING] PR2 — Wire context provider pipeline, remove old types, update all consumers (#3850)
* PR2: Wire context provider pipeline and update all internal consumers
  - Replace AgentThread with AgentSession across all packages
  - Replace ContextProvider with BaseContextProvider across all packages
  - Replace context_provider param with context_providers (Sequence)
  - Replace thread= with session= in run() signatures
  - Replace get_new_thread() with create_session()
  - Add get_session(service_session_id) to agent interface
  - DurableAgentThread -> DurableAgentSession
  - Remove _notify_thread_of_new_messages from WorkflowAgent
  - Wire before_run/after_run context provider pipeline in RawAgent
  - Auto-inject InMemoryHistoryProvider when no providers configured
* fix: update all tests for context provider pipeline, fix lazy-loaders, remove old test files
* refactor: update all sample files for context provider pipeline (AgentThread→AgentSession, ContextProvider→BaseContextProvider)
* fix: update remaining ag-ui references (client docstring, getting_started sample)
* fix: make get_session service_session_id keyword-only to avoid confusion with session_id
* refactor: rename _RunContext.thread_messages to session_messages
* refactor: remove _threads.py, _memory.py, and old provider files; migrate devui to use plain message lists
* rename: remove _new_ prefix from test files
* refactor: rewrite SlidingWindowChatMessageStore as SlidingWindowHistoryProvider(InMemoryHistoryProvider)
* fix: read full history from session state directly instead of reaching into provider internals
* fix: update stale .pyi stubs, sample imports, and README references for new provider types
* fix: remove stale message_store, _notify_thread_of_new_messages, and session_id.key references in samples
* refactor: merge context_providers and sessions sample folders into sessions, remove aggregate_context_provider
* refactor: UserInfoMemory stores state in session.state instead of instance attributes
* feat: add Pydantic BaseModel support to session state serialization
  Pydantic models stored in session.state are now automatically serialized via model_dump() and restored via model_validate() during to_dict()/from_dict() round-trips. Models are auto-registered on first serialization; use register_state_type() for cold-start deserialization. Also export register_state_type as a public API.
* fix mem0
* Update sample README links and descriptions for session terminology
  - Replace 'thread' with 'session' in sample descriptions across all READMEs
  - Update file links for renamed samples (mem0_sessions, redis_sessions, etc.)
  - Fix Threads section → Sessions section in main samples/README.md
  - Update tools, middleware, workflows, durabletask, azure_functions READMEs
  - Update architecture diagrams in concepts/tools/README.md
  - Update migration guides (autogen, semantic-kernel)
* Fix broken Redis README link to renamed sample
* Fix Mem0 OSS client search: pass scoping params as direct kwargs
  AsyncMemory (OSS) expects user_id/agent_id/run_id as direct kwargs, while AsyncMemoryClient (Platform) expects them in a filters dict. Adds tests for both client types. Port of fix from #3844 to new Mem0ContextProvider.
* Fix rebase issues: restore missing _conversation_state.py and checkpoint decode logic
  - Add back _conversation_state.py (encode/decode_chat_messages) lost in rebase
  - Fix on_checkpoint_restore to decode cache/conversation with decode_chat_messages
  - Fix on_checkpoint_restore to use decode_checkpoint_value for pending requests
  - Add tests/workflow/__init__.py for relative import support
  - Fix test_agent_executor checkpoint selection (checkpoints[1] not superstep)
* Add STORES_BY_DEFAULT ClassVar to skip redundant InMemoryHistoryProvider injection
  Chat clients that store history server-side by default (OpenAI Responses API, Azure AI Agent) now declare STORES_BY_DEFAULT = True. The agent checks this during auto-injection and skips InMemoryHistoryProvider unless the user explicitly sets store=False.
* Fix broken markdown links in azure_ai and redis READMEs
* Fix getting-started samples to use session API instead of removed thread/ContextProvider API
* updates to workflow as agent
* fix group chat import
* Rename Thread→Session throughout, fix service_session_id propagation, remove stale AGUIThread
  - Fix: Propagate conversation_id from ChatResponse back to session.service_session_id in both streaming and non-streaming paths in _agents.py
  - Rename AgentThreadException → AgentSessionException
  - Remove stale AGUIThread from ag_ui lazy-loader
  - Rename use_service_thread → use_service_session in ag-ui package
  - Rename test functions from *_thread_* to *_session_*
  - Rename sample files from *_thread* to *_session*
  - Update docstrings and comments: thread → session
  - Update _mcp.py kwargs filter: add 'session' alongside 'thread'
  - Fix ContinuationToken docstring example: thread=thread → session=session
  - Fix _clients.py docstring: 'Agent threads' → 'Agent sessions'
* Fix broken markdown links after thread→session file renames
* fix azure ai test
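The auto-injection rule described in the commit message (inject `InMemoryHistoryProvider` when no providers are configured, but skip it for clients that declare `STORES_BY_DEFAULT = True` unless `store=False` is set) can be sketched roughly as below. This is a stand-alone model of the stated rule, not the framework's code: `needs_default_history_provider` and its parameter names are hypothetical, and any behaviour beyond what the message states is left out.

```python
from __future__ import annotations

from collections.abc import Sequence


def needs_default_history_provider(
    context_providers: Sequence[object],
    *,
    stores_by_default: bool,
    store: bool | None = None,
) -> bool:
    """Decide whether an in-memory history provider should be auto-injected.

    Hypothetical helper: it models only the rule stated in the commit message.
    """
    if context_providers:
        # Providers were configured explicitly, so nothing is injected.
        return False
    if stores_by_default and store is not False:
        # The service keeps history server-side; a local copy is redundant.
        return False
    return True
```

Under these assumptions, a client with server-side storage gets the in-memory provider only when the caller opts out of storage with `store=False`.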
Committed by GitHub
Parent: 0c67dbbce5
Commit: 1e350ea22f
@@ -1,63 +1,57 @@
 # Copyright (c) Microsoft. All rights reserved.
 
 import json
-from collections.abc import Sequence
 from typing import Any
 
 import tiktoken
-from agent_framework import ChatMessageStore, Message
+from agent_framework import InMemoryHistoryProvider, Message
 from loguru import logger
 
 
-class SlidingWindowChatMessageStore(ChatMessageStore):
-    """A token-aware sliding window implementation of ChatMessageStore.
+class SlidingWindowHistoryProvider(InMemoryHistoryProvider):
+    """A token-aware sliding window implementation of InMemoryHistoryProvider.
 
-    Maintains two message lists: complete history and truncated window.
-    Automatically removes oldest messages when token limit is exceeded.
-    Also removes leading tool messages to ensure valid conversation flow.
+    Stores all messages in session state but returns a truncated window from
+    ``get_messages`` that fits within ``max_tokens``. Automatically removes
+    oldest messages and leading tool messages to ensure valid conversation flow.
     """
 
     def __init__(
         self,
-        messages: Sequence[Message] | None = None,
+        source_id: str = "memory",
+        *,
         max_tokens: int = 3800,
         system_message: str | None = None,
         tool_definitions: Any | None = None,
     ):
-        super().__init__(messages=messages)
-        self.truncated_messages = self.messages.copy()
+        super().__init__(source_id)
         self.max_tokens = max_tokens
         self.system_message = system_message # Included in token count
         self.tool_definitions = tool_definitions
         # An estimation based on a commonly used vocab table
         self.encoding = tiktoken.get_encoding("o200k_base")
 
-    async def add_messages(self, messages: Sequence[Message]) -> None:
-        await super().add_messages(messages)
+    async def get_messages(
+        self, session_id: str | None, *, state: dict[str, Any] | None = None, **kwargs: Any
+    ) -> list[Message]:
+        """Retrieve messages from session state, truncated to fit within max_tokens."""
+        all_messages = await super().get_messages(session_id, state=state, **kwargs)
+        return self._truncate(list(all_messages))
 
-        self.truncated_messages = self.messages.copy()
-        self.truncate_messages()
-
-    async def list_messages(self) -> list[Message]:
-        """Get the current list of messages, which may be truncated."""
-        return self.truncated_messages
-
-    async def list_all_messages(self) -> list[Message]:
-        """Get all messages from the store including the truncated ones."""
-        return self.messages
-
-    def truncate_messages(self) -> None:
-        while len(self.truncated_messages) > 0 and self.get_token_count() > self.max_tokens:
+    def _truncate(self, messages: list[Message]) -> list[Message]:
+        """Truncate messages to fit within max_tokens and remove leading tool messages."""
+        while len(messages) > 0 and self._get_token_count(messages) > self.max_tokens:
             logger.warning("Messages exceed max tokens. Truncating oldest message.")
-            self.truncated_messages.pop(0)
+            messages.pop(0)
         # Remove leading tool messages
-        while len(self.truncated_messages) > 0:
-            if self.truncated_messages[0].role != "tool":
+        while len(messages) > 0:
+            if messages[0].role != "tool":
                 break
             logger.warning("Removing leading tool message because tool result cannot be the first message.")
-            self.truncated_messages.pop(0)
+            messages.pop(0)
+        return messages
 
-    def get_token_count(self) -> int:
+    def _get_token_count(self, messages: list[Message]) -> int:
         """Estimate token count for a list of messages using tiktoken.
 
         Returns:
@@ -70,7 +64,7 @@ class SlidingWindowChatMessageStore(ChatMessageStore):
             total_tokens += len(self.encoding.encode(self.system_message))
             total_tokens += 4 # Extra tokens for system message formatting
 
-        for msg in self.truncated_messages:
+        for msg in messages:
             # Add 4 tokens per message for role, formatting, etc.
             total_tokens += 4
 
@@ -87,7 +81,7 @@ class SlidingWindowChatMessageStore(ChatMessageStore):
                                 "name": content.name,
                                 "arguments": content.arguments,
                             }
-                            total_tokens += self.estimate_any_object_token_count(func_call_data)
+                            total_tokens += self._estimate_any_object_token_count(func_call_data)
                         elif content.type == "function_result":
                             total_tokens += 4
                             # Serialize function result and count tokens
@@ -95,19 +89,16 @@ class SlidingWindowChatMessageStore(ChatMessageStore):
                                 "call_id": content.call_id,
                                 "result": content.result,
                             }
-                            total_tokens += self.estimate_any_object_token_count(func_result_data)
+                            total_tokens += self._estimate_any_object_token_count(func_result_data)
                         else:
                             # For other content types, serialize the whole content
-                            total_tokens += self.estimate_any_object_token_count(content)
+                            total_tokens += self._estimate_any_object_token_count(content)
                     else:
                         # Content without type, treat as text
-                        total_tokens += self.estimate_any_object_token_count(content)
+                        total_tokens += self._estimate_any_object_token_count(content)
             elif hasattr(msg, "text") and msg.text:
                 # Simple text message
-                total_tokens += self.estimate_any_object_token_count(msg.text)
-            else:
-                # Skip it
-                pass
+                total_tokens += self._estimate_any_object_token_count(msg.text)
 
         if total_tokens > self.max_tokens / 2:
             logger.opt(colors=True).warning(
@@ -122,7 +113,7 @@ class SlidingWindowChatMessageStore(ChatMessageStore):
 
         return total_tokens
 
-    def estimate_any_object_token_count(self, obj: Any) -> int:
+    def _estimate_any_object_token_count(self, obj: Any) -> int:
         try:
             serialized = json.dumps(obj)
         except Exception:
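The truncation strategy in `_truncate` above (drop the oldest messages until the budget fits, then drop any leading tool results) can be tried in isolation with the sketch below. It is a simplified model under stated assumptions: `Msg`, `count_tokens`, and `truncate` are hypothetical stand-ins, and the word-based counter replaces the tiktoken-based estimate used in the diff. Unlike `_truncate`, which mutates the list it is given, this version copies first.

```python
from dataclasses import dataclass


@dataclass
class Msg:
    """Hypothetical stand-in for the framework's Message type."""

    role: str
    text: str


def count_tokens(messages: list[Msg]) -> int:
    """Crude estimate: 4 tokens of overhead per message plus one per word."""
    return sum(4 + len(m.text.split()) for m in messages)


def truncate(messages: list[Msg], max_tokens: int) -> list[Msg]:
    """Return the newest messages that fit the budget, never starting with a tool result."""
    window = list(messages)  # copy so the full history is left untouched
    while window and count_tokens(window) > max_tokens:
        window.pop(0)  # drop the oldest message first
    while window and window[0].role == "tool":
        window.pop(0)  # a tool result cannot open the conversation
    return window


history = [
    Msg("user", "look up order 42"),
    Msg("assistant", "calling the lookup tool"),
    Msg("tool", "order 42 shipped yesterday"),
    Msg("assistant", "your order shipped yesterday"),
]
window = truncate(history, max_tokens=16)
```

With this toy counter each message costs 8 tokens, so a budget of 16 first leaves the tool result and the final assistant message, and the second loop then removes the orphaned tool result.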
@@ -3,7 +3,7 @@
 from __future__ import annotations
 
 import uuid
-from typing import cast
+from typing import Any
 
 from agent_framework import (
     Agent,
@@ -32,7 +32,7 @@ from tau2.user.user_simulator import ( # type: ignore[import-untyped]
 from tau2.utils.utils import get_now # type: ignore[import-untyped]
 
 from ._message_utils import flip_messages, log_messages
-from ._sliding_window import SlidingWindowChatMessageStore
+from ._sliding_window import SlidingWindowHistoryProvider
 from ._tau2_utils import convert_agent_framework_messages_to_tau2_messages, convert_tau2_tool_to_function_tool
 
 __all__ = ["ASSISTANT_AGENT_ID", "ORCHESTRATOR_ID", "USER_SIMULATOR_ID", "TaskRunner"]
@@ -201,11 +201,13 @@ class TaskRunner:
             instructions=assistant_system_prompt,
             tools=tools,
             temperature=self.assistant_sampling_temperature,
-            chat_message_store_factory=lambda: SlidingWindowChatMessageStore(
-                system_message=assistant_system_prompt,
-                tool_definitions=[tool.openai_schema for tool in tools],
-                max_tokens=self.assistant_window_size,
-            ),
+            context_providers=[
+                SlidingWindowHistoryProvider(
+                    system_message=assistant_system_prompt,
+                    tool_definitions=[tool.openai_schema for tool in tools],
+                    max_tokens=self.assistant_window_size,
+                )
+            ],
         )
 
     def user_simulator(self, user_simuator_chat_client: SupportsChatGetResponse, task: Task) -> Agent:
@@ -354,11 +356,11 @@ class TaskRunner:
         # STEP 5: Ensemble the conversation history needed for evaluation.
         # It's coming from three parts:
         # 1. The initial greeting
-        # 2. The assistant's message store (not just the truncated window)
+        # 2. The assistant's session state (full history, not just the truncated window)
         # 3. The final user message (if any)
-        assistant_executor = cast(AgentExecutor, self._assistant_executor)
-        message_store = cast(SlidingWindowChatMessageStore, assistant_executor._agent_thread.message_store)
-        full_conversation = [first_message] + await message_store.list_all_messages()
+        session_state: dict[str, Any] = self._assistant_executor._session.state # type: ignore
+        all_messages: list[Message] = list(session_state.get("memory", {}).get("messages", [])) # type: ignore
+        full_conversation = [first_message, *all_messages]
         if self._final_user_message is not None:
             full_conversation.extend(self._final_user_message)
 
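The hunk above reads the untruncated history straight from session state, which the tests in this commit lay out as one sub-dict per provider keyed by its `source_id`, with the messages under `"messages"`. A minimal sketch of that lookup, with plain strings standing in for `Message` objects and `full_history` as a hypothetical helper name:

```python
from __future__ import annotations

from typing import Any


def full_history(state: dict[str, Any], source_id: str = "memory") -> list[Any]:
    """Return a copy of every stored message for ``source_id``, or [] if none."""
    return list(state.get(source_id, {}).get("messages", []))


# Assumed layout: one sub-dict per provider, keyed by its source_id.
state = {"memory": {"messages": ["greeting", "question", "answer"]}}
everything = full_history(state)
```

Copying the list keeps later mutations of the evaluation transcript from leaking back into the session state.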
@@ -1,145 +1,120 @@
 # Copyright (c) Microsoft. All rights reserved.
 
-"""Tests for sliding window message list."""
+"""Tests for sliding window history provider."""
 
 from unittest.mock import patch
 
 from agent_framework._types import Content, Message
-from agent_framework_lab_tau2._sliding_window import SlidingWindowChatMessageStore
+from agent_framework_lab_tau2._sliding_window import SlidingWindowHistoryProvider
 
 
-def test_initialization_empty():
-    """Test initializing with no messages."""
-    sliding_window = SlidingWindowChatMessageStore(max_tokens=1000)
-
-    assert sliding_window.max_tokens == 1000
-    assert sliding_window.system_message is None
-    assert sliding_window.tool_definitions is None
-    assert len(sliding_window.messages) == 0
-    assert len(sliding_window.truncated_messages) == 0
+def _make_state(provider: SlidingWindowHistoryProvider, messages: list[Message] | None = None) -> dict:
+    """Helper to create a session state dict with messages pre-loaded."""
+    state: dict = {}
+    if messages:
+        state[provider.source_id] = {"messages": list(messages)}
+    return state
 
 
-def test_initialization_with_parameters():
-    """Test initializing with system message and tool definitions."""
-    system_msg = "You are a helpful assistant"
-    tool_defs = [{"name": "test_tool", "description": "A test tool"}]
-
-    sliding_window = SlidingWindowChatMessageStore(
-        max_tokens=2000, system_message=system_msg, tool_definitions=tool_defs
+def test_initialization():
+    """Test initializing with parameters."""
+    provider = SlidingWindowHistoryProvider(
+        max_tokens=2000,
+        system_message="You are a helpful assistant",
+        tool_definitions=[{"name": "test_tool"}],
     )
 
-    assert sliding_window.max_tokens == 2000
-    assert sliding_window.system_message == system_msg
-    assert sliding_window.tool_definitions == tool_defs
+    assert provider.max_tokens == 2000
+    assert provider.system_message == "You are a helpful assistant"
+    assert provider.tool_definitions == [{"name": "test_tool"}]
+    assert provider.source_id == "memory"
 
 
-def test_initialization_with_messages():
-    """Test initializing with existing messages."""
-    messages = [
-        Message(role="user", contents=[Content.from_text(text="Hello")]),
-        Message(role="assistant", contents=[Content.from_text(text="Hi there!")]),
-    ]
-
-    sliding_window = SlidingWindowChatMessageStore(messages=messages, max_tokens=1000)
-
-    assert len(sliding_window.messages) == 2
-    assert len(sliding_window.truncated_messages) == 2
+async def test_get_messages_empty():
+    """Test getting messages from empty state."""
+    provider = SlidingWindowHistoryProvider(max_tokens=1000)
+    messages = await provider.get_messages(None, state={})
+    assert messages == []
 
 
-async def test_add_messages_simple():
-    """Test adding messages without truncation."""
-    sliding_window = SlidingWindowChatMessageStore(max_tokens=10000) # Large limit
-
-    new_messages = [
+async def test_get_messages_simple():
+    """Test getting messages without truncation."""
+    provider = SlidingWindowHistoryProvider(max_tokens=10000)
+    msgs = [
         Message(role="user", contents=[Content.from_text(text="What's the weather?")]),
         Message(role="assistant", contents=[Content.from_text(text="I can help with that.")]),
     ]
+    state = _make_state(provider, msgs)
 
-    await sliding_window.add_messages(new_messages)
-
-    messages = await sliding_window.list_messages()
-    assert len(messages) == 2
-    assert messages[0].text == "What's the weather?"
-    assert messages[1].text == "I can help with that."
+    result = await provider.get_messages(None, state=state)
+    assert len(result) == 2
+    assert result[0].text == "What's the weather?"
+    assert result[1].text == "I can help with that."
 
 
-async def test_list_all_messages_vs_list_messages():
-    """Test difference between list_all_messages and list_messages."""
-    sliding_window = SlidingWindowChatMessageStore(max_tokens=50) # Small limit to force truncation
+async def test_save_and_get_messages():
+    """Test saving then getting messages with truncation."""
+    provider = SlidingWindowHistoryProvider(max_tokens=50)
+    state: dict = {}
 
-    # Add many messages to trigger truncation
-    messages = [
+    # Save many messages
+    msgs = [
         Message(role="user", contents=[Content.from_text(text=f"Message {i} with some content")]) for i in range(10)
     ]
+    await provider.save_messages(None, msgs, state=state)
 
-    await sliding_window.add_messages(messages)
+    # get_messages returns truncated
+    truncated = await provider.get_messages(None, state=state)
+    # Full history is in session state
+    all_msgs = state[provider.source_id]["messages"]
 
-    truncated_messages = await sliding_window.list_messages()
-    all_messages = await sliding_window.list_all_messages()
-
-    # All messages should contain everything
-    assert len(all_messages) == 10
-
-    # Truncated messages should be fewer due to token limit
-    assert len(truncated_messages) < len(all_messages)
+    assert len(all_msgs) == 10
+    assert len(truncated) < len(all_msgs)
 
 
 def test_get_token_count_basic():
     """Test basic token counting."""
-    sliding_window = SlidingWindowChatMessageStore(max_tokens=1000)
-    sliding_window.truncated_messages = [Message(role="user", contents=[Content.from_text(text="Hello")])]
+    provider = SlidingWindowHistoryProvider(max_tokens=1000)
+    messages = [Message(role="user", contents=[Content.from_text(text="Hello")])]
 
-    token_count = sliding_window.get_token_count()
-
-    # Should be more than 0 (exact count depends on encoding)
+    token_count = provider._get_token_count(messages)
     assert token_count > 0
 
 
 def test_get_token_count_with_system_message():
     """Test token counting includes system message."""
-    system_msg = "You are a helpful assistant"
-    sliding_window = SlidingWindowChatMessageStore(max_tokens=1000, system_message=system_msg)
+    provider = SlidingWindowHistoryProvider(max_tokens=1000, system_message="You are a helpful assistant")
 
-    # Without messages
-    token_count_empty = sliding_window.get_token_count()
+    count_empty = provider._get_token_count([])
+    count_with_msg = provider._get_token_count([Message(role="user", contents=[Content.from_text(text="Hello")])])
 
-    # Add a message
-    sliding_window.truncated_messages = [Message(role="user", contents=[Content.from_text(text="Hello")])]
-    token_count_with_message = sliding_window.get_token_count()
-
-    # With message should be more tokens
-    assert token_count_with_message > token_count_empty
-    assert token_count_empty > 0 # System message contributes tokens
+    assert count_with_msg > count_empty
+    assert count_empty > 0 # System message contributes tokens
 
 
 def test_get_token_count_function_call():
     """Test token counting with function calls."""
     function_call = Content.from_function_call(call_id="call_123", name="test_function", arguments={"param": "value"})
+    provider = SlidingWindowHistoryProvider(max_tokens=1000)
 
-    sliding_window = SlidingWindowChatMessageStore(max_tokens=1000)
-    sliding_window.truncated_messages = [Message(role="assistant", contents=[function_call])]
-
-    token_count = sliding_window.get_token_count()
+    token_count = provider._get_token_count([Message(role="assistant", contents=[function_call])])
     assert token_count > 0
 
 
 def test_get_token_count_function_result():
     """Test token counting with function results."""
     function_result = Content.from_function_result(call_id="call_123", result={"success": True, "data": "result"})
+    provider = SlidingWindowHistoryProvider(max_tokens=1000)
 
-    sliding_window = SlidingWindowChatMessageStore(max_tokens=1000)
-    sliding_window.truncated_messages = [Message(role="tool", contents=[function_result])]
-
-    token_count = sliding_window.get_token_count()
+    token_count = provider._get_token_count([Message(role="tool", contents=[function_result])])
     assert token_count > 0
 
 
 @patch("agent_framework_lab_tau2._sliding_window.logger")
-def test_truncate_messages_removes_old_messages(mock_logger):
+def test_truncate_removes_old_messages(mock_logger):
     """Test that truncation removes old messages when token limit exceeded."""
-    sliding_window = SlidingWindowChatMessageStore(max_tokens=20) # Very small limit
+    provider = SlidingWindowHistoryProvider(max_tokens=20)
 
-    # Create messages that will exceed the limit
     messages = [
         Message(
             role="user",
@@ -154,80 +129,45 @@ def test_truncate_messages_removes_old_messages(mock_logger):
         Message(role="user", contents=[Content.from_text(text="Short msg")]),
     ]
 
-    sliding_window.truncated_messages = messages.copy()
-    sliding_window.truncate_messages()
-
-    # Should have fewer messages after truncation
-    assert len(sliding_window.truncated_messages) < len(messages)
-
-    # Should have logged warnings
+    result = provider._truncate(list(messages))
+    assert len(result) < len(messages)
     assert mock_logger.warning.called
 
 
 @patch("agent_framework_lab_tau2._sliding_window.logger")
-def test_truncate_messages_removes_leading_tool_messages(mock_logger):
+def test_truncate_removes_leading_tool_messages(mock_logger):
     """Test that truncation removes leading tool messages."""
-    sliding_window = SlidingWindowChatMessageStore(max_tokens=10000) # Large limit
+    provider = SlidingWindowHistoryProvider(max_tokens=10000)
 
-    # Create messages starting with tool message
     tool_message = Message(role="tool", contents=[Content.from_function_result(call_id="call_123", result="result")])
     user_message = Message(role="user", contents=[Content.from_text(text="Hello")])
 
-    sliding_window.truncated_messages = [tool_message, user_message]
-    sliding_window.truncate_messages()
-
-    # Tool message should be removed from the beginning
-    assert len(sliding_window.truncated_messages) == 1
-    assert sliding_window.truncated_messages[0].role == "user"
-
-    # Should have logged warning about removing tool message
+    result = provider._truncate([tool_message, user_message])
+    assert len(result) == 1
+    assert result[0].role == "user"
     mock_logger.warning.assert_called()
 
 
-def test_estimate_any_object_token_count_dict():
-    """Test token counting for dictionary objects."""
-    sliding_window = SlidingWindowChatMessageStore(max_tokens=1000)
+def test_estimate_any_object_token_count():
+    """Test token counting for various object types."""
+    provider = SlidingWindowHistoryProvider(max_tokens=1000)
 
-    test_dict = {"key": "value", "number": 42}
-    token_count = sliding_window.estimate_any_object_token_count(test_dict)
+    assert provider._estimate_any_object_token_count({"key": "value"}) > 0
+    assert provider._estimate_any_object_token_count("test string") > 0
 
-    assert token_count > 0
-
-
-def test_estimate_any_object_token_count_string():
-    """Test token counting for string objects."""
-    sliding_window = SlidingWindowChatMessageStore(max_tokens=1000)
-
-    test_string = "This is a test string"
-    token_count = sliding_window.estimate_any_object_token_count(test_string)
-
-    assert token_count > 0
-
-
-def test_estimate_any_object_token_count_non_serializable():
-    """Test token counting for non-JSON-serializable objects."""
-    sliding_window = SlidingWindowChatMessageStore(max_tokens=1000)
-
-    # Create an object that can't be JSON serialized
-    class CustomObject:
+    # Non-serializable falls back to str()
+    class Custom:
         def __str__(self):
-            return "CustomObject instance"
+            return "Custom instance"
 
-    custom_obj = CustomObject()
-    token_count = sliding_window.estimate_any_object_token_count(custom_obj)
-
-    # Should fall back to string representation
-    assert token_count > 0
+    assert provider._estimate_any_object_token_count(Custom()) > 0
 
 
 async def test_real_world_scenario():
     """Test a realistic conversation scenario."""
-    sliding_window = SlidingWindowChatMessageStore(
-        max_tokens=30,
-        system_message="You are a helpful assistant", # Moderate limit
-    )
+    provider = SlidingWindowHistoryProvider(max_tokens=30, system_message="You are a helpful assistant")
+    state: dict = {}
 
-    # Simulate a conversation
     conversation = [
         Message(role="user", contents=[Content.from_text(text="Hello, how are you?")]),
         Message(
@@ -253,18 +193,13 @@ async def test_real_world_scenario():
         ),
     ]
 
-    await sliding_window.add_messages(conversation)
+    await provider.save_messages(None, conversation, state=state)
 
-    current_messages = await sliding_window.list_messages()
-    all_messages = await sliding_window.list_all_messages()
+    truncated = await provider.get_messages(None, state=state)
+    all_msgs = state[provider.source_id]["messages"]
 
-    # All messages should be preserved
-    assert len(all_messages) == 6
+    assert len(all_msgs) == 6
+    assert len(truncated) <= 6
 
-    # Current messages might be truncated
-    assert len(current_messages) <= 6
-
-    # Token count should be within or close to limit
-    token_count = sliding_window.get_token_count()
-    # Allow some margin since truncation happens when exceeded
-    assert token_count <= sliding_window.max_tokens * 1.1
+    token_count = provider._get_token_count(truncated)
+    assert token_count <= provider.max_tokens * 1.1
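The fallback exercised by `test_estimate_any_object_token_count` (try `json.dumps`, fall back to `str()` for objects that cannot be serialized) can be sketched without tiktoken as follows. This is an approximation under stated assumptions: `estimate_tokens` is a hypothetical name, and a whitespace word count stands in for the tiktoken encoding used by the real provider.

```python
import json
from typing import Any


def estimate_tokens(obj: Any) -> int:
    """Serialize ``obj`` and return a rough token count (word count stand-in)."""
    try:
        serialized = json.dumps(obj)
    except (TypeError, ValueError):
        # Not JSON-serializable: fall back to the string representation.
        serialized = str(obj)
    return len(serialized.split())


class Custom:
    """Object that json.dumps rejects, used to trigger the fallback."""

    def __str__(self) -> str:
        return "Custom instance"


dict_estimate = estimate_tokens({"key": "value"})
custom_estimate = estimate_tokens(Custom())
```

Swapping the last line of `estimate_tokens` for a real tokenizer call would bring the sketch closer to the provider's behaviour; the fallback logic itself would stay the same.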