Python: Azure AI Search provider improvements - EmbeddingGenerator, async context manager, KB message handling (#4212)

* Small updates and improvements in the Azure AI Search provider

* Fix mypy errors and embedding function test

- Use separate variable for embeddings result to avoid mypy type reassignment error
- Fix test_vectorized_query_with_embedding_function: use real async function
  instead of AsyncMock which falsely matches SupportsGetEmbeddings protocol

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* fixes from feedback

---------

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
Eduard van Valkenburg
2026-02-25 07:47:26 +01:00
committed by GitHub
Unverified
parent 2ad0caf069
commit 4530504a3d
4 changed files with 570 additions and 61 deletions
@@ -13,7 +13,7 @@ import sys
from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypedDict
from agent_framework import AGENT_FRAMEWORK_USER_AGENT, Message
from agent_framework import AGENT_FRAMEWORK_USER_AGENT, Annotation, Content, Message, SupportsGetEmbeddings
from agent_framework._sessions import AgentSession, BaseContextProvider, SessionContext
from agent_framework._settings import SecretString, load_settings
from agent_framework.azure._entra_id_authentication import AzureCredentialTypes
@@ -47,8 +47,12 @@ if TYPE_CHECKING:
from azure.search.documents.knowledgebases.aio import KnowledgeBaseRetrievalClient
from azure.search.documents.knowledgebases.models import (
KnowledgeBaseMessage,
KnowledgeBaseMessageImageContent,
KnowledgeBaseMessageImageContentImage,
KnowledgeBaseMessageTextContent,
KnowledgeBaseReference,
KnowledgeBaseRetrievalRequest,
KnowledgeBaseRetrievalResponse,
KnowledgeRetrievalIntent,
KnowledgeRetrievalSemanticIntent,
)
@@ -78,8 +82,12 @@ try:
from azure.search.documents.knowledgebases.aio import KnowledgeBaseRetrievalClient
from azure.search.documents.knowledgebases.models import (
KnowledgeBaseMessage,
KnowledgeBaseMessageImageContent,
KnowledgeBaseMessageImageContentImage,
KnowledgeBaseMessageTextContent,
KnowledgeBaseReference,
KnowledgeBaseRetrievalRequest,
KnowledgeBaseRetrievalResponse,
KnowledgeRetrievalIntent,
KnowledgeRetrievalSemanticIntent,
)
@@ -154,7 +162,9 @@ class AzureAISearchContextProvider(BaseContextProvider):
top_k: int = 5,
semantic_configuration_name: str | None = None,
vector_field_name: str | None = None,
embedding_function: Callable[[str], Awaitable[list[float]]] | None = None,
embedding_function: Callable[[str], Awaitable[list[float]]]
| SupportsGetEmbeddings[str, list[float], Any]
| None = None,
context_prompt: str | None = None,
azure_openai_resource_url: str | None = None,
model_deployment_name: str | None = None,
@@ -181,7 +191,7 @@ class AzureAISearchContextProvider(BaseContextProvider):
top_k: Maximum number of documents to retrieve. Default: 5.
semantic_configuration_name: Name of semantic configuration in the index.
vector_field_name: Name of the vector field in the index.
embedding_function: Async function to generate embeddings.
embedding_function: Async function to generate embeddings or a SupportsGetEmbeddings instance.
context_prompt: Custom prompt to prepend to retrieved context.
azure_openai_resource_url: Azure OpenAI resource URL for Knowledge Base.
model_deployment_name: Model deployment name in Azure OpenAI.
@@ -309,9 +319,20 @@ class AzureAISearchContextProvider(BaseContextProvider):
exc_tb: Any,
) -> None:
"""Async context manager exit - cleanup clients."""
await self.close()
async def close(self) -> None:
    """Close all the open clients.

    Clients are closed in the order retrieval, search, index. Each client
    reference is reset to ``None`` only after its ``close()`` succeeded, and
    ``_knowledge_base_initialized`` is cleared together with the retrieval
    client.

    A failure while closing one client no longer prevents the remaining
    clients from being closed: the later clients are still attempted and the
    error is propagated afterwards.

    Raises:
        Exception: Whatever a client's ``close()`` raised. If more than one
            client fails, the last error propagates with the earlier ones
            chained as its context.
    """
    try:
        if self._retrieval_client is not None:
            await self._retrieval_client.close()
            self._retrieval_client = None
            self._knowledge_base_initialized = False
    finally:
        # Runs even when the retrieval client failed to close, so the
        # remaining clients are not leaked.
        try:
            if self._search_client is not None:
                await self._search_client.close()
                self._search_client = None
        finally:
            if self._index_client is not None:
                await self._index_client.close()
                self._index_client = None
