mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
1ac68f65bf
* Initial plan * Fix: Replace alpha with linear_alpha in HybridQuery for redisvl 0.14.0 compatibility Co-authored-by: markwallace-microsoft <127216156+markwallace-microsoft@users.noreply.github.com> * Address code review: Improve test readability and add explanatory comment Co-authored-by: markwallace-microsoft <127216156+markwallace-microsoft@users.noreply.github.com> * Add CHANGELOG entry for redisvl 0.14.0 compatibility fix Co-authored-by: eavanvalkenburg <13749212+eavanvalkenburg@users.noreply.github.com> * Use AggregateHybridQuery instead of HybridQuery for backward compatibility Replace HybridQuery with AggregateHybridQuery to preserve existing functionality that works with older Redis versions. The new HybridQuery in redisvl 0.14.0 requires Redis 8.4.0+ and uses a different API, while AggregateHybridQuery maintains compatibility with the original implementation. Co-authored-by: TaoChenOSU <12570346+TaoChenOSU@users.noreply.github.com> * Fix test to use linear_alpha parameter matching _redis_search implementation The test was passing alpha as a keyword argument to _redis_search(), but the method uses linear_alpha to match the redisvl 0.14.0 AggregateHybridQuery API. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Fix pyright error: use alpha parameter matching AggregateHybridQuery API AggregateHybridQuery expects 'alpha', not 'linear_alpha'. Updated the _redis_search method parameter and the test accordingly. 
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: markwallace-microsoft <127216156+markwallace-microsoft@users.noreply.github.com> Co-authored-by: eavanvalkenburg <13749212+eavanvalkenburg@users.noreply.github.com> Co-authored-by: TaoChenOSU <12570346+TaoChenOSU@users.noreply.github.com> Co-authored-by: Ben Thomas <ben.thomas@microsoft.com> Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
531 lines
23 KiB
Python
531 lines
23 KiB
Python
# Copyright (c) Microsoft. All rights reserved.
|
|
|
|
"""Tests for RedisContextProvider and RedisHistoryProvider."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
from agent_framework import AgentResponse, Message
|
|
from agent_framework._sessions import AgentSession, SessionContext
|
|
|
|
from agent_framework_redis._context_provider import RedisContextProvider
|
|
from agent_framework_redis._history_provider import RedisHistoryProvider
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Shared fixtures
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.fixture
def mock_index() -> AsyncMock:
    """Provide the ``AsyncSearchIndex`` stand-in returned by the patched ``from_dict``.

    ``create`` and ``load`` are bare async mocks; ``query`` resolves to an empty result
    list and ``exists`` resolves to ``False`` unless a test overrides them.
    """
    index = AsyncMock()
    index.create = AsyncMock()
    index.load = AsyncMock()
    index.query = AsyncMock(return_value=[])
    index.exists = AsyncMock(return_value=False)
    return index
|
|
|
|
|
|
@pytest.fixture
def patch_index_from_dict(mock_index: AsyncMock):
    """Patch ``AsyncSearchIndex`` in the context-provider module for the duration of a test.

    * ``AsyncSearchIndex.from_dict`` returns ``mock_index``.
    * ``AsyncSearchIndex.from_existing`` returns a fresh mock whose ``schema.to_dict()``
      echoes the schema dict most recently passed to ``from_dict`` (or ``{}`` if
      ``from_dict`` has not been called yet).
      NOTE(review): presumably this makes any "existing schema matches" comparison in the
      provider succeed — confirm against ``_context_provider``.

    Yields:
        The patched ``AsyncSearchIndex`` class mock.
    """
    with patch("agent_framework_redis._context_provider.AsyncSearchIndex") as index_cls:
        index_cls.from_dict = MagicMock(return_value=mock_index)

        def _last_schema_dict():
            # First positional argument of the latest from_dict call, if there was one.
            last_call = index_cls.from_dict.call_args
            return last_call[0][0] if last_call else {}

        async def _fake_from_existing(index_name: str, redis_url: str):  # noqa: ARG001
            existing = AsyncMock()
            existing.schema.to_dict = MagicMock(side_effect=_last_schema_dict)
            return existing

        index_cls.from_existing = AsyncMock(side_effect=_fake_from_existing)
        yield index_cls
|
|
|
|
|
|
@pytest.fixture
def mock_redis_client():
    """Provide a ``MagicMock`` Redis client with async commands and an async pipeline.

    ``lrange`` resolves to ``[]`` and ``llen`` to ``0`` by default; ``ltrim`` and ``delete``
    are bare async mocks. ``client.pipeline()`` used with ``async with`` yields a pipeline
    mock with async ``rpush``/``execute``. Tests reach that pipeline through
    ``client.pipeline.return_value.__aenter__.return_value``.
    """
    pipeline = AsyncMock()
    pipeline.rpush = AsyncMock()
    pipeline.execute = AsyncMock()

    client = MagicMock()
    client.lrange = AsyncMock(return_value=[])
    client.llen = AsyncMock(return_value=0)
    client.ltrim = AsyncMock()
    client.delete = AsyncMock()
    client.pipeline.return_value.__aenter__.return_value = pipeline
    return client
|
|
|
|
|
|
# ===========================================================================
|
|
# RedisContextProvider tests
|
|
# ===========================================================================
|
|
|
|
|
|
class TestRedisContextProviderInit:
    """Constructor defaults, explicit overrides, and vectorizer validation of ``RedisContextProvider``."""

    def test_basic_construction(self, patch_index_from_dict: MagicMock):  # noqa: ARG002
        provider = RedisContextProvider(source_id="ctx", user_id="u1")
        assert provider.source_id == "ctx"
        assert provider.user_id == "u1"
        assert provider.redis_url == "redis://localhost:6379"
        assert provider.index_name == "context"
        assert provider.prefix == "context"

    def test_custom_params(self, patch_index_from_dict: MagicMock):  # noqa: ARG002
        """Every explicitly passed constructor argument is reflected on the instance."""
        provider = RedisContextProvider(
            source_id="ctx",
            redis_url="redis://custom:6380",
            index_name="my_idx",
            prefix="my_prefix",
            application_id="app1",
            agent_id="agent1",
            user_id="user1",
            context_prompt="Custom prompt",
        )
        assert provider.source_id == "ctx"
        assert provider.redis_url == "redis://custom:6380"
        assert provider.index_name == "my_idx"
        assert provider.prefix == "my_prefix"
        assert provider.application_id == "app1"
        assert provider.agent_id == "agent1"
        # user_id was passed above but previously never checked.
        assert provider.user_id == "user1"
        assert provider.context_prompt == "Custom prompt"

    def test_default_context_prompt(self, patch_index_from_dict: MagicMock):  # noqa: ARG002
        provider = RedisContextProvider(source_id="ctx", user_id="u1")
        assert "Memories" in provider.context_prompt

    def test_invalid_vectorizer_raises(self, patch_index_from_dict: MagicMock):  # noqa: ARG002
        """A non-vectorizer value for ``redis_vectorizer`` is rejected at construction."""
        from agent_framework.exceptions import AgentException

        with pytest.raises(AgentException, match="not a valid type"):
            RedisContextProvider(source_id="ctx", user_id="u1", redis_vectorizer="bad")  # type: ignore[arg-type]
|
|
|
|
|
|
class TestRedisContextProviderValidateFilters:
    """``_validate_filters`` rejects a provider with no partition id and accepts any single one."""

    def test_no_filters_raises(self, patch_index_from_dict: MagicMock):  # noqa: ARG002
        provider = RedisContextProvider(source_id="ctx")
        with pytest.raises(ValueError, match="(?i)at least one"):
            provider._validate_filters()

    # Parametrized (instead of looping inside one test) so a failure names the offending filter.
    @pytest.mark.parametrize(
        "filter_kwargs",
        [{"user_id": "u"}, {"agent_id": "a"}, {"application_id": "app"}],
        ids=["user_id", "agent_id", "application_id"],
    )
    def test_any_single_filter_ok(
        self,
        patch_index_from_dict: MagicMock,  # noqa: ARG002
        filter_kwargs: dict[str, str],
    ):
        provider = RedisContextProvider(source_id="ctx", **filter_kwargs)
        provider._validate_filters()  # should not raise
|
|
|
|
|
|
class TestRedisContextProviderSchema:
    """Contents of the index schema dict built by ``RedisContextProvider``."""

    def test_schema_has_expected_fields(self, patch_index_from_dict: MagicMock):  # noqa: ARG002
        schema = RedisContextProvider(source_id="ctx", user_id="u1").schema_dict
        names = {field["name"] for field in schema["fields"]}
        required = ("role", "content", "conversation_id", "message_id", "application_id", "agent_id", "user_id")
        for name in required:
            assert name in names
        index_section = schema["index"]
        assert index_section["name"] == "context"
        assert index_section["prefix"] == "context"

    def test_schema_no_vector_without_vectorizer(self, patch_index_from_dict: MagicMock):  # noqa: ARG002
        """Without a vectorizer, no field of type ``vector`` is declared."""
        provider = RedisContextProvider(source_id="ctx", user_id="u1")
        types = {field["type"] for field in provider.schema_dict["fields"]}
        assert "vector" not in types
|
|
|
|
|
|
class TestRedisContextProviderBeforeRun:
    """Retrieval behaviour of ``RedisContextProvider.before_run``.

    ``agent=None`` is passed deliberately. Its ``# type: ignore[arg-type]`` sits on the
    ``agent=None`` line because type checkers report argument errors on the argument's own
    line, not on the call's closing parenthesis.
    """

    async def test_search_results_added_to_context(
        self,
        mock_index: AsyncMock,
        patch_index_from_dict: MagicMock,  # noqa: ARG002
    ):
        """Search hits are injected into the context under the provider's ``source_id``."""
        mock_index.query = AsyncMock(return_value=[{"content": "Memory A"}, {"content": "Memory B"}])
        provider = RedisContextProvider(source_id="ctx", user_id="u1")
        session = AgentSession(session_id="test-session")
        ctx = SessionContext(input_messages=[Message(role="user", contents=["test query"])], session_id="s1")

        await provider.before_run(
            agent=None,  # type: ignore[arg-type]
            session=session,
            context=ctx,
            state=session.state.setdefault(provider.source_id, {}),
        )

        assert "ctx" in ctx.context_messages
        msgs = ctx.context_messages["ctx"]
        # Both hits are combined into a single context message.
        assert len(msgs) == 1
        assert "Memory A" in msgs[0].text
        assert "Memory B" in msgs[0].text

    async def test_empty_input_no_search(
        self,
        mock_index: AsyncMock,
        patch_index_from_dict: MagicMock,  # noqa: ARG002
    ):
        """Whitespace-only input skips the index query and adds no context entry."""
        provider = RedisContextProvider(source_id="ctx", user_id="u1")
        session = AgentSession(session_id="test-session")
        ctx = SessionContext(input_messages=[Message(role="user", contents=["   "])], session_id="s1")

        await provider.before_run(
            agent=None,  # type: ignore[arg-type]
            session=session,
            context=ctx,
            state=session.state.setdefault(provider.source_id, {}),
        )

        mock_index.query.assert_not_called()
        assert "ctx" not in ctx.context_messages

    async def test_before_run_searches_without_session_id(
        self,
        mock_index: AsyncMock,
        patch_index_from_dict: MagicMock,  # noqa: ARG002
    ):
        """Verify that before_run performs cross-session retrieval (no session_id filter)."""
        mock_index.query = AsyncMock(return_value=[{"content": "Memory"}])
        provider = RedisContextProvider(source_id="ctx", user_id="u1")
        session = AgentSession(session_id="test-session")
        ctx = SessionContext(input_messages=[Message(role="user", contents=["test query"])], session_id="s1")

        # wraps= keeps the real search running while recording how it was called.
        with patch.object(provider, "_redis_search", wraps=provider._redis_search) as spy:
            await provider.before_run(
                agent=None,  # type: ignore[arg-type]
                session=session,
                context=ctx,
                state=session.state.setdefault(provider.source_id, {}),
            )

        spy.assert_called_once()
        # session_id should not be passed to _redis_search (cross-session retrieval)
        assert "session_id" not in spy.call_args.kwargs

    async def test_empty_results_no_messages(
        self,
        mock_index: AsyncMock,
        patch_index_from_dict: MagicMock,  # noqa: ARG002
    ):
        """A search returning no hits adds no context entry."""
        mock_index.query = AsyncMock(return_value=[])
        provider = RedisContextProvider(source_id="ctx", user_id="u1")
        session = AgentSession(session_id="test-session")
        ctx = SessionContext(input_messages=[Message(role="user", contents=["hello"])], session_id="s1")

        await provider.before_run(
            agent=None,  # type: ignore[arg-type]
            session=session,
            context=ctx,
            state=session.state.setdefault(provider.source_id, {}),
        )

        assert "ctx" not in ctx.context_messages
