def _get_real_url_from_citation_reference(
    self, citation_url: str, azure_search_tool_calls: list[dict[str, Any]]
) -> str:
    """Map an Azure AI Search citation placeholder to the real document URL.

    The first ``doc_<N>`` token found in ``citation_url`` is used as an index into the
    ``metadata.get_urls`` list carried by the output of the most recent captured
    Azure AI Search tool call.

    Args:
        citation_url: Citation reference URL (e.g., "doc_0", "#doc_1", or full URL with doc_N).
        azure_search_tool_calls: Captured Azure AI Search tool calls, oldest first. Each entry
            holds the tool payload under the "azure_ai_search" key.

    Returns:
        The real document URL if it can be resolved, otherwise ``citation_url`` unchanged.
        Parse and lookup failures are logged at debug level and the placeholder is returned.
    """
    match = re.search(r"doc_(\d+)", citation_url)
    if not match or not azure_search_tool_calls:
        return citation_url

    doc_index = int(match.group(1))

    try:
        # Only the most recent search call is consulted.
        output = azure_search_tool_calls[-1]["azure_ai_search"]["output"]

        if isinstance(output, str):
            try:
                # Output is tried as a Python-literal (repr-style) string first.
                output = ast.literal_eval(output)
            except (ValueError, SyntaxError):
                # Fall back to JSON so payloads containing null/true/false still resolve.
                output = json.loads(output)
        # Non-string output (an already-parsed mapping) is used as-is.

        all_urls = output["metadata"]["get_urls"]
    except (
        KeyError,
        IndexError,
        TypeError,
        ValueError,
        SyntaxError,
        # ast.literal_eval may raise these on pathological input; resolving a citation
        # is best-effort and must not abort the response stream.
        RecursionError,
        MemoryError,
    ) as ex:
        logger.debug("Failed to extract real URL for %s: %s", citation_url, ex)
        return citation_url

    # Guard against a bare string: indexing it would return a single character, not a URL.
    if isinstance(all_urls, (list, tuple)) and 0 <= doc_index < len(all_urls):
        return str(all_urls[doc_index])

    return citation_url
def _capture_azure_search_tool_calls(
    self, step_data: RunStep, azure_search_tool_calls: list[dict[str, Any]]
) -> None:
    """Record every Azure AI Search tool call found on a completed run step.

    Args:
        step_data: The completed run step to inspect. Steps without ``step_details``
            or without tool calls are ignored.
        azure_search_tool_calls: Per-stream accumulator; one dict with the keys
            "id", "type" and "azure_ai_search" is appended for each matching tool call.

    Any exception raised while inspecting the step is logged at debug level and swallowed.
    """
    try:
        if not hasattr(step_data, "step_details"):
            return
        if not hasattr(step_data.step_details, "tool_calls"):
            return

        # An empty/None tool_calls value simply yields nothing to iterate.
        for candidate in step_data.step_details.tool_calls or ():
            is_search_call = hasattr(candidate, "type") and candidate.type == "azure_ai_search"
            if not is_search_call:
                continue

            # Keep a plain-dict snapshot of the fields needed later for URL resolution.
            captured = {
                "id": getattr(candidate, "id", None),
                "type": candidate.type,
                "azure_ai_search": getattr(candidate, "azure_ai_search", None),
            }
            azure_search_tool_calls.append(captured)
            logger.debug(f"Captured Azure AI Search tool call: {captured['id']}")
    except Exception as ex:
        logger.debug(f"Failed to capture Azure AI Search tool call: {ex}")
def test_azure_ai_chat_client_capture_azure_search_tool_calls(mock_agents_client: MagicMock) -> None:
    """A completed step carrying an Azure AI Search tool call is appended to the caller's list."""
    client = create_test_azure_ai_chat_client(mock_agents_client)

    # Tool call shaped like an Azure AI Search invocation
    search_call = MagicMock()
    search_call.id = "call_123"
    search_call.type = "azure_ai_search"
    search_call.azure_ai_search = {"input": "test query", "output": "test output"}

    # Completed step exposing that single tool call
    completed_step = MagicMock()
    completed_step.step_details.tool_calls = [search_call]

    captured: list[dict[str, Any]] = []
    client._capture_azure_search_tool_calls(completed_step, captured)  # type: ignore

    assert len(captured) == 1
    entry = captured[0]
    assert entry["type"] == "azure_ai_search"
    assert entry["id"] == "call_123"
    assert entry["azure_ai_search"] == {"input": "test query", "output": "test output"}


def test_azure_ai_chat_client_get_real_url_from_citation_reference_no_tool_calls(
    mock_agents_client: MagicMock,
) -> None:
    """With no captured tool calls the citation placeholder comes back unchanged."""
    client = create_test_azure_ai_chat_client(mock_agents_client)

    resolved = client._get_real_url_from_citation_reference("doc_1", [])  # type: ignore

    assert resolved == "doc_1"


