mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Azure AI Search provider improvements - EmbeddingGenerator, async context manager, KB message handling (#4212)
* small updates and improvements in the azure AISearch provider * Fix mypy errors and embedding function test - Use separate variable for embeddings result to avoid mypy type reassignment error - Fix test_vectorized_query_with_embedding_function: use real async function instead of AsyncMock which falsely matches SupportsGetEmbeddings protocol Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * fixes from feedback --------- Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
2ad0caf069
commit
4530504a3d
+171
-47
@@ -13,7 +13,7 @@ import sys
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypedDict
|
||||
|
||||
from agent_framework import AGENT_FRAMEWORK_USER_AGENT, Message
|
||||
from agent_framework import AGENT_FRAMEWORK_USER_AGENT, Annotation, Content, Message, SupportsGetEmbeddings
|
||||
from agent_framework._sessions import AgentSession, BaseContextProvider, SessionContext
|
||||
from agent_framework._settings import SecretString, load_settings
|
||||
from agent_framework.azure._entra_id_authentication import AzureCredentialTypes
|
||||
@@ -47,8 +47,12 @@ if TYPE_CHECKING:
|
||||
from azure.search.documents.knowledgebases.aio import KnowledgeBaseRetrievalClient
|
||||
from azure.search.documents.knowledgebases.models import (
|
||||
KnowledgeBaseMessage,
|
||||
KnowledgeBaseMessageImageContent,
|
||||
KnowledgeBaseMessageImageContentImage,
|
||||
KnowledgeBaseMessageTextContent,
|
||||
KnowledgeBaseReference,
|
||||
KnowledgeBaseRetrievalRequest,
|
||||
KnowledgeBaseRetrievalResponse,
|
||||
KnowledgeRetrievalIntent,
|
||||
KnowledgeRetrievalSemanticIntent,
|
||||
)
|
||||
@@ -78,8 +82,12 @@ try:
|
||||
from azure.search.documents.knowledgebases.aio import KnowledgeBaseRetrievalClient
|
||||
from azure.search.documents.knowledgebases.models import (
|
||||
KnowledgeBaseMessage,
|
||||
KnowledgeBaseMessageImageContent,
|
||||
KnowledgeBaseMessageImageContentImage,
|
||||
KnowledgeBaseMessageTextContent,
|
||||
KnowledgeBaseReference,
|
||||
KnowledgeBaseRetrievalRequest,
|
||||
KnowledgeBaseRetrievalResponse,
|
||||
KnowledgeRetrievalIntent,
|
||||
KnowledgeRetrievalSemanticIntent,
|
||||
)
|
||||
@@ -154,7 +162,9 @@ class AzureAISearchContextProvider(BaseContextProvider):
|
||||
top_k: int = 5,
|
||||
semantic_configuration_name: str | None = None,
|
||||
vector_field_name: str | None = None,
|
||||
embedding_function: Callable[[str], Awaitable[list[float]]] | None = None,
|
||||
embedding_function: Callable[[str], Awaitable[list[float]]]
|
||||
| SupportsGetEmbeddings[str, list[float], Any]
|
||||
| None = None,
|
||||
context_prompt: str | None = None,
|
||||
azure_openai_resource_url: str | None = None,
|
||||
model_deployment_name: str | None = None,
|
||||
@@ -181,7 +191,7 @@ class AzureAISearchContextProvider(BaseContextProvider):
|
||||
top_k: Maximum number of documents to retrieve. Default: 5.
|
||||
semantic_configuration_name: Name of semantic configuration in the index.
|
||||
vector_field_name: Name of the vector field in the index.
|
||||
embedding_function: Async function to generate embeddings.
|
||||
embedding_function: Async function to generate embeddings or a SupportsGetEmbeddings instance.
|
||||
context_prompt: Custom prompt to prepend to retrieved context.
|
||||
azure_openai_resource_url: Azure OpenAI resource URL for Knowledge Base.
|
||||
model_deployment_name: Model deployment name in Azure OpenAI.
|
||||
@@ -309,9 +319,20 @@ class AzureAISearchContextProvider(BaseContextProvider):
|
||||
exc_tb: Any,
|
||||
) -> None:
|
||||
"""Async context manager exit - cleanup clients."""
|
||||
await self.close()
|
||||
|
||||
async def close(self) -> None:
    """Close all the open clients.

    Each client that is currently set (retrieval, search, index) is detached
    from the provider and closed. A failure while closing one client does not
    stop the remaining clients from being closed; once every client has been
    attempted, the first error encountered is re-raised.

    Raises:
        Exception: The first exception raised by any client's ``close()``.
    """
    first_error: Exception | None = None
    for attr_name in ("_retrieval_client", "_search_client", "_index_client"):
        client = getattr(self, attr_name)
        if client is None:
            continue
        # Detach before awaiting so the provider never keeps a reference to a
        # client whose close() failed part-way through.
        setattr(self, attr_name, None)
        if attr_name == "_retrieval_client":
            # As before, the flag is only reset when a retrieval client was open.
            self._knowledge_base_initialized = False
        try:
            await client.close()
        except Exception as exc:  # keep going so the other clients are not leaked
            if first_error is None:
                first_error = exc
    if first_error is not None:
        raise first_error