|
|
|
|
|
|
class TestRedisContextProviderAfterRun:
    """Storage behaviour of ``RedisContextProvider.after_run``.

    ``agent=None`` is passed deliberately. Its ``# type: ignore[arg-type]`` sits on the
    ``agent=None`` line because type checkers report argument errors on the argument's own
    line, not on the call's closing parenthesis.
    """

    async def test_stores_messages(
        self,
        mock_index: AsyncMock,
        patch_index_from_dict: MagicMock,  # noqa: ARG002
    ):
        """The input message and the response message are written in a single ``load`` call."""
        provider = RedisContextProvider(source_id="ctx", user_id="u1")
        session = AgentSession(session_id="test-session")
        response = AgentResponse(messages=[Message(role="assistant", contents=["response text"])])
        ctx = SessionContext(input_messages=[Message(role="user", contents=["user input"])], session_id="s1")
        ctx._response = response

        await provider.after_run(
            agent=None,  # type: ignore[arg-type]
            session=session,
            context=ctx,
            state=session.state.setdefault(provider.source_id, {}),
        )

        mock_index.load.assert_called_once()
        loaded = mock_index.load.call_args[0][0]
        assert len(loaded) == 2
        roles = {d["role"] for d in loaded}
        assert roles == {"user", "assistant"}

    async def test_skips_empty_conversations(
        self,
        mock_index: AsyncMock,
        patch_index_from_dict: MagicMock,  # noqa: ARG002
    ):
        """Whitespace-only input with no response set writes nothing to the index."""
        provider = RedisContextProvider(source_id="ctx", user_id="u1")
        session = AgentSession(session_id="test-session")
        ctx = SessionContext(input_messages=[Message(role="user", contents=["   "])], session_id="s1")

        await provider.after_run(
            agent=None,  # type: ignore[arg-type]
            session=session,
            context=ctx,
            state=session.state.setdefault(provider.source_id, {}),
        )

        mock_index.load.assert_not_called()

    async def test_stores_partition_fields(
        self,
        mock_index: AsyncMock,
        patch_index_from_dict: MagicMock,  # noqa: ARG002
    ):
        """Stored documents carry the partition ids and the context's ``session_id`` as ``conversation_id``."""
        provider = RedisContextProvider(source_id="ctx", application_id="app", agent_id="ag", user_id="u1")
        session = AgentSession(session_id="test-session")
        ctx = SessionContext(input_messages=[Message(role="user", contents=["hello"])], session_id="s1")

        await provider.after_run(
            agent=None,  # type: ignore[arg-type]
            session=session,
            context=ctx,
            state=session.state.setdefault(provider.source_id, {}),
        )

        loaded = mock_index.load.call_args[0][0]
        doc = loaded[0]
        assert doc["application_id"] == "app"
        assert doc["agent_id"] == "ag"
        assert doc["user_id"] == "u1"
        assert doc["conversation_id"] == "s1"