# -- Hooks pattern ---------------------------------------------------------
@@ -326,32 +347,23 @@ class AzureAISearchContextProvider(BaseContextProvider):
"""Retrieve relevant context from Azure AI Search and add to session context."""
messages_list = list(context.input_messages)
def get_role_value(role: str | Any) -> str:
return role.value if hasattr(role, "value") else str(role)
filtered_messages = [
msg
for msg in messages_list
if msg and msg.text and msg.text.strip() and get_role_value(msg.role) in ["user", "assistant"]
msg for msg in messages_list if msg and msg.text and msg.text.strip() and msg.role in ["user", "assistant"]
]
if not filtered_messages:
return
if self.mode == "semantic":
query = "\n".join(msg.text for msg in filtered_messages)
search_result_parts = await self._semantic_search(query)
result_messages = await self._semantic_search(query)
else:
recent_messages = filtered_messages[-self.agentic_message_history_count :]
search_result_parts = await self._agentic_search(recent_messages)
result_messages = await self._agentic_search(recent_messages)
if not search_result_parts:
if not result_messages:
return
context_messages = [Message(role="user", text=self.context_prompt)]
context_messages.extend([Message(role="user", text=part) for part in search_result_parts])
context.extend_messages(self.source_id, context_messages)
# -- Internal methods (ported from AzureAISearchContextProvider) -----------
context.extend_messages(self.source_id, [Message(role="user", text=self.context_prompt), *result_messages])
def _find_vector_fields(self, index: Any) -> list[str]:
"""Find all fields that can store vectors."""
@@ -432,7 +444,7 @@ class AzureAISearchContextProvider(BaseContextProvider):
self._auto_discovered_vector_field = True
async def _semantic_search(self, query: str) -> list[str]:
async def _semantic_search(self, query: str) -> list[Message]:
"""Perform semantic hybrid search."""
await self._auto_discover_vector_field()
@@ -440,14 +452,14 @@ class AzureAISearchContextProvider(BaseContextProvider):
if self.vector_field_name:
vector_k = max(self.top_k, 50) if self.semantic_configuration_name else self.top_k
if self._use_vectorizable_query:
vector_queries = [
VectorizableTextQuery(text=query, k_nearest_neighbors=vector_k, fields=self.vector_field_name)
]
vector_queries = [VectorizableTextQuery(text=query, k=vector_k, fields=self.vector_field_name)]
elif self.embedding_function:
query_vector = await self.embedding_function(query)
vector_queries = [
VectorizedQuery(vector=query_vector, k_nearest_neighbors=vector_k, fields=self.vector_field_name)
]
if isinstance(self.embedding_function, SupportsGetEmbeddings):
embeddings = await self.embedding_function.get_embeddings([query]) # type: ignore[reportUnknownVariableType]
query_vector: list[float] = embeddings[0].vector # type: ignore[reportUnknownVariableType]
else:
query_vector = await self.embedding_function(query)
vector_queries = [VectorizedQuery(vector=query_vector, k=vector_k, fields=self.vector_field_name)]
search_params: dict[str, Any] = {"search_text": query, "top": self.top_k}
if vector_queries:
@@ -461,13 +473,13 @@ class AzureAISearchContextProvider(BaseContextProvider):
raise RuntimeError("Search client is not initialized.")
results = await self._search_client.search(**search_params) # type: ignore[reportUnknownVariableType]
formatted_results: list[str] = []
result_messages: list[Message] = []
async for doc in results: # type: ignore[reportUnknownVariableType]
doc_id = doc.get("id") or doc.get("@search.id") # type: ignore[reportUnknownVariableType]
doc_text: str = self._extract_document_text(doc, doc_id=doc_id) # type: ignore[reportUnknownArgumentType]
if doc_text:
formatted_results.append(doc_text) # type: ignore[reportUnknownArgumentType]
return formatted_results
result_messages.append(Message(role="user", text=doc_text)) # type: ignore[reportUnknownArgumentType]
return result_messages
async def _ensure_knowledge_base(self) -> None:
"""Ensure Knowledge Base and knowledge source are created or use existing KB."""
@@ -550,7 +562,7 @@ class AzureAISearchContextProvider(BaseContextProvider):
user_agent=AGENT_FRAMEWORK_USER_AGENT,
)
async def _agentic_search(self, messages: list[Message]) -> list[str]:
async def _agentic_search(self, messages: list[Message]) -> list[Message]:
"""Perform agentic retrieval with multi-hop reasoning."""
await self._ensure_knowledge_base()
@@ -577,14 +589,7 @@ class AzureAISearchContextProvider(BaseContextProvider):
include_activity=True,
)
else:
kb_messages = [
KnowledgeBaseMessage(
role=msg.role if hasattr(msg.role, "value") else str(msg.role),
content=[KnowledgeBaseMessageTextContent(text=msg.text)],
)
for msg in messages
if msg.text
]
kb_messages = self._prepare_messages_for_kb_search(messages)
retrieval_request = KnowledgeBaseRetrievalRequest(
messages=kb_messages,
retrieval_reasoning_effort=reasoning_effort,
@@ -596,17 +601,136 @@ class AzureAISearchContextProvider(BaseContextProvider):
raise RuntimeError("Retrieval client not initialized.")
retrieval_result = await self._retrieval_client.retrieve(retrieval_request=retrieval_request)
if retrieval_result.response and len(retrieval_result.response) > 0:
assistant_message = retrieval_result.response[-1]
if assistant_message.content:
answer_parts: list[str] = []
for content_item in assistant_message.content:
if isinstance(content_item, KnowledgeBaseMessageTextContent) and content_item.text:
answer_parts.append(content_item.text)
if answer_parts:
return answer_parts
return self._parse_messages_from_kb_response(retrieval_result)
return ["No results found from Knowledge Base."]
@staticmethod
def _prepare_messages_for_kb_search(messages: list[Message]) -> list[KnowledgeBaseMessage]:
    """Translate framework Messages into KnowledgeBaseMessages for agentic retrieval.

    Text items are carried over when their text is non-empty. ``uri`` and
    ``data`` items are carried over as images when they have a URI and a media
    type starting with ``image/``. Every other content item is dropped. A
    message without ``contents`` falls back to its ``text``. Messages that end
    up with no usable content are left out of the result.

    Args:
        messages: Framework messages to convert.

    Returns:
        KnowledgeBaseMessage objects, in input order, suitable for retrieval requests.
    """
    converted: list[KnowledgeBaseMessage] = []
    for message in messages:
        parts: list[KnowledgeBaseMessageTextContent | KnowledgeBaseMessageImageContent] = []
        if message.contents:
            for item in message.contents:
                kind = item.type
                if kind == "text":
                    # Empty text items contribute nothing to the request.
                    if item.text:
                        parts.append(KnowledgeBaseMessageTextContent(text=item.text))
                    continue
                if kind not in ("uri", "data"):
                    continue
                # Only image payloads are forwarded; other media types are skipped.
                if item.uri and item.media_type and item.media_type.startswith("image/"):
                    image = KnowledgeBaseMessageImageContentImage(url=item.uri)
                    parts.append(KnowledgeBaseMessageImageContent(image=image))
        elif message.text:
            parts.append(KnowledgeBaseMessageTextContent(text=message.text))
        if not parts:
            continue
        converted.append(KnowledgeBaseMessage(role=message.role, content=parts))  # type: ignore[arg-type]
    return converted
@staticmethod
def _parse_references_to_annotations(references: list[KnowledgeBaseReference] | None) -> list[Annotation]:
    """Build one citation Annotation per Knowledge Base reference.

    The annotation URL is the first non-empty value among the ``url``,
    ``blob_url``, ``doc_url`` and ``web_url`` attributes (empty string when
    none is set). The title is the reference's ``title`` when present,
    otherwise its ``id``. Remaining details (reference id and type, activity
    source and, when available, reranker score, source data, doc key, SDK
    additional properties and sensitivity label) are stored under
    ``additional_properties``; the reference object itself is stored under
    ``raw_representation``.

    Args:
        references: The references from a Knowledge Base retrieval response.

    Returns:
        List of citation Annotations; empty when ``references`` is ``None`` or empty.
    """
    if not references:
        return []

    url_attrs = ("url", "blob_url", "doc_url", "web_url")
    citations: list[Annotation] = []
    for reference in references:
        # Lazily probe the URL-like attributes and stop at the first non-empty one.
        candidates = (getattr(reference, name, None) for name in url_attrs)
        link = next((value for value in candidates if value), "")

        citation = Annotation(
            type="citation",
            url=link,
            title=getattr(reference, "title", None) or reference.id,
        )

        details: dict[str, Any] = {
            "reference_id": reference.id,
            "reference_type": getattr(reference, "type", None),
            "activity_source": reference.activity_source,
        }
        # A score of 0 is still meaningful, hence the explicit None check.
        if reference.reranker_score is not None:
            details["reranker_score"] = reference.reranker_score
        if reference.source_data:
            details["source_data"] = reference.source_data
        if doc_key := getattr(reference, "doc_key", None):
            details["doc_key"] = doc_key
        if reference.additional_properties:
            details["sdk_additional_properties"] = reference.additional_properties
        if label := getattr(reference, "search_sensitivity_label_info", None):
            details["sensitivity_label"] = {
                "display_name": label.display_name,
                "sensitivity_label_id": label.sensitivity_label_id,
                "is_encrypted": label.is_encrypted,
            }

        citation["additional_properties"] = details
        citation["raw_representation"] = reference
        citations.append(citation)
    return citations
@staticmethod
def _parse_messages_from_kb_response(retrieval_result: KnowledgeBaseRetrievalResponse) -> list[Message]:
    """Convert a Knowledge Base retrieval response to framework Messages.

    Each KnowledgeBaseMessage with usable content becomes a Message. Text items
    become text content and image items with a URL become URI content. The
    response's references are converted to Annotations and attached to every
    content item; each item receives its own list so that mutating one item's
    annotations cannot affect the others (the annotation entries themselves
    are shared).

    Args:
        retrieval_result: The full retrieval response including messages and references.

    Returns:
        List of Messages, or a single default assistant Message if no results were found.
    """

    def _no_results() -> list[Message]:
        # Single place for the fallback used by both "empty response" exits.
        return [Message(role="assistant", text="No results found from Knowledge Base.")]

    if not retrieval_result.response:
        return _no_results()

    annotations = AzureAISearchContextProvider._parse_references_to_annotations(retrieval_result.references)

    result_messages: list[Message] = []
    for kb_msg in retrieval_result.response:
        if not kb_msg.content:
            continue
        contents: list[Content] = []
        for item in kb_msg.content:
            if isinstance(item, KnowledgeBaseMessageTextContent) and item.text:
                contents.append(Content.from_text(item.text))
            elif isinstance(item, KnowledgeBaseMessageImageContent) and item.image and item.image.url:
                # NOTE(review): the media type is fixed to PNG regardless of the
                # actual image format - confirm this is acceptable for consumers.
                contents.append(Content.from_uri(uri=item.image.url, media_type="image/png"))
        if not contents:
            continue
        if annotations:
            for c in contents:
                # Copy per item: assigning the same list object everywhere would
                # alias the annotations across all contents and messages.
                c.annotations = list(annotations)
        result_messages.append(Message(role=kb_msg.role or "assistant", contents=contents))

    return result_messages or _no_results()
def _extract_document_text(self, doc: dict[str, Any], doc_id: str | None = None) -> str:
"""Extract readable text from a search document with optional citation."""
@@ -6,7 +6,7 @@ from types import SimpleNamespace
from unittest.mock import AsyncMock, Mock, patch
import pytest
from agent_framework import Message
from agent_framework import Content, Message
from agent_framework._sessions import AgentSession, SessionContext
from agent_framework.exceptions import SettingNotFoundError
from azure.core.credentials import AzureKeyCredential
@@ -720,7 +720,7 @@ class TestSemanticSearch:
results = await provider._semantic_search("test query")
assert len(results) == 1
assert "result text" in results[0]
assert "result text" in results[0].text
call_kwargs = mock_client.search.call_args[1]
assert call_kwargs["search_text"] == "test query"
@@ -746,7 +746,11 @@ class TestSemanticSearch:
provider = _make_provider()
provider._use_vectorizable_query = False
provider.vector_field_name = "embedding"
provider.embedding_function = AsyncMock(return_value=[0.1, 0.2, 0.3])
async def _embed(query: str) -> list[float]:
return [0.1, 0.2, 0.3]
provider.embedding_function = _embed
mock_client = AsyncMock()
async def _search(**kwargs):
@@ -757,7 +761,6 @@ class TestSemanticSearch:
results = await provider._semantic_search("embed query")
assert len(results) == 1
provider.embedding_function.assert_awaited_once_with("embed query")
call_kwargs = mock_client.search.call_args[1]
assert "vector_queries" in call_kwargs
@@ -1100,9 +1103,11 @@ class TestAgenticSearch:
mock_content = Mock()
mock_content.text = "Answer text"
mock_message = Mock()
mock_message.role = "assistant"
mock_message.content = [mock_content]
mock_result = Mock()
mock_result.response = [mock_message]
mock_result.references = None
mock_retrieval = AsyncMock()
mock_retrieval.retrieve = AsyncMock(return_value=mock_result)
@@ -1115,7 +1120,9 @@ class TestAgenticSearch:
):
results = await provider._agentic_search([Message(role="user", contents=["test query"])])
assert results == ["Answer text"]
assert len(results) == 1
assert results[0].text == "Answer text"
assert results[0].role == "assistant"
async def test_non_minimal_reasoning_uses_messages(self) -> None:
provider = _make_provider()
@@ -1126,9 +1133,11 @@ class TestAgenticSearch:
mock_content = Mock()
mock_content.text = "Medium answer"
mock_message = Mock()
mock_message.role = "assistant"
mock_message.content = [mock_content]
mock_result = Mock()
mock_result.response = [mock_message]
mock_result.references = None
mock_retrieval = AsyncMock()
mock_retrieval.retrieve = AsyncMock(return_value=mock_result)
@@ -1143,7 +1152,8 @@ class TestAgenticSearch:
Message(role="assistant", contents=["answer"]),
])
assert results == ["Medium answer"]
assert len(results) == 1
assert results[0].text == "Medium answer"
mock_retrieval.retrieve.assert_awaited_once()
async def test_no_response_returns_default_message(self) -> None:
@@ -1154,13 +1164,15 @@ class TestAgenticSearch:
mock_result = Mock()
mock_result.response = []
mock_result.references = None
mock_retrieval = AsyncMock()
mock_retrieval.retrieve = AsyncMock(return_value=mock_result)
provider._retrieval_client = mock_retrieval
results = await provider._agentic_search([Message(role="user", contents=["query"])])
assert results == ["No results found from Knowledge Base."]
assert len(results) == 1
assert results[0].text == "No results found from Knowledge Base."
async def test_empty_content_returns_default_message(self) -> None:
provider = _make_provider()
@@ -1172,13 +1184,15 @@ class TestAgenticSearch:
mock_message.content = None
mock_result = Mock()
mock_result.response = [mock_message]
mock_result.references = None
mock_retrieval = AsyncMock()
mock_retrieval.retrieve = AsyncMock(return_value=mock_result)
provider._retrieval_client = mock_retrieval
results = await provider._agentic_search([Message(role="user", contents=["query"])])
assert results == ["No results found from Knowledge Base."]
assert len(results) == 1
assert results[0].text == "No results found from Knowledge Base."
async def test_answer_synthesis_output_mode(self) -> None:
provider = _make_provider()
@@ -1190,9 +1204,11 @@ class TestAgenticSearch:
mock_content = Mock()
mock_content.text = "Synthesized answer"
mock_message = Mock()
mock_message.role = "assistant"
mock_message.content = [mock_content]
mock_result = Mock()
mock_result.response = [mock_message]
mock_result.references = None
mock_retrieval = AsyncMock()
mock_retrieval.retrieve = AsyncMock(return_value=mock_result)
@@ -1204,7 +1220,8 @@ class TestAgenticSearch:
):
results = await provider._agentic_search([Message(role="user", contents=["query"])])
assert results == ["Synthesized answer"]
assert len(results) == 1
assert results[0].text == "Synthesized answer"
async def test_content_without_text_excluded(self) -> None:
provider = _make_provider()
@@ -1217,9 +1234,11 @@ class TestAgenticSearch:
mock_content_no_text = Mock()
mock_content_no_text.text = None
mock_message = Mock()
mock_message.role = "assistant"
mock_message.content = [mock_content_no_text, mock_content_with_text]
mock_result = Mock()
mock_result.response = [mock_message]
mock_result.references = None
mock_retrieval = AsyncMock()
mock_retrieval.retrieve = AsyncMock(return_value=mock_result)
@@ -1231,7 +1250,8 @@ class TestAgenticSearch:
):
results = await provider._agentic_search([Message(role="user", contents=["query"])])
assert results == ["Good content"]
assert len(results) == 1
assert results[0].text == "Good content"
async def test_none_response_returns_default_message(self) -> None:
provider = _make_provider()
@@ -1241,13 +1261,355 @@ class TestAgenticSearch:
mock_result = Mock()
mock_result.response = None
mock_result.references = None
mock_retrieval = AsyncMock()
mock_retrieval.retrieve = AsyncMock(return_value=mock_result)
provider._retrieval_client = mock_retrieval
results = await provider._agentic_search([Message(role="user", contents=["query"])])
assert results == ["No results found from Knowledge Base."]
assert len(results) == 1
assert results[0].text == "No results found from Knowledge Base."
# -- before_run: agentic mode --------------------------------------------------
# -- _prepare_messages_for_kb_search / _parse_references_to_annotations / _parse_messages_from_kb_response --------
class TestPrepareMessagesForKbSearch:
    """Tests for _prepare_messages_for_kb_search."""

    def test_text_only_messages(self) -> None:
        """Text messages keep their roles and are converted to text content."""
        from azure.search.documents.knowledgebases.models import KnowledgeBaseMessageTextContent

        conversation = [
            Message(role="user", contents=["hello"]),
            Message(role="assistant", contents=["world"]),
        ]

        kb_messages = AzureAISearchContextProvider._prepare_messages_for_kb_search(conversation)

        assert len(kb_messages) == 2
        assert kb_messages[0].role == "user"
        assert kb_messages[1].role == "assistant"
        # Verify content is KnowledgeBaseMessageTextContent
        first_part = kb_messages[0].content[0]
        assert isinstance(first_part, KnowledgeBaseMessageTextContent)
        assert first_part.text == "hello"

    def test_image_uri_content(self) -> None:
        """An image URI is converted to image content carrying the same URL."""
        from agent_framework import Content
        from azure.search.documents.knowledgebases.models import KnowledgeBaseMessageImageContent

        picture = Content.from_uri(uri="https://example.com/photo.png", media_type="image/png")

        kb_messages = AzureAISearchContextProvider._prepare_messages_for_kb_search(
            [Message(role="user", contents=[picture])]
        )

        assert len(kb_messages) == 1
        image_part = kb_messages[0].content[0]
        assert isinstance(image_part, KnowledgeBaseMessageImageContent)
        assert image_part.image.url == "https://example.com/photo.png"

    def test_mixed_text_and_image_content(self) -> None:
        """Text and image items of one message both end up in the same KB message."""
        from agent_framework import Content

        caption = Content.from_text("describe this image")
        picture = Content.from_uri(uri="https://example.com/img.jpg", media_type="image/jpeg")

        kb_messages = AzureAISearchContextProvider._prepare_messages_for_kb_search(
            [Message(role="user", contents=[caption, picture])]
        )

        assert len(kb_messages) == 1
        assert len(kb_messages[0].content) == 2

    def test_skips_non_text_non_image_content(self) -> None:
        """A message holding only error content produces no KB message."""
        from agent_framework import Content

        failure = Content.from_error(message="oops")

        kb_messages = AzureAISearchContextProvider._prepare_messages_for_kb_search(
            [Message(role="user", contents=[failure])]
        )

        assert len(kb_messages) == 0  # message had no usable content

    def test_skips_empty_text(self) -> None:
        """Empty text content is not forwarded."""
        from agent_framework import Content

        blank = Content.from_text("")

        kb_messages = AzureAISearchContextProvider._prepare_messages_for_kb_search(
            [Message(role="user", contents=[blank])]
        )

        assert len(kb_messages) == 0

    def test_fallback_to_msg_text_when_no_contents(self) -> None:
        """A message built from ``text`` alone is still converted."""
        message = Message(role="user", text="fallback text")

        kb_messages = AzureAISearchContextProvider._prepare_messages_for_kb_search([message])

        assert len(kb_messages) == 1
        assert kb_messages[0].content[0].text == "fallback text"

    def test_data_uri_image(self) -> None:
        """Image bytes supplied as data content are converted to image content."""
        from agent_framework import Content
        from azure.search.documents.knowledgebases.models import KnowledgeBaseMessageImageContent

        picture = Content.from_data(data=b"\x89PNG", media_type="image/png")

        kb_messages = AzureAISearchContextProvider._prepare_messages_for_kb_search(
            [Message(role="user", contents=[picture])]
        )

        assert len(kb_messages) == 1
        assert isinstance(kb_messages[0].content[0], KnowledgeBaseMessageImageContent)

    def test_non_image_uri_skipped(self) -> None:
        """URI content with a non-image media type is dropped."""
        from agent_framework import Content

        document = Content.from_uri(uri="https://example.com/doc.pdf", media_type="application/pdf")

        kb_messages = AzureAISearchContextProvider._prepare_messages_for_kb_search(
            [Message(role="user", contents=[document])]
        )

        assert len(kb_messages) == 0