|
||||
|
||||
# -- Hooks pattern ---------------------------------------------------------
|
||||
|
||||
@@ -326,32 +347,23 @@ class AzureAISearchContextProvider(BaseContextProvider):
|
||||
"""Retrieve relevant context from Azure AI Search and add to session context."""
|
||||
messages_list = list(context.input_messages)
|
||||
|
||||
def get_role_value(role: str | Any) -> str:
|
||||
return role.value if hasattr(role, "value") else str(role)
|
||||
|
||||
filtered_messages = [
|
||||
msg
|
||||
for msg in messages_list
|
||||
if msg and msg.text and msg.text.strip() and get_role_value(msg.role) in ["user", "assistant"]
|
||||
msg for msg in messages_list if msg and msg.text and msg.text.strip() and msg.role in ["user", "assistant"]
|
||||
]
|
||||
if not filtered_messages:
|
||||
return
|
||||
|
||||
if self.mode == "semantic":
|
||||
query = "\n".join(msg.text for msg in filtered_messages)
|
||||
search_result_parts = await self._semantic_search(query)
|
||||
result_messages = await self._semantic_search(query)
|
||||
else:
|
||||
recent_messages = filtered_messages[-self.agentic_message_history_count :]
|
||||
search_result_parts = await self._agentic_search(recent_messages)
|
||||
result_messages = await self._agentic_search(recent_messages)
|
||||
|
||||
if not search_result_parts:
|
||||
if not result_messages:
|
||||
return
|
||||
|
||||
context_messages = [Message(role="user", text=self.context_prompt)]
|
||||
context_messages.extend([Message(role="user", text=part) for part in search_result_parts])
|
||||
context.extend_messages(self.source_id, context_messages)
|
||||
|
||||
# -- Internal methods (ported from AzureAISearchContextProvider) -----------
|
||||
context.extend_messages(self.source_id, [Message(role="user", text=self.context_prompt), *result_messages])
|
||||
|
||||
def _find_vector_fields(self, index: Any) -> list[str]:
|
||||
"""Find all fields that can store vectors."""
|
||||
@@ -432,7 +444,7 @@ class AzureAISearchContextProvider(BaseContextProvider):
|
||||
|
||||
self._auto_discovered_vector_field = True
|
||||
|
||||
async def _semantic_search(self, query: str) -> list[str]:
|
||||
async def _semantic_search(self, query: str) -> list[Message]:
|
||||
"""Perform semantic hybrid search."""
|
||||
await self._auto_discover_vector_field()
|
||||
|
||||
@@ -440,14 +452,14 @@ class AzureAISearchContextProvider(BaseContextProvider):
|
||||
if self.vector_field_name:
|
||||
vector_k = max(self.top_k, 50) if self.semantic_configuration_name else self.top_k
|
||||
if self._use_vectorizable_query:
|
||||
vector_queries = [
|
||||
VectorizableTextQuery(text=query, k_nearest_neighbors=vector_k, fields=self.vector_field_name)
|
||||
]
|
||||
vector_queries = [VectorizableTextQuery(text=query, k=vector_k, fields=self.vector_field_name)]
|
||||
elif self.embedding_function:
|
||||
query_vector = await self.embedding_function(query)
|
||||
vector_queries = [
|
||||
VectorizedQuery(vector=query_vector, k_nearest_neighbors=vector_k, fields=self.vector_field_name)
|
||||
]
|
||||
if isinstance(self.embedding_function, SupportsGetEmbeddings):
|
||||
embeddings = await self.embedding_function.get_embeddings([query]) # type: ignore[reportUnknownVariableType]
|
||||
query_vector: list[float] = embeddings[0].vector # type: ignore[reportUnknownVariableType]
|
||||
else:
|
||||
query_vector = await self.embedding_function(query)
|
||||
vector_queries = [VectorizedQuery(vector=query_vector, k=vector_k, fields=self.vector_field_name)]
|
||||
|
||||
search_params: dict[str, Any] = {"search_text": query, "top": self.top_k}
|
||||
if vector_queries:
|
||||
@@ -461,13 +473,13 @@ class AzureAISearchContextProvider(BaseContextProvider):
|
||||
raise RuntimeError("Search client is not initialized.")
|
||||
results = await self._search_client.search(**search_params) # type: ignore[reportUnknownVariableType]
|
||||
|
||||
formatted_results: list[str] = []
|
||||
result_messages: list[Message] = []
|
||||
async for doc in results: # type: ignore[reportUnknownVariableType]
|
||||
doc_id = doc.get("id") or doc.get("@search.id") # type: ignore[reportUnknownVariableType]
|
||||
doc_text: str = self._extract_document_text(doc, doc_id=doc_id) # type: ignore[reportUnknownArgumentType]
|
||||
if doc_text:
|
||||
formatted_results.append(doc_text) # type: ignore[reportUnknownArgumentType]
|
||||
return formatted_results
|
||||
result_messages.append(Message(role="user", text=doc_text)) # type: ignore[reportUnknownArgumentType]
|
||||
return result_messages
|
||||
|
||||
async def _ensure_knowledge_base(self) -> None:
|
||||
"""Ensure Knowledge Base and knowledge source are created or use existing KB."""
|
||||
@@ -550,7 +562,7 @@ class AzureAISearchContextProvider(BaseContextProvider):
|
||||
user_agent=AGENT_FRAMEWORK_USER_AGENT,
|
||||
)
|
||||
|
||||
async def _agentic_search(self, messages: list[Message]) -> list[str]:
|
||||
async def _agentic_search(self, messages: list[Message]) -> list[Message]:
|
||||
"""Perform agentic retrieval with multi-hop reasoning."""
|
||||
await self._ensure_knowledge_base()
|
||||
|
||||
@@ -577,14 +589,7 @@ class AzureAISearchContextProvider(BaseContextProvider):
|
||||
include_activity=True,
|
||||
)
|
||||
else:
|
||||
kb_messages = [
|
||||
KnowledgeBaseMessage(
|
||||
role=msg.role if hasattr(msg.role, "value") else str(msg.role),
|
||||
content=[KnowledgeBaseMessageTextContent(text=msg.text)],
|
||||
)
|
||||
for msg in messages
|
||||
if msg.text
|
||||
]
|
||||
kb_messages = self._prepare_messages_for_kb_search(messages)
|
||||
retrieval_request = KnowledgeBaseRetrievalRequest(
|
||||
messages=kb_messages,
|
||||
retrieval_reasoning_effort=reasoning_effort,
|
||||
@@ -596,17 +601,136 @@ class AzureAISearchContextProvider(BaseContextProvider):
|
||||
raise RuntimeError("Retrieval client not initialized.")
|
||||
retrieval_result = await self._retrieval_client.retrieve(retrieval_request=retrieval_request)
|
||||
|
||||
if retrieval_result.response and len(retrieval_result.response) > 0:
|
||||
assistant_message = retrieval_result.response[-1]
|
||||
if assistant_message.content:
|
||||
answer_parts: list[str] = []
|
||||
for content_item in assistant_message.content:
|
||||
if isinstance(content_item, KnowledgeBaseMessageTextContent) and content_item.text:
|
||||
answer_parts.append(content_item.text)
|
||||
if answer_parts:
|
||||
return answer_parts
|
||||
return self._parse_messages_from_kb_response(retrieval_result)
|
||||
|
||||
return ["No results found from Knowledge Base."]
|
||||
@staticmethod
def _prepare_messages_for_kb_search(messages: list[Message]) -> list[KnowledgeBaseMessage]:
    """Translate framework Messages into KnowledgeBaseMessages for agentic retrieval.

    Text items and image items (``uri``/``data`` content whose media type
    starts with ``image/``) are converted; any other content type is dropped.
    A message with no content items falls back to its ``text``. Messages that
    yield nothing convertible are left out of the result.

    Args:
        messages: Framework messages to convert.

    Returns:
        List of KnowledgeBaseMessage objects suitable for retrieval requests.
    """
    converted: list[KnowledgeBaseMessage] = []
    for message in messages:
        parts: list[KnowledgeBaseMessageTextContent | KnowledgeBaseMessageImageContent] = []
        if not message.contents:
            # No structured content: use the plain text of the message, if any.
            if message.text:
                parts.append(KnowledgeBaseMessageTextContent(text=message.text))
        else:
            for item in message.contents:
                kind = item.type
                if kind == "text" and item.text:
                    parts.append(KnowledgeBaseMessageTextContent(text=item.text))
                elif (
                    kind in ("uri", "data")
                    and item.uri
                    and item.media_type
                    and item.media_type.startswith("image/")
                ):
                    image = KnowledgeBaseMessageImageContentImage(url=item.uri)
                    parts.append(KnowledgeBaseMessageImageContent(image=image))
        if not parts:
            continue
        converted.append(KnowledgeBaseMessage(role=message.role, content=parts))  # type: ignore[arg-type]
    return converted
|
||||
|
||||
@staticmethod
def _parse_references_to_annotations(references: list[KnowledgeBaseReference] | None) -> list[Annotation]:
    """Build citation Annotations from Knowledge Base references.

    For every reference the first non-empty value among ``url``, ``blob_url``,
    ``doc_url`` and ``web_url`` becomes the citation URL (empty string if none
    is set), and the title falls back to the reference id. Remaining details
    (id, type, activity source, reranker score, source data, doc key, SDK
    additional properties, sensitivity label) are stored under
    ``additional_properties``; the reference object itself is kept under
    ``raw_representation``.

    Args:
        references: The references from a Knowledge Base retrieval response.

    Returns:
        List of citation Annotations; empty when there are no references.
    """
    citations: list[Annotation] = []
    for reference in references or []:
        # Lazily probe the URL-like attributes and stop at the first one that is set.
        candidates = (getattr(reference, name, None) for name in ("url", "blob_url", "doc_url", "web_url"))
        link = next((value for value in candidates if value), "")

        citation = Annotation(
            type="citation",
            url=link,
            title=getattr(reference, "title", None) or reference.id,
        )

        details: dict[str, Any] = {
            "reference_id": reference.id,
            "reference_type": getattr(reference, "type", None),
            "activity_source": reference.activity_source,
        }
        score = reference.reranker_score
        if score is not None:
            details["reranker_score"] = score
        source_data = reference.source_data
        if source_data:
            details["source_data"] = source_data
        doc_key = getattr(reference, "doc_key", None)
        if doc_key:
            details["doc_key"] = doc_key
        sdk_extras = reference.additional_properties
        if sdk_extras:
            details["sdk_additional_properties"] = sdk_extras
        label = getattr(reference, "search_sensitivity_label_info", None)
        if label:
            details["sensitivity_label"] = {
                "display_name": label.display_name,
                "sensitivity_label_id": label.sensitivity_label_id,
                "is_encrypted": label.is_encrypted,
            }

        citation["additional_properties"] = details
        citation["raw_representation"] = reference
        citations.append(citation)
    return citations
|
||||
|
||||
@staticmethod
def _parse_messages_from_kb_response(retrieval_result: KnowledgeBaseRetrievalResponse) -> list[Message]:
    """Convert a Knowledge Base retrieval response to framework Messages.

    Each KnowledgeBaseMessage with usable content becomes a Message. Text
    items become text content; image items become URI content whose media
    type is guessed from the image URL (falling back to ``image/png``).
    References from the response are converted to Annotations and a copy of
    that list is attached to every content item.

    Args:
        retrieval_result: The full retrieval response including messages and references.

    Returns:
        List of Messages, or a single default Message if no results found.
    """
    import mimetypes  # local import: only the image branch needs it

    def _no_results() -> list[Message]:
        return [Message(role="assistant", text="No results found from Knowledge Base.")]

    if not retrieval_result.response:
        return _no_results()

    annotations = AzureAISearchContextProvider._parse_references_to_annotations(retrieval_result.references)

    result_messages: list[Message] = []
    for kb_msg in retrieval_result.response:
        if not kb_msg.content:
            continue
        contents: list[Content] = []
        for item in kb_msg.content:
            if isinstance(item, KnowledgeBaseMessageTextContent) and item.text:
                contents.append(Content.from_text(item.text))
            elif isinstance(item, KnowledgeBaseMessageImageContent) and item.image and item.image.url:
                # Derive the media type from the URL (extension or data-URI prefix)
                # instead of labelling every image as PNG.
                try:
                    guessed, _ = mimetypes.guess_type(item.image.url)
                except ValueError:  # malformed URL: fall back to the default below
                    guessed = None
                media_type = guessed if guessed and guessed.startswith("image/") else "image/png"
                contents.append(Content.from_uri(uri=item.image.url, media_type=media_type))
        if not contents:
            continue
        if annotations:
            for content in contents:
                # Give each content item its own list so mutating one item's
                # annotations cannot silently affect the others.
                content.annotations = list(annotations)
        result_messages.append(Message(role=kb_msg.role or "assistant", contents=contents))

    return result_messages or _no_results()
|
||||
|
||||
def _extract_document_text(self, doc: dict[str, Any], doc_id: str | None = None) -> str:
|
||||
"""Extract readable text from a search document with optional citation."""
|
||||
|
||||
@@ -6,7 +6,7 @@ from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from agent_framework import Message
|
||||
from agent_framework import Content, Message
|
||||
from agent_framework._sessions import AgentSession, SessionContext
|
||||
from agent_framework.exceptions import SettingNotFoundError
|
||||
from azure.core.credentials import AzureKeyCredential
|
||||
@@ -720,7 +720,7 @@ class TestSemanticSearch:
|
||||
|
||||
results = await provider._semantic_search("test query")
|
||||
assert len(results) == 1
|
||||
assert "result text" in results[0]
|
||||
assert "result text" in results[0].text
|
||||
call_kwargs = mock_client.search.call_args[1]
|
||||
assert call_kwargs["search_text"] == "test query"
|
||||
|
||||
@@ -746,7 +746,11 @@ class TestSemanticSearch:
|
||||
provider = _make_provider()
|
||||
provider._use_vectorizable_query = False
|
||||
provider.vector_field_name = "embedding"
|
||||
provider.embedding_function = AsyncMock(return_value=[0.1, 0.2, 0.3])
|
||||
|
||||
async def _embed(query: str) -> list[float]:
|
||||
return [0.1, 0.2, 0.3]
|
||||
|
||||
provider.embedding_function = _embed
|
||||
mock_client = AsyncMock()
|
||||
|
||||
async def _search(**kwargs):
|
||||
@@ -757,7 +761,6 @@ class TestSemanticSearch:
|
||||
|
||||
results = await provider._semantic_search("embed query")
|
||||
assert len(results) == 1
|
||||
provider.embedding_function.assert_awaited_once_with("embed query")
|
||||
call_kwargs = mock_client.search.call_args[1]
|
||||
assert "vector_queries" in call_kwargs
|
||||
|
||||
@@ -1100,9 +1103,11 @@ class TestAgenticSearch:
|
||||
mock_content = Mock()
|
||||
mock_content.text = "Answer text"
|
||||
mock_message = Mock()
|
||||
mock_message.role = "assistant"
|
||||
mock_message.content = [mock_content]
|
||||
mock_result = Mock()
|
||||
mock_result.response = [mock_message]
|
||||
mock_result.references = None
|
||||
|
||||
mock_retrieval = AsyncMock()
|
||||
mock_retrieval.retrieve = AsyncMock(return_value=mock_result)
|
||||
@@ -1115,7 +1120,9 @@ class TestAgenticSearch:
|
||||
):
|
||||
results = await provider._agentic_search([Message(role="user", contents=["test query"])])
|
||||
|
||||
assert results == ["Answer text"]
|
||||
assert len(results) == 1
|
||||
assert results[0].text == "Answer text"
|
||||
assert results[0].role == "assistant"
|
||||
|
||||
async def test_non_minimal_reasoning_uses_messages(self) -> None:
|
||||
provider = _make_provider()
|
||||
@@ -1126,9 +1133,11 @@ class TestAgenticSearch:
|
||||
mock_content = Mock()
|
||||
mock_content.text = "Medium answer"
|
||||
mock_message = Mock()
|
||||
mock_message.role = "assistant"
|
||||
mock_message.content = [mock_content]
|
||||
mock_result = Mock()
|
||||
mock_result.response = [mock_message]
|
||||
mock_result.references = None
|
||||
|
||||
mock_retrieval = AsyncMock()
|
||||
mock_retrieval.retrieve = AsyncMock(return_value=mock_result)
|
||||
@@ -1143,7 +1152,8 @@ class TestAgenticSearch:
|
||||
Message(role="assistant", contents=["answer"]),
|
||||
])
|
||||
|
||||
assert results == ["Medium answer"]
|
||||
assert len(results) == 1
|
||||
assert results[0].text == "Medium answer"
|
||||
mock_retrieval.retrieve.assert_awaited_once()
|
||||
|
||||
async def test_no_response_returns_default_message(self) -> None:
|
||||
@@ -1154,13 +1164,15 @@ class TestAgenticSearch:
|
||||
|
||||
mock_result = Mock()
|
||||
mock_result.response = []
|
||||
mock_result.references = None
|
||||
|
||||
mock_retrieval = AsyncMock()
|
||||
mock_retrieval.retrieve = AsyncMock(return_value=mock_result)
|
||||
provider._retrieval_client = mock_retrieval
|
||||
|
||||
results = await provider._agentic_search([Message(role="user", contents=["query"])])
|
||||
assert results == ["No results found from Knowledge Base."]
|
||||
assert len(results) == 1
|
||||
assert results[0].text == "No results found from Knowledge Base."
|
||||
|
||||
async def test_empty_content_returns_default_message(self) -> None:
|
||||
provider = _make_provider()
|
||||
@@ -1172,13 +1184,15 @@ class TestAgenticSearch:
|
||||
mock_message.content = None
|
||||
mock_result = Mock()
|
||||
mock_result.response = [mock_message]
|
||||
mock_result.references = None
|
||||
|
||||
mock_retrieval = AsyncMock()
|
||||
mock_retrieval.retrieve = AsyncMock(return_value=mock_result)
|
||||
provider._retrieval_client = mock_retrieval
|
||||
|
||||
results = await provider._agentic_search([Message(role="user", contents=["query"])])
|
||||
assert results == ["No results found from Knowledge Base."]
|
||||
assert len(results) == 1
|
||||
assert results[0].text == "No results found from Knowledge Base."
|
||||
|
||||
async def test_answer_synthesis_output_mode(self) -> None:
|
||||
provider = _make_provider()
|
||||
@@ -1190,9 +1204,11 @@ class TestAgenticSearch:
|
||||
mock_content = Mock()
|
||||
mock_content.text = "Synthesized answer"
|
||||
mock_message = Mock()
|
||||
mock_message.role = "assistant"
|
||||
mock_message.content = [mock_content]
|
||||
mock_result = Mock()
|
||||
mock_result.response = [mock_message]
|
||||
mock_result.references = None
|
||||
|
||||
mock_retrieval = AsyncMock()
|
||||
mock_retrieval.retrieve = AsyncMock(return_value=mock_result)
|
||||
@@ -1204,7 +1220,8 @@ class TestAgenticSearch:
|
||||
):
|
||||
results = await provider._agentic_search([Message(role="user", contents=["query"])])
|
||||
|
||||
assert results == ["Synthesized answer"]
|
||||
assert len(results) == 1
|
||||
assert results[0].text == "Synthesized answer"
|
||||
|
||||
async def test_content_without_text_excluded(self) -> None:
|
||||
provider = _make_provider()
|
||||
@@ -1217,9 +1234,11 @@ class TestAgenticSearch:
|
||||
mock_content_no_text = Mock()
|
||||
mock_content_no_text.text = None
|
||||
mock_message = Mock()
|
||||
mock_message.role = "assistant"
|
||||
mock_message.content = [mock_content_no_text, mock_content_with_text]
|
||||
mock_result = Mock()
|
||||
mock_result.response = [mock_message]
|
||||
mock_result.references = None
|
||||
|
||||
mock_retrieval = AsyncMock()
|
||||
mock_retrieval.retrieve = AsyncMock(return_value=mock_result)
|
||||
@@ -1231,7 +1250,8 @@ class TestAgenticSearch:
|
||||
):
|
||||
results = await provider._agentic_search([Message(role="user", contents=["query"])])
|
||||
|
||||
assert results == ["Good content"]
|
||||
assert len(results) == 1
|
||||
assert results[0].text == "Good content"
|
||||
|
||||
async def test_none_response_returns_default_message(self) -> None:
|
||||
provider = _make_provider()
|
||||
@@ -1241,13 +1261,355 @@ class TestAgenticSearch:
|
||||
|
||||
mock_result = Mock()
|
||||
mock_result.response = None
|
||||
mock_result.references = None
|
||||
|
||||
mock_retrieval = AsyncMock()
|
||||
mock_retrieval.retrieve = AsyncMock(return_value=mock_result)
|
||||
provider._retrieval_client = mock_retrieval
|
||||
|
||||
results = await provider._agentic_search([Message(role="user", contents=["query"])])
|
||||
assert results == ["No results found from Knowledge Base."]
|
||||
assert len(results) == 1
|
||||
assert results[0].text == "No results found from Knowledge Base."
|
||||
|
||||
|
||||
# -- before_run: agentic mode --------------------------------------------------
|
||||
|
||||
|
||||
# -- _prepare_messages_for_kb_search / _parse_references_to_annotations / _parse_messages_from_kb_response --------
|
||||
|
||||
|
||||
class TestPrepareMessagesForKbSearch:
    """Tests for _prepare_messages_for_kb_search."""

    def test_text_only_messages(self) -> None:
        """Text-only messages convert one-to-one, keeping role and text."""
        from azure.search.documents.knowledgebases.models import KnowledgeBaseMessageTextContent

        messages = [
            Message(role="user", contents=["hello"]),
            Message(role="assistant", contents=["world"]),
        ]
        result = AzureAISearchContextProvider._prepare_messages_for_kb_search(messages)
        assert len(result) == 2
        assert result[0].role == "user"
        assert result[1].role == "assistant"
        # Verify content is KnowledgeBaseMessageTextContent
        assert isinstance(result[0].content[0], KnowledgeBaseMessageTextContent)
        assert result[0].content[0].text == "hello"

    def test_image_uri_content(self) -> None:
        """An image URI content item becomes image content carrying the same URL."""
        from azure.search.documents.knowledgebases.models import KnowledgeBaseMessageImageContent

        img = Content.from_uri(uri="https://example.com/photo.png", media_type="image/png")
        messages = [Message(role="user", contents=[img])]
        result = AzureAISearchContextProvider._prepare_messages_for_kb_search(messages)
        assert len(result) == 1
        assert isinstance(result[0].content[0], KnowledgeBaseMessageImageContent)
        assert result[0].content[0].image.url == "https://example.com/photo.png"

    def test_mixed_text_and_image_content(self) -> None:
        """Text and image items in one message are both converted."""
        text = Content.from_text("describe this image")
        img = Content.from_uri(uri="https://example.com/img.jpg", media_type="image/jpeg")
        messages = [Message(role="user", contents=[text, img])]
        result = AzureAISearchContextProvider._prepare_messages_for_kb_search(messages)
        assert len(result) == 1
        assert len(result[0].content) == 2

    def test_skips_non_text_non_image_content(self) -> None:
        """A message holding only an error content item is dropped."""
        error = Content.from_error(message="oops")
        messages = [Message(role="user", contents=[error])]
        result = AzureAISearchContextProvider._prepare_messages_for_kb_search(messages)
        assert len(result) == 0  # message had no usable content

    def test_skips_empty_text(self) -> None:
        """A message holding only empty text is dropped."""
        empty = Content.from_text("")
        messages = [Message(role="user", contents=[empty])]
        result = AzureAISearchContextProvider._prepare_messages_for_kb_search(messages)
        assert len(result) == 0

    def test_fallback_to_msg_text_when_no_contents(self) -> None:
        """A message built from plain text is converted to a text content item."""
        msg = Message(role="user", text="fallback text")
        result = AzureAISearchContextProvider._prepare_messages_for_kb_search([msg])
        assert len(result) == 1
        assert result[0].content[0].text == "fallback text"

    def test_data_uri_image(self) -> None:
        """Image bytes supplied via Content.from_data become image content."""
        from azure.search.documents.knowledgebases.models import KnowledgeBaseMessageImageContent

        img = Content.from_data(data=b"\x89PNG", media_type="image/png")
        messages = [Message(role="user", contents=[img])]
        result = AzureAISearchContextProvider._prepare_messages_for_kb_search(messages)
        assert len(result) == 1
        assert isinstance(result[0].content[0], KnowledgeBaseMessageImageContent)

    def test_non_image_uri_skipped(self) -> None:
        """URI content with a non-image media type is dropped."""
        pdf = Content.from_uri(uri="https://example.com/doc.pdf", media_type="application/pdf")
        messages = [Message(role="user", contents=[pdf])]
        result = AzureAISearchContextProvider._prepare_messages_for_kb_search(messages)
        assert len(result) == 0