|
|
|
|
|
class TestRedisContextProviderContextManager:
    """Async context-manager protocol of ``RedisContextProvider``."""

    async def test_aenter_returns_self(self, patch_index_from_dict: MagicMock):  # noqa: ARG002
        provider = RedisContextProvider(source_id="ctx", user_id="u1")
        async with provider as entered:
            # __aenter__ hands back the very same provider instance.
            assert entered is provider
|
|
|
|
|
|
class TestRedisContextProviderHybridQuery:
    """Hybrid search forwards its weighting as ``alpha`` to redisvl 0.14.0's ``AggregateHybridQuery``."""

    async def test_aggregate_hybrid_query_uses_alpha(
        self,
        mock_index: AsyncMock,
        patch_index_from_dict: MagicMock,  # noqa: ARG002 - fixture modifies behavior via side effects
    ):
        """Ensure AggregateHybridQuery is called with alpha parameter."""
        from redisvl.utils.vectorize import BaseVectorizer

        # spec=BaseVectorizer makes isinstance(vectorizer, BaseVectorizer) true, so the
        # provider accepts the mock as a real vectorizer.
        vectorizer = MagicMock(spec=BaseVectorizer)
        vectorizer.dims = 128
        vectorizer.dtype = "float32"
        vectorizer.aembed = AsyncMock(return_value=[0.1] * 128)

        mock_index.query = AsyncMock(return_value=[{"content": "test result"}])

        provider = RedisContextProvider(
            source_id="ctx",
            user_id="u1",
            redis_vectorizer=vectorizer,
            vector_field_name="embedding",
        )

        # Replace the query class so the constructor arguments can be inspected.
        with patch("agent_framework_redis._context_provider.AggregateHybridQuery") as hybrid_query_cls:
            hybrid_query_cls.return_value = MagicMock()
            await provider._redis_search(text="test query", alpha=0.5)

        hybrid_query_cls.assert_called_once()
        kwargs = hybrid_query_cls.call_args.kwargs
        assert "alpha" in kwargs
        assert kwargs["alpha"] == 0.5
