Python: [Feature Branch] Fixed "store" parameter handling (#2069)

* Fixed store parameter handling

* Small fix
This commit is contained in:
Dmytro Struk
2025-11-10 18:24:32 -08:00
committed by GitHub
Unverified
parent 476fbbefc3
commit c3ef6475a2
7 changed files with 12 additions and 179 deletions
@@ -26,11 +26,7 @@ from azure.ai.projects.models import (
)
from azure.core.credentials_async import AsyncTokenCredential
from azure.core.exceptions import ResourceNotFoundError
from openai.types.responses.parsed_response import (
ParsedResponse,
)
from openai.types.responses.response import Response as OpenAIResponse
from pydantic import BaseModel, ValidationError
from pydantic import ValidationError
from ._shared import AzureAISettings
@@ -249,18 +245,6 @@ class AzureAIClient(OpenAIBaseResponsesClient):
return {"name": agent_name, "version": self.agent_version, "type": "agent_reference"}
async def _get_conversation_id_or_create(self, run_options: dict[str, Any]) -> str:
# Since "conversation" property is used, remove "previous_response_id" from options
# Use global conversation_id as fallback
conversation_id = run_options.pop("previous_response_id", self.conversation_id)
if conversation_id:
return conversation_id
# Create a new conversation with messages
created_conversation = await self.client.conversations.create()
return created_conversation.id
async def _close_client_if_needed(self) -> None:
"""Close project_client session if we created it."""
if self._should_close_client:
@@ -288,16 +272,11 @@ class AzureAIClient(OpenAIBaseResponsesClient):
async def prepare_options(
self, messages: MutableSequence[ChatMessage], chat_options: ChatOptions
) -> dict[str, Any]:
chat_options.store = bool(chat_options.store or chat_options.store is None)
prepared_messages, instructions = self._prepare_input(messages)
run_options = await super().prepare_options(prepared_messages, chat_options)
agent_reference = await self._get_agent_reference_or_create(run_options, instructions)
store = run_options.get("store", False)
if store:
conversation_id = await self._get_conversation_id_or_create(run_options)
run_options["conversation"] = conversation_id
run_options["extra_body"] = {"agent": agent_reference}
# Remove properties that are not supported on request level
@@ -313,10 +292,6 @@ class AzureAIClient(OpenAIBaseResponsesClient):
"""Initialize OpenAI client asynchronously."""
self.client = await self.project_client.get_openai_client() # type: ignore
def get_conversation_id(self, response: OpenAIResponse | ParsedResponse[BaseModel], store: bool) -> str | None:
"""Get the conversation ID from the response if store is True."""
return response.conversation.id if response.conversation and store else None
def _update_agent_name(self, agent_name: str | None) -> None:
"""Update the agent name in the chat client.
@@ -193,34 +193,6 @@ async def test_azure_ai_client_get_agent_reference_missing_model(
await client._get_agent_reference_or_create({}, None) # type: ignore
async def test_azure_ai_client_get_conversation_id_or_create_existing(
mock_project_client: MagicMock,
) -> None:
"""Test _get_conversation_id_or_create when conversation_id is already provided."""
client = create_test_azure_ai_client(mock_project_client, conversation_id="existing-conversation")
conversation_id = await client._get_conversation_id_or_create({}) # type: ignore
assert conversation_id == "existing-conversation"
async def test_azure_ai_client_get_conversation_id_or_create_new(
mock_project_client: MagicMock,
) -> None:
"""Test _get_conversation_id_or_create when creating a new conversation."""
client = create_test_azure_ai_client(mock_project_client)
# Mock conversation creation response
mock_conversation = MagicMock()
mock_conversation.id = "new-conversation-123"
client.client.conversations.create = AsyncMock(return_value=mock_conversation)
conversation_id = await client._get_conversation_id_or_create({}) # type: ignore
assert conversation_id == "new-conversation-123"
client.client.conversations.create.assert_called_once()
async def test_azure_ai_client_prepare_input_with_system_messages(
mock_project_client: MagicMock,
) -> None:
@@ -279,34 +251,6 @@ async def test_azure_ai_client_prepare_options_basic(mock_project_client: MagicM
assert run_options["extra_body"]["agent"]["name"] == "test-agent"
async def test_azure_ai_client_prepare_options_with_store(mock_project_client: MagicMock) -> None:
"""Test prepare_options with store=True creates conversation."""
client = create_test_azure_ai_client(mock_project_client, agent_name="test-agent", agent_version="1.0")
# Mock conversation creation
mock_conversation = MagicMock()
mock_conversation.id = "new-conversation-456"
client.client.conversations.create = AsyncMock(return_value=mock_conversation)
messages = [ChatMessage(role=Role.USER, contents=[TextContent(text="Hello")])]
chat_options = ChatOptions(store=True)
with (
patch.object(
client.__class__.__bases__[0], "prepare_options", return_value={"model": "test-model", "store": True}
),
patch.object(
client,
"_get_agent_reference_or_create",
return_value={"name": "test-agent", "version": "1.0", "type": "agent_reference"},
),
):
run_options = await client.prepare_options(messages, chat_options)
assert "conversation" in run_options
assert run_options["conversation"] == "new-conversation-456"
async def test_azure_ai_client_initialize_client(mock_project_client: MagicMock) -> None:
"""Test initialize_client method."""
client = create_test_azure_ai_client(mock_project_client)
@@ -320,27 +264,6 @@ async def test_azure_ai_client_initialize_client(mock_project_client: MagicMock)
mock_project_client.get_openai_client.assert_called_once()
def test_azure_ai_client_get_conversation_id_from_response(mock_project_client: MagicMock) -> None:
"""Test get_conversation_id method."""
client = create_test_azure_ai_client(mock_project_client)
# Test with conversation and store=True
mock_response = MagicMock()
mock_response.conversation.id = "test-conversation-123"
conversation_id = client.get_conversation_id(mock_response, store=True)
assert conversation_id == "test-conversation-123"
# Test with store=False
conversation_id = client.get_conversation_id(mock_response, store=False)
assert conversation_id is None
# Test with no conversation
mock_response.conversation = None
conversation_id = client.get_conversation_id(mock_response, store=True)
assert conversation_id is None
def test_azure_ai_client_update_agent_name(mock_project_client: MagicMock) -> None:
"""Test _update_agent_name method."""
client = create_test_azure_ai_client(mock_project_client)
@@ -564,10 +564,6 @@ class BaseChatClient(SerializationMixin, ABC):
# Validate that store is True when conversation_id is set
if chat_options.conversation_id is not None and chat_options.store is not True:
logger.warning(
"When conversation_id is set, store must be True for service-managed threads. "
"Automatically setting store=True."
)
chat_options.store = True
if chat_options.instructions:
@@ -663,10 +659,6 @@ class BaseChatClient(SerializationMixin, ABC):
# Validate that store is True when conversation_id is set
if chat_options.conversation_id is not None and chat_options.store is not True:
logger.warning(
"When conversation_id is set, store must be True for service-managed threads. "
"Automatically setting store=True."
)
chat_options.store = True
if chat_options.instructions:
@@ -1630,7 +1630,7 @@ def _handle_function_calls_response(
# this runs in every but the first run
# we need to keep track of all function call messages
fcc_messages.extend(response.messages)
if getattr(kwargs.get("chat_options"), "store", False):
if response.conversation_id is not None:
prepped_messages.clear()
prepped_messages.append(result_message)
else:
@@ -1833,7 +1833,7 @@ def _handle_function_calls_streaming_response(
# this runs in every but the first run
# we need to keep track of all function call messages
fcc_messages.extend(response.messages)
if getattr(kwargs.get("chat_options"), "store", False):
if response.conversation_id is not None:
prepped_messages.clear()
prepped_messages.append(result_message)
else:
@@ -302,6 +302,8 @@ class OpenAIBaseResponsesClient(OpenAIBase, BaseChatClient):
if never_require_approvals := tool.approval_mode.get("never_require_approval"):
mcp["require_approval"] = {"never": {"tool_names": list(never_require_approvals)}}
return mcp
async def prepare_options(
self, messages: MutableSequence[ChatMessage], chat_options: ChatOptions
) -> dict[str, Any]:
@@ -1,58 +0,0 @@
# Copyright (c) Microsoft. All rights reserved.
import asyncio
import os
from random import randint
from typing import Annotated
from agent_framework import ChatAgent
from agent_framework.azure import AzureAIClient
from azure.ai.projects.aio import AIProjectClient
from azure.identity.aio import AzureCliCredential
from pydantic import Field
"""
Azure AI Agent with Existing Conversation Example
This sample demonstrates working with pre-existing conversation
by providing conversation ID for reuse patterns.
"""
def get_weather(
location: Annotated[str, Field(description="The location to get the weather for.")],
) -> str:
"""Get the weather for a given location."""
conditions = ["sunny", "cloudy", "rainy", "stormy"]
return f"The weather in {location} is {conditions[randint(0, 3)]} with a high of {randint(10, 30)}°C."
async def main() -> None:
# Create the client
async with (
AzureCliCredential() as credential,
AIProjectClient(endpoint=os.environ["AZURE_AI_PROJECT_ENDPOINT"], credential=credential) as project_client,
):
openai_client = await project_client.get_openai_client() # type: ignore
# Create a conversation that will persist
created_conversation = await openai_client.conversations.create()
try:
async with ChatAgent(
chat_client=AzureAIClient(project_client=project_client),
instructions="You are a helpful weather agent.",
tools=get_weather,
store=True,
) as agent:
thread = agent.get_new_thread(service_thread_id=created_conversation.id)
assert thread.is_initialized
result = await agent.run("What's the weather like in Tokyo?", thread=thread)
print(f"Result: {result}\n")
finally:
# Clean up the conversation manually
await openai_client.conversations.delete(conversation_id=created_conversation.id)
if __name__ == "__main__":
asyncio.run(main())
@@ -71,19 +71,19 @@ async def example_with_thread_persistence_in_memory() -> None:
# First conversation
query1 = "What's the weather like in Tokyo?"
print(f"User: {query1}")
result1 = await agent.run(query1, thread=thread)
result1 = await agent.run(query1, thread=thread, store=False)
print(f"Agent: {result1.text}")
# Second conversation using the same thread - maintains context
query2 = "How about London?"
print(f"\nUser: {query2}")
result2 = await agent.run(query2, thread=thread)
result2 = await agent.run(query2, thread=thread, store=False)
print(f"Agent: {result2.text}")
# Third conversation - agent should remember both previous cities
query3 = "Which of the cities I asked about has better weather?"
print(f"\nUser: {query3}")
result3 = await agent.run(query3, thread=thread)
result3 = await agent.run(query3, thread=thread, store=False)
print(f"Agent: {result3.text}")
print("Note: The agent remembers context from previous messages in the same thread.\n")
@@ -91,7 +91,7 @@ async def example_with_thread_persistence_in_memory() -> None:
async def example_with_existing_thread_id() -> None:
"""
Example showing how to work with an existing thread ID from the service.
In this example, messages are stored on the server using Azure AI conversation state.
In this example, messages are stored on the server.
"""
print("=== Existing Thread ID Example ===")
@@ -111,8 +111,7 @@ async def example_with_existing_thread_id() -> None:
query1 = "What's the weather in Paris?"
print(f"User: {query1}")
# Enable Azure AI conversation state by setting `store` parameter to True
result1 = await agent.run(query1, thread=thread, store=True)
result1 = await agent.run(query1, thread=thread)
print(f"Agent: {result1.text}")
# The thread ID is set after the first response
@@ -134,7 +133,7 @@ async def example_with_existing_thread_id() -> None:
query2 = "What was the last city I asked about?"
print(f"User: {query2}")
result2 = await agent.run(query2, thread=thread, store=True)
result2 = await agent.run(query2, thread=thread)
print(f"Agent: {result2.text}")
print("Note: The agent continues the conversation from the previous thread by using thread ID.\n")