|
||||
|
||||
|
||||
class TestParseReferencesToAnnotations:
    """Tests for _parse_references_to_annotations."""

    def test_none_references(self) -> None:
        """None yields an empty annotation list."""
        assert AzureAISearchContextProvider._parse_references_to_annotations(None) == []

    def test_empty_references(self) -> None:
        """An empty list yields an empty annotation list."""
        assert AzureAISearchContextProvider._parse_references_to_annotations([]) == []

    def test_search_index_reference_captures_doc_key(self) -> None:
        """A search-index reference records its id, type, activity source and doc key."""
        from azure.search.documents.knowledgebases.models import KnowledgeBaseSearchIndexReference

        reference = KnowledgeBaseSearchIndexReference(id="ref-1", activity_source=0, doc_key="doc-1")
        annotations = AzureAISearchContextProvider._parse_references_to_annotations([reference])
        assert len(annotations) == 1
        citation = annotations[0]
        assert citation["type"] == "citation"
        assert citation["title"] == "ref-1"
        props = citation["additional_properties"]
        assert props["reference_id"] == "ref-1"
        assert props["reference_type"] == "searchIndex"
        assert props["activity_source"] == 0
        assert props["doc_key"] == "doc-1"

    def test_web_reference_with_url_and_title(self) -> None:
        """A web reference exposes its own URL and title on the annotation."""
        from azure.search.documents.knowledgebases.models import KnowledgeBaseWebReference

        reference = KnowledgeBaseWebReference(
            id="ref-2", activity_source=0, url="https://example.com/page", title="Example Page"
        )
        annotations = AzureAISearchContextProvider._parse_references_to_annotations([reference])
        assert len(annotations) == 1
        citation = annotations[0]
        assert citation["url"] == "https://example.com/page"
        assert citation["title"] == "Example Page"
        assert citation["additional_properties"]["reference_type"] == "web"

    def test_blob_reference_extracts_blob_url(self) -> None:
        """A blob reference's blob_url is used as the annotation URL."""
        from azure.search.documents.knowledgebases.models import KnowledgeBaseAzureBlobReference

        reference = KnowledgeBaseAzureBlobReference(
            id="ref-3", activity_source=0, blob_url="https://storage.blob.core.windows.net/doc.pdf"
        )
        citation = AzureAISearchContextProvider._parse_references_to_annotations([reference])[0]
        assert citation["url"] == "https://storage.blob.core.windows.net/doc.pdf"
        assert citation["additional_properties"]["reference_type"] == "azureBlob"

    def test_source_data_and_reranker_score(self) -> None:
        """Source data and reranker score are copied into additional_properties."""
        from azure.search.documents.knowledgebases.models import KnowledgeBaseSearchIndexReference

        reference = KnowledgeBaseSearchIndexReference(
            id="ref-4", activity_source=0, source_data={"chunk": "some text"}, reranker_score=0.95
        )
        annotations = AzureAISearchContextProvider._parse_references_to_annotations([reference])
        props = annotations[0]["additional_properties"]
        assert props["source_data"] == {"chunk": "some text"}
        assert props["reranker_score"] == 0.95

    def test_raw_representation_stores_original_ref(self) -> None:
        """The annotation keeps the very same reference object as raw_representation."""
        from azure.search.documents.knowledgebases.models import KnowledgeBaseSearchIndexReference

        reference = KnowledgeBaseSearchIndexReference(id="ref-5", activity_source=0)
        annotations = AzureAISearchContextProvider._parse_references_to_annotations([reference])
        assert annotations[0]["raw_representation"] is reference

    def test_remote_sharepoint_captures_sensitivity_label(self) -> None:
        """A SharePoint reference exposes web_url and its sensitivity label fields."""
        from azure.search.documents.knowledgebases.models import (
            KnowledgeBaseRemoteSharePointReference,
            SharePointSensitivityLabelInfo,
        )

        label_info = SharePointSensitivityLabelInfo(
            display_name="Confidential", sensitivity_label_id="lbl-1", is_encrypted=True
        )
        reference = KnowledgeBaseRemoteSharePointReference(
            id="ref-6",
            activity_source=0,
            web_url="https://sp.example.com/doc",
            search_sensitivity_label_info=label_info,
        )
        citation = AzureAISearchContextProvider._parse_references_to_annotations([reference])[0]
        assert citation["url"] == "https://sp.example.com/doc"
        sensitivity = citation["additional_properties"]["sensitivity_label"]
        assert sensitivity["display_name"] == "Confidential"
        assert sensitivity["sensitivity_label_id"] == "lbl-1"
        assert sensitivity["is_encrypted"] is True

    def test_multiple_references(self) -> None:
        """Several references yield one annotation each, in input order."""
        from azure.search.documents.knowledgebases.models import (
            KnowledgeBaseSearchIndexReference,
            KnowledgeBaseWebReference,
        )

        first = KnowledgeBaseSearchIndexReference(id="ref-a", activity_source=0)
        second = KnowledgeBaseWebReference(id="ref-b", activity_source=1, url="https://example.com")
        annotations = AzureAISearchContextProvider._parse_references_to_annotations([first, second])
        assert len(annotations) == 2
        assert annotations[0]["additional_properties"]["activity_source"] == 0
        assert annotations[1]["additional_properties"]["activity_source"] == 1