|
|
|
|
|
|
# ===========================================================================
|
|
# RedisHistoryProvider tests
|
|
# ===========================================================================
|
|
|
|
|
|
class TestRedisHistoryProviderInit:
    """Constructor defaults, overrides, and connection-argument validation of ``RedisHistoryProvider``."""

    def test_basic_construction(self, mock_redis_client: MagicMock):
        with patch("agent_framework_redis._history_provider.redis.from_url", return_value=mock_redis_client):
            provider = RedisHistoryProvider("memory", redis_url="redis://localhost:6379")

        assert provider.source_id == "memory"
        assert provider.key_prefix == "chat_messages"
        assert provider.max_messages is None
        assert provider.load_messages is True
        assert provider.store_outputs is True
        assert provider.store_inputs is True

    def test_custom_params(self, mock_redis_client: MagicMock):
        with patch("agent_framework_redis._history_provider.redis.from_url", return_value=mock_redis_client):
            provider = RedisHistoryProvider(
                "mem",
                redis_url="redis://localhost:6379",
                key_prefix="custom",
                max_messages=50,
                load_messages=False,
                store_outputs=False,
                store_inputs=False,
            )

        assert provider.key_prefix == "custom"
        assert provider.max_messages == 50
        assert provider.load_messages is False
        assert provider.store_outputs is False
        assert provider.store_inputs is False

    def test_no_redis_url_or_credential_raises(self):
        with pytest.raises(ValueError, match="Either redis_url or credential_provider must be provided"):
            RedisHistoryProvider("mem")

    def test_both_url_and_credential_raises(self):
        credential = MagicMock()
        with pytest.raises(ValueError, match="mutually exclusive"):
            RedisHistoryProvider(
                "mem",
                redis_url="redis://localhost:6379",
                credential_provider=credential,
                host="myhost",
            )

    def test_credential_provider_without_host_raises(self):
        credential = MagicMock()
        with pytest.raises(ValueError, match="host is required"):
            RedisHistoryProvider("mem", credential_provider=credential)

    def test_credential_provider_with_host(self):
        """With only a host, the client is built with port 6380, ssl=True and username=None."""
        credential = MagicMock()
        with patch(
            "agent_framework_redis._history_provider.redis.Redis", return_value=MagicMock()
        ) as redis_cls:
            provider = RedisHistoryProvider("mem", credential_provider=credential, host="myhost")

        redis_cls.assert_called_once_with(
            host="myhost",
            port=6380,
            ssl=True,
            username=None,
            credential_provider=credential,
            decode_responses=True,
        )
        assert provider.redis_url is None
