mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: [BREAKING] Scope provider state by source_id and standardize source IDs (#3995)
* Initial plan * Add FoundryMemoryProvider and tests Co-authored-by: eavanvalkenburg <13749212+eavanvalkenburg@users.noreply.github.com> * Add sample and documentation for FoundryMemoryProvider Co-authored-by: eavanvalkenburg <13749212+eavanvalkenburg@users.noreply.github.com> * Address code review feedback for FoundryMemoryProvider Co-authored-by: eavanvalkenburg <13749212+eavanvalkenburg@users.noreply.github.com> * Address PR review comments: Add DEFAULT_SOURCE_ID, use logging.getLogger, move state to session.state Co-authored-by: eavanvalkenburg <13749212+eavanvalkenburg@users.noreply.github.com> * Fix Foundry memory ItemParam usage and exports Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Refactor provider hook state and standardize source IDs Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Support endpoint-based Foundry memory init Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Fix core README workflows link Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * updated implementation and sample * Split out Foundry memory provider changes Remove FoundryMemoryProvider implementation/tests/sample plus export and docs mentions from this branch so only non-Foundry changes remain. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Trigger CI rerun for PR #3995 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: eavanvalkenburg <13749212+eavanvalkenburg@users.noreply.github.com> Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
a5f948c215
commit
cc98d5b6f7
@@ -13,6 +13,7 @@ from agent_framework import (
|
||||
ChatResponseUpdate,
|
||||
Content,
|
||||
FunctionInvocationLayer,
|
||||
InMemoryHistoryProvider,
|
||||
Message,
|
||||
ResponseStream,
|
||||
)
|
||||
@@ -195,7 +196,7 @@ async def main() -> None:
|
||||
print(f"Agent: {result.messages[0].text}\n")
|
||||
|
||||
# Check conversation history
|
||||
memory_state = session.state.get("memory", {})
|
||||
memory_state = session.state.get(InMemoryHistoryProvider.DEFAULT_SOURCE_ID, {})
|
||||
session_messages = memory_state.get("messages", [])
|
||||
if session_messages:
|
||||
print(f"Session contains {len(session_messages)} messages")
|
||||
|
||||
@@ -17,7 +17,9 @@ class UserInfo(BaseModel):
|
||||
|
||||
|
||||
class UserInfoMemory(BaseContextProvider):
|
||||
def __init__(self, source_id: str = "user-info-memory", *, client: SupportsChatGetResponse, **kwargs: Any):
|
||||
DEFAULT_SOURCE_ID = "user_info_memory"
|
||||
|
||||
def __init__(self, source_id: str = DEFAULT_SOURCE_ID, *, client: SupportsChatGetResponse, **kwargs: Any):
|
||||
"""Create the memory.
|
||||
|
||||
If you pass in kwargs, they will be attempted to be used to create a UserInfo object.
|
||||
@@ -39,9 +41,7 @@ class UserInfoMemory(BaseContextProvider):
|
||||
# Check if we need to extract user info from user messages
|
||||
user_messages = [msg for msg in request_messages if hasattr(msg, "role") and msg.role == "user"] # type: ignore
|
||||
|
||||
if (
|
||||
state[self.source_id]["user_info"].name is None or state[self.source_id]["user_info"].age is None
|
||||
) and user_messages:
|
||||
if (state["user_info"].name is None or state["user_info"].age is None) and user_messages:
|
||||
with suppress(Exception):
|
||||
# Use the chat client to extract structured information
|
||||
result = await self._chat_client.get_response(
|
||||
@@ -54,10 +54,10 @@ class UserInfoMemory(BaseContextProvider):
|
||||
# Update user info with extracted data
|
||||
with suppress(Exception):
|
||||
extracted = result.value
|
||||
if state[self.source_id]["user_info"].name is None and extracted.name:
|
||||
state[self.source_id]["user_info"].name = extracted.name
|
||||
if state[self.source_id]["user_info"].age is None and extracted.age:
|
||||
state[self.source_id]["user_info"].age = extracted.age
|
||||
if state["user_info"].name is None and extracted.name:
|
||||
state["user_info"].name = extracted.name
|
||||
if state["user_info"].age is None and extracted.age:
|
||||
state["user_info"].age = extracted.age
|
||||
|
||||
async def before_run(
|
||||
self,
|
||||
@@ -68,20 +68,19 @@ class UserInfoMemory(BaseContextProvider):
|
||||
state: dict[str, Any],
|
||||
) -> None:
|
||||
"""Provide user information context before each agent call."""
|
||||
if state.setdefault(self.source_id, None) is None:
|
||||
state[self.source_id] = {"user_info": UserInfo()}
|
||||
state.setdefault("user_info", UserInfo())
|
||||
|
||||
context.extend_instructions(
|
||||
self.source_id,
|
||||
"Ask the user for their name and politely decline to answer any questions until they provide it."
|
||||
if state[self.source_id]["user_info"].name is None
|
||||
else f"The user's name is {state[self.source_id]['user_info'].name}.",
|
||||
if state["user_info"].name is None
|
||||
else f"The user's name is {state['user_info'].name}.",
|
||||
)
|
||||
context.extend_instructions(
|
||||
self.source_id,
|
||||
"Ask the user for their age and politely decline to answer any questions until they provide it."
|
||||
if state[self.source_id]["user_info"].age is None
|
||||
else f"The user's age is {state[self.source_id]['user_info'].age}.",
|
||||
if state["user_info"].age is None
|
||||
else f"The user's age is {state['user_info'].age}.",
|
||||
)
|
||||
|
||||
|
||||
@@ -92,7 +91,7 @@ async def main():
|
||||
credential=AzureCliCredential(),
|
||||
)
|
||||
|
||||
context_name = "user-info-memory"
|
||||
context_name = UserInfoMemory.DEFAULT_SOURCE_ID
|
||||
|
||||
# Create the memory provider
|
||||
memory_provider = UserInfoMemory(context_name, client=client)
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Annotated
|
||||
|
||||
from agent_framework import (
|
||||
AgentContext,
|
||||
InMemoryHistoryProvider,
|
||||
tool,
|
||||
)
|
||||
from agent_framework.azure import AzureOpenAIChatClient
|
||||
@@ -50,7 +51,7 @@ async def thread_tracking_middleware(
|
||||
"""MiddlewareTypes that tracks and logs session behavior across runs."""
|
||||
session_message_count = 0
|
||||
if context.session:
|
||||
memory_state = context.session.state.get("memory", {})
|
||||
memory_state = context.session.state.get(InMemoryHistoryProvider.DEFAULT_SOURCE_ID, {})
|
||||
session_message_count = len(memory_state.get("messages", []))
|
||||
|
||||
print(f"[MiddlewareTypes pre-execution] Current input messages: {len(context.messages)}")
|
||||
@@ -62,7 +63,7 @@ async def thread_tracking_middleware(
|
||||
# Check session state after agent execution
|
||||
updated_session_message_count = 0
|
||||
if context.session:
|
||||
memory_state = context.session.state.get("memory", {})
|
||||
memory_state = context.session.state.get(InMemoryHistoryProvider.DEFAULT_SOURCE_ID, {})
|
||||
updated_session_message_count = len(memory_state.get("messages", []))
|
||||
|
||||
print(f"[MiddlewareTypes post-execution] Updated session messages: {updated_session_message_count}")
|
||||
|
||||
@@ -4,7 +4,7 @@ import asyncio
|
||||
from random import randint
|
||||
from typing import Annotated
|
||||
|
||||
from agent_framework import Agent, AgentSession, tool
|
||||
from agent_framework import Agent, AgentSession, InMemoryHistoryProvider, tool
|
||||
from agent_framework.azure import AzureOpenAIChatClient
|
||||
from azure.identity import AzureCliCredential
|
||||
from pydantic import Field
|
||||
@@ -112,7 +112,7 @@ async def example_with_existing_session_messages() -> None:
|
||||
print(f"Agent: {result1.text}")
|
||||
|
||||
# The session now contains the conversation history in state
|
||||
memory_state = session.state.get("memory", {})
|
||||
memory_state = session.state.get(InMemoryHistoryProvider.DEFAULT_SOURCE_ID, {})
|
||||
messages = memory_state.get("messages", [])
|
||||
if messages:
|
||||
print(f"Session contains {len(messages)} messages")
|
||||
|
||||
@@ -10,6 +10,7 @@ from agent_framework import (
|
||||
AgentSession,
|
||||
BaseAgent,
|
||||
Content,
|
||||
InMemoryHistoryProvider,
|
||||
Message,
|
||||
Role,
|
||||
normalize_messages,
|
||||
@@ -93,7 +94,9 @@ class EchoAgent(BaseAgent):
|
||||
if not normalized_messages:
|
||||
response_message = Message(
|
||||
role=Role.ASSISTANT,
|
||||
contents=[Content.from_text(text="Hello! I'm a custom echo agent. Send me a message and I'll echo it back.")],
|
||||
contents=[
|
||||
Content.from_text(text="Hello! I'm a custom echo agent. Send me a message and I'll echo it back.")
|
||||
],
|
||||
)
|
||||
else:
|
||||
# For simplicity, echo the last user message
|
||||
@@ -199,7 +202,7 @@ async def main() -> None:
|
||||
print(f"Agent: {result2.messages[0].text}")
|
||||
|
||||
# Check conversation history
|
||||
memory_state = session.state.get("memory", {})
|
||||
memory_state = session.state.get(InMemoryHistoryProvider.DEFAULT_SOURCE_ID, {})
|
||||
messages = memory_state.get("messages", [])
|
||||
if messages:
|
||||
print(f"\nSession contains {len(messages)} messages in history")
|
||||
|
||||
@@ -4,7 +4,7 @@ import asyncio
|
||||
from random import randint
|
||||
from typing import Annotated
|
||||
|
||||
from agent_framework import Agent, AgentSession, tool
|
||||
from agent_framework import Agent, AgentSession, InMemoryHistoryProvider, tool
|
||||
from agent_framework.openai import OpenAIChatClient
|
||||
from pydantic import Field
|
||||
|
||||
@@ -105,7 +105,7 @@ async def example_with_existing_session_messages() -> None:
|
||||
print(f"Agent: {result1.text}")
|
||||
|
||||
# The session now contains the conversation history in state
|
||||
memory_state = session.state.get("memory", {})
|
||||
memory_state = session.state.get(InMemoryHistoryProvider.DEFAULT_SOURCE_ID, {})
|
||||
messages = memory_state.get("messages", [])
|
||||
if messages:
|
||||
print(f"Session contains {len(messages)} messages")
|
||||
|
||||
@@ -59,17 +59,17 @@ async def main() -> None:
|
||||
credential=AzureCliCredential(),
|
||||
)
|
||||
|
||||
# set the same context provider, with the same source_id, for both agents to share the thread
|
||||
# set the same context provider (same default source_id) for both agents to share the thread
|
||||
writer = client.as_agent(
|
||||
instructions=("You are a concise copywriter. Provide a single, punchy marketing sentence based on the prompt."),
|
||||
name="writer",
|
||||
context_providers=[InMemoryHistoryProvider("memory")],
|
||||
context_providers=[InMemoryHistoryProvider()],
|
||||
)
|
||||
|
||||
reviewer = client.as_agent(
|
||||
instructions=("You are a thoughtful reviewer. Give brief feedback on the previous assistant message."),
|
||||
name="reviewer",
|
||||
context_providers=[InMemoryHistoryProvider("memory")],
|
||||
context_providers=[InMemoryHistoryProvider()],
|
||||
)
|
||||
|
||||
# Create the shared session
|
||||
@@ -96,7 +96,7 @@ async def main() -> None:
|
||||
|
||||
# The shared session now contains the conversation between the writer and reviewer. Print it out.
|
||||
print("=== Shared Session Conversation ===")
|
||||
memory_state = shared_session.state.get("memory", {})
|
||||
memory_state = shared_session.state.get(InMemoryHistoryProvider.DEFAULT_SOURCE_ID, {})
|
||||
for message in memory_state.get("messages", []):
|
||||
print(f"{message.author_name or message.role}: {message.text}")
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from agent_framework import AgentSession
|
||||
from agent_framework import AgentSession, InMemoryHistoryProvider
|
||||
from agent_framework.azure import AzureOpenAIResponsesClient
|
||||
from agent_framework.orchestrations import SequentialBuilder
|
||||
from azure.identity import AzureCliCredential
|
||||
@@ -109,7 +109,7 @@ async def main() -> None:
|
||||
print("\n" + "=" * 60)
|
||||
print("Full Session History")
|
||||
print("=" * 60)
|
||||
memory_state = session.state.get("memory", {})
|
||||
memory_state = session.state.get(InMemoryHistoryProvider.DEFAULT_SOURCE_ID, {})
|
||||
history = memory_state.get("messages", [])
|
||||
for i, msg in enumerate(history, start=1):
|
||||
role = msg.role if hasattr(msg.role, "value") else str(msg.role)
|
||||
|
||||
@@ -29,6 +29,7 @@ import os
|
||||
|
||||
from agent_framework import (
|
||||
InMemoryCheckpointStorage,
|
||||
InMemoryHistoryProvider,
|
||||
)
|
||||
from agent_framework.azure import AzureOpenAIResponsesClient
|
||||
from agent_framework.orchestrations import SequentialBuilder
|
||||
@@ -122,7 +123,7 @@ async def checkpointing_with_thread() -> None:
|
||||
checkpoints = await checkpoint_storage.list_checkpoints(workflow_name=workflow.name)
|
||||
print(f"\nTotal checkpoints across both turns: {len(checkpoints)}")
|
||||
|
||||
memory_state = session.state.get("memory", {})
|
||||
memory_state = session.state.get(InMemoryHistoryProvider.DEFAULT_SOURCE_ID, {})
|
||||
history = memory_state.get("messages", [])
|
||||
print(f"Messages in session history: {len(history)}")
|
||||
|
||||
|
||||
@@ -23,12 +23,14 @@ class UserInfo(BaseModel):
|
||||
class UserInfoMemory(BaseContextProvider):
|
||||
"""Context provider that extracts and remembers user info (name, age).
|
||||
|
||||
State is stored in ``session.state["user-info-memory"]`` so it survives
|
||||
State is stored in ``session.state["user_info_memory"]`` so it survives
|
||||
serialization via ``session.to_dict()`` / ``AgentSession.from_dict()``.
|
||||
"""
|
||||
|
||||
DEFAULT_SOURCE_ID = "user_info_memory"
|
||||
|
||||
def __init__(self, client: SupportsChatGetResponse):
|
||||
super().__init__("user-info-memory")
|
||||
super().__init__(self.DEFAULT_SOURCE_ID)
|
||||
self._chat_client = client
|
||||
|
||||
async def before_run(
|
||||
@@ -40,8 +42,7 @@ class UserInfoMemory(BaseContextProvider):
|
||||
state: dict[str, Any],
|
||||
) -> None:
|
||||
"""Provide user information context before each agent call."""
|
||||
my_state = state.setdefault(self.source_id, {})
|
||||
user_info = my_state.setdefault("user_info", UserInfo())
|
||||
user_info = state.setdefault("user_info", UserInfo())
|
||||
|
||||
instructions: list[str] = []
|
||||
|
||||
@@ -70,8 +71,7 @@ class UserInfoMemory(BaseContextProvider):
|
||||
state: dict[str, Any],
|
||||
) -> None:
|
||||
"""Extract user information from messages after each agent call."""
|
||||
my_state = state.setdefault(self.source_id, {})
|
||||
user_info = my_state.setdefault("user_info", UserInfo())
|
||||
user_info = state.setdefault("user_info", UserInfo())
|
||||
if user_info.name is not None and user_info.age is not None:
|
||||
return # Already have everything
|
||||
|
||||
@@ -92,7 +92,7 @@ class UserInfoMemory(BaseContextProvider):
|
||||
user_info.name = extracted.name
|
||||
if extracted and user_info.age is None and extracted.age:
|
||||
user_info.age = extracted.age
|
||||
state.setdefault(self.source_id, {})["user_info"] = user_info
|
||||
state["user_info"] = user_info
|
||||
except Exception:
|
||||
pass # Failed to extract, continue without updating
|
||||
|
||||
@@ -113,7 +113,7 @@ async def main():
|
||||
print(await agent.run("I am 20 years old", session=session))
|
||||
|
||||
# Inspect extracted user info from session state
|
||||
user_info = session.state.get("user-info-memory", {}).get("user_info", UserInfo())
|
||||
user_info = session.state.get(UserInfoMemory.DEFAULT_SOURCE_ID, {}).get("user_info", UserInfo())
|
||||
print()
|
||||
print(f"MEMORY - User Name: {user_info.name}")
|
||||
print(f"MEMORY - User Age: {user_info.age}")
|
||||
|
||||
Reference in New Issue
Block a user