|
||||
|
||||
|
||||
class TestParseMessagesFromKbResponse:
    """Exercise ``AzureAISearchContextProvider._parse_messages_from_kb_response``."""

    def test_converts_all_messages(self) -> None:
        """Check that every knowledge-base message is converted, keeping role and text."""
        from azure.search.documents.knowledgebases.models import (
            KnowledgeBaseMessage,
            KnowledgeBaseMessageTextContent,
            KnowledgeBaseRetrievalResponse,
        )

        user_message = KnowledgeBaseMessage(role="user", content=[KnowledgeBaseMessageTextContent(text="q")])
        assistant_message = KnowledgeBaseMessage(
            role="assistant", content=[KnowledgeBaseMessageTextContent(text="answer")]
        )
        kb_response = KnowledgeBaseRetrievalResponse(
            response=[user_message, assistant_message],
            references=None,
        )

        messages = AzureAISearchContextProvider._parse_messages_from_kb_response(kb_response)

        assert len(messages) == 2
        first, second = messages[0], messages[1]
        assert first.role == "user"
        assert first.text == "q"
        assert second.role == "assistant"
        assert second.text == "answer"

    def test_none_response_returns_default(self) -> None:
        """Check that a ``None`` response list yields the single default message."""
        from azure.search.documents.knowledgebases.models import KnowledgeBaseRetrievalResponse

        kb_response = KnowledgeBaseRetrievalResponse(response=None, references=None)

        messages = AzureAISearchContextProvider._parse_messages_from_kb_response(kb_response)

        assert len(messages) == 1
        assert messages[0].text == "No results found from Knowledge Base."

    def test_empty_response_returns_default(self) -> None:
        """Check that an empty response list yields the single default message."""
        from azure.search.documents.knowledgebases.models import KnowledgeBaseRetrievalResponse

        kb_response = KnowledgeBaseRetrievalResponse(response=[], references=None)

        messages = AzureAISearchContextProvider._parse_messages_from_kb_response(kb_response)

        assert len(messages) == 1
        assert messages[0].text == "No results found from Knowledge Base."

    def test_image_content(self) -> None:
        """Check that an image-only message becomes one ``uri`` content holding the image URL."""
        from azure.search.documents.knowledgebases.models import (
            KnowledgeBaseMessage,
            KnowledgeBaseMessageImageContent,
            KnowledgeBaseMessageImageContentImage,
            KnowledgeBaseRetrievalResponse,
        )

        image_part = KnowledgeBaseMessageImageContent(
            image=KnowledgeBaseMessageImageContentImage(url="https://img.example.com/a.png")
        )
        assistant_message = KnowledgeBaseMessage(role="assistant", content=[image_part])
        kb_response = KnowledgeBaseRetrievalResponse(response=[assistant_message], references=None)

        messages = AzureAISearchContextProvider._parse_messages_from_kb_response(kb_response)

        assert len(messages) == 1
        image_content = messages[0].contents[0]
        assert image_content.type == "uri"
        assert image_content.uri == "https://img.example.com/a.png"

    def test_mixed_text_and_image_content(self) -> None:
        """Check that text and image parts of one message are both kept, in their original order."""
        from azure.search.documents.knowledgebases.models import (
            KnowledgeBaseMessage,
            KnowledgeBaseMessageImageContent,
            KnowledgeBaseMessageImageContentImage,
            KnowledgeBaseMessageTextContent,
            KnowledgeBaseRetrievalResponse,
        )

        text_part = KnowledgeBaseMessageTextContent(text="description")
        image_part = KnowledgeBaseMessageImageContent(
            image=KnowledgeBaseMessageImageContentImage(url="https://img.example.com/b.png")
        )
        assistant_message = KnowledgeBaseMessage(role="assistant", content=[text_part, image_part])
        kb_response = KnowledgeBaseRetrievalResponse(response=[assistant_message], references=None)

        messages = AzureAISearchContextProvider._parse_messages_from_kb_response(kb_response)

        assert len(messages) == 1
        contents = messages[0].contents
        assert len(contents) == 2
        assert contents[0].type == "text"
        assert contents[1].type == "uri"

    def test_references_become_annotations(self) -> None:
        """Check that a response-level reference shows up as a citation annotation on the content."""
        from azure.search.documents.knowledgebases.models import (
            KnowledgeBaseMessage,
            KnowledgeBaseMessageTextContent,
            KnowledgeBaseRetrievalResponse,
            KnowledgeBaseWebReference,
        )

        assistant_message = KnowledgeBaseMessage(
            role="assistant", content=[KnowledgeBaseMessageTextContent(text="answer")]
        )
        web_reference = KnowledgeBaseWebReference(
            id="ref-1",
            activity_source=0,
            url="https://example.com",
            title="Example",
        )
        kb_response = KnowledgeBaseRetrievalResponse(
            response=[assistant_message],
            references=[web_reference],
        )

        messages = AzureAISearchContextProvider._parse_messages_from_kb_response(kb_response)

        assert len(messages) == 1
        citations = messages[0].contents[0].annotations
        assert citations is not None
        assert len(citations) == 1
        citation = citations[0]
        assert citation["type"] == "citation"
        assert citation["url"] == "https://example.com"
        assert citation["title"] == "Example"

    def test_multiple_messages_with_references(self) -> None:
        """Check that each content item of each message receives the response's single reference."""
        from azure.search.documents.knowledgebases.models import (
            KnowledgeBaseMessage,
            KnowledgeBaseMessageTextContent,
            KnowledgeBaseRetrievalResponse,
            KnowledgeBaseSearchIndexReference,
        )

        user_message = KnowledgeBaseMessage(role="user", content=[KnowledgeBaseMessageTextContent(text="q")])
        assistant_message = KnowledgeBaseMessage(
            role="assistant",
            content=[
                KnowledgeBaseMessageTextContent(text="part1"),
                KnowledgeBaseMessageTextContent(text="part2"),
            ],
        )
        kb_response = KnowledgeBaseRetrievalResponse(
            response=[user_message, assistant_message],
            references=[KnowledgeBaseSearchIndexReference(id="doc-1", activity_source=0)],
        )

        messages = AzureAISearchContextProvider._parse_messages_from_kb_response(kb_response)

        assert len(messages) == 2
        # Walk every content item across all messages; each must carry exactly one annotation.
        for content in (item for message in messages for item in message.contents):
            assert content.annotations is not None
            assert len(content.annotations) == 1