|
|
|
|
|
|
class TestRedisHistoryProviderRedisKey:
    """Key naming used by ``RedisHistoryProvider`` for per-session message lists."""

    def test_key_format(self, mock_redis_client: MagicMock):
        with patch("agent_framework_redis._history_provider.redis.from_url", return_value=mock_redis_client):
            provider = RedisHistoryProvider("mem", redis_url="redis://localhost:6379", key_prefix="msgs")

        # "<key_prefix>:<session_id>", with "default" standing in for a missing session id.
        expectations = {"session-123": "msgs:session-123", None: "msgs:default"}
        for session_id, expected_key in expectations.items():
            assert provider._redis_key(session_id) == expected_key
|
|
|
|
|
|
class TestRedisHistoryProviderGetMessages:
    """``get_messages`` turns the JSON strings stored in the Redis list back into ``Message`` objects."""

    async def test_returns_deserialized_messages(self, mock_redis_client: MagicMock):
        stored = [Message(role="user", contents=["Hello"]), Message(role="assistant", contents=["Hi!"])]
        mock_redis_client.lrange = AsyncMock(return_value=[json.dumps(m.to_dict()) for m in stored])

        with patch("agent_framework_redis._history_provider.redis.from_url", return_value=mock_redis_client):
            provider = RedisHistoryProvider("mem", redis_url="redis://localhost:6379")

        messages = await provider.get_messages("s1")
        assert len(messages) == 2
        first, second = messages
        assert first.role == "user"
        assert first.text == "Hello"
        assert second.role == "assistant"
        assert second.text == "Hi!"

    async def test_empty_returns_empty(self, mock_redis_client: MagicMock):
        mock_redis_client.lrange = AsyncMock(return_value=[])

        with patch("agent_framework_redis._history_provider.redis.from_url", return_value=mock_redis_client):
            provider = RedisHistoryProvider("mem", redis_url="redis://localhost:6379")

        assert await provider.get_messages("s1") == []
