mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Enhance Azure AI Search Citations with Complete URL Information (#2066)
* add get_url to raw rep for absolute path url * fixes * add real url to citation annotation * small fix * project client + openapi fix * openapi sample revert * tool call list fix
This commit is contained in:
committed by
GitHub
Unverified
parent
6d890e46ed
commit
21dceca482
@@ -1,7 +1,9 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import ast
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from collections.abc import AsyncIterable, MutableMapping, MutableSequence, Sequence
|
||||
from typing import Any, ClassVar, TypeVar
|
||||
@@ -427,7 +429,9 @@ class AzureAIAgentClient(BaseChatClient):
|
||||
# and remove until here.
|
||||
return thread_id
|
||||
|
||||
def _extract_url_citations(self, message_delta_chunk: MessageDeltaChunk) -> list[CitationAnnotation]:
|
||||
def _extract_url_citations(
|
||||
self, message_delta_chunk: MessageDeltaChunk, azure_search_tool_calls: list[dict[str, Any]]
|
||||
) -> list[CitationAnnotation]:
|
||||
"""Extract URL citations from MessageDeltaChunk."""
|
||||
url_citations: list[CitationAnnotation] = []
|
||||
|
||||
@@ -446,10 +450,15 @@ class AzureAIAgentClient(BaseChatClient):
|
||||
)
|
||||
]
|
||||
|
||||
# Create CitationAnnotation from AzureAI annotation
|
||||
# Extract real URL from Azure AI Search tool calls
|
||||
real_url = self._get_real_url_from_citation_reference(
|
||||
annotation.url_citation.url, azure_search_tool_calls
|
||||
)
|
||||
|
||||
# Create CitationAnnotation with real URL
|
||||
citation = CitationAnnotation(
|
||||
title=getattr(annotation.url_citation, "title", None),
|
||||
url=annotation.url_citation.url,
|
||||
url=real_url,
|
||||
snippet=None,
|
||||
annotated_regions=annotated_regions,
|
||||
raw_representation=annotation,
|
||||
@@ -458,11 +467,54 @@ class AzureAIAgentClient(BaseChatClient):
|
||||
|
||||
return url_citations
|
||||
|
||||
def _get_real_url_from_citation_reference(
|
||||
self, citation_url: str, azure_search_tool_calls: list[dict[str, Any]]
|
||||
) -> str:
|
||||
"""Extract real URL from Azure AI Search tool calls based on citation reference.
|
||||
|
||||
Args:
|
||||
citation_url: Citation reference URL (e.g., "doc_0", "#doc_1", or full URL with doc_N)
|
||||
azure_search_tool_calls: List of captured Azure AI Search tool calls
|
||||
|
||||
Returns:
|
||||
Real document URL if found, otherwise original citation_url
|
||||
"""
|
||||
# Extract document index from citation URL (e.g., "doc_0" -> 0)
|
||||
match = re.search(r"doc_(\d+)", citation_url)
|
||||
if not match:
|
||||
return citation_url
|
||||
|
||||
doc_index = int(match.group(1))
|
||||
|
||||
# Get Azure AI Search tool calls
|
||||
if not azure_search_tool_calls:
|
||||
return citation_url
|
||||
|
||||
try:
|
||||
# Extract URLs from the most recent Azure AI Search tool call
|
||||
tool_call = azure_search_tool_calls[-1] # Most recent call
|
||||
output_str = tool_call["azure_ai_search"]["output"]
|
||||
|
||||
# Parse the tool call output to get URLs
|
||||
output_data = ast.literal_eval(output_str)
|
||||
all_urls = output_data["metadata"]["get_urls"]
|
||||
|
||||
# Return the URL at the specified index, if it exists
|
||||
if 0 <= doc_index < len(all_urls):
|
||||
return str(all_urls[doc_index])
|
||||
|
||||
except (KeyError, IndexError, TypeError, ValueError, SyntaxError) as ex:
|
||||
logger.debug(f"Failed to extract real URL for {citation_url}: {ex}")
|
||||
|
||||
return citation_url
|
||||
|
||||
async def _process_stream(
|
||||
self, stream: AsyncAgentRunStream[AsyncAgentEventHandler[Any]] | AsyncAgentEventHandler[Any], thread_id: str
|
||||
) -> AsyncIterable[ChatResponseUpdate]:
|
||||
"""Process events from the stream iterator and yield ChatResponseUpdate objects."""
|
||||
response_id: str | None = None
|
||||
# Track Azure Search tool calls for this stream only
|
||||
azure_search_tool_calls: list[dict[str, Any]] = []
|
||||
response_stream = await stream.__aenter__() if isinstance(stream, AsyncAgentRunStream) else stream # type: ignore[no-untyped-call]
|
||||
try:
|
||||
async for event_type, event_data, _ in response_stream: # type: ignore
|
||||
@@ -472,7 +524,7 @@ class AzureAIAgentClient(BaseChatClient):
|
||||
role = Role.USER if event_data.delta.role == MessageRole.USER else Role.ASSISTANT
|
||||
|
||||
# Extract URL citations from the delta chunk
|
||||
url_citations = self._extract_url_citations(event_data)
|
||||
url_citations = self._extract_url_citations(event_data, azure_search_tool_calls)
|
||||
|
||||
# Create contents with citations if any exist
|
||||
citation_content: list[Contents] = []
|
||||
@@ -545,6 +597,10 @@ class AzureAIAgentClient(BaseChatClient):
|
||||
case AgentStreamEvent.THREAD_RUN_STEP_CREATED:
|
||||
response_id = event_data.run_id
|
||||
case AgentStreamEvent.THREAD_RUN_COMPLETED | AgentStreamEvent.THREAD_RUN_STEP_COMPLETED:
|
||||
# Capture Azure AI Search tool calls when steps complete
|
||||
if event_type == AgentStreamEvent.THREAD_RUN_STEP_COMPLETED:
|
||||
self._capture_azure_search_tool_calls(event_data, azure_search_tool_calls)
|
||||
|
||||
if event_data.usage:
|
||||
usage_content = UsageContent(
|
||||
UsageDetails(
|
||||
@@ -623,6 +679,29 @@ class AzureAIAgentClient(BaseChatClient):
|
||||
if isinstance(stream, AsyncAgentRunStream):
|
||||
await stream.__aexit__(None, None, None) # type: ignore[no-untyped-call]
|
||||
|
||||
def _capture_azure_search_tool_calls(
    self, step_data: RunStep, azure_search_tool_calls: list[dict[str, Any]]
) -> None:
    """Record the Azure AI Search tool calls found on a run step.

    Args:
        step_data: Run step to inspect. Its ``step_details.tool_calls`` are scanned; a step
            lacking that attribute chain, or with no tool calls, contributes nothing.
        azure_search_tool_calls: List mutated in place. One dictionary with the keys
            ``id``, ``type`` and ``azure_ai_search`` is appended per tool call whose
            ``type`` equals ``"azure_ai_search"``.

    Any exception raised while inspecting the step is logged at debug level and swallowed.
    """
    try:
        details = getattr(step_data, "step_details", None)
        calls = getattr(details, "tool_calls", None)
        if not calls:
            return

        for call in calls:
            call_type = getattr(call, "type", None)
            if call_type == "azure_ai_search":
                call_id = getattr(call, "id", None)
                # Keep the whole search payload so URLs can be looked up later.
                azure_search_tool_calls.append({
                    "id": call_id,
                    "type": call_type,
                    "azure_ai_search": getattr(call, "azure_ai_search", None),
                })
                logger.debug(f"Captured Azure AI Search tool call: {call_id}")
    except Exception as ex:
        logger.debug(f"Failed to capture Azure AI Search tool call: {ex}")
|
||||
|
||||
def _create_function_call_contents(self, event_data: ThreadRun, response_id: str | None) -> list[Contents]:
|
||||
"""Create function call contents from a tool action event."""
|
||||
if isinstance(event_data, ThreadRun) and event_data.required_action is not None:
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Annotated
|
||||
from typing import Annotated, Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@@ -92,6 +92,7 @@ def create_test_azure_ai_chat_client(
|
||||
client._agent_created = False
|
||||
client._should_close_client = False
|
||||
client._agent_definition = None
|
||||
client._azure_search_tool_calls = [] # Add the new instance variable
|
||||
client.additional_properties = {}
|
||||
client.middleware = None
|
||||
|
||||
@@ -1335,8 +1336,8 @@ def test_azure_ai_chat_client_extract_url_citations_with_citations(mock_agents_c
|
||||
mock_chunk = MagicMock(spec=MessageDeltaChunk)
|
||||
mock_chunk.delta = mock_delta
|
||||
|
||||
# Call the method
|
||||
citations = chat_client._extract_url_citations(mock_chunk) # type: ignore
|
||||
# Call the method with empty azure_search_tool_calls
|
||||
citations = chat_client._extract_url_citations(mock_chunk, []) # type: ignore
|
||||
|
||||
# Verify results
|
||||
assert len(citations) == 1
|
||||
@@ -1804,3 +1805,166 @@ async def test_azure_ai_chat_client_no_cleanup_when_agent_not_created_by_client(
|
||||
# Verify agent was NOT deleted
|
||||
mock_agents_client.delete_agent.assert_not_called()
|
||||
assert chat_client.agent_id == "existing-agent-id"
|
||||
|
||||
|
||||
def test_azure_ai_chat_client_capture_azure_search_tool_calls(mock_agents_client: MagicMock) -> None:
    """Check that an Azure AI Search tool call on a step is recorded in the supplied list."""
    client = create_test_azure_ai_chat_client(mock_agents_client)

    # Step carrying a single Azure AI Search tool call
    search_call = MagicMock()
    search_call.id = "call_123"
    search_call.type = "azure_ai_search"
    search_call.azure_ai_search = {"input": "test query", "output": "test output"}

    run_step = MagicMock()
    run_step.step_details.tool_calls = [search_call]

    # The method fills the list passed in by the caller
    captured: list[dict[str, Any]] = []
    client._capture_azure_search_tool_calls(run_step, captured)  # type: ignore

    assert len(captured) == 1
    entry = captured[0]
    assert entry["type"] == "azure_ai_search"
    assert entry["id"] == "call_123"
    assert entry["azure_ai_search"] == {"input": "test query", "output": "test output"}
|
||||
|
||||
|
||||
def test_azure_ai_chat_client_get_real_url_from_citation_reference_no_tool_calls(
    mock_agents_client: MagicMock,
) -> None:
    """With an empty tool-call list the citation reference is returned unchanged."""
    client = create_test_azure_ai_chat_client(mock_agents_client)
    no_tool_calls: list[dict[str, Any]] = []

    resolved = client._get_real_url_from_citation_reference("doc_1", no_tool_calls)  # type: ignore

    assert resolved == "doc_1"
|
||||
|
||||
|
||||
def test_azure_ai_chat_client_get_real_url_from_citation_reference_invalid_output(
    mock_agents_client: MagicMock,
) -> None:
    """An unparseable tool-call output leaves the citation reference unchanged."""
    client = create_test_azure_ai_chat_client(mock_agents_client)

    # The output string is not a parseable literal
    broken_call = {
        "id": "call_123",
        "type": "azure_ai_search",
        "azure_ai_search": {"output": "invalid_json_format"},
    }

    resolved = client._get_real_url_from_citation_reference("doc_1", [broken_call])  # type: ignore

    assert resolved == "doc_1"
|
||||
|
||||
|
||||
async def test_azure_ai_chat_client_context_manager(mock_agents_client: MagicMock) -> None:
    """Entering the client yields the client itself and leaving it calls ``close`` once."""
    chat_client = create_test_azure_ai_chat_client(mock_agents_client)

    # Replace close so no real cleanup runs when the context exits
    close_mock = AsyncMock()
    chat_client.close = close_mock

    async with chat_client as entered:
        assert entered is chat_client

    chat_client.close.assert_called_once()
|
||||
|
||||
|
||||
async def test_azure_ai_chat_client_close_method(mock_agents_client: MagicMock) -> None:
    """``close`` invokes both the agent cleanup and the client shutdown helper once each."""
    chat_client = create_test_azure_ai_chat_client(mock_agents_client)

    # Stub both cleanup helpers so the calls can be observed
    for helper_name in ("_cleanup_agent_if_needed", "_close_client_if_needed"):
        setattr(chat_client, helper_name, AsyncMock())

    await chat_client.close()

    chat_client._cleanup_agent_if_needed.assert_called_once()
    chat_client._close_client_if_needed.assert_called_once()
|
||||
|
||||
|
||||
def test_azure_ai_chat_client_extract_url_citations_with_azure_search_enhanced_url(
    mock_agents_client: MagicMock,
) -> None:
    """A ``doc_N`` citation is replaced by the matching URL from the search tool call output."""
    chat_client = create_test_azure_ai_chat_client(mock_agents_client)

    # Captured search call whose output lists the real document URLs
    search_output = str({
        "metadata": {"get_urls": ["https://real-example.com/doc1", "https://real-example.com/doc2"]}
    })
    azure_search_tool_calls = [
        {
            "id": "call_123",
            "type": "azure_ai_search",
            "azure_ai_search": {"output": search_output},
        }
    ]

    # Build the delta chunk from the inside out: citation -> annotation -> text -> content -> delta
    url_citation = MagicMock()
    url_citation.title = "Test Title"
    url_citation.url = "doc_1"

    annotation = MagicMock(spec=MessageDeltaTextUrlCitationAnnotation)
    annotation.start_index = 10
    annotation.end_index = 20
    annotation.url_citation = url_citation

    text = MagicMock()
    text.annotations = [annotation]

    text_content = MagicMock(spec=MessageDeltaTextContent)
    text_content.text = text

    delta = MagicMock()
    delta.content = [text_content]

    chunk = MagicMock(spec=MessageDeltaChunk)
    chunk.delta = delta

    citations = chat_client._extract_url_citations(chunk, azure_search_tool_calls)  # type: ignore

    # doc_1 maps to index 1 of get_urls
    assert len(citations) == 1
    assert citations[0].url == "https://real-example.com/doc2"
|
||||
|
||||
|
||||
def test_azure_ai_chat_client_init_with_auto_created_agents_client(
    azure_ai_unit_test_env: dict[str, str], mock_azure_credential: MagicMock
) -> None:
    """When no ``agents_client`` is supplied, the client builds its own and owns its closing."""
    endpoint = azure_ai_unit_test_env["AZURE_AI_PROJECT_ENDPOINT"]
    deployment_name = azure_ai_unit_test_env["AZURE_AI_MODEL_DEPLOYMENT_NAME"]

    # Patch the AgentsClient constructor so no real client is created
    with patch("agent_framework_azure_ai._chat_client.AgentsClient") as agents_client_cls:
        created_instance = MagicMock()
        agents_client_cls.return_value = created_instance

        # agents_client=None triggers creation of an AgentsClient
        client = AzureAIAgentClient(
            agents_client=None,
            agent_id="test-agent",
            project_endpoint=endpoint,
            model_deployment_name=deployment_name,
            async_credential=mock_azure_credential,
        )

        # Constructor arguments used for the auto-created AgentsClient
        agents_client_cls.assert_called_once_with(
            endpoint=endpoint,
            credential=mock_azure_credential,
            user_agent="agent-framework-python/0.0.0",
        )

        # State exposed by the resulting client
        assert client.agents_client is created_instance
        assert client.agent_id == "test-agent"
        assert client.credential is mock_azure_credential
        assert client._should_close_client is True  # Should close since we created it # type: ignore[attr-defined]
|
||||
|
||||
+3
-3
@@ -103,11 +103,11 @@ async def main() -> None:
|
||||
|
||||
print()
|
||||
|
||||
# Display collected citations
|
||||
# Display collected citation
|
||||
if citations:
|
||||
print("\n\nCitations:")
|
||||
print("\n\nCitation:")
|
||||
for i, citation in enumerate(citations, 1):
|
||||
print(f"[{i}] Reference: {citation.url}")
|
||||
print(f"[{i}] {citation.url}")
|
||||
|
||||
print("\n" + "=" * 50 + "\n")
|
||||
print("Hotel search conversation completed!")
|
||||
|
||||
Reference in New Issue
Block a user