mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: [BREAKING] cleanup of thread API and serialization (#893)
* cleanup of threads and serialization * fix for sliding window * fix redis test * updated from comments * updated context provider and threads * updated lock * add asyncio default * fix redis tests * fix tests * fix tests * renamed to invoking * fixed tests * fix for instructions
This commit is contained in:
committed by
GitHub
Unverified
parent
bf5931932e
commit
10d10364a9
@@ -9,6 +9,7 @@ This folder contains examples demonstrating different ways to create and use age
|
||||
| [`azure_ai_basic.py`](azure_ai_basic.py) | The simplest way to create an agent using `ChatAgent` with `AzureAIAgentClient`. It automatically handles all configuration using environment variables. |
|
||||
| [`azure_ai_with_explicit_settings.py`](azure_ai_with_explicit_settings.py) | Shows how to create an agent with explicitly configured `AzureAIAgentClient` settings, including project endpoint, model deployment, credentials, and agent name. |
|
||||
| [`azure_ai_with_existing_agent.py`](azure_ai_with_existing_agent.py) | Shows how to work with a pre-existing agent by providing the agent ID to the Azure AI chat client. This example also demonstrates proper cleanup of manually created agents. |
|
||||
| [`azure_ai_with_existing_thread.py`](azure_ai_with_existing_thread.py) | Shows how to work with a pre-existing thread by providing the thread ID to the Azure AI chat client. This example also demonstrates proper cleanup of manually created threads. |
|
||||
| [`azure_ai_with_function_tools.py`](azure_ai_with_function_tools.py) | Demonstrates how to use function tools with agents. Shows both agent-level tools (defined when creating the agent) and query-level tools (provided with specific queries). |
|
||||
| [`azure_ai_with_code_interpreter.py`](azure_ai_with_code_interpreter.py) | Shows how to use the HostedCodeInterpreterTool with Azure AI agents to write and execute Python code. Includes helper methods for accessing code interpreter data from response chunks. |
|
||||
| [`azure_ai_with_local_mcp.py`](azure_ai_with_local_mcp.py) | Shows how to integrate Azure AI agents with Model Context Protocol (MCP) servers for enhanced functionality and tool integration. Demonstrates both agent-level and run-level tool configuration. |
|
||||
|
||||
@@ -0,0 +1,52 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from random import randint
|
||||
from typing import Annotated
|
||||
|
||||
from agent_framework import ChatAgent
|
||||
from agent_framework.azure import AzureAIAgentClient
|
||||
from azure.ai.projects.aio import AIProjectClient
|
||||
from azure.identity.aio import AzureCliCredential
|
||||
from pydantic import Field
|
||||
|
||||
|
||||
def get_weather(
|
||||
location: Annotated[str, Field(description="The location to get the weather for.")],
|
||||
) -> str:
|
||||
"""Get the weather for a given location."""
|
||||
conditions = ["sunny", "cloudy", "rainy", "stormy"]
|
||||
return f"The weather in {location} is {conditions[randint(0, 3)]} with a high of {randint(10, 30)}°C."
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
print("=== Azure AI Chat Client with Existing Thread ===")
|
||||
|
||||
# Create the client
|
||||
async with (
|
||||
AzureCliCredential() as credential,
|
||||
AIProjectClient(endpoint=os.environ["AZURE_AI_PROJECT_ENDPOINT"], credential=credential) as client,
|
||||
):
|
||||
# Create an thread that will persist
|
||||
created_thread = await client.agents.threads.create()
|
||||
|
||||
try:
|
||||
async with ChatAgent(
|
||||
# passing in the client is optional here, so if you take the agent_id from the portal
|
||||
# you can use it directly without the two lines above.
|
||||
chat_client=AzureAIAgentClient(project_client=client),
|
||||
instructions="You are a helpful weather agent.",
|
||||
tools=get_weather,
|
||||
) as agent:
|
||||
thread = agent.get_new_thread(service_thread_id=created_thread.id)
|
||||
assert thread.is_initialized
|
||||
result = await agent.run("What's the weather like in Tokyo?", thread=thread)
|
||||
print(f"Result: {result}\n")
|
||||
finally:
|
||||
# Clean up the thread manually
|
||||
await client.agents.threads.delete(created_thread.id)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -4,7 +4,7 @@ import asyncio
|
||||
from random import randint
|
||||
from typing import Annotated
|
||||
|
||||
from agent_framework import AgentThread, ChatAgent, ChatMessageList
|
||||
from agent_framework import AgentThread, ChatAgent, ChatMessageStore
|
||||
from agent_framework.azure import AzureOpenAIChatClient
|
||||
from azure.identity import AzureCliCredential
|
||||
from pydantic import Field
|
||||
@@ -125,7 +125,7 @@ async def example_with_existing_thread_messages() -> None:
|
||||
|
||||
# You can also create a new thread from existing messages
|
||||
messages = await thread.message_store.list_messages() if thread.message_store else []
|
||||
new_thread = AgentThread(message_store=ChatMessageList(messages))
|
||||
new_thread = AgentThread(message_store=ChatMessageStore(messages))
|
||||
|
||||
query3 = "How does the Paris weather compare to London?"
|
||||
print(f"User: {query3}")
|
||||
|
||||
@@ -101,8 +101,7 @@ class EchoAgent(BaseAgent):
|
||||
|
||||
# Notify the thread of new messages if provided
|
||||
if thread is not None:
|
||||
await self._notify_thread_of_new_messages(thread, normalized_messages)
|
||||
await self._notify_thread_of_new_messages(thread, response_message)
|
||||
await self._notify_thread_of_new_messages(thread, normalized_messages, response_message)
|
||||
|
||||
return AgentRunResponse(messages=[response_message])
|
||||
|
||||
@@ -136,10 +135,6 @@ class EchoAgent(BaseAgent):
|
||||
else:
|
||||
response_text = f"{self.echo_prefix}[Non-text message received]"
|
||||
|
||||
# Notify the thread of input messages if provided
|
||||
if thread is not None:
|
||||
await self._notify_thread_of_new_messages(thread, normalized_messages)
|
||||
|
||||
# Simulate streaming by yielding the response word by word
|
||||
words = response_text.split()
|
||||
for i, word in enumerate(words):
|
||||
@@ -157,7 +152,7 @@ class EchoAgent(BaseAgent):
|
||||
# Notify the thread of the complete response if provided
|
||||
if thread is not None:
|
||||
complete_response = ChatMessage(role=Role.ASSISTANT, contents=[TextContent(text=response_text)])
|
||||
await self._notify_thread_of_new_messages(thread, complete_response)
|
||||
await self._notify_thread_of_new_messages(thread, normalized_messages, complete_response)
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
|
||||
@@ -4,7 +4,7 @@ import asyncio
|
||||
from random import randint
|
||||
from typing import Annotated
|
||||
|
||||
from agent_framework import AgentThread, ChatAgent, ChatMessageList
|
||||
from agent_framework import AgentThread, ChatAgent, ChatMessageStore
|
||||
from agent_framework.openai import OpenAIChatClient
|
||||
from pydantic import Field
|
||||
|
||||
@@ -119,7 +119,7 @@ async def example_with_existing_thread_messages() -> None:
|
||||
# You can also create a new thread from existing messages
|
||||
messages = await thread.message_store.list_messages() if thread.message_store else []
|
||||
|
||||
new_thread = AgentThread(message_store=ChatMessageList(messages))
|
||||
new_thread = AgentThread(message_store=ChatMessageStore(messages))
|
||||
|
||||
query3 = "How does the Paris weather compare to London?"
|
||||
print(f"User: {query3}")
|
||||
|
||||
@@ -63,7 +63,7 @@ The provider supports both full‑text only and hybrid vector search:
|
||||
|
||||
`redis_basics.py` walks through three scenarios:
|
||||
|
||||
1. Standalone provider usage: adds messages and retrieves context via `model_invoking`.
|
||||
1. Standalone provider usage: adds messages and retrieves context via `invoking`.
|
||||
2. Agent integration: teaches the agent a preference and verifies it is remembered across turns.
|
||||
3. Agent + tool: calls a sample tool (flight search) and then asks the agent to recall details remembered from the tool output.
|
||||
|
||||
@@ -108,5 +108,3 @@ You should see the agent responses and, when using embeddings, context retrieved
|
||||
- Ensure at least one of `application_id`, `agent_id`, `user_id`, or `thread_id` is set; the provider requires a scope.
|
||||
- If using embeddings, verify `OPENAI_API_KEY` is set and reachable.
|
||||
- Make sure Redis exposes RediSearch (Redis Stack image or managed service with search enabled).
|
||||
|
||||
|
||||
|
||||
@@ -27,21 +27,17 @@ Run:
|
||||
python redis_basics.py
|
||||
"""
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from agent_framework import ChatMessage, Role
|
||||
from agent_framework_redis._provider import RedisProvider
|
||||
from agent_framework.openai import OpenAIChatClient
|
||||
from redisvl.utils.vectorize import OpenAITextVectorizer
|
||||
from agent_framework_redis._provider import RedisProvider
|
||||
from redisvl.extensions.cache.embeddings import EmbeddingsCache
|
||||
from redisvl.utils.vectorize import OpenAITextVectorizer
|
||||
|
||||
|
||||
def search_flights(
|
||||
origin_airport_code: str,
|
||||
destination_airport_code: str,
|
||||
detailed: bool = False
|
||||
) -> str:
|
||||
def search_flights(origin_airport_code: str, destination_airport_code: str, detailed: bool = False) -> str:
|
||||
"""Simulated flight-search tool to demonstrate tool memory.
|
||||
|
||||
The agent can call this function, and the returned details can be stored
|
||||
@@ -50,9 +46,27 @@ def search_flights(
|
||||
"""
|
||||
# Minimal static catalog used to simulate a tool's structured output
|
||||
flights = {
|
||||
("JFK", "LAX"): {"airline": "SkyJet", "duration": "6h 15m", "price": 325, "cabin": "Economy", "baggage": "1 checked bag"},
|
||||
("SFO", "SEA"): {"airline": "Pacific Air", "duration": "2h 5m", "price": 129, "cabin": "Economy", "baggage": "Carry-on only"},
|
||||
("LHR", "DXB"): {"airline": "EuroWings", "duration": "6h 50m", "price": 499, "cabin": "Business", "baggage": "2 bags included"},
|
||||
("JFK", "LAX"): {
|
||||
"airline": "SkyJet",
|
||||
"duration": "6h 15m",
|
||||
"price": 325,
|
||||
"cabin": "Economy",
|
||||
"baggage": "1 checked bag",
|
||||
},
|
||||
("SFO", "SEA"): {
|
||||
"airline": "Pacific Air",
|
||||
"duration": "2h 5m",
|
||||
"price": 129,
|
||||
"cabin": "Economy",
|
||||
"baggage": "Carry-on only",
|
||||
},
|
||||
("LHR", "DXB"): {
|
||||
"airline": "EuroWings",
|
||||
"duration": "6h 50m",
|
||||
"price": 499,
|
||||
"cabin": "Business",
|
||||
"baggage": "2 bags included",
|
||||
},
|
||||
}
|
||||
|
||||
route = (origin_airport_code.upper(), destination_airport_code.upper())
|
||||
@@ -97,7 +111,7 @@ async def main() -> None:
|
||||
)
|
||||
# The provider manages persistence and retrieval. application_id/agent_id/user_id
|
||||
# scope data for multi-tenant separation; thread_id (set later) narrows to a
|
||||
# specific conversation.
|
||||
# specific conversation.
|
||||
provider = RedisProvider(
|
||||
redis_url="redis://localhost:6379",
|
||||
index_name="redis_basics",
|
||||
@@ -109,7 +123,7 @@ async def main() -> None:
|
||||
vector_algorithm="hnsw",
|
||||
vector_distance_metric="cosine",
|
||||
)
|
||||
|
||||
|
||||
# Build sample chat messages to persist to Redis
|
||||
messages = [
|
||||
ChatMessage(role=Role.USER, text="runA CONVO: User Message"),
|
||||
@@ -121,14 +135,12 @@ async def main() -> None:
|
||||
# Threads are logical boundaries used by the provider to group and retrieve
|
||||
# conversation-specific context.
|
||||
await provider.thread_created(thread_id="runA")
|
||||
await provider.messages_adding(thread_id="runA", new_messages=messages)
|
||||
await provider.invoked(request_messages=messages)
|
||||
|
||||
# Retrieve relevant memories for a hypothetical model call. The provider uses
|
||||
# the current request messages as the retrieval query and returns context to
|
||||
# be injected into the model's instructions.
|
||||
ctx = await provider.model_invoking([
|
||||
ChatMessage(role=Role.SYSTEM, text="B: Assistant Message")
|
||||
])
|
||||
ctx = await provider.invoking([ChatMessage(role=Role.SYSTEM, text="B: Assistant Message")])
|
||||
|
||||
# Inspect retrieved memories that would be injected into instructions
|
||||
# (Debug-only output so you can verify retrieval works as expected.)
|
||||
@@ -167,13 +179,14 @@ async def main() -> None:
|
||||
# Create agent wired to the Redis context provider. The provider automatically
|
||||
# persists conversational details and surfaces relevant context on each turn.
|
||||
agent = client.create_agent(
|
||||
name="MemoryEnhancedAssistant",
|
||||
instructions=(
|
||||
"You are a helpful assistant. Personalize replies using provided context. "
|
||||
"Before answering, always check for stored context"
|
||||
),
|
||||
tools=[],
|
||||
context_providers=provider)
|
||||
name="MemoryEnhancedAssistant",
|
||||
instructions=(
|
||||
"You are a helpful assistant. Personalize replies using provided context. "
|
||||
"Before answering, always check for stored context"
|
||||
),
|
||||
tools=[],
|
||||
context_providers=provider,
|
||||
)
|
||||
|
||||
# Teach a user preference; the agent writes this to the provider's memory
|
||||
query = "Remember that I enjoy glugenflorgle"
|
||||
@@ -201,20 +214,21 @@ async def main() -> None:
|
||||
prefix="context_3",
|
||||
application_id="matrix_of_kermits",
|
||||
agent_id="agent_kermit",
|
||||
user_id="kermit"
|
||||
user_id="kermit",
|
||||
)
|
||||
|
||||
# Create agent exposing the flight search tool. Tool outputs are captured by the
|
||||
# provider and become retrievable context for later turns.
|
||||
client = OpenAIChatClient(ai_model_id=os.getenv("OPENAI_CHAT_MODEL_ID"), api_key=os.getenv("OPENAI_API_KEY"))
|
||||
agent = client.create_agent(
|
||||
name="MemoryEnhancedAssistant",
|
||||
instructions=(
|
||||
"You are a helpful assistant. Personalize replies using provided context. "
|
||||
"Before answering, always check for stored context"
|
||||
),
|
||||
tools=search_flights,
|
||||
context_providers=provider)
|
||||
name="MemoryEnhancedAssistant",
|
||||
instructions=(
|
||||
"You are a helpful assistant. Personalize replies using provided context. "
|
||||
"Before answering, always check for stored context"
|
||||
),
|
||||
tools=search_flights,
|
||||
context_providers=provider,
|
||||
)
|
||||
# Invoke the tool; outputs become part of memory/context
|
||||
query = "Are there any flights from new york city (jfk) to la? Give me details"
|
||||
result = await agent.run(query)
|
||||
@@ -229,5 +243,6 @@ async def main() -> None:
|
||||
# Drop / delete the provider index in Redis
|
||||
await provider.redis_index.delete()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
"""Redis Context Provider: Basic usage and agent integration
|
||||
|
||||
This example demonstrates how to use the Redis ChatMessageStore to persist
|
||||
This example demonstrates how to use the Redis ChatMessageStoreProtocol to persist
|
||||
conversational details. Pass it as a constructor argument to create_agent.
|
||||
|
||||
Requirements:
|
||||
@@ -14,15 +14,14 @@ Run:
|
||||
python redis_conversation.py
|
||||
"""
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from agent_framework_redis._provider import RedisProvider
|
||||
from agent_framework_redis._chat_message_store import RedisChatMessageStore
|
||||
from agent_framework.openai import OpenAIChatClient
|
||||
from redisvl.utils.vectorize import OpenAITextVectorizer
|
||||
from agent_framework_redis._chat_message_store import RedisChatMessageStore
|
||||
from agent_framework_redis._provider import RedisProvider
|
||||
from redisvl.extensions.cache.embeddings import EmbeddingsCache
|
||||
|
||||
from redisvl.utils.vectorize import OpenAITextVectorizer
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
@@ -65,15 +64,15 @@ async def main() -> None:
|
||||
# Create agent wired to the Redis context provider. The provider automatically
|
||||
# persists conversational details and surfaces relevant context on each turn.
|
||||
agent = client.create_agent(
|
||||
name="MemoryEnhancedAssistant",
|
||||
instructions=(
|
||||
"You are a helpful assistant. Personalize replies using provided context. "
|
||||
"Before answering, always check for stored context"
|
||||
),
|
||||
tools=[],
|
||||
context_providers=provider,
|
||||
chat_message_store_factory=chat_message_store_factory,
|
||||
)
|
||||
name="MemoryEnhancedAssistant",
|
||||
instructions=(
|
||||
"You are a helpful assistant. Personalize replies using provided context. "
|
||||
"Before answering, always check for stored context"
|
||||
),
|
||||
tools=[],
|
||||
context_providers=provider,
|
||||
chat_message_store_factory=chat_message_store_factory,
|
||||
)
|
||||
|
||||
# Teach a user preference; the agent writes this to the provider's memory
|
||||
query = "Remember that I enjoy gumbo"
|
||||
@@ -109,5 +108,6 @@ async def main() -> None:
|
||||
# Drop / delete the provider index in Redis
|
||||
await provider.redis_index.delete()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
||||
@@ -0,0 +1,120 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import asyncio
|
||||
from collections.abc import MutableSequence, Sequence
|
||||
from typing import Any
|
||||
|
||||
from agent_framework import ChatAgent, ChatClientProtocol, ChatMessage, ChatOptions, Context, ContextProvider
|
||||
from agent_framework.azure import AzureAIAgentClient
|
||||
from azure.identity.aio import AzureCliCredential
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class UserInfo(BaseModel):
|
||||
name: str | None = None
|
||||
age: int | None = None
|
||||
|
||||
|
||||
class UserInfoMemory(ContextProvider):
|
||||
def __init__(self, chat_client: ChatClientProtocol, user_info: UserInfo | None = None, **kwargs: Any):
|
||||
"""Create the memory.
|
||||
|
||||
If you pass in kwargs, they will be attempted to be used to create a UserInfo object.
|
||||
"""
|
||||
|
||||
self._chat_client = chat_client
|
||||
if user_info:
|
||||
self.user_info = user_info
|
||||
elif kwargs:
|
||||
self.user_info = UserInfo.model_validate(kwargs)
|
||||
else:
|
||||
self.user_info = UserInfo()
|
||||
|
||||
async def invoked(
|
||||
self,
|
||||
request_messages: ChatMessage | Sequence[ChatMessage],
|
||||
response_messages: ChatMessage | Sequence[ChatMessage] | None = None,
|
||||
invoke_exception: Exception | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Extract user information from messages after each agent call."""
|
||||
# Check if we need to extract user info from user messages
|
||||
user_messages = [msg for msg in request_messages if hasattr(msg, "role") and msg.role.value == "user"] # type: ignore
|
||||
|
||||
if (self.user_info.name is None or self.user_info.age is None) and user_messages:
|
||||
try:
|
||||
# Use the chat client to extract structured information
|
||||
result = await self._chat_client.get_response(
|
||||
messages=request_messages, # type: ignore
|
||||
chat_options=ChatOptions(
|
||||
instructions="Extract the user's name and age from the message if present. If not present return nulls.",
|
||||
response_format=UserInfo,
|
||||
),
|
||||
)
|
||||
|
||||
# Update user info with extracted data
|
||||
if result.value:
|
||||
if self.user_info.name is None and result.value.name:
|
||||
self.user_info.name = result.value.name
|
||||
if self.user_info.age is None and result.value.age:
|
||||
self.user_info.age = result.value.age
|
||||
|
||||
except Exception:
|
||||
pass # Failed to extract, continue without updating
|
||||
|
||||
async def invoking(self, messages: ChatMessage | MutableSequence[ChatMessage], **kwargs: Any) -> Context:
|
||||
"""Provide user information context before each agent call."""
|
||||
instructions: list[str] = []
|
||||
|
||||
if self.user_info.name is None:
|
||||
instructions.append(
|
||||
"Ask the user for their name and politely decline to answer any questions until they provide it."
|
||||
)
|
||||
else:
|
||||
instructions.append(f"The user's name is {self.user_info.name}.")
|
||||
|
||||
if self.user_info.age is None:
|
||||
instructions.append(
|
||||
"Ask the user for their age and politely decline to answer any questions until they provide it."
|
||||
)
|
||||
else:
|
||||
instructions.append(f"The user's age is {self.user_info.age}.")
|
||||
|
||||
# Return context with additional instructions
|
||||
return Context(instructions=" ".join(instructions))
|
||||
|
||||
def serialize(self) -> str:
|
||||
"""Serialize the user info for thread persistence."""
|
||||
return self.user_info.model_dump_json()
|
||||
|
||||
|
||||
async def main():
|
||||
async with AzureCliCredential() as credential:
|
||||
chat_client = AzureAIAgentClient(async_credential=credential)
|
||||
|
||||
# Create the memory provider
|
||||
memory_provider = UserInfoMemory(chat_client)
|
||||
|
||||
# Create the agent with memory
|
||||
async with ChatAgent(
|
||||
chat_client=chat_client,
|
||||
instructions="You are a friendly assistant. Always address the user by their name.",
|
||||
context_providers=memory_provider,
|
||||
) as agent:
|
||||
# Create a new thread for the conversation
|
||||
thread = agent.get_new_thread()
|
||||
|
||||
print(await agent.run("Hello, what is the square root of 9?", thread=thread))
|
||||
print(await agent.run("My name is Ruaidhrí", thread=thread))
|
||||
print(await agent.run("I am 20 years old", thread=thread))
|
||||
|
||||
# Access the memory component via the thread's get_service method and inspect the memories
|
||||
user_info_memory = thread.context_provider.providers[0] # type: ignore
|
||||
if user_info_memory:
|
||||
print()
|
||||
print(f"MEMORY - User Name: {user_info_memory.user_info.name}") # type: ignore
|
||||
print(f"MEMORY - User Age: {user_info_memory.user_info.age}") # type: ignore
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -4,7 +4,7 @@ import asyncio
|
||||
from collections.abc import Collection
|
||||
from typing import Any
|
||||
|
||||
from agent_framework import ChatMessage, ChatMessageStore
|
||||
from agent_framework import ChatMessage, ChatMessageStoreProtocol
|
||||
from agent_framework.openai import OpenAIChatClient
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -15,7 +15,7 @@ class CustomStoreState(BaseModel):
|
||||
messages: list[ChatMessage]
|
||||
|
||||
|
||||
class CustomChatMessageStore(ChatMessageStore):
|
||||
class CustomChatMessageStore(ChatMessageStoreProtocol):
|
||||
"""Implementation of custom chat message store.
|
||||
In real applications, this can be an implementation of relational database or vector store."""
|
||||
|
||||
|
||||
@@ -5,48 +5,47 @@ import os
|
||||
from uuid import uuid4
|
||||
|
||||
from agent_framework import AgentThread
|
||||
from agent_framework._threads import deserialize_thread_state
|
||||
from agent_framework.openai import OpenAIChatClient
|
||||
from agent_framework.redis import RedisChatMessageStore
|
||||
|
||||
|
||||
async def example_basic_redis_store() -> None:
|
||||
async def example_manual_memory_store() -> None:
|
||||
"""Basic example of using Redis chat message store."""
|
||||
print("=== Basic Redis Chat Message Store Example ===")
|
||||
|
||||
|
||||
# Create Redis store with auto-generated thread ID
|
||||
redis_store = RedisChatMessageStore(
|
||||
redis_url="redis://localhost:6379",
|
||||
# thread_id will be auto-generated if not provided
|
||||
)
|
||||
|
||||
|
||||
print(f"Created store with thread ID: {redis_store.thread_id}")
|
||||
|
||||
|
||||
# Create thread with Redis store
|
||||
thread = AgentThread(message_store=redis_store)
|
||||
|
||||
|
||||
# Create agent
|
||||
agent = OpenAIChatClient().create_agent(
|
||||
name="RedisBot",
|
||||
instructions="You are a helpful assistant that remembers our conversation using Redis.",
|
||||
)
|
||||
|
||||
|
||||
# Have a conversation
|
||||
print("\n--- Starting conversation ---")
|
||||
query1 = "Hello! My name is Alice and I love pizza."
|
||||
print(f"User: {query1}")
|
||||
response1 = await agent.run(query1, thread=thread)
|
||||
print(f"Agent: {response1.text}")
|
||||
|
||||
|
||||
query2 = "What do you remember about me?"
|
||||
print(f"User: {query2}")
|
||||
response2 = await agent.run(query2, thread=thread)
|
||||
print(f"Agent: {response2.text}")
|
||||
|
||||
|
||||
# Show messages are stored in Redis
|
||||
messages = await redis_store.list_messages()
|
||||
print(f"\nTotal messages in Redis: {len(messages)}")
|
||||
|
||||
|
||||
# Cleanup
|
||||
await redis_store.clear()
|
||||
await redis_store.aclose()
|
||||
@@ -56,51 +55,51 @@ async def example_basic_redis_store() -> None:
|
||||
async def example_user_session_management() -> None:
|
||||
"""Example of managing user sessions with Redis."""
|
||||
print("=== User Session Management Example ===")
|
||||
|
||||
|
||||
user_id = "alice_123"
|
||||
session_id = f"session_{uuid4()}"
|
||||
|
||||
|
||||
# Create Redis store for specific user session
|
||||
def create_user_session_store():
|
||||
return RedisChatMessageStore(
|
||||
redis_url="redis://localhost:6379",
|
||||
thread_id=f"user_{user_id}_{session_id}",
|
||||
max_messages=10 # Keep only last 10 messages
|
||||
max_messages=10, # Keep only last 10 messages
|
||||
)
|
||||
|
||||
|
||||
# Create agent with factory pattern
|
||||
agent = OpenAIChatClient().create_agent(
|
||||
name="SessionBot",
|
||||
instructions="You are a helpful assistant. Keep track of user preferences.",
|
||||
chat_message_store_factory=create_user_session_store,
|
||||
)
|
||||
|
||||
|
||||
# Start conversation
|
||||
thread = agent.get_new_thread()
|
||||
|
||||
|
||||
print(f"Started session for user {user_id}")
|
||||
if hasattr(thread.message_store, 'thread_id'):
|
||||
if hasattr(thread.message_store, "thread_id"):
|
||||
print(f"Thread ID: {thread.message_store.thread_id}") # type: ignore[union-attr]
|
||||
|
||||
|
||||
# Simulate conversation
|
||||
queries = [
|
||||
"Hi, I'm Alice and I prefer vegetarian food.",
|
||||
"What restaurants would you recommend?",
|
||||
"I also love Italian cuisine.",
|
||||
"Can you remember my food preferences?"
|
||||
"Can you remember my food preferences?",
|
||||
]
|
||||
|
||||
|
||||
for i, query in enumerate(queries, 1):
|
||||
print(f"\n--- Message {i} ---")
|
||||
print(f"User: {query}")
|
||||
response = await agent.run(query, thread=thread)
|
||||
print(f"Agent: {response.text}")
|
||||
|
||||
|
||||
# Show persistent storage
|
||||
if thread.message_store:
|
||||
messages = await thread.message_store.list_messages() # type: ignore[union-attr]
|
||||
print(f"\nMessages stored for user {user_id}: {len(messages)}")
|
||||
|
||||
|
||||
# Cleanup
|
||||
if thread.message_store:
|
||||
await thread.message_store.clear() # type: ignore[union-attr]
|
||||
@@ -111,58 +110,58 @@ async def example_user_session_management() -> None:
|
||||
async def example_conversation_persistence() -> None:
|
||||
"""Example of conversation persistence across application restarts."""
|
||||
print("=== Conversation Persistence Example ===")
|
||||
|
||||
|
||||
conversation_id = "persistent_chat_001"
|
||||
|
||||
|
||||
# Phase 1: Start conversation
|
||||
print("--- Phase 1: Starting conversation ---")
|
||||
store1 = RedisChatMessageStore(
|
||||
redis_url="redis://localhost:6379",
|
||||
thread_id=conversation_id,
|
||||
)
|
||||
|
||||
|
||||
thread1 = AgentThread(message_store=store1)
|
||||
agent = OpenAIChatClient().create_agent(
|
||||
name="PersistentBot",
|
||||
instructions="You are a helpful assistant. Remember our conversation history.",
|
||||
)
|
||||
|
||||
|
||||
# Start conversation
|
||||
query1 = "Hello! I'm working on a Python project about machine learning."
|
||||
print(f"User: {query1}")
|
||||
response1 = await agent.run(query1, thread=thread1)
|
||||
print(f"Agent: {response1.text}")
|
||||
|
||||
|
||||
query2 = "I'm specifically interested in neural networks."
|
||||
print(f"User: {query2}")
|
||||
response2 = await agent.run(query2, thread=thread1)
|
||||
print(f"Agent: {response2.text}")
|
||||
|
||||
|
||||
print(f"Stored {len(await store1.list_messages())} messages in Redis")
|
||||
await store1.aclose()
|
||||
|
||||
|
||||
# Phase 2: Resume conversation (simulating app restart)
|
||||
print("\n--- Phase 2: Resuming conversation (after 'restart') ---")
|
||||
store2 = RedisChatMessageStore(
|
||||
redis_url="redis://localhost:6379",
|
||||
thread_id=conversation_id, # Same thread ID
|
||||
)
|
||||
|
||||
|
||||
thread2 = AgentThread(message_store=store2)
|
||||
|
||||
|
||||
# Continue conversation - agent should remember context
|
||||
query3 = "What was I working on before?"
|
||||
print(f"User: {query3}")
|
||||
response3 = await agent.run(query3, thread=thread2)
|
||||
print(f"Agent: {response3.text}")
|
||||
|
||||
|
||||
query4 = "Can you suggest some Python libraries for neural networks?"
|
||||
print(f"User: {query4}")
|
||||
response4 = await agent.run(query4, thread=thread2)
|
||||
print(f"Agent: {response4.text}")
|
||||
|
||||
|
||||
print(f"Total messages after resuming: {len(await store2.list_messages())}")
|
||||
|
||||
|
||||
# Cleanup
|
||||
await store2.clear()
|
||||
await store2.aclose()
|
||||
@@ -172,52 +171,49 @@ async def example_conversation_persistence() -> None:
|
||||
async def example_thread_serialization() -> None:
|
||||
"""Example of thread state serialization and deserialization."""
|
||||
print("=== Thread Serialization Example ===")
|
||||
|
||||
|
||||
# Create initial thread with Redis store
|
||||
original_store = RedisChatMessageStore(
|
||||
redis_url="redis://localhost:6379",
|
||||
thread_id="serialization_test",
|
||||
max_messages=50,
|
||||
)
|
||||
|
||||
|
||||
original_thread = AgentThread(message_store=original_store)
|
||||
|
||||
|
||||
agent = OpenAIChatClient().create_agent(
|
||||
name="SerializationBot",
|
||||
instructions="You are a helpful assistant.",
|
||||
)
|
||||
|
||||
|
||||
# Have initial conversation
|
||||
print("--- Initial conversation ---")
|
||||
query1 = "Hello! I'm testing serialization."
|
||||
print(f"User: {query1}")
|
||||
response1 = await agent.run(query1, thread=original_thread)
|
||||
print(f"Agent: {response1.text}")
|
||||
|
||||
|
||||
# Serialize thread state
|
||||
serialized_thread = await original_thread.serialize()
|
||||
print(f"\nSerialized thread state: {serialized_thread}")
|
||||
|
||||
|
||||
# Close original connection
|
||||
await original_store.aclose()
|
||||
|
||||
|
||||
# Deserialize thread state (simulating loading from database/file)
|
||||
print("\n--- Deserializing thread state ---")
|
||||
|
||||
|
||||
# Create a new thread with the same Redis store type
|
||||
# This ensures the correct store type is used for deserialization
|
||||
restored_store = RedisChatMessageStore(redis_url="redis://localhost:6379")
|
||||
restored_thread = AgentThread(message_store=restored_store)
|
||||
|
||||
# Deserialize the thread state into the properly typed thread
|
||||
await deserialize_thread_state(restored_thread, serialized_thread)
|
||||
|
||||
restored_thread = await AgentThread.deserialize(serialized_thread, message_store=restored_store)
|
||||
|
||||
# Continue conversation with restored thread
|
||||
query2 = "Do you remember what I said about testing?"
|
||||
print(f"User: {query2}")
|
||||
response2 = await agent.run(query2, thread=restored_thread)
|
||||
print(f"Agent: {response2.text}")
|
||||
|
||||
|
||||
# Cleanup
|
||||
if restored_thread.message_store:
|
||||
await restored_thread.message_store.clear() # type: ignore[union-attr]
|
||||
@@ -228,20 +224,20 @@ async def example_thread_serialization() -> None:
|
||||
async def example_message_limits() -> None:
|
||||
"""Example of automatic message trimming with limits."""
|
||||
print("=== Message Limits Example ===")
|
||||
|
||||
|
||||
# Create store with small message limit
|
||||
store = RedisChatMessageStore(
|
||||
redis_url="redis://localhost:6379",
|
||||
thread_id="limits_test",
|
||||
max_messages=3, # Keep only 3 most recent messages
|
||||
)
|
||||
|
||||
|
||||
thread = AgentThread(message_store=store)
|
||||
agent = OpenAIChatClient().create_agent(
|
||||
name="LimitBot",
|
||||
instructions="You are a helpful assistant with limited memory.",
|
||||
)
|
||||
|
||||
|
||||
# Send multiple messages to test trimming
|
||||
messages = [
|
||||
"Message 1: Hello!",
|
||||
@@ -250,22 +246,22 @@ async def example_message_limits() -> None:
|
||||
"Message 4: Tell me a joke.",
|
||||
"Message 5: This should trigger trimming.",
|
||||
]
|
||||
|
||||
|
||||
for i, query in enumerate(messages, 1):
|
||||
print(f"\n--- Sending message {i} ---")
|
||||
print(f"User: {query}")
|
||||
response = await agent.run(query, thread=thread)
|
||||
print(f"Agent: {response.text}")
|
||||
|
||||
|
||||
stored_messages = await store.list_messages()
|
||||
print(f"Messages in store: {len(stored_messages)}")
|
||||
if len(stored_messages) > 0:
|
||||
print(f"Oldest message: {stored_messages[0].text[:30]}...")
|
||||
|
||||
|
||||
# Final check
|
||||
final_messages = await store.list_messages()
|
||||
print(f"\nFinal message count: {len(final_messages)} (should be <= 6: 3 messages × 2 per exchange)")
|
||||
|
||||
|
||||
# Cleanup
|
||||
await store.clear()
|
||||
await store.aclose()
|
||||
@@ -280,12 +276,12 @@ async def main() -> None:
|
||||
print("- Redis server running on localhost:6379")
|
||||
print("- OPENAI_API_KEY environment variable set")
|
||||
print("=" * 50)
|
||||
|
||||
|
||||
# Check prerequisites
|
||||
if not os.getenv("OPENAI_API_KEY"):
|
||||
print("ERROR: OPENAI_API_KEY environment variable not set")
|
||||
return
|
||||
|
||||
|
||||
try:
|
||||
# Test Redis connection
|
||||
test_store = RedisChatMessageStore(redis_url="redis://localhost:6379")
|
||||
@@ -298,17 +294,17 @@ async def main() -> None:
|
||||
print(f"ERROR: Cannot connect to Redis: {e}")
|
||||
print("Please ensure Redis is running on localhost:6379")
|
||||
return
|
||||
|
||||
|
||||
try:
|
||||
# Run all examples
|
||||
await example_basic_redis_store()
|
||||
await example_manual_memory_store()
|
||||
await example_user_session_management()
|
||||
await example_conversation_persistence()
|
||||
await example_thread_serialization()
|
||||
await example_message_limits()
|
||||
|
||||
|
||||
print("All examples completed successfully!")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error running examples: {e}")
|
||||
raise
|
||||
|
||||
Reference in New Issue
Block a user