|
|
|
|
|
|
class TestRedisHistoryProviderSaveMessages:
    """``save_messages`` pipelines RPUSHes and trims the list to ``max_messages`` when needed."""

    async def test_saves_serialized_messages(self, mock_redis_client: MagicMock):
        with patch("agent_framework_redis._history_provider.redis.from_url", return_value=mock_redis_client):
            provider = RedisHistoryProvider("mem", redis_url="redis://localhost:6379")

        await provider.save_messages(
            "s1", [Message(role="user", contents=["Hello"]), Message(role="assistant", contents=["Hi"])]
        )

        # One RPUSH per message, all flushed by a single pipeline execute.
        pipeline = mock_redis_client.pipeline.return_value.__aenter__.return_value
        assert pipeline.rpush.call_count == 2
        pipeline.execute.assert_called_once()

    async def test_empty_messages_noop(self, mock_redis_client: MagicMock):
        with patch("agent_framework_redis._history_provider.redis.from_url", return_value=mock_redis_client):
            provider = RedisHistoryProvider("mem", redis_url="redis://localhost:6379")

        await provider.save_messages("s1", [])
        mock_redis_client.pipeline.assert_not_called()

    async def test_max_messages_trimming(self, mock_redis_client: MagicMock):
        """A list longer than ``max_messages`` is trimmed to its newest ``max_messages`` entries."""
        mock_redis_client.llen = AsyncMock(return_value=15)

        with patch("agent_framework_redis._history_provider.redis.from_url", return_value=mock_redis_client):
            provider = RedisHistoryProvider("mem", redis_url="redis://localhost:6379", max_messages=10)

        await provider.save_messages("s1", [Message(role="user", contents=["msg"])])
        mock_redis_client.ltrim.assert_called_once_with("chat_messages:s1", -10, -1)

    async def test_no_trim_when_under_limit(self, mock_redis_client: MagicMock):
        mock_redis_client.llen = AsyncMock(return_value=3)

        with patch("agent_framework_redis._history_provider.redis.from_url", return_value=mock_redis_client):
            provider = RedisHistoryProvider("mem", redis_url="redis://localhost:6379", max_messages=10)

        await provider.save_messages("s1", [Message(role="user", contents=["msg"])])
        mock_redis_client.ltrim.assert_not_called()
