mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: [BREAKING] Scope provider state by source_id and standardize source IDs (#3995)
* Initial plan * Add FoundryMemoryProvider and tests Co-authored-by: eavanvalkenburg <13749212+eavanvalkenburg@users.noreply.github.com> * Add sample and documentation for FoundryMemoryProvider Co-authored-by: eavanvalkenburg <13749212+eavanvalkenburg@users.noreply.github.com> * Address code review feedback for FoundryMemoryProvider Co-authored-by: eavanvalkenburg <13749212+eavanvalkenburg@users.noreply.github.com> * Address PR review comments: Add DEFAULT_SOURCE_ID, use logging.getLogger, move state to session.state Co-authored-by: eavanvalkenburg <13749212+eavanvalkenburg@users.noreply.github.com> * Fix Foundry memory ItemParam usage and exports Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Refactor provider hook state and standardize source IDs Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Support endpoint-based Foundry memory init Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Fix core README workflows link Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * updated implementation and sample * Split out Foundry memory provider changes Remove FoundryMemoryProvider implementation/tests/sample plus export and docs mentions from this branch so only non-Foundry changes remain. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Trigger CI rerun for PR #3995 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: eavanvalkenburg <13749212+eavanvalkenburg@users.noreply.github.com> Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
a5f948c215
commit
cc98d5b6f7
@@ -220,7 +220,7 @@ if __name__ == "__main__":
|
||||
- [Getting Started with Agents](../../samples/02-agents): Basic agent creation and tool usage
|
||||
- [Chat Client Examples](../../samples/02-agents/chat_client): Direct chat client usage patterns
|
||||
- [Azure AI Integration](https://github.com/microsoft/agent-framework/tree/main/python/packages/azure-ai): Azure AI integration
|
||||
- [.NET Workflows Samples](https://github.com/microsoft/agent-framework/tree/main/dotnet/samples/GettingStarted/Workflows): Advanced multi-agent patterns (.NET)
|
||||
- [.NET Workflows Samples](../../../dotnet/samples/GettingStarted/Workflows): Advanced multi-agent patterns (.NET)
|
||||
|
||||
## Agent Framework Documentation
|
||||
|
||||
|
||||
@@ -420,13 +420,18 @@ class BaseAgent(SerializationMixin):
|
||||
session: The conversation session.
|
||||
context: The invocation context with response populated.
|
||||
"""
|
||||
state = session.state if session else {}
|
||||
provider_session = session
|
||||
if provider_session is None and self.context_providers:
|
||||
provider_session = AgentSession()
|
||||
|
||||
for provider in reversed(self.context_providers):
|
||||
if provider_session is None:
|
||||
raise RuntimeError("Provider session must be available when context providers are configured.")
|
||||
await provider.after_run(
|
||||
agent=self, # type: ignore[arg-type]
|
||||
session=session, # type: ignore[arg-type]
|
||||
session=provider_session,
|
||||
context=context,
|
||||
state=state,
|
||||
state=provider_session.state.setdefault(provider.source_id, {}),
|
||||
)
|
||||
|
||||
def as_tool(
|
||||
@@ -988,10 +993,14 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
and not opts.get("store")
|
||||
and not (getattr(self.client, "STORES_BY_DEFAULT", False) and opts.get("store") is not False)
|
||||
):
|
||||
self.context_providers.append(InMemoryHistoryProvider("memory"))
|
||||
self.context_providers.append(InMemoryHistoryProvider())
|
||||
|
||||
active_session = session
|
||||
if active_session is None and self.context_providers:
|
||||
active_session = AgentSession()
|
||||
|
||||
session_context, chat_options = await self._prepare_session_and_messages(
|
||||
session=session,
|
||||
session=active_session,
|
||||
input_messages=input_messages,
|
||||
options=opts,
|
||||
)
|
||||
@@ -1018,7 +1027,9 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
# Build options dict from run() options merged with provided options
|
||||
run_opts: dict[str, Any] = {
|
||||
"model_id": opts.pop("model_id", None),
|
||||
"conversation_id": session.service_session_id if session else opts.pop("conversation_id", None),
|
||||
"conversation_id": active_session.service_session_id
|
||||
if active_session
|
||||
else opts.pop("conversation_id", None),
|
||||
"allow_multiple_tool_calls": opts.pop("allow_multiple_tool_calls", None),
|
||||
"additional_function_arguments": opts.pop("additional_function_arguments", None),
|
||||
"frequency_penalty": opts.pop("frequency_penalty", None),
|
||||
@@ -1046,12 +1057,12 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
|
||||
# Ensure session is forwarded in kwargs for tool invocation
|
||||
finalize_kwargs = dict(kwargs)
|
||||
finalize_kwargs["session"] = session
|
||||
finalize_kwargs["session"] = active_session
|
||||
# Filter chat_options from kwargs to prevent duplicate keyword argument
|
||||
filtered_kwargs = {k: v for k, v in finalize_kwargs.items() if k != "chat_options"}
|
||||
|
||||
return {
|
||||
"session": session,
|
||||
"session": active_session,
|
||||
"session_context": session_context,
|
||||
"input_messages": input_messages,
|
||||
"session_messages": session_messages,
|
||||
@@ -1129,23 +1140,28 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
else:
|
||||
chat_options = {}
|
||||
|
||||
provider_session = session
|
||||
if provider_session is None and self.context_providers:
|
||||
provider_session = AgentSession()
|
||||
|
||||
session_context = SessionContext(
|
||||
session_id=session.session_id if session else None,
|
||||
service_session_id=session.service_session_id if session else None,
|
||||
session_id=provider_session.session_id if provider_session else None,
|
||||
service_session_id=provider_session.service_session_id if provider_session else None,
|
||||
input_messages=input_messages or [],
|
||||
options=options or {},
|
||||
)
|
||||
|
||||
# Run before_run providers (forward order, skip BaseHistoryProvider with load_messages=False)
|
||||
state = session.state if session else {}
|
||||
for provider in self.context_providers:
|
||||
if isinstance(provider, BaseHistoryProvider) and not provider.load_messages:
|
||||
continue
|
||||
if provider_session is None:
|
||||
raise RuntimeError("Provider session must be available when context providers are configured.")
|
||||
await provider.before_run(
|
||||
agent=self, # type: ignore[arg-type]
|
||||
session=session, # type: ignore[arg-type]
|
||||
session=provider_session,
|
||||
context=session_context,
|
||||
state=state,
|
||||
state=provider_session.state.setdefault(provider.source_id, {}),
|
||||
)
|
||||
|
||||
# Merge provider-contributed tools into chat_options
|
||||
|
||||
@@ -16,7 +16,7 @@ import copy
|
||||
import uuid
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, ClassVar
|
||||
|
||||
from ._types import AgentResponse, Message
|
||||
|
||||
@@ -310,7 +310,8 @@ class BaseContextProvider:
|
||||
agent: The agent running this invocation.
|
||||
session: The current session.
|
||||
context: The invocation context - add messages/instructions/tools here.
|
||||
state: The session's mutable state dict.
|
||||
state: The provider-scoped mutable state dict for this provider.
|
||||
Full cross-provider state remains available at ``session.state``.
|
||||
"""
|
||||
|
||||
async def after_run(
|
||||
@@ -330,7 +331,8 @@ class BaseContextProvider:
|
||||
agent: The agent that ran this invocation.
|
||||
session: The current session.
|
||||
context: The invocation context with response populated.
|
||||
state: The session's mutable state dict.
|
||||
state: The provider-scoped mutable state dict for this provider.
|
||||
Full cross-provider state remains available at ``session.state``.
|
||||
"""
|
||||
|
||||
|
||||
@@ -520,25 +522,56 @@ class AgentSession:
|
||||
class InMemoryHistoryProvider(BaseHistoryProvider):
|
||||
"""Built-in history provider that stores messages in session.state.
|
||||
|
||||
Messages are stored in ``state[source_id]["messages"]`` as a list of
|
||||
Messages are stored in ``state["messages"]`` as a list of
|
||||
``Message`` objects. Serialization to/from dicts is handled by
|
||||
``AgentSession.to_dict()``/``from_dict()`` using ``SerializationProtocol``.
|
||||
|
||||
This provider holds no instance state — all data lives in the session's
|
||||
state dict, passed as a named ``state`` parameter to ``get_messages``/``save_messages``.
|
||||
|
||||
This is the default provider auto-added by the agent when no providers
|
||||
are configured and ``conversation_id`` or ``store=True`` is set.
|
||||
This is the default provider auto-added by the agent for local sessions
|
||||
when no providers are configured and service-side storage is not requested.
|
||||
"""
|
||||
|
||||
DEFAULT_SOURCE_ID: ClassVar[str] = "in_memory"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
source_id: str | None = None,
|
||||
*,
|
||||
load_messages: bool = True,
|
||||
store_inputs: bool = True,
|
||||
store_context_messages: bool = False,
|
||||
store_context_from: set[str] | None = None,
|
||||
store_outputs: bool = True,
|
||||
) -> None:
|
||||
"""Initialize the in-memory history provider.
|
||||
|
||||
Args:
|
||||
source_id: Unique identifier for this provider instance.
|
||||
Defaults to DEFAULT_SOURCE_ID when not provided.
|
||||
load_messages: Whether to load messages before invocation.
|
||||
store_inputs: Whether to store input messages.
|
||||
store_context_messages: Whether to store context from other providers.
|
||||
store_context_from: If set, only store context from these source_ids.
|
||||
store_outputs: Whether to store response messages.
|
||||
"""
|
||||
super().__init__(
|
||||
source_id=source_id or self.DEFAULT_SOURCE_ID,
|
||||
load_messages=load_messages,
|
||||
store_inputs=store_inputs,
|
||||
store_context_messages=store_context_messages,
|
||||
store_context_from=store_context_from,
|
||||
store_outputs=store_outputs,
|
||||
)
|
||||
|
||||
async def get_messages(
|
||||
self, session_id: str | None, *, state: dict[str, Any] | None = None, **kwargs: Any
|
||||
) -> list[Message]:
|
||||
"""Retrieve messages from session state."""
|
||||
if state is None:
|
||||
return []
|
||||
my_state = state.get(self.source_id, {})
|
||||
return list(my_state.get("messages", []))
|
||||
return list(state.get("messages", []))
|
||||
|
||||
async def save_messages(
|
||||
self,
|
||||
@@ -551,6 +584,5 @@ class InMemoryHistoryProvider(BaseHistoryProvider):
|
||||
"""Persist messages to session state."""
|
||||
if state is None:
|
||||
return
|
||||
my_state = state.setdefault(self.source_id, {})
|
||||
existing = my_state.get("messages", [])
|
||||
my_state["messages"] = [*existing, *messages]
|
||||
existing = state.get("messages", [])
|
||||
state["messages"] = [*existing, *messages]
|
||||
|
||||
@@ -121,7 +121,7 @@ class WorkflowAgent(BaseAgent):
|
||||
|
||||
resolved_context_providers = list(context_providers) if context_providers is not None else []
|
||||
if not resolved_context_providers:
|
||||
resolved_context_providers.append(InMemoryHistoryProvider("memory"))
|
||||
resolved_context_providers.append(InMemoryHistoryProvider())
|
||||
|
||||
super().__init__(
|
||||
id=id,
|
||||
@@ -237,23 +237,27 @@ class WorkflowAgent(BaseAgent):
|
||||
An AgentResponse representing the workflow execution results.
|
||||
"""
|
||||
input_messages = normalize_messages_input(messages)
|
||||
provider_session = session
|
||||
if provider_session is None and self.context_providers:
|
||||
provider_session = AgentSession()
|
||||
|
||||
# run the context providers with the session
|
||||
session_context = SessionContext(
|
||||
session_id=session.session_id if session else None,
|
||||
service_session_id=session.service_session_id if session else None,
|
||||
session_id=provider_session.session_id if provider_session else None,
|
||||
service_session_id=provider_session.service_session_id if provider_session else None,
|
||||
input_messages=input_messages or [],
|
||||
options={},
|
||||
)
|
||||
state = session.state if session else {}
|
||||
for provider in self.context_providers:
|
||||
if isinstance(provider, BaseHistoryProvider) and not provider.load_messages:
|
||||
continue
|
||||
if provider_session is None:
|
||||
raise RuntimeError("Provider session must be available when context providers are configured.")
|
||||
await provider.before_run(
|
||||
agent=self, # type: ignore[arg-type]
|
||||
session=session, # type: ignore[arg-type]
|
||||
session=provider_session,
|
||||
context=session_context,
|
||||
state=state,
|
||||
state=provider_session.state.setdefault(provider.source_id, {}),
|
||||
)
|
||||
# combine the messages
|
||||
session_messages: list[Message] = session_context.get_messages(include_input=True)
|
||||
@@ -266,7 +270,7 @@ class WorkflowAgent(BaseAgent):
|
||||
output_events.append(event)
|
||||
|
||||
result = self._convert_workflow_events_to_agent_response(response_id, output_events)
|
||||
await self._run_after_providers(session=session, context=session_context)
|
||||
await self._run_after_providers(session=provider_session, context=session_context)
|
||||
return result
|
||||
|
||||
async def _run_stream_impl(
|
||||
@@ -293,23 +297,27 @@ class WorkflowAgent(BaseAgent):
|
||||
AgentResponseUpdate objects representing the workflow execution progress.
|
||||
"""
|
||||
input_messages = normalize_messages_input(messages)
|
||||
provider_session = session
|
||||
if provider_session is None and self.context_providers:
|
||||
provider_session = AgentSession()
|
||||
|
||||
# run the context providers with the session
|
||||
session_context = SessionContext(
|
||||
session_id=session.session_id if session else None,
|
||||
service_session_id=session.service_session_id if session else None,
|
||||
session_id=provider_session.session_id if provider_session else None,
|
||||
service_session_id=provider_session.service_session_id if provider_session else None,
|
||||
input_messages=input_messages or [],
|
||||
options={},
|
||||
)
|
||||
state = session.state if session else {}
|
||||
for provider in self.context_providers:
|
||||
if isinstance(provider, BaseHistoryProvider) and not provider.load_messages:
|
||||
continue
|
||||
if provider_session is None:
|
||||
raise RuntimeError("Provider session must be available when context providers are configured.")
|
||||
await provider.before_run(
|
||||
agent=self, # type: ignore[arg-type]
|
||||
session=session, # type: ignore[arg-type]
|
||||
session=provider_session,
|
||||
context=session_context,
|
||||
state=state,
|
||||
state=provider_session.state.setdefault(provider.source_id, {}),
|
||||
)
|
||||
# combine the messages
|
||||
|
||||
@@ -320,7 +328,7 @@ class WorkflowAgent(BaseAgent):
|
||||
updates = self._convert_workflow_event_to_agent_response_updates(response_id, event)
|
||||
for update in updates:
|
||||
yield update
|
||||
await self._run_after_providers(session=session, context=session_context)
|
||||
await self._run_after_providers(session=provider_session, context=session_context)
|
||||
|
||||
async def _run_core(
|
||||
self,
|
||||
|
||||
@@ -107,10 +107,10 @@ async def test_chat_client_agent_create_session(client: SupportsChatGetResponse)
|
||||
async def test_chat_client_agent_prepare_session_and_messages(client: SupportsChatGetResponse) -> None:
|
||||
from agent_framework._sessions import InMemoryHistoryProvider
|
||||
|
||||
agent = Agent(client=client, context_providers=[InMemoryHistoryProvider("memory")])
|
||||
agent = Agent(client=client, context_providers=[InMemoryHistoryProvider()])
|
||||
message = Message(role="user", text="Hello")
|
||||
session = AgentSession()
|
||||
session.state["memory"] = {"messages": [message]}
|
||||
session.state[InMemoryHistoryProvider.DEFAULT_SOURCE_ID] = {"messages": [message]}
|
||||
|
||||
session_context, _ = await agent._prepare_session_and_messages( # type: ignore[reportPrivateUsage]
|
||||
session=session,
|
||||
@@ -267,6 +267,8 @@ async def test_chat_client_agent_update_session_id_streaming_does_not_use_respon
|
||||
|
||||
|
||||
async def test_chat_client_agent_update_session_messages(client: SupportsChatGetResponse) -> None:
|
||||
from agent_framework._sessions import InMemoryHistoryProvider
|
||||
|
||||
agent = Agent(client=client)
|
||||
session = agent.create_session()
|
||||
|
||||
@@ -275,7 +277,7 @@ async def test_chat_client_agent_update_session_messages(client: SupportsChatGet
|
||||
|
||||
assert session.service_session_id is None
|
||||
|
||||
chat_messages: list[Message] = session.state.get("memory", {}).get("messages", [])
|
||||
chat_messages: list[Message] = session.state.get(InMemoryHistoryProvider.DEFAULT_SOURCE_ID, {}).get("messages", [])
|
||||
|
||||
assert chat_messages is not None
|
||||
assert len(chat_messages) == 2
|
||||
|
||||
@@ -27,6 +27,7 @@ from agent_framework import (
|
||||
chat_middleware,
|
||||
function_middleware,
|
||||
)
|
||||
from agent_framework._sessions import InMemoryHistoryProvider
|
||||
|
||||
from .conftest import MockBaseChatClient, MockChatClient
|
||||
|
||||
@@ -1416,8 +1417,10 @@ class TestChatAgentSessionBehavior:
|
||||
async def process(self, context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
# Capture state before next() call
|
||||
thread_messages = []
|
||||
if context.session and context.session.state.get("memory"):
|
||||
thread_messages = context.session.state.get("memory", {}).get("messages", [])
|
||||
if context.session and context.session.state.get(InMemoryHistoryProvider.DEFAULT_SOURCE_ID):
|
||||
thread_messages = context.session.state.get(InMemoryHistoryProvider.DEFAULT_SOURCE_ID, {}).get(
|
||||
"messages", []
|
||||
)
|
||||
|
||||
before_state = {
|
||||
"before_next": True,
|
||||
@@ -1432,8 +1435,10 @@ class TestChatAgentSessionBehavior:
|
||||
|
||||
# Capture state after next() call
|
||||
thread_messages_after = []
|
||||
if context.session and context.session.state.get("memory"):
|
||||
thread_messages_after = context.session.state.get("memory", {}).get("messages", [])
|
||||
if context.session and context.session.state.get(InMemoryHistoryProvider.DEFAULT_SOURCE_ID):
|
||||
thread_messages_after = context.session.state.get(
|
||||
InMemoryHistoryProvider.DEFAULT_SOURCE_ID, {}
|
||||
).get("messages", [])
|
||||
|
||||
after_state = {
|
||||
"before_next": False,
|
||||
|
||||
@@ -359,30 +359,50 @@ class TestAgentSession:
|
||||
|
||||
class TestInMemoryHistoryProvider:
|
||||
async def test_empty_state_returns_no_messages(self) -> None:
|
||||
provider = InMemoryHistoryProvider("memory")
|
||||
provider = InMemoryHistoryProvider()
|
||||
session = AgentSession()
|
||||
ctx = SessionContext(session_id="s1", input_messages=[])
|
||||
await provider.before_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type]
|
||||
assert ctx.context_messages.get("memory", []) == []
|
||||
await provider.before_run( # type: ignore[arg-type]
|
||||
agent=None,
|
||||
session=session,
|
||||
context=ctx,
|
||||
state=session.state.setdefault(provider.source_id, {}),
|
||||
)
|
||||
assert ctx.context_messages.get(provider.source_id, []) == []
|
||||
|
||||
async def test_stores_and_loads_messages(self) -> None:
|
||||
from agent_framework import AgentResponse
|
||||
|
||||
provider = InMemoryHistoryProvider("memory")
|
||||
provider = InMemoryHistoryProvider()
|
||||
session = AgentSession()
|
||||
|
||||
# First run: send input, get response
|
||||
input_msg = Message(role="user", contents=["hello"])
|
||||
resp_msg = Message(role="assistant", contents=["hi there"])
|
||||
ctx1 = SessionContext(session_id="s1", input_messages=[input_msg])
|
||||
await provider.before_run(agent=None, session=session, context=ctx1, state=session.state) # type: ignore[arg-type]
|
||||
await provider.before_run( # type: ignore[arg-type]
|
||||
agent=None,
|
||||
session=session,
|
||||
context=ctx1,
|
||||
state=session.state.setdefault(provider.source_id, {}),
|
||||
)
|
||||
ctx1._response = AgentResponse(messages=[resp_msg])
|
||||
await provider.after_run(agent=None, session=session, context=ctx1, state=session.state) # type: ignore[arg-type]
|
||||
await provider.after_run( # type: ignore[arg-type]
|
||||
agent=None,
|
||||
session=session,
|
||||
context=ctx1,
|
||||
state=session.state.setdefault(provider.source_id, {}),
|
||||
)
|
||||
|
||||
# Second run: should load previous messages
|
||||
ctx2 = SessionContext(session_id="s1", input_messages=[Message(role="user", contents=["again"])])
|
||||
await provider.before_run(agent=None, session=session, context=ctx2, state=session.state) # type: ignore[arg-type]
|
||||
loaded = ctx2.context_messages.get("memory", [])
|
||||
await provider.before_run( # type: ignore[arg-type]
|
||||
agent=None,
|
||||
session=session,
|
||||
context=ctx2,
|
||||
state=session.state.setdefault(provider.source_id, {}),
|
||||
)
|
||||
loaded = ctx2.context_messages.get(provider.source_id, [])
|
||||
assert len(loaded) == 2
|
||||
assert loaded[0].text == "hello"
|
||||
assert loaded[1].text == "hi there"
|
||||
@@ -390,17 +410,27 @@ class TestInMemoryHistoryProvider:
|
||||
async def test_state_is_serializable(self) -> None:
|
||||
from agent_framework import AgentResponse
|
||||
|
||||
provider = InMemoryHistoryProvider("memory")
|
||||
provider = InMemoryHistoryProvider()
|
||||
session = AgentSession()
|
||||
|
||||
input_msg = Message(role="user", contents=["test"])
|
||||
ctx = SessionContext(session_id="s1", input_messages=[input_msg])
|
||||
await provider.before_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type]
|
||||
await provider.before_run( # type: ignore[arg-type]
|
||||
agent=None,
|
||||
session=session,
|
||||
context=ctx,
|
||||
state=session.state.setdefault(provider.source_id, {}),
|
||||
)
|
||||
ctx._response = AgentResponse(messages=[Message(role="assistant", contents=["reply"])])
|
||||
await provider.after_run(agent=None, session=session, context=ctx, state=session.state) # type: ignore[arg-type]
|
||||
await provider.after_run( # type: ignore[arg-type]
|
||||
agent=None,
|
||||
session=session,
|
||||
context=ctx,
|
||||
state=session.state.setdefault(provider.source_id, {}),
|
||||
)
|
||||
|
||||
# State contains Message objects (not dicts)
|
||||
assert isinstance(session.state["memory"]["messages"][0], Message)
|
||||
assert isinstance(session.state[provider.source_id]["messages"][0], Message)
|
||||
|
||||
# to_dict() serializes them via SerializationProtocol
|
||||
session_dict = session.to_dict()
|
||||
@@ -409,9 +439,9 @@ class TestInMemoryHistoryProvider:
|
||||
|
||||
# Round-trip through session serialization restores Message objects
|
||||
restored = AgentSession.from_dict(json.loads(json_str))
|
||||
assert isinstance(restored.state["memory"]["messages"][0], Message)
|
||||
assert restored.state["memory"]["messages"][0].text == "test"
|
||||
assert restored.state["memory"]["messages"][1].text == "reply"
|
||||
assert isinstance(restored.state[provider.source_id]["messages"][0], Message)
|
||||
assert restored.state[provider.source_id]["messages"][0].text == "test"
|
||||
assert restored.state[provider.source_id]["messages"][1].text == "reply"
|
||||
|
||||
async def test_source_id_attribution(self) -> None:
|
||||
provider = InMemoryHistoryProvider("custom-source")
|
||||
|
||||
Reference in New Issue
Block a user