Python: Enhance Azure AI Search Citations with Complete URL Information (#2066)

* add get_url to raw rep for absolute path url

* fixes

* add real url to citation annotation

* small fix

* project client + openapi fix

* openapi sample revert

* tool call list fix
This commit is contained in:
Giles Odigwe
2025-11-13 11:23:11 -08:00
committed by GitHub
Unverified
parent 6d890e46ed
commit 21dceca482
3 changed files with 253 additions and 10 deletions
@@ -1,7 +1,9 @@
# Copyright (c) Microsoft. All rights reserved.
import ast
import json
import os
import re
import sys
from collections.abc import AsyncIterable, MutableMapping, MutableSequence, Sequence
from typing import Any, ClassVar, TypeVar
@@ -427,7 +429,9 @@ class AzureAIAgentClient(BaseChatClient):
# and remove until here.
return thread_id
def _extract_url_citations(self, message_delta_chunk: MessageDeltaChunk) -> list[CitationAnnotation]:
def _extract_url_citations(
self, message_delta_chunk: MessageDeltaChunk, azure_search_tool_calls: list[dict[str, Any]]
) -> list[CitationAnnotation]:
"""Extract URL citations from MessageDeltaChunk."""
url_citations: list[CitationAnnotation] = []
@@ -446,10 +450,15 @@ class AzureAIAgentClient(BaseChatClient):
)
]
# Create CitationAnnotation from AzureAI annotation
# Extract real URL from Azure AI Search tool calls
real_url = self._get_real_url_from_citation_reference(
annotation.url_citation.url, azure_search_tool_calls
)
# Create CitationAnnotation with real URL
citation = CitationAnnotation(
title=getattr(annotation.url_citation, "title", None),
url=annotation.url_citation.url,
url=real_url,
snippet=None,
annotated_regions=annotated_regions,
raw_representation=annotation,
@@ -458,11 +467,54 @@ class AzureAIAgentClient(BaseChatClient):
return url_citations
def _get_real_url_from_citation_reference(
self, citation_url: str, azure_search_tool_calls: list[dict[str, Any]]
) -> str:
"""Extract real URL from Azure AI Search tool calls based on citation reference.
Args:
citation_url: Citation reference URL (e.g., "doc_0", "#doc_1", or full URL with doc_N)
azure_search_tool_calls: List of captured Azure AI Search tool calls
Returns:
Real document URL if found, otherwise original citation_url
"""
# Extract document index from citation URL (e.g., "doc_0" -> 0)
match = re.search(r"doc_(\d+)", citation_url)
if not match:
return citation_url
doc_index = int(match.group(1))
# Get Azure AI Search tool calls
if not azure_search_tool_calls:
return citation_url
try:
# Extract URLs from the most recent Azure AI Search tool call
tool_call = azure_search_tool_calls[-1] # Most recent call
output_str = tool_call["azure_ai_search"]["output"]
# Parse the tool call output to get URLs
output_data = ast.literal_eval(output_str)
all_urls = output_data["metadata"]["get_urls"]
# Return the URL at the specified index, if it exists
if 0 <= doc_index < len(all_urls):
return str(all_urls[doc_index])
except (KeyError, IndexError, TypeError, ValueError, SyntaxError) as ex:
logger.debug(f"Failed to extract real URL for {citation_url}: {ex}")
return citation_url
async def _process_stream(
self, stream: AsyncAgentRunStream[AsyncAgentEventHandler[Any]] | AsyncAgentEventHandler[Any], thread_id: str
) -> AsyncIterable[ChatResponseUpdate]:
"""Process events from the stream iterator and yield ChatResponseUpdate objects."""
response_id: str | None = None
# Track Azure Search tool calls for this stream only
azure_search_tool_calls: list[dict[str, Any]] = []
response_stream = await stream.__aenter__() if isinstance(stream, AsyncAgentRunStream) else stream # type: ignore[no-untyped-call]
try:
async for event_type, event_data, _ in response_stream: # type: ignore
@@ -472,7 +524,7 @@ class AzureAIAgentClient(BaseChatClient):
role = Role.USER if event_data.delta.role == MessageRole.USER else Role.ASSISTANT
# Extract URL citations from the delta chunk
url_citations = self._extract_url_citations(event_data)
url_citations = self._extract_url_citations(event_data, azure_search_tool_calls)
# Create contents with citations if any exist
citation_content: list[Contents] = []
@@ -545,6 +597,10 @@ class AzureAIAgentClient(BaseChatClient):
case AgentStreamEvent.THREAD_RUN_STEP_CREATED:
response_id = event_data.run_id
case AgentStreamEvent.THREAD_RUN_COMPLETED | AgentStreamEvent.THREAD_RUN_STEP_COMPLETED:
# Capture Azure AI Search tool calls when steps complete
if event_type == AgentStreamEvent.THREAD_RUN_STEP_COMPLETED:
self._capture_azure_search_tool_calls(event_data, azure_search_tool_calls)
if event_data.usage:
usage_content = UsageContent(
UsageDetails(
@@ -623,6 +679,29 @@ class AzureAIAgentClient(BaseChatClient):
if isinstance(stream, AsyncAgentRunStream):
await stream.__aexit__(None, None, None) # type: ignore[no-untyped-call]
def _capture_azure_search_tool_calls(
    self, step_data: RunStep, azure_search_tool_calls: list[dict[str, Any]]
) -> None:
    """Append every Azure AI Search tool call found on a completed step to the given list.

    Each captured entry is a dict with the keys ``id``, ``type`` and
    ``azure_ai_search``. Any exception raised while inspecting the step is
    logged at debug level and otherwise ignored.
    """
    try:
        details = getattr(step_data, "step_details", None)
        calls = getattr(details, "tool_calls", None)
        if not calls:
            return

        for call in calls:
            call_type = getattr(call, "type", None)
            if call_type == "azure_ai_search":
                # Keep the whole search payload so the URLs can be looked up later.
                captured = {
                    "id": getattr(call, "id", None),
                    "type": call_type,
                    "azure_ai_search": getattr(call, "azure_ai_search", None),
                }
                azure_search_tool_calls.append(captured)
                logger.debug(f"Captured Azure AI Search tool call: {captured['id']}")
    except Exception as ex:
        logger.debug(f"Failed to capture Azure AI Search tool call: {ex}")
def _create_function_call_contents(self, event_data: ThreadRun, response_id: str | None) -> list[Contents]:
"""Create function call contents from a tool action event."""
if isinstance(event_data, ThreadRun) and event_data.required_action is not None:
@@ -3,7 +3,7 @@
import json
import os
from pathlib import Path
from typing import Annotated
from typing import Annotated, Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
@@ -92,6 +92,7 @@ def create_test_azure_ai_chat_client(
client._agent_created = False
client._should_close_client = False
client._agent_definition = None
client._azure_search_tool_calls = [] # Add the new instance variable
client.additional_properties = {}
client.middleware = None
@@ -1335,8 +1336,8 @@ def test_azure_ai_chat_client_extract_url_citations_with_citations(mock_agents_c
mock_chunk = MagicMock(spec=MessageDeltaChunk)
mock_chunk.delta = mock_delta
# Call the method
citations = chat_client._extract_url_citations(mock_chunk) # type: ignore
# Call the method with empty azure_search_tool_calls
citations = chat_client._extract_url_citations(mock_chunk, []) # type: ignore
# Verify results
assert len(citations) == 1
@@ -1804,3 +1805,166 @@ async def test_azure_ai_chat_client_no_cleanup_when_agent_not_created_by_client(
# Verify agent was NOT deleted
mock_agents_client.delete_agent.assert_not_called()
assert chat_client.agent_id == "existing-agent-id"
def test_azure_ai_chat_client_capture_azure_search_tool_calls(mock_agents_client: MagicMock) -> None:
    """A completed step's Azure AI Search tool call is recorded in the supplied list."""
    client = create_test_azure_ai_chat_client(mock_agents_client)

    # A single search tool call attached to a mocked completed step.
    search_call = MagicMock()
    search_call.type = "azure_ai_search"
    search_call.id = "call_123"
    search_call.azure_ai_search = {"input": "test query", "output": "test output"}

    completed_step = MagicMock()
    completed_step.step_details.tool_calls = [search_call]

    # The caller owns the list; the method appends to it in place.
    captured: list[dict[str, Any]] = []
    client._capture_azure_search_tool_calls(completed_step, captured)  # type: ignore

    assert len(captured) == 1
    entry = captured[0]
    assert entry["type"] == "azure_ai_search"
    assert entry["id"] == "call_123"
    assert entry["azure_ai_search"] == {"input": "test query", "output": "test output"}
def test_azure_ai_chat_client_get_real_url_from_citation_reference_no_tool_calls(
    mock_agents_client: MagicMock,
) -> None:
    """With no captured tool calls the citation reference is returned unchanged."""
    client = create_test_azure_ai_chat_client(mock_agents_client)

    no_tool_calls: list = []
    resolved = client._get_real_url_from_citation_reference("doc_1", no_tool_calls)  # type: ignore

    assert resolved == "doc_1"
def test_azure_ai_chat_client_get_real_url_from_citation_reference_invalid_output(
    mock_agents_client: MagicMock,
) -> None:
    """An unparseable tool call output leaves the citation reference unchanged."""
    client = create_test_azure_ai_chat_client(mock_agents_client)

    # The output string cannot be parsed into the expected structure.
    unparseable_call = {
        "id": "call_123",
        "type": "azure_ai_search",
        "azure_ai_search": {"output": "invalid_json_format"},
    }

    resolved = client._get_real_url_from_citation_reference("doc_1", [unparseable_call])  # type: ignore

    assert resolved == "doc_1"
async def test_azure_ai_chat_client_context_manager(mock_agents_client: MagicMock) -> None:
    """Entering the client as an async context manager yields it; exiting calls ``close`` once."""
    client_under_test = create_test_azure_ai_chat_client(mock_agents_client)

    # Replace close with a stub so no real cleanup runs.
    close_stub = AsyncMock()
    client_under_test.close = close_stub

    async with client_under_test as entered:
        assert entered is client_under_test

    close_stub.assert_called_once()
async def test_azure_ai_chat_client_close_method(mock_agents_client: MagicMock) -> None:
    """``close`` invokes both the agent cleanup hook and the client close hook exactly once."""
    client_under_test = create_test_azure_ai_chat_client(mock_agents_client)

    # Stub both hooks so the test only observes that they are invoked.
    cleanup_stub = AsyncMock()
    close_client_stub = AsyncMock()
    client_under_test._cleanup_agent_if_needed = cleanup_stub
    client_under_test._close_client_if_needed = close_client_stub

    await client_under_test.close()

    cleanup_stub.assert_called_once()
    close_client_stub.assert_called_once()
def test_azure_ai_chat_client_extract_url_citations_with_azure_search_enhanced_url(
    mock_agents_client: MagicMock,
) -> None:
    """A ``doc_N`` citation is replaced by the matching URL from the captured search output."""
    client = create_test_azure_ai_chat_client(mock_agents_client)

    # Captured search tool call whose output lists the real document URLs.
    search_output = str({
        "metadata": {"get_urls": ["https://real-example.com/doc1", "https://real-example.com/doc2"]}
    })
    captured_calls = [
        {
            "id": "call_123",
            "type": "azure_ai_search",
            "azure_ai_search": {"output": search_output},
        }
    ]

    # Build the chunk from the inside out: url citation -> annotation -> text -> content -> delta -> chunk.
    url_citation = MagicMock()
    url_citation.url = "doc_1"
    url_citation.title = "Test Title"

    annotation = MagicMock(spec=MessageDeltaTextUrlCitationAnnotation)
    annotation.url_citation = url_citation
    annotation.start_index = 10
    annotation.end_index = 20

    text = MagicMock()
    text.annotations = [annotation]

    text_content = MagicMock(spec=MessageDeltaTextContent)
    text_content.text = text

    delta = MagicMock()
    delta.content = [text_content]

    chunk = MagicMock(spec=MessageDeltaChunk)
    chunk.delta = delta

    found = client._extract_url_citations(chunk, captured_calls)  # type: ignore

    # doc_1 maps to index 1 of get_urls.
    assert len(found) == 1
    assert found[0].url == "https://real-example.com/doc2"
def test_azure_ai_chat_client_init_with_auto_created_agents_client(
    azure_ai_unit_test_env: dict[str, str], mock_azure_credential: MagicMock
) -> None:
    """Without an ``agents_client`` argument the client constructs its own ``AgentsClient``."""
    endpoint = azure_ai_unit_test_env["AZURE_AI_PROJECT_ENDPOINT"]
    deployment_name = azure_ai_unit_test_env["AZURE_AI_MODEL_DEPLOYMENT_NAME"]

    # Patch the constructor so the instance that gets built can be inspected.
    with patch("agent_framework_azure_ai._chat_client.AgentsClient") as agents_client_cls:
        built_instance = MagicMock()
        agents_client_cls.return_value = built_instance

        # agents_client=None is what triggers creation of an AgentsClient.
        client = AzureAIAgentClient(
            agents_client=None,
            agent_id="test-agent",
            project_endpoint=endpoint,
            model_deployment_name=deployment_name,
            async_credential=mock_azure_credential,
        )

        agents_client_cls.assert_called_once_with(
            endpoint=endpoint,
            credential=mock_azure_credential,
            user_agent="agent-framework-python/0.0.0",
        )

        assert client.agents_client is built_instance
        assert client.agent_id == "test-agent"
        assert client.credential is mock_azure_credential
        # The client created the AgentsClient itself, so it is flagged for closing.
        assert client._should_close_client is True  # type: ignore[attr-defined]
@@ -103,11 +103,11 @@ async def main() -> None:
print()
# Display collected citations
# Display collected citation
if citations:
print("\n\nCitations:")
print("\n\nCitation:")
for i, citation in enumerate(citations, 1):
print(f"[{i}] Reference: {citation.url}")
print(f"[{i}] {citation.url}")
print("\n" + "=" * 50 + "\n")
print("Hotel search conversation completed!")