class TestParseReferencesToAnnotations:
    """Tests for _parse_references_to_annotations."""

    def test_none_references(self) -> None:
        """``None`` yields an empty annotation list."""
        annotations = AzureAISearchContextProvider._parse_references_to_annotations(None)
        assert annotations == []

    def test_empty_references(self) -> None:
        """An empty reference list yields an empty annotation list."""
        annotations = AzureAISearchContextProvider._parse_references_to_annotations([])
        assert annotations == []

    def test_search_index_reference_captures_doc_key(self) -> None:
        """A search-index reference falls back to its id as title and keeps the doc key."""
        from azure.search.documents.knowledgebases.models import KnowledgeBaseSearchIndexReference

        reference = KnowledgeBaseSearchIndexReference(id="ref-1", activity_source=0, doc_key="doc-1")

        annotations = AzureAISearchContextProvider._parse_references_to_annotations([reference])

        assert len(annotations) == 1
        citation = annotations[0]
        assert citation["type"] == "citation"
        assert citation["title"] == "ref-1"
        properties = citation["additional_properties"]
        assert properties["reference_id"] == "ref-1"
        assert properties["reference_type"] == "searchIndex"
        assert properties["activity_source"] == 0
        assert properties["doc_key"] == "doc-1"

    def test_web_reference_with_url_and_title(self) -> None:
        """A web reference contributes its own URL and title."""
        from azure.search.documents.knowledgebases.models import KnowledgeBaseWebReference

        reference = KnowledgeBaseWebReference(
            id="ref-2",
            activity_source=0,
            url="https://example.com/page",
            title="Example Page",
        )

        annotations = AzureAISearchContextProvider._parse_references_to_annotations([reference])

        assert len(annotations) == 1
        citation = annotations[0]
        assert citation["url"] == "https://example.com/page"
        assert citation["title"] == "Example Page"
        assert citation["additional_properties"]["reference_type"] == "web"

    def test_blob_reference_extracts_blob_url(self) -> None:
        """A blob reference's ``blob_url`` is used as the annotation URL."""
        from azure.search.documents.knowledgebases.models import KnowledgeBaseAzureBlobReference

        reference = KnowledgeBaseAzureBlobReference(
            id="ref-3",
            activity_source=0,
            blob_url="https://storage.blob.core.windows.net/doc.pdf",
        )

        citation = AzureAISearchContextProvider._parse_references_to_annotations([reference])[0]

        assert citation["url"] == "https://storage.blob.core.windows.net/doc.pdf"
        assert citation["additional_properties"]["reference_type"] == "azureBlob"

    def test_source_data_and_reranker_score(self) -> None:
        """Source data and reranker score are copied into the additional properties."""
        from azure.search.documents.knowledgebases.models import KnowledgeBaseSearchIndexReference

        reference = KnowledgeBaseSearchIndexReference(
            id="ref-4",
            activity_source=0,
            source_data={"chunk": "some text"},
            reranker_score=0.95,
        )

        citation = AzureAISearchContextProvider._parse_references_to_annotations([reference])[0]

        properties = citation["additional_properties"]
        assert properties["source_data"] == {"chunk": "some text"}
        assert properties["reranker_score"] == 0.95

    def test_raw_representation_stores_original_ref(self) -> None:
        """The original reference object is kept under ``raw_representation``."""
        from azure.search.documents.knowledgebases.models import KnowledgeBaseSearchIndexReference

        reference = KnowledgeBaseSearchIndexReference(id="ref-5", activity_source=0)

        citation = AzureAISearchContextProvider._parse_references_to_annotations([reference])[0]

        assert citation["raw_representation"] is reference

    def test_remote_sharepoint_captures_sensitivity_label(self) -> None:
        """A SharePoint reference exposes its web URL and sensitivity label details."""
        from azure.search.documents.knowledgebases.models import (
            KnowledgeBaseRemoteSharePointReference,
            SharePointSensitivityLabelInfo,
        )

        sensitivity = SharePointSensitivityLabelInfo(
            display_name="Confidential",
            sensitivity_label_id="lbl-1",
            is_encrypted=True,
        )
        reference = KnowledgeBaseRemoteSharePointReference(
            id="ref-6",
            activity_source=0,
            web_url="https://sp.example.com/doc",
            search_sensitivity_label_info=sensitivity,
        )

        citation = AzureAISearchContextProvider._parse_references_to_annotations([reference])[0]

        assert citation["url"] == "https://sp.example.com/doc"
        label = citation["additional_properties"]["sensitivity_label"]
        assert label["display_name"] == "Confidential"
        assert label["sensitivity_label_id"] == "lbl-1"
        assert label["is_encrypted"] is True

    def test_multiple_references(self) -> None:
        """Several references yield one annotation each, in input order."""
        from azure.search.documents.knowledgebases.models import (
            KnowledgeBaseSearchIndexReference,
            KnowledgeBaseWebReference,
        )

        index_reference = KnowledgeBaseSearchIndexReference(id="ref-a", activity_source=0)
        web_reference = KnowledgeBaseWebReference(id="ref-b", activity_source=1, url="https://example.com")

        annotations = AzureAISearchContextProvider._parse_references_to_annotations(
            [index_reference, web_reference]
        )

        assert len(annotations) == 2
        assert annotations[0]["additional_properties"]["activity_source"] == 0
        assert annotations[1]["additional_properties"]["activity_source"] == 1