def test_azure_ai_chat_client_get_real_url_from_citation_reference_invalid_output(
    mock_agents_client: MagicMock,
) -> None:
    """Unparseable tool output falls back to the citation placeholder."""
    client = create_test_azure_ai_chat_client(mock_agents_client)

    unparseable_call = {
        "id": "call_123",
        "type": "azure_ai_search",
        "azure_ai_search": {"output": "invalid_json_format"},
    }

    resolved = client._get_real_url_from_citation_reference("doc_1", [unparseable_call])  # type: ignore

    assert resolved == "doc_1"


async def test_azure_ai_chat_client_context_manager(mock_agents_client: MagicMock) -> None:
    """Entering yields the client itself; leaving the context calls close() once."""
    client = create_test_azure_ai_chat_client(mock_agents_client)

    # Replace close so no real cleanup runs
    client.close = AsyncMock()

    async with client as entered:
        assert entered is client

    client.close.assert_called_once()


async def test_azure_ai_chat_client_close_method(mock_agents_client: MagicMock) -> None:
    """close() delegates to both the agent cleanup and the client shutdown helper."""
    client = create_test_azure_ai_chat_client(mock_agents_client)

    # Stub both cleanup helpers
    client._close_client_if_needed = AsyncMock()
    client._cleanup_agent_if_needed = AsyncMock()

    await client.close()

    client._cleanup_agent_if_needed.assert_called_once()
    client._close_client_if_needed.assert_called_once()


def test_azure_ai_chat_client_extract_url_citations_with_azure_search_enhanced_url(
    mock_agents_client: MagicMock,
) -> None:
    """A ``doc_N`` citation is rewritten to the N-th URL from the captured search output."""
    client = create_test_azure_ai_chat_client(mock_agents_client)

    # Captured search call whose output lists the real document URLs
    search_output = str({
        "metadata": {"get_urls": ["https://real-example.com/doc1", "https://real-example.com/doc2"]}
    })
    captured_calls = [
        {
            "id": "call_123",
            "type": "azure_ai_search",
            "azure_ai_search": {"output": search_output},
        }
    ]

    # Citation pointing at the placeholder "doc_1"
    url_citation = MagicMock()
    url_citation.title = "Test Title"
    url_citation.url = "doc_1"

    annotation = MagicMock(spec=MessageDeltaTextUrlCitationAnnotation)
    annotation.start_index = 10
    annotation.end_index = 20
    annotation.url_citation = url_citation

    text = MagicMock()
    text.annotations = [annotation]

    text_content = MagicMock(spec=MessageDeltaTextContent)
    text_content.text = text

    delta = MagicMock()
    delta.content = [text_content]

    chunk = MagicMock(spec=MessageDeltaChunk)
    chunk.delta = delta

    citations = client._extract_url_citations(chunk, captured_calls)  # type: ignore

    # doc_1 maps to index 1 of get_urls
    assert len(citations) == 1
    assert citations[0].url == "https://real-example.com/doc2"


def test_azure_ai_chat_client_init_with_auto_created_agents_client(
    azure_ai_unit_test_env: dict[str, str], mock_azure_credential: MagicMock
) -> None:
    """Constructing without an agents_client builds an AgentsClient and marks it for closing."""
    endpoint = azure_ai_unit_test_env["AZURE_AI_PROJECT_ENDPOINT"]
    deployment = azure_ai_unit_test_env["AZURE_AI_MODEL_DEPLOYMENT_NAME"]

    with patch("agent_framework_azure_ai._chat_client.AgentsClient") as agents_client_cls:
        created_instance = MagicMock()
        agents_client_cls.return_value = created_instance

        # agents_client=None makes the client construct its own AgentsClient
        client = AzureAIAgentClient(
            agents_client=None,
            agent_id="test-agent",
            project_endpoint=endpoint,
            model_deployment_name=deployment,
            async_credential=mock_azure_credential,
        )

        agents_client_cls.assert_called_once_with(
            endpoint=endpoint,
            credential=mock_azure_credential,
            user_agent="agent-framework-python/0.0.0",
        )

        assert client.agents_client is created_instance
        assert client.agent_id == "test-agent"
        assert client.credential is mock_azure_credential
        # The client created the AgentsClient itself, so it must also close it
        assert client._should_close_client is True  # type: ignore[attr-defined]
b/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_azure_ai_search.py @@ -103,11 +103,11 @@ async def main() -> None: print() - # Display collected citations + # Display collected citation if citations: - print("\n\nCitations:") + print("\n\nCitation:") for i, citation in enumerate(citations, 1): - print(f"[{i}] Reference: {citation.url}") + print(f"[{i}] {citation.url}") print("\n" + "=" * 50 + "\n") print("Hotel search conversation completed!")