mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: [Feature Branch] Fixed "store" parameter handling (#2069)
* Fixed store parameter handling * Small fix
This commit is contained in:
committed by
GitHub
Unverified
parent
476fbbefc3
commit
c3ef6475a2
@@ -26,11 +26,7 @@ from azure.ai.projects.models import (
|
||||
)
|
||||
from azure.core.credentials_async import AsyncTokenCredential
|
||||
from azure.core.exceptions import ResourceNotFoundError
|
||||
from openai.types.responses.parsed_response import (
|
||||
ParsedResponse,
|
||||
)
|
||||
from openai.types.responses.response import Response as OpenAIResponse
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from pydantic import ValidationError
|
||||
|
||||
from ._shared import AzureAISettings
|
||||
|
||||
@@ -249,18 +245,6 @@ class AzureAIClient(OpenAIBaseResponsesClient):
|
||||
|
||||
return {"name": agent_name, "version": self.agent_version, "type": "agent_reference"}
|
||||
|
||||
async def _get_conversation_id_or_create(self, run_options: dict[str, Any]) -> str:
|
||||
# Since "conversation" property is used, remove "previous_response_id" from options
|
||||
# Use global conversation_id as fallback
|
||||
conversation_id = run_options.pop("previous_response_id", self.conversation_id)
|
||||
|
||||
if conversation_id:
|
||||
return conversation_id
|
||||
|
||||
# Create a new conversation with messages
|
||||
created_conversation = await self.client.conversations.create()
|
||||
return created_conversation.id
|
||||
|
||||
async def _close_client_if_needed(self) -> None:
|
||||
"""Close project_client session if we created it."""
|
||||
if self._should_close_client:
|
||||
@@ -288,16 +272,11 @@ class AzureAIClient(OpenAIBaseResponsesClient):
|
||||
async def prepare_options(
|
||||
self, messages: MutableSequence[ChatMessage], chat_options: ChatOptions
|
||||
) -> dict[str, Any]:
|
||||
chat_options.store = bool(chat_options.store or chat_options.store is None)
|
||||
prepared_messages, instructions = self._prepare_input(messages)
|
||||
run_options = await super().prepare_options(prepared_messages, chat_options)
|
||||
agent_reference = await self._get_agent_reference_or_create(run_options, instructions)
|
||||
|
||||
store = run_options.get("store", False)
|
||||
|
||||
if store:
|
||||
conversation_id = await self._get_conversation_id_or_create(run_options)
|
||||
run_options["conversation"] = conversation_id
|
||||
|
||||
run_options["extra_body"] = {"agent": agent_reference}
|
||||
|
||||
# Remove properties that are not supported on request level
|
||||
@@ -313,10 +292,6 @@ class AzureAIClient(OpenAIBaseResponsesClient):
|
||||
"""Initialize OpenAI client asynchronously."""
|
||||
self.client = await self.project_client.get_openai_client() # type: ignore
|
||||
|
||||
def get_conversation_id(self, response: OpenAIResponse | ParsedResponse[BaseModel], store: bool) -> str | None:
|
||||
"""Get the conversation ID from the response if store is True."""
|
||||
return response.conversation.id if response.conversation and store else None
|
||||
|
||||
def _update_agent_name(self, agent_name: str | None) -> None:
|
||||
"""Update the agent name in the chat client.
|
||||
|
||||
|
||||
@@ -193,34 +193,6 @@ async def test_azure_ai_client_get_agent_reference_missing_model(
|
||||
await client._get_agent_reference_or_create({}, None) # type: ignore
|
||||
|
||||
|
||||
async def test_azure_ai_client_get_conversation_id_or_create_existing(
|
||||
mock_project_client: MagicMock,
|
||||
) -> None:
|
||||
"""Test _get_conversation_id_or_create when conversation_id is already provided."""
|
||||
client = create_test_azure_ai_client(mock_project_client, conversation_id="existing-conversation")
|
||||
|
||||
conversation_id = await client._get_conversation_id_or_create({}) # type: ignore
|
||||
|
||||
assert conversation_id == "existing-conversation"
|
||||
|
||||
|
||||
async def test_azure_ai_client_get_conversation_id_or_create_new(
|
||||
mock_project_client: MagicMock,
|
||||
) -> None:
|
||||
"""Test _get_conversation_id_or_create when creating a new conversation."""
|
||||
client = create_test_azure_ai_client(mock_project_client)
|
||||
|
||||
# Mock conversation creation response
|
||||
mock_conversation = MagicMock()
|
||||
mock_conversation.id = "new-conversation-123"
|
||||
client.client.conversations.create = AsyncMock(return_value=mock_conversation)
|
||||
|
||||
conversation_id = await client._get_conversation_id_or_create({}) # type: ignore
|
||||
|
||||
assert conversation_id == "new-conversation-123"
|
||||
client.client.conversations.create.assert_called_once()
|
||||
|
||||
|
||||
async def test_azure_ai_client_prepare_input_with_system_messages(
|
||||
mock_project_client: MagicMock,
|
||||
) -> None:
|
||||
@@ -279,34 +251,6 @@ async def test_azure_ai_client_prepare_options_basic(mock_project_client: MagicM
|
||||
assert run_options["extra_body"]["agent"]["name"] == "test-agent"
|
||||
|
||||
|
||||
async def test_azure_ai_client_prepare_options_with_store(mock_project_client: MagicMock) -> None:
|
||||
"""Test prepare_options with store=True creates conversation."""
|
||||
client = create_test_azure_ai_client(mock_project_client, agent_name="test-agent", agent_version="1.0")
|
||||
|
||||
# Mock conversation creation
|
||||
mock_conversation = MagicMock()
|
||||
mock_conversation.id = "new-conversation-456"
|
||||
client.client.conversations.create = AsyncMock(return_value=mock_conversation)
|
||||
|
||||
messages = [ChatMessage(role=Role.USER, contents=[TextContent(text="Hello")])]
|
||||
chat_options = ChatOptions(store=True)
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
client.__class__.__bases__[0], "prepare_options", return_value={"model": "test-model", "store": True}
|
||||
),
|
||||
patch.object(
|
||||
client,
|
||||
"_get_agent_reference_or_create",
|
||||
return_value={"name": "test-agent", "version": "1.0", "type": "agent_reference"},
|
||||
),
|
||||
):
|
||||
run_options = await client.prepare_options(messages, chat_options)
|
||||
|
||||
assert "conversation" in run_options
|
||||
assert run_options["conversation"] == "new-conversation-456"
|
||||
|
||||
|
||||
async def test_azure_ai_client_initialize_client(mock_project_client: MagicMock) -> None:
|
||||
"""Test initialize_client method."""
|
||||
client = create_test_azure_ai_client(mock_project_client)
|
||||
@@ -320,27 +264,6 @@ async def test_azure_ai_client_initialize_client(mock_project_client: MagicMock)
|
||||
mock_project_client.get_openai_client.assert_called_once()
|
||||
|
||||
|
||||
def test_azure_ai_client_get_conversation_id_from_response(mock_project_client: MagicMock) -> None:
|
||||
"""Test get_conversation_id method."""
|
||||
client = create_test_azure_ai_client(mock_project_client)
|
||||
|
||||
# Test with conversation and store=True
|
||||
mock_response = MagicMock()
|
||||
mock_response.conversation.id = "test-conversation-123"
|
||||
|
||||
conversation_id = client.get_conversation_id(mock_response, store=True)
|
||||
assert conversation_id == "test-conversation-123"
|
||||
|
||||
# Test with store=False
|
||||
conversation_id = client.get_conversation_id(mock_response, store=False)
|
||||
assert conversation_id is None
|
||||
|
||||
# Test with no conversation
|
||||
mock_response.conversation = None
|
||||
conversation_id = client.get_conversation_id(mock_response, store=True)
|
||||
assert conversation_id is None
|
||||
|
||||
|
||||
def test_azure_ai_client_update_agent_name(mock_project_client: MagicMock) -> None:
|
||||
"""Test _update_agent_name method."""
|
||||
client = create_test_azure_ai_client(mock_project_client)
|
||||
|
||||
@@ -564,10 +564,6 @@ class BaseChatClient(SerializationMixin, ABC):
|
||||
|
||||
# Validate that store is True when conversation_id is set
|
||||
if chat_options.conversation_id is not None and chat_options.store is not True:
|
||||
logger.warning(
|
||||
"When conversation_id is set, store must be True for service-managed threads. "
|
||||
"Automatically setting store=True."
|
||||
)
|
||||
chat_options.store = True
|
||||
|
||||
if chat_options.instructions:
|
||||
@@ -663,10 +659,6 @@ class BaseChatClient(SerializationMixin, ABC):
|
||||
|
||||
# Validate that store is True when conversation_id is set
|
||||
if chat_options.conversation_id is not None and chat_options.store is not True:
|
||||
logger.warning(
|
||||
"When conversation_id is set, store must be True for service-managed threads. "
|
||||
"Automatically setting store=True."
|
||||
)
|
||||
chat_options.store = True
|
||||
|
||||
if chat_options.instructions:
|
||||
|
||||
@@ -1630,7 +1630,7 @@ def _handle_function_calls_response(
|
||||
# this runs in every but the first run
|
||||
# we need to keep track of all function call messages
|
||||
fcc_messages.extend(response.messages)
|
||||
if getattr(kwargs.get("chat_options"), "store", False):
|
||||
if response.conversation_id is not None:
|
||||
prepped_messages.clear()
|
||||
prepped_messages.append(result_message)
|
||||
else:
|
||||
@@ -1833,7 +1833,7 @@ def _handle_function_calls_streaming_response(
|
||||
# this runs in every but the first run
|
||||
# we need to keep track of all function call messages
|
||||
fcc_messages.extend(response.messages)
|
||||
if getattr(kwargs.get("chat_options"), "store", False):
|
||||
if response.conversation_id is not None:
|
||||
prepped_messages.clear()
|
||||
prepped_messages.append(result_message)
|
||||
else:
|
||||
|
||||
@@ -302,6 +302,8 @@ class OpenAIBaseResponsesClient(OpenAIBase, BaseChatClient):
|
||||
if never_require_approvals := tool.approval_mode.get("never_require_approval"):
|
||||
mcp["require_approval"] = {"never": {"tool_names": list(never_require_approvals)}}
|
||||
|
||||
return mcp
|
||||
|
||||
async def prepare_options(
|
||||
self, messages: MutableSequence[ChatMessage], chat_options: ChatOptions
|
||||
) -> dict[str, Any]:
|
||||
|
||||
@@ -1,58 +0,0 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from random import randint
|
||||
from typing import Annotated
|
||||
|
||||
from agent_framework import ChatAgent
|
||||
from agent_framework.azure import AzureAIClient
|
||||
from azure.ai.projects.aio import AIProjectClient
|
||||
from azure.identity.aio import AzureCliCredential
|
||||
from pydantic import Field
|
||||
|
||||
"""
|
||||
Azure AI Agent with Existing Conversation Example
|
||||
|
||||
This sample demonstrates working with pre-existing conversation
|
||||
by providing conversation ID for reuse patterns.
|
||||
"""
|
||||
|
||||
|
||||
def get_weather(
|
||||
location: Annotated[str, Field(description="The location to get the weather for.")],
|
||||
) -> str:
|
||||
"""Get the weather for a given location."""
|
||||
conditions = ["sunny", "cloudy", "rainy", "stormy"]
|
||||
return f"The weather in {location} is {conditions[randint(0, 3)]} with a high of {randint(10, 30)}°C."
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
# Create the client
|
||||
async with (
|
||||
AzureCliCredential() as credential,
|
||||
AIProjectClient(endpoint=os.environ["AZURE_AI_PROJECT_ENDPOINT"], credential=credential) as project_client,
|
||||
):
|
||||
openai_client = await project_client.get_openai_client() # type: ignore
|
||||
|
||||
# Create a conversation that will persist
|
||||
created_conversation = await openai_client.conversations.create()
|
||||
|
||||
try:
|
||||
async with ChatAgent(
|
||||
chat_client=AzureAIClient(project_client=project_client),
|
||||
instructions="You are a helpful weather agent.",
|
||||
tools=get_weather,
|
||||
store=True,
|
||||
) as agent:
|
||||
thread = agent.get_new_thread(service_thread_id=created_conversation.id)
|
||||
assert thread.is_initialized
|
||||
result = await agent.run("What's the weather like in Tokyo?", thread=thread)
|
||||
print(f"Result: {result}\n")
|
||||
finally:
|
||||
# Clean up the conversation manually
|
||||
await openai_client.conversations.delete(conversation_id=created_conversation.id)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -71,19 +71,19 @@ async def example_with_thread_persistence_in_memory() -> None:
|
||||
# First conversation
|
||||
query1 = "What's the weather like in Tokyo?"
|
||||
print(f"User: {query1}")
|
||||
result1 = await agent.run(query1, thread=thread)
|
||||
result1 = await agent.run(query1, thread=thread, store=False)
|
||||
print(f"Agent: {result1.text}")
|
||||
|
||||
# Second conversation using the same thread - maintains context
|
||||
query2 = "How about London?"
|
||||
print(f"\nUser: {query2}")
|
||||
result2 = await agent.run(query2, thread=thread)
|
||||
result2 = await agent.run(query2, thread=thread, store=False)
|
||||
print(f"Agent: {result2.text}")
|
||||
|
||||
# Third conversation - agent should remember both previous cities
|
||||
query3 = "Which of the cities I asked about has better weather?"
|
||||
print(f"\nUser: {query3}")
|
||||
result3 = await agent.run(query3, thread=thread)
|
||||
result3 = await agent.run(query3, thread=thread, store=False)
|
||||
print(f"Agent: {result3.text}")
|
||||
print("Note: The agent remembers context from previous messages in the same thread.\n")
|
||||
|
||||
@@ -91,7 +91,7 @@ async def example_with_thread_persistence_in_memory() -> None:
|
||||
async def example_with_existing_thread_id() -> None:
|
||||
"""
|
||||
Example showing how to work with an existing thread ID from the service.
|
||||
In this example, messages are stored on the server using Azure AI conversation state.
|
||||
In this example, messages are stored on the server.
|
||||
"""
|
||||
print("=== Existing Thread ID Example ===")
|
||||
|
||||
@@ -111,8 +111,7 @@ async def example_with_existing_thread_id() -> None:
|
||||
|
||||
query1 = "What's the weather in Paris?"
|
||||
print(f"User: {query1}")
|
||||
# Enable Azure AI conversation state by setting `store` parameter to True
|
||||
result1 = await agent.run(query1, thread=thread, store=True)
|
||||
result1 = await agent.run(query1, thread=thread)
|
||||
print(f"Agent: {result1.text}")
|
||||
|
||||
# The thread ID is set after the first response
|
||||
@@ -134,7 +133,7 @@ async def example_with_existing_thread_id() -> None:
|
||||
|
||||
query2 = "What was the last city I asked about?"
|
||||
print(f"User: {query2}")
|
||||
result2 = await agent.run(query2, thread=thread, store=True)
|
||||
result2 = await agent.run(query2, thread=thread)
|
||||
print(f"Agent: {result2.text}")
|
||||
print("Note: The agent continues the conversation from the previous thread by using thread ID.\n")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user