Files
agent-framework/python/packages/redis/tests/test_providers.py
Copilot 1ac68f65bf Python: Fix RedisContextProvider for redisvl 0.14.0 by using AggregateHybridQuery (#3954)
* Initial plan

* Fix: Replace alpha with linear_alpha in HybridQuery for redisvl 0.14.0 compatibility

Co-authored-by: markwallace-microsoft <127216156+markwallace-microsoft@users.noreply.github.com>

* Address code review: Improve test readability and add explanatory comment

Co-authored-by: markwallace-microsoft <127216156+markwallace-microsoft@users.noreply.github.com>

* Add CHANGELOG entry for redisvl 0.14.0 compatibility fix

Co-authored-by: eavanvalkenburg <13749212+eavanvalkenburg@users.noreply.github.com>

* Use AggregateHybridQuery instead of HybridQuery for backward compatibility

Replace HybridQuery with AggregateHybridQuery to preserve existing functionality that works with older Redis versions. The new HybridQuery in redisvl 0.14.0 requires Redis 8.4.0+ and uses a different API, while AggregateHybridQuery maintains compatibility with the original implementation.

Co-authored-by: TaoChenOSU <12570346+TaoChenOSU@users.noreply.github.com>

* Fix test to use linear_alpha parameter matching _redis_search implementation

The test was passing alpha as a keyword argument to _redis_search(), but the
method uses linear_alpha to match the redisvl 0.14.0 AggregateHybridQuery API.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Fix pyright error: use alpha parameter matching AggregateHybridQuery API

AggregateHybridQuery expects 'alpha', not 'linear_alpha'. Updated the
_redis_search method parameter and the test accordingly.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: markwallace-microsoft <127216156+markwallace-microsoft@users.noreply.github.com>
Co-authored-by: eavanvalkenburg <13749212+eavanvalkenburg@users.noreply.github.com>
Co-authored-by: TaoChenOSU <12570346+TaoChenOSU@users.noreply.github.com>
Co-authored-by: Ben Thomas <ben.thomas@microsoft.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
2026-03-06 01:34:35 +00:00

531 lines
23 KiB
Python

# Copyright (c) Microsoft. All rights reserved.
"""Tests for RedisContextProvider and RedisHistoryProvider."""
from __future__ import annotations
import json
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from agent_framework import AgentResponse, Message
from agent_framework._sessions import AgentSession, SessionContext
from agent_framework_redis._context_provider import RedisContextProvider
from agent_framework_redis._history_provider import RedisHistoryProvider
# ---------------------------------------------------------------------------
# Shared fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def mock_index() -> AsyncMock:
    """Build an async mock search index whose ``query`` yields no hits and whose ``exists`` reports ``False``."""
    index = AsyncMock()
    # Give every index method the tests inspect its own explicit AsyncMock.
    index.configure_mock(
        create=AsyncMock(),
        load=AsyncMock(),
        query=AsyncMock(return_value=[]),
        exists=AsyncMock(return_value=False),
    )
    return index
@pytest.fixture
def patch_index_from_dict(mock_index: AsyncMock):
    """Patch ``AsyncSearchIndex`` in the context provider module and yield the patched class.

    ``from_dict`` hands back the shared ``mock_index``. ``from_existing`` resolves to a fresh mock whose
    ``schema.to_dict()`` echoes the first positional argument of the latest ``from_dict`` call, or ``{}``
    when ``from_dict`` has not been called yet.
    """
    with patch("agent_framework_redis._context_provider.AsyncSearchIndex") as mock_cls:
        mock_cls.from_dict = MagicMock(return_value=mock_index)

        def _last_schema_dict():
            # Looked up at call time so it reflects whatever from_dict received most recently.
            recorded = mock_cls.from_dict.call_args
            if not recorded:
                return {}
            return recorded[0][0]

        async def mock_from_existing(index_name: str, redis_url: str):  # noqa: ARG001
            existing_index = AsyncMock()
            existing_index.schema.to_dict = MagicMock(side_effect=_last_schema_dict)
            return existing_index

        mock_cls.from_existing = AsyncMock(side_effect=mock_from_existing)
        yield mock_cls
@pytest.fixture
def mock_redis_client():
    """Build a mock Redis client with async list commands and an async-context-manager pipeline."""
    # Object produced by entering ``client.pipeline()`` as an async context manager.
    pipeline = AsyncMock()
    pipeline.configure_mock(rpush=AsyncMock(), execute=AsyncMock())

    redis_client = MagicMock()
    redis_client.configure_mock(
        lrange=AsyncMock(return_value=[]),
        llen=AsyncMock(return_value=0),
        ltrim=AsyncMock(),
        delete=AsyncMock(),
    )
    redis_client.pipeline.return_value.__aenter__.return_value = pipeline
    return redis_client
# ===========================================================================
# RedisContextProvider tests
# ===========================================================================
class TestRedisContextProviderInit:
    """Constructor behavior of ``RedisContextProvider``."""

    def test_basic_construction(self, patch_index_from_dict: MagicMock):  # noqa: ARG002
        """A provider built with only ``source_id`` and ``user_id`` exposes the expected defaults."""
        provider = RedisContextProvider(source_id="ctx", user_id="u1")

        observed = (
            provider.source_id,
            provider.user_id,
            provider.redis_url,
            provider.index_name,
            provider.prefix,
        )
        assert observed == ("ctx", "u1", "redis://localhost:6379", "context", "context")

    def test_custom_params(self, patch_index_from_dict: MagicMock):  # noqa: ARG002
        """Explicit constructor arguments are reflected on the matching attributes."""
        overrides = {
            "redis_url": "redis://custom:6380",
            "index_name": "my_idx",
            "prefix": "my_prefix",
            "application_id": "app1",
            "agent_id": "agent1",
            "context_prompt": "Custom prompt",
        }

        provider = RedisContextProvider(source_id="ctx", user_id="user1", **overrides)

        observed = {name: getattr(provider, name) for name in overrides}
        assert observed == overrides

    def test_default_context_prompt(self, patch_index_from_dict: MagicMock):  # noqa: ARG002
        """The default context prompt mentions "Memories"."""
        prompt = RedisContextProvider(source_id="ctx", user_id="u1").context_prompt
        assert "Memories" in prompt

    def test_invalid_vectorizer_raises(self, patch_index_from_dict: MagicMock):  # noqa: ARG002
        """Passing a non-vectorizer value as ``redis_vectorizer`` raises ``AgentException``."""
        from agent_framework.exceptions import AgentException

        with pytest.raises(AgentException, match="not a valid type"):
            RedisContextProvider(source_id="ctx", user_id="u1", redis_vectorizer="bad")  # type: ignore[arg-type]
class TestRedisContextProviderValidateFilters:
    """Checks for ``RedisContextProvider._validate_filters``."""

    def test_no_filters_raises(self, patch_index_from_dict: MagicMock):  # noqa: ARG002
        """With none of the partition ids set, validation raises ``ValueError``."""
        provider = RedisContextProvider(source_id="ctx")
        with pytest.raises(ValueError, match="(?i)at least one"):
            provider._validate_filters()

    # One test case per filter so a failure names the offending filter and
    # does not prevent the remaining filters from being checked.
    @pytest.mark.parametrize(
        "filter_kwargs",
        [{"user_id": "u"}, {"agent_id": "a"}, {"application_id": "app"}],
        ids=["user_id", "agent_id", "application_id"],
    )
    def test_any_single_filter_ok(
        self,
        filter_kwargs: dict[str, str],
        patch_index_from_dict: MagicMock,  # noqa: ARG002
    ):
        """Supplying any single partition id is enough for validation to pass."""
        provider = RedisContextProvider(source_id="ctx", **filter_kwargs)
        provider._validate_filters()  # should not raise
class TestRedisContextProviderSchema:
    """Shape of the index schema exposed through ``schema_dict``."""

    def test_schema_has_expected_fields(self, patch_index_from_dict: MagicMock):  # noqa: ARG002
        """The schema lists every expected field and uses the default index name and prefix."""
        schema = RedisContextProvider(source_id="ctx", user_id="u1").schema_dict

        declared = [field["name"] for field in schema["fields"]]
        required = ("role", "content", "conversation_id", "message_id", "application_id", "agent_id", "user_id")
        missing = [name for name in required if name not in declared]
        assert not missing

        index_section = schema["index"]
        assert index_section["name"] == "context"
        assert index_section["prefix"] == "context"

    def test_schema_no_vector_without_vectorizer(self, patch_index_from_dict: MagicMock):  # noqa: ARG002
        """Without a vectorizer, no schema field has type ``vector``."""
        fields = RedisContextProvider(source_id="ctx", user_id="u1").schema_dict["fields"]
        assert all(field["type"] != "vector" for field in fields)
class TestRedisContextProviderBeforeRun:
    """Retrieval performed by ``RedisContextProvider.before_run``."""

    async def test_search_results_added_to_context(
        self,
        mock_index: AsyncMock,
        patch_index_from_dict: MagicMock,  # noqa: ARG002
    ):
        """Query hits end up in a single context message stored under the provider's source id."""
        hits = [{"content": "Memory A"}, {"content": "Memory B"}]
        mock_index.query = AsyncMock(return_value=hits)
        provider = RedisContextProvider(source_id="ctx", user_id="u1")
        session = AgentSession(session_id="test-session")
        user_message = Message(role="user", contents=["test query"])
        ctx = SessionContext(input_messages=[user_message], session_id="s1")
        state = session.state.setdefault(provider.source_id, {})

        await provider.before_run(agent=None, session=session, context=ctx, state=state)  # type: ignore[arg-type]

        assert "ctx" in ctx.context_messages
        stored = ctx.context_messages["ctx"]
        assert len(stored) == 1
        combined_text = stored[0].text
        assert "Memory A" in combined_text
        assert "Memory B" in combined_text

    async def test_empty_input_no_search(
        self,
        mock_index: AsyncMock,
        patch_index_from_dict: MagicMock,  # noqa: ARG002
    ):
        """Whitespace-only input triggers no index query and adds no context messages."""
        provider = RedisContextProvider(source_id="ctx", user_id="u1")
        session = AgentSession(session_id="test-session")
        blank_message = Message(role="user", contents=[" "])
        ctx = SessionContext(input_messages=[blank_message], session_id="s1")
        state = session.state.setdefault(provider.source_id, {})

        await provider.before_run(agent=None, session=session, context=ctx, state=state)  # type: ignore[arg-type]

        mock_index.query.assert_not_called()
        assert "ctx" not in ctx.context_messages

    async def test_before_run_searches_without_session_id(
        self,
        mock_index: AsyncMock,
        patch_index_from_dict: MagicMock,  # noqa: ARG002
    ):
        """Verify that before_run performs cross-session retrieval (no session_id filter)."""
        mock_index.query = AsyncMock(return_value=[{"content": "Memory"}])
        provider = RedisContextProvider(source_id="ctx", user_id="u1")
        session = AgentSession(session_id="test-session")
        user_message = Message(role="user", contents=["test query"])
        ctx = SessionContext(input_messages=[user_message], session_id="s1")
        state = session.state.setdefault(provider.source_id, {})

        # Wrap the real method so it still runs while its call arguments are recorded.
        with patch.object(provider, "_redis_search", wraps=provider._redis_search) as search_spy:
            await provider.before_run(agent=None, session=session, context=ctx, state=state)  # type: ignore[arg-type]

        search_spy.assert_called_once()
        # session_id should not be passed to _redis_search (cross-session retrieval)
        assert "session_id" not in search_spy.call_args.kwargs

    async def test_empty_results_no_messages(
        self,
        mock_index: AsyncMock,
        patch_index_from_dict: MagicMock,  # noqa: ARG002
    ):
        """An empty query result leaves the context without an entry for the provider."""
        mock_index.query = AsyncMock(return_value=[])
        provider = RedisContextProvider(source_id="ctx", user_id="u1")
        session = AgentSession(session_id="test-session")
        user_message = Message(role="user", contents=["hello"])
        ctx = SessionContext(input_messages=[user_message], session_id="s1")
        state = session.state.setdefault(provider.source_id, {})

        await provider.before_run(agent=None, session=session, context=ctx, state=state)  # type: ignore[arg-type]

        assert "ctx" not in ctx.context_messages
class TestRedisContextProviderAfterRun:
    """Persistence performed by ``RedisContextProvider.after_run``."""

    async def test_stores_messages(
        self,
        mock_index: AsyncMock,
        patch_index_from_dict: MagicMock,  # noqa: ARG002
    ):
        """The user input and the assistant response are loaded into the index in one call."""
        provider = RedisContextProvider(source_id="ctx", user_id="u1")
        session = AgentSession(session_id="test-session")
        response = AgentResponse(messages=[Message(role="assistant", contents=["response text"])])
        ctx = SessionContext(input_messages=[Message(role="user", contents=["user input"])], session_id="s1")
        ctx._response = response
        state = session.state.setdefault(provider.source_id, {})

        await provider.after_run(agent=None, session=session, context=ctx, state=state)  # type: ignore[arg-type]

        mock_index.load.assert_called_once()
        loaded = mock_index.load.call_args[0][0]
        assert len(loaded) == 2
        roles = {doc["role"] for doc in loaded}
        assert roles == {"user", "assistant"}

    async def test_skips_empty_conversations(
        self,
        mock_index: AsyncMock,
        patch_index_from_dict: MagicMock,  # noqa: ARG002
    ):
        """Whitespace-only input with no response results in nothing being loaded."""
        provider = RedisContextProvider(source_id="ctx", user_id="u1")
        session = AgentSession(session_id="test-session")
        ctx = SessionContext(input_messages=[Message(role="user", contents=[" "])], session_id="s1")
        state = session.state.setdefault(provider.source_id, {})

        await provider.after_run(agent=None, session=session, context=ctx, state=state)  # type: ignore[arg-type]

        mock_index.load.assert_not_called()

    async def test_stores_partition_fields(
        self,
        mock_index: AsyncMock,
        patch_index_from_dict: MagicMock,  # noqa: ARG002
    ):
        """Stored documents carry the provider's partition ids and the context's session id."""
        provider = RedisContextProvider(source_id="ctx", application_id="app", agent_id="ag", user_id="u1")
        session = AgentSession(session_id="test-session")
        ctx = SessionContext(input_messages=[Message(role="user", contents=["hello"])], session_id="s1")
        state = session.state.setdefault(provider.source_id, {})

        await provider.after_run(agent=None, session=session, context=ctx, state=state)  # type: ignore[arg-type]

        # Fail with a clear assertion (not a TypeError on ``None``) if nothing was stored.
        mock_index.load.assert_called_once()
        loaded = mock_index.load.call_args[0][0]
        doc = loaded[0]
        assert doc["application_id"] == "app"
        assert doc["agent_id"] == "ag"
        assert doc["user_id"] == "u1"
        assert doc["conversation_id"] == "s1"
class TestRedisContextProviderContextManager:
    """Async context-manager protocol of ``RedisContextProvider``."""

    async def test_aenter_returns_self(self, patch_index_from_dict: MagicMock):  # noqa: ARG002
        """Entering the provider with ``async with`` binds the provider itself."""
        provider = RedisContextProvider(source_id="ctx", user_id="u1")
        async with provider as entered:
            assert entered is provider
class TestRedisContextProviderHybridQuery:
    """Test for AggregateHybridQuery parameter compatibility with redisvl 0.14.0."""

    async def test_aggregate_hybrid_query_uses_alpha(
        self,
        mock_index: AsyncMock,
        patch_index_from_dict: MagicMock,  # noqa: ARG002 - fixture modifies behavior via side effects
    ):
        """Ensure AggregateHybridQuery is called with alpha parameter."""
        from redisvl.utils.vectorize import BaseVectorizer

        # ``spec`` makes the mock pass ``isinstance`` checks against BaseVectorizer.
        vectorizer = MagicMock(spec=BaseVectorizer)
        vectorizer.dims = 128
        vectorizer.dtype = "float32"
        vectorizer.aembed = AsyncMock(return_value=[0.1] * 128)

        mock_index.query = AsyncMock(return_value=[{"content": "test result"}])
        provider = RedisContextProvider(
            source_id="ctx",
            user_id="u1",
            redis_vectorizer=vectorizer,
            vector_field_name="embedding",
        )

        # Run the search with a custom alpha while the query class is replaced by a mock.
        query_target = "agent_framework_redis._context_provider.AggregateHybridQuery"
        with patch(query_target, return_value=MagicMock()) as query_cls:
            await provider._redis_search(text="test query", alpha=0.5)

        # The query must be built exactly once, with alpha forwarded as a keyword argument.
        query_cls.assert_called_once()
        assert query_cls.call_args.kwargs.get("alpha") == 0.5
# ===========================================================================
# RedisHistoryProvider tests
# ===========================================================================
class TestRedisHistoryProviderInit:
    """Constructor behavior and argument validation of ``RedisHistoryProvider``."""

    def test_basic_construction(self, mock_redis_client: MagicMock):
        """A URL-only provider exposes the default key prefix, message limit and load/store flags."""
        with patch("agent_framework_redis._history_provider.redis.from_url", return_value=mock_redis_client):
            provider = RedisHistoryProvider("memory", redis_url="redis://localhost:6379")
            assert (provider.source_id, provider.key_prefix) == ("memory", "chat_messages")
            assert provider.max_messages is None
            flags = (provider.load_messages, provider.store_outputs, provider.store_inputs)
            assert all(flag is True for flag in flags)

    def test_custom_params(self, mock_redis_client: MagicMock):
        """Explicit constructor arguments override every default."""
        with patch("agent_framework_redis._history_provider.redis.from_url", return_value=mock_redis_client):
            provider = RedisHistoryProvider(
                "mem",
                redis_url="redis://localhost:6379",
                key_prefix="custom",
                max_messages=50,
                load_messages=False,
                store_outputs=False,
                store_inputs=False,
            )
            assert (provider.key_prefix, provider.max_messages) == ("custom", 50)
            flags = (provider.load_messages, provider.store_outputs, provider.store_inputs)
            assert all(flag is False for flag in flags)

    def test_no_redis_url_or_credential_raises(self):
        """Omitting both ``redis_url`` and ``credential_provider`` raises ``ValueError``."""
        expected_message = "Either redis_url or credential_provider must be provided"
        with pytest.raises(ValueError, match=expected_message):
            RedisHistoryProvider("mem")

    def test_both_url_and_credential_raises(self):
        """Supplying ``redis_url`` together with ``credential_provider`` raises ``ValueError``."""
        credential = MagicMock()
        with pytest.raises(ValueError, match="mutually exclusive"):
            RedisHistoryProvider(
                "mem",
                redis_url="redis://localhost:6379",
                credential_provider=credential,
                host="myhost",
            )

    def test_credential_provider_without_host_raises(self):
        """A ``credential_provider`` without ``host`` raises ``ValueError``."""
        credential = MagicMock()
        with pytest.raises(ValueError, match="host is required"):
            RedisHistoryProvider("mem", credential_provider=credential)

    def test_credential_provider_with_host(self):
        """With a credential provider and host, the Redis client class is called once with the expected kwargs."""
        credential = MagicMock()
        with patch("agent_framework_redis._history_provider.redis.Redis", return_value=MagicMock()) as redis_cls:
            provider = RedisHistoryProvider("mem", credential_provider=credential, host="myhost")
            expected_kwargs = {
                "host": "myhost",
                "port": 6380,
                "ssl": True,
                "username": None,
                "credential_provider": credential,
                "decode_responses": True,
            }
            redis_cls.assert_called_once_with(**expected_kwargs)
            assert provider.redis_url is None
class TestRedisHistoryProviderRedisKey:
    """Key construction in ``RedisHistoryProvider._redis_key``."""

    def test_key_format(self, mock_redis_client: MagicMock):
        """Keys are ``<prefix>:<session id>``; a ``None`` session id maps to ``default``."""
        with patch("agent_framework_redis._history_provider.redis.from_url", return_value=mock_redis_client):
            provider = RedisHistoryProvider("mem", redis_url="redis://localhost:6379", key_prefix="msgs")
            keys = [provider._redis_key(session_id) for session_id in ("session-123", None)]
            assert keys == ["msgs:session-123", "msgs:default"]
class TestRedisHistoryProviderGetMessages:
    """Reading stored messages through ``RedisHistoryProvider.get_messages``."""

    async def test_returns_deserialized_messages(self, mock_redis_client: MagicMock):
        """JSON payloads returned by ``lrange`` come back as messages with their role and text."""
        originals = [Message(role="user", contents=["Hello"]), Message(role="assistant", contents=["Hi!"])]
        payloads = [json.dumps(original.to_dict()) for original in originals]
        mock_redis_client.lrange = AsyncMock(return_value=payloads)

        with patch("agent_framework_redis._history_provider.redis.from_url", return_value=mock_redis_client):
            provider = RedisHistoryProvider("mem", redis_url="redis://localhost:6379")
            messages = await provider.get_messages("s1")
            assert len(messages) == 2
            first, second = messages
            assert (first.role, first.text) == ("user", "Hello")
            assert (second.role, second.text) == ("assistant", "Hi!")

    async def test_empty_returns_empty(self, mock_redis_client: MagicMock):
        """An empty Redis list produces an empty message list."""
        mock_redis_client.lrange = AsyncMock(return_value=[])

        with patch("agent_framework_redis._history_provider.redis.from_url", return_value=mock_redis_client):
            provider = RedisHistoryProvider("mem", redis_url="redis://localhost:6379")
            messages = await provider.get_messages("s1")
            assert messages == []
class TestRedisHistoryProviderSaveMessages:
    """Writing messages through ``RedisHistoryProvider.save_messages``."""

    async def test_saves_serialized_messages(self, mock_redis_client: MagicMock):
        """Each message is pushed to the pipeline and the pipeline is executed once."""
        with patch("agent_framework_redis._history_provider.redis.from_url", return_value=mock_redis_client):
            provider = RedisHistoryProvider("mem", redis_url="redis://localhost:6379")
            outgoing = [Message(role="user", contents=["Hello"]), Message(role="assistant", contents=["Hi"])]
            await provider.save_messages("s1", outgoing)
            entered_pipeline = mock_redis_client.pipeline.return_value.__aenter__.return_value
            assert entered_pipeline.rpush.call_count == 2
            entered_pipeline.execute.assert_called_once()

    async def test_empty_messages_noop(self, mock_redis_client: MagicMock):
        """Saving an empty list never opens a pipeline."""
        with patch("agent_framework_redis._history_provider.redis.from_url", return_value=mock_redis_client):
            provider = RedisHistoryProvider("mem", redis_url="redis://localhost:6379")
            await provider.save_messages("s1", [])
            mock_redis_client.pipeline.assert_not_called()

    async def test_max_messages_trimming(self, mock_redis_client: MagicMock):
        """When the list length exceeds ``max_messages``, the list is trimmed to the newest entries."""
        mock_redis_client.llen = AsyncMock(return_value=15)
        with patch("agent_framework_redis._history_provider.redis.from_url", return_value=mock_redis_client):
            provider = RedisHistoryProvider("mem", redis_url="redis://localhost:6379", max_messages=10)
            new_message = Message(role="user", contents=["msg"])
            await provider.save_messages("s1", [new_message])
            mock_redis_client.ltrim.assert_called_once_with("chat_messages:s1", -10, -1)

    async def test_no_trim_when_under_limit(self, mock_redis_client: MagicMock):
        """When the list length is below ``max_messages``, no trim is issued."""
        mock_redis_client.llen = AsyncMock(return_value=3)
        with patch("agent_framework_redis._history_provider.redis.from_url", return_value=mock_redis_client):
            provider = RedisHistoryProvider("mem", redis_url="redis://localhost:6379", max_messages=10)
            new_message = Message(role="user", contents=["msg"])
            await provider.save_messages("s1", [new_message])
            mock_redis_client.ltrim.assert_not_called()
class TestRedisHistoryProviderClear:
    """Deleting a session's history through ``RedisHistoryProvider.clear``."""

    async def test_clear_calls_delete(self, mock_redis_client: MagicMock):
        """Clearing a session deletes exactly that session's Redis key."""
        with patch("agent_framework_redis._history_provider.redis.from_url", return_value=mock_redis_client):
            provider = RedisHistoryProvider("mem", redis_url="redis://localhost:6379")
            await provider.clear("session-1")
            mock_redis_client.delete.assert_called_once_with("chat_messages:session-1")
class TestRedisHistoryProviderBeforeAfterRun:
    """Test before_run/after_run integration via BaseHistoryProvider defaults."""

    async def test_before_run_loads_history(self, mock_redis_client: MagicMock):
        """A message already stored in Redis shows up in the context under the provider's source id."""
        stored = Message(role="user", contents=["old msg"])
        mock_redis_client.lrange = AsyncMock(return_value=[json.dumps(stored.to_dict())])

        with patch("agent_framework_redis._history_provider.redis.from_url", return_value=mock_redis_client):
            provider = RedisHistoryProvider("mem", redis_url="redis://localhost:6379")
            session = AgentSession(session_id="test")
            ctx = SessionContext(input_messages=[Message(role="user", contents=["new msg"])], session_id="s1")
            state = session.state.setdefault(provider.source_id, {})
            await provider.before_run(agent=None, session=session, context=ctx, state=state)  # type: ignore[arg-type]
            assert "mem" in ctx.context_messages
            history = ctx.context_messages["mem"]
            assert len(history) == 1
            assert history[0].text == "old msg"

    async def test_after_run_stores_input_and_response(self, mock_redis_client: MagicMock):
        """Both the input message and the response message are pushed, and the pipeline runs once."""
        with patch("agent_framework_redis._history_provider.redis.from_url", return_value=mock_redis_client):
            provider = RedisHistoryProvider("mem", redis_url="redis://localhost:6379")
            session = AgentSession(session_id="test")
            ctx = SessionContext(input_messages=[Message(role="user", contents=["hi"])], session_id="s1")
            ctx._response = AgentResponse(messages=[Message(role="assistant", contents=["hello"])])
            state = session.state.setdefault(provider.source_id, {})
            await provider.after_run(agent=None, session=session, context=ctx, state=state)  # type: ignore[arg-type]
            entered_pipeline = mock_redis_client.pipeline.return_value.__aenter__.return_value
            assert entered_pipeline.rpush.call_count == 2
            entered_pipeline.execute.assert_called_once()

    async def test_after_run_skips_when_no_messages(self, mock_redis_client: MagicMock):
        """With input and output storing both disabled, no pipeline is opened."""
        with patch("agent_framework_redis._history_provider.redis.from_url", return_value=mock_redis_client):
            provider = RedisHistoryProvider(
                "mem",
                redis_url="redis://localhost:6379",
                store_inputs=False,
                store_outputs=False,
            )
            session = AgentSession(session_id="test")
            ctx = SessionContext(input_messages=[Message(role="user", contents=["hi"])], session_id="s1")
            state = session.state.setdefault(provider.source_id, {})
            await provider.after_run(agent=None, session=session, context=ctx, state=state)  # type: ignore[arg-type]
            mock_redis_client.pipeline.assert_not_called()