|
||||
|
||||
|
||||
# -- before_run: agentic mode --------------------------------------------------
|
||||
@@ -1266,9 +1628,11 @@ class TestBeforeRunAgentic:
|
||||
mock_content = Mock()
|
||||
mock_content.text = "agentic result"
|
||||
mock_message = Mock()
|
||||
mock_message.role = "assistant"
|
||||
mock_message.content = [mock_content]
|
||||
mock_result = Mock()
|
||||
mock_result.response = [mock_message]
|
||||
mock_result.references = None
|
||||
|
||||
mock_retrieval = AsyncMock()
|
||||
mock_retrieval.retrieve = AsyncMock(return_value=mock_result)
|
||||
|
||||
+3
@@ -135,6 +135,9 @@ async def main() -> None:
|
||||
async for chunk in agent.run(user_input, stream=True):
|
||||
if chunk.text:
|
||||
print(chunk.text, end="", flush=True)
|
||||
for content in chunk.contents:
|
||||
if content.annotations:
|
||||
print(f"\n[Sources: {content.annotations}]", end="", flush=True)
|
||||
|
||||
print("\n")
|
||||
|
||||
|
||||
+21
-3
@@ -4,7 +4,7 @@ import asyncio
|
||||
import os
|
||||
|
||||
from agent_framework import Agent
|
||||
from agent_framework.azure import AzureAIAgentClient, AzureAISearchContextProvider
|
||||
from agent_framework.azure import AzureAIAgentClient, AzureAISearchContextProvider, AzureOpenAIEmbeddingClient
|
||||
from azure.identity.aio import AzureCliCredential
|
||||
from dotenv import load_dotenv
|
||||
|
||||
@@ -30,6 +30,8 @@ Prerequisites:
|
||||
- AZURE_SEARCH_INDEX_NAME: Your search index name
|
||||
- AZURE_AI_PROJECT_ENDPOINT: Your Azure AI Foundry project endpoint
|
||||
- AZURE_AI_MODEL_DEPLOYMENT_NAME: Your model deployment name (e.g., "gpt-4o")
|
||||
- AZURE_OPENAI_EMBEDDING_MODEL_ID: (Optional) Your embedding model for hybrid search (e.g., "text-embedding-3-small")
|
||||
- AZURE_OPENAI_ENDPOINT: (Optional) Your Azure OpenAI resource URL, required if using an OpenAI embedding model for hybrid search
|
||||
"""
|
||||
|
||||
# Sample queries to demonstrate RAG
|
||||
@@ -43,12 +45,24 @@ USER_INPUTS = [
|
||||
async def main() -> None:
|
||||
"""Main function demonstrating Azure AI Search semantic mode."""
|
||||
|
||||
credential = AzureCliCredential()
|
||||
|
||||
# Get configuration from environment
|
||||
search_endpoint = os.environ["AZURE_SEARCH_ENDPOINT"]
|
||||
search_key = os.environ.get("AZURE_SEARCH_API_KEY")
|
||||
index_name = os.environ["AZURE_SEARCH_INDEX_NAME"]
|
||||
project_endpoint = os.environ["AZURE_AI_PROJECT_ENDPOINT"]
|
||||
model_deployment = os.environ.get("AZURE_AI_MODEL_DEPLOYMENT_NAME", "gpt-4o")
|
||||
openai_endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT")
|
||||
embedding_model = os.environ.get("AZURE_OPENAI_EMBEDDING_MODEL_ID", "text-embedding-3-small")
|
||||
|
||||
embedding_client = None
|
||||
if openai_endpoint and embedding_model:
|
||||
embedding_client = AzureOpenAIEmbeddingClient(
|
||||
endpoint=openai_endpoint,
|
||||
deployment_name=embedding_model,
|
||||
credential=credential,
|
||||
)
|
||||
|
||||
# Create Azure AI Search context provider with semantic mode (recommended, fast)
|
||||
print("Using SEMANTIC mode (hybrid search + semantic ranking, fast)\n")
|
||||
@@ -57,9 +71,13 @@ async def main() -> None:
|
||||
endpoint=search_endpoint,
|
||||
index_name=index_name,
|
||||
api_key=search_key, # Use api_key for API key auth, or credential for managed identity
|
||||
credential=AzureCliCredential() if not search_key else None,
|
||||
credential=credential if not search_key else None,
|
||||
mode="semantic", # Default mode
|
||||
top_k=3, # Retrieve top 3 most relevant documents
|
||||
embedding_function=embedding_client, # Provide embedding function for hybrid search
|
||||
vector_field_name="DescriptionVector"
|
||||
if embedding_client
|
||||
else None, # Set vector field for hybrid search if using embeddings
|
||||
)
|
||||
|
||||
# Create agent with search context provider
|
||||
@@ -68,7 +86,7 @@ async def main() -> None:
|
||||
AzureAIAgentClient(
|
||||
project_endpoint=project_endpoint,
|
||||
model_deployment_name=model_deployment,
|
||||
credential=AzureCliCredential(),
|
||||
credential=credential,
|
||||
) as client,
|
||||
Agent(
|
||||
client=client,
|
||||
|
||||
Reference in New Issue
Block a user