class TestParseMessagesFromKbResponse:
    """Tests for _parse_messages_from_kb_response."""

    def test_converts_all_messages(self) -> None:
        """Every KB message in the response is converted, keeping role and text."""
        from azure.search.documents.knowledgebases.models import (
            KnowledgeBaseMessage,
            KnowledgeBaseMessageTextContent,
            KnowledgeBaseRetrievalResponse,
        )

        question = KnowledgeBaseMessage(role="user", content=[KnowledgeBaseMessageTextContent(text="q")])
        reply = KnowledgeBaseMessage(role="assistant", content=[KnowledgeBaseMessageTextContent(text="answer")])
        kb_response = KnowledgeBaseRetrievalResponse(response=[question, reply], references=None)

        parsed = AzureAISearchContextProvider._parse_messages_from_kb_response(kb_response)

        assert len(parsed) == 2
        assert parsed[0].role == "user"
        assert parsed[0].text == "q"
        assert parsed[1].role == "assistant"
        assert parsed[1].text == "answer"

    def test_none_response_returns_default(self) -> None:
        """A ``None`` response yields the single default message."""
        from azure.search.documents.knowledgebases.models import KnowledgeBaseRetrievalResponse

        kb_response = KnowledgeBaseRetrievalResponse(response=None, references=None)

        parsed = AzureAISearchContextProvider._parse_messages_from_kb_response(kb_response)

        assert len(parsed) == 1
        assert parsed[0].text == "No results found from Knowledge Base."

    def test_empty_response_returns_default(self) -> None:
        """An empty response list yields the single default message."""
        from azure.search.documents.knowledgebases.models import KnowledgeBaseRetrievalResponse

        kb_response = KnowledgeBaseRetrievalResponse(response=[], references=None)

        parsed = AzureAISearchContextProvider._parse_messages_from_kb_response(kb_response)

        assert len(parsed) == 1
        assert parsed[0].text == "No results found from Knowledge Base."

    def test_image_content(self) -> None:
        """Image content is converted to URI content carrying the image URL."""
        from azure.search.documents.knowledgebases.models import (
            KnowledgeBaseMessage,
            KnowledgeBaseMessageImageContent,
            KnowledgeBaseMessageImageContentImage,
            KnowledgeBaseRetrievalResponse,
        )

        picture = KnowledgeBaseMessageImageContent(
            image=KnowledgeBaseMessageImageContentImage(url="https://img.example.com/a.png")
        )
        kb_response = KnowledgeBaseRetrievalResponse(
            response=[KnowledgeBaseMessage(role="assistant", content=[picture])],
            references=None,
        )

        parsed = AzureAISearchContextProvider._parse_messages_from_kb_response(kb_response)

        assert len(parsed) == 1
        converted = parsed[0].contents[0]
        assert converted.type == "uri"
        assert converted.uri == "https://img.example.com/a.png"

    def test_mixed_text_and_image_content(self) -> None:
        """Text and image items of one KB message are converted in order."""
        from azure.search.documents.knowledgebases.models import (
            KnowledgeBaseMessage,
            KnowledgeBaseMessageImageContent,
            KnowledgeBaseMessageImageContentImage,
            KnowledgeBaseMessageTextContent,
            KnowledgeBaseRetrievalResponse,
        )

        caption = KnowledgeBaseMessageTextContent(text="description")
        picture = KnowledgeBaseMessageImageContent(
            image=KnowledgeBaseMessageImageContentImage(url="https://img.example.com/b.png")
        )
        kb_response = KnowledgeBaseRetrievalResponse(
            response=[KnowledgeBaseMessage(role="assistant", content=[caption, picture])],
            references=None,
        )

        parsed = AzureAISearchContextProvider._parse_messages_from_kb_response(kb_response)

        assert len(parsed) == 1
        converted = parsed[0].contents
        assert len(converted) == 2
        assert converted[0].type == "text"
        assert converted[1].type == "uri"

    def test_references_become_annotations(self) -> None:
        """Response references are attached to the content as citation annotations."""
        from azure.search.documents.knowledgebases.models import (
            KnowledgeBaseMessage,
            KnowledgeBaseMessageTextContent,
            KnowledgeBaseRetrievalResponse,
            KnowledgeBaseWebReference,
        )

        reply = KnowledgeBaseMessage(role="assistant", content=[KnowledgeBaseMessageTextContent(text="answer")])
        web_reference = KnowledgeBaseWebReference(
            id="ref-1",
            activity_source=0,
            url="https://example.com",
            title="Example",
        )
        kb_response = KnowledgeBaseRetrievalResponse(response=[reply], references=[web_reference])

        parsed = AzureAISearchContextProvider._parse_messages_from_kb_response(kb_response)

        assert len(parsed) == 1
        attached = parsed[0].contents[0].annotations
        assert attached is not None
        assert len(attached) == 1
        citation = attached[0]
        assert citation["type"] == "citation"
        assert citation["url"] == "https://example.com"
        assert citation["title"] == "Example"

    def test_multiple_messages_with_references(self) -> None:
        """With references present, every content item of every message is annotated."""
        from azure.search.documents.knowledgebases.models import (
            KnowledgeBaseMessage,
            KnowledgeBaseMessageTextContent,
            KnowledgeBaseRetrievalResponse,
            KnowledgeBaseSearchIndexReference,
        )

        question = KnowledgeBaseMessage(role="user", content=[KnowledgeBaseMessageTextContent(text="q")])
        reply = KnowledgeBaseMessage(
            role="assistant",
            content=[
                KnowledgeBaseMessageTextContent(text="part1"),
                KnowledgeBaseMessageTextContent(text="part2"),
            ],
        )
        kb_response = KnowledgeBaseRetrievalResponse(
            response=[question, reply],
            references=[KnowledgeBaseSearchIndexReference(id="doc-1", activity_source=0)],
        )

        parsed = AzureAISearchContextProvider._parse_messages_from_kb_response(kb_response)

        assert len(parsed) == 2
        # All content items get annotations
        for message in parsed:
            for content_item in message.contents:
                assert content_item.annotations is not None
                assert len(content_item.annotations) == 1