|
|
|
|
|
|
class TestRedisHistoryProviderClear:
    """``clear`` deletes the Redis key holding a session's messages."""

    async def test_clear_calls_delete(self, mock_redis_client: MagicMock):
        with patch("agent_framework_redis._history_provider.redis.from_url", return_value=mock_redis_client):
            provider = RedisHistoryProvider("mem", redis_url="redis://localhost:6379")

        await provider.clear("session-1")
        expected_key = "chat_messages:session-1"
        mock_redis_client.delete.assert_called_once_with(expected_key)
|
|
|
|
|
|
class TestRedisHistoryProviderBeforeAfterRun:
    """Test before_run/after_run integration via BaseHistoryProvider defaults.

    ``agent=None`` is passed deliberately. Its ``# type: ignore[arg-type]`` sits on the
    ``agent=None`` line because type checkers report argument errors on the argument's own
    line, not on the call's closing parenthesis.
    """

    async def test_before_run_loads_history(self, mock_redis_client: MagicMock):
        """Stored history is exposed in the context under the provider's ``source_id``."""
        msg = Message(role="user", contents=["old msg"])
        mock_redis_client.lrange = AsyncMock(return_value=[json.dumps(msg.to_dict())])

        with patch("agent_framework_redis._history_provider.redis.from_url") as mock_from_url:
            mock_from_url.return_value = mock_redis_client
            provider = RedisHistoryProvider("mem", redis_url="redis://localhost:6379")

        session = AgentSession(session_id="test")
        ctx = SessionContext(input_messages=[Message(role="user", contents=["new msg"])], session_id="s1")

        await provider.before_run(
            agent=None,  # type: ignore[arg-type]
            session=session,
            context=ctx,
            state=session.state.setdefault(provider.source_id, {}),
        )

        assert "mem" in ctx.context_messages
        assert len(ctx.context_messages["mem"]) == 1
        assert ctx.context_messages["mem"][0].text == "old msg"

    async def test_after_run_stores_input_and_response(self, mock_redis_client: MagicMock):
        """With default store flags, the input and the response are pushed through one pipeline."""
        with patch("agent_framework_redis._history_provider.redis.from_url") as mock_from_url:
            mock_from_url.return_value = mock_redis_client
            provider = RedisHistoryProvider("mem", redis_url="redis://localhost:6379")

        session = AgentSession(session_id="test")
        ctx = SessionContext(input_messages=[Message(role="user", contents=["hi"])], session_id="s1")
        ctx._response = AgentResponse(messages=[Message(role="assistant", contents=["hello"])])

        await provider.after_run(
            agent=None,  # type: ignore[arg-type]
            session=session,
            context=ctx,
            state=session.state.setdefault(provider.source_id, {}),
        )

        pipeline = mock_redis_client.pipeline.return_value.__aenter__.return_value
        assert pipeline.rpush.call_count == 2
        pipeline.execute.assert_called_once()

    async def test_after_run_skips_when_no_messages(self, mock_redis_client: MagicMock):
        """With both store flags disabled, no pipeline is opened at all."""
        with patch("agent_framework_redis._history_provider.redis.from_url") as mock_from_url:
            mock_from_url.return_value = mock_redis_client
            provider = RedisHistoryProvider(
                "mem", redis_url="redis://localhost:6379", store_inputs=False, store_outputs=False
            )

        session = AgentSession(session_id="test")
        ctx = SessionContext(input_messages=[Message(role="user", contents=["hi"])], session_id="s1")

        await provider.after_run(
            agent=None,  # type: ignore[arg-type]
            session=session,
            context=ctx,
            state=session.state.setdefault(provider.source_id, {}),
        )

        mock_redis_client.pipeline.assert_not_called()
|