# -- before_run: agentic mode --------------------------------------------------
@@ -1266,9 +1628,11 @@ class TestBeforeRunAgentic:
mock_content = Mock()
mock_content.text = "agentic result"
mock_message = Mock()
mock_message.role = "assistant"
mock_message.content = [mock_content]
mock_result = Mock()
mock_result.response = [mock_message]
mock_result.references = None
mock_retrieval = AsyncMock()
mock_retrieval.retrieve = AsyncMock(return_value=mock_result)
@@ -135,6 +135,9 @@ async def main() -> None:
async for chunk in agent.run(user_input, stream=True):
if chunk.text:
print(chunk.text, end="", flush=True)
for content in chunk.contents:
if content.annotations:
print(f"\n[Sources: {content.annotations}]", end="", flush=True)
print("\n")
@@ -4,7 +4,7 @@ import asyncio
import os
from agent_framework import Agent
from agent_framework.azure import AzureAIAgentClient, AzureAISearchContextProvider
from agent_framework.azure import AzureAIAgentClient, AzureAISearchContextProvider, AzureOpenAIEmbeddingClient
from azure.identity.aio import AzureCliCredential
from dotenv import load_dotenv
@@ -30,6 +30,8 @@ Prerequisites:
- AZURE_SEARCH_INDEX_NAME: Your search index name
- AZURE_AI_PROJECT_ENDPOINT: Your Azure AI Foundry project endpoint
- AZURE_AI_MODEL_DEPLOYMENT_NAME: Your model deployment name (e.g., "gpt-4o")
- AZURE_OPENAI_EMBEDDING_MODEL_ID: (Optional) Your embedding model for hybrid search (e.g., "text-embedding-3-small")
- AZURE_OPENAI_ENDPOINT: (Optional) Your Azure OpenAI resource URL, required if using an OpenAI embedding model for hybrid search
"""
# Sample queries to demonstrate RAG
@@ -43,12 +45,24 @@ USER_INPUTS = [
async def main() -> None:
"""Main function demonstrating Azure AI Search semantic mode."""
credential = AzureCliCredential()
# Get configuration from environment
search_endpoint = os.environ["AZURE_SEARCH_ENDPOINT"]
search_key = os.environ.get("AZURE_SEARCH_API_KEY")
index_name = os.environ["AZURE_SEARCH_INDEX_NAME"]
project_endpoint = os.environ["AZURE_AI_PROJECT_ENDPOINT"]
model_deployment = os.environ.get("AZURE_AI_MODEL_DEPLOYMENT_NAME", "gpt-4o")
openai_endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT")
embedding_model = os.environ.get("AZURE_OPENAI_EMBEDDING_MODEL_ID", "text-embedding-3-small")
embedding_client = None
if openai_endpoint and embedding_model:
embedding_client = AzureOpenAIEmbeddingClient(
endpoint=openai_endpoint,
deployment_name=embedding_model,
credential=credential,
)
# Create Azure AI Search context provider with semantic mode (recommended, fast)
print("Using SEMANTIC mode (hybrid search + semantic ranking, fast)\n")
@@ -57,9 +71,13 @@ async def main() -> None:
endpoint=search_endpoint,
index_name=index_name,
api_key=search_key, # Use api_key for API key auth, or credential for managed identity
credential=AzureCliCredential() if not search_key else None,
credential=credential if not search_key else None,
mode="semantic", # Default mode
top_k=3, # Retrieve top 3 most relevant documents
embedding_function=embedding_client, # Provide embedding function for hybrid search
vector_field_name="DescriptionVector"
if embedding_client
else None, # Set vector field for hybrid search if using embeddings
)
# Create agent with search context provider
@@ -68,7 +86,7 @@ async def main() -> None:
AzureAIAgentClient(
project_endpoint=project_endpoint,
model_deployment_name=model_deployment,
credential=AzureCliCredential(),
credential=credential,
) as client,
Agent(
client=client,