mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Add Entity State Providers for DurableTask Package (#2981)
* Add Entity State Providers * address comments * Fix tests * Fix tests * Revert unrelated changes and remove thread_id * Revert unrelated files
This commit is contained in:
committed by
GitHub
Unverified
parent
87a38bc7da
commit
a02527f00a
@@ -40,6 +40,7 @@ from ._durable_agent_state import (
|
||||
DurableAgentStateUsage,
|
||||
DurableAgentStateUsageContent,
|
||||
)
|
||||
from ._entities import AgentEntity, AgentEntityStateProviderMixin
|
||||
from ._models import RunRequest, serialize_response_format
|
||||
|
||||
__all__ = [
|
||||
@@ -54,6 +55,8 @@ __all__ = [
|
||||
"WAIT_FOR_RESPONSE_FIELD",
|
||||
"WAIT_FOR_RESPONSE_HEADER",
|
||||
"AgentCallbackContext",
|
||||
"AgentEntity",
|
||||
"AgentEntityStateProviderMixin",
|
||||
"AgentResponseCallbackProtocol",
|
||||
"ApiResponseFields",
|
||||
"ContentTypes",
|
||||
|
||||
@@ -82,7 +82,7 @@ def _parse_created_at(value: Any) -> datetime:
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
logger.error(
|
||||
logger.warning(
|
||||
f"Invalid or missing created_at value in durable agent state; defaulting to current UTC time, {value}",
|
||||
stack_info=True,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,351 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Durable Task entity implementations for Microsoft Agent Framework."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
from collections.abc import AsyncIterable
|
||||
from typing import Any, cast
|
||||
|
||||
from agent_framework import (
|
||||
AgentProtocol,
|
||||
AgentRunResponse,
|
||||
AgentRunResponseUpdate,
|
||||
ChatMessage,
|
||||
ErrorContent,
|
||||
Role,
|
||||
get_logger,
|
||||
)
|
||||
from durabletask.entities import DurableEntity
|
||||
|
||||
from ._callbacks import AgentCallbackContext, AgentResponseCallbackProtocol
|
||||
from ._durable_agent_state import (
|
||||
DurableAgentState,
|
||||
DurableAgentStateEntry,
|
||||
DurableAgentStateRequest,
|
||||
DurableAgentStateResponse,
|
||||
)
|
||||
from ._models import RunRequest
|
||||
|
||||
logger = get_logger("agent_framework.durabletask.entities")
|
||||
|
||||
|
||||
class AgentEntityStateProviderMixin:
|
||||
"""Mixin implementing durable agent state caching + (de)serialization + persistence.
|
||||
|
||||
Concrete classes must implement:
|
||||
- _get_state_dict(): fetch raw persisted state dict (default should be {})
|
||||
- _set_state_dict(): persist raw state dict
|
||||
- _get_thread_id_from_entity(): fetch the thread ID from the underlying context
|
||||
"""
|
||||
|
||||
_state_cache: DurableAgentState | None = None
|
||||
|
||||
def _get_state_dict(self) -> dict[str, Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
def _set_state_dict(self, state: dict[str, Any]) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def _get_thread_id_from_entity(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def thread_id(self) -> str:
|
||||
return self._get_thread_id_from_entity()
|
||||
|
||||
@property
|
||||
def state(self) -> DurableAgentState:
|
||||
if self._state_cache is None:
|
||||
raw_state = self._get_state_dict()
|
||||
self._state_cache = DurableAgentState.from_dict(raw_state) if raw_state else DurableAgentState()
|
||||
return self._state_cache
|
||||
|
||||
@state.setter
|
||||
def state(self, value: DurableAgentState) -> None:
|
||||
self._state_cache = value
|
||||
self.persist_state()
|
||||
|
||||
def persist_state(self) -> None:
|
||||
"""Persist the current state to the underlying storage provider."""
|
||||
if self._state_cache is None:
|
||||
self._state_cache = DurableAgentState()
|
||||
self._set_state_dict(self._state_cache.to_dict())
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Clear conversation history by resetting state to a fresh DurableAgentState."""
|
||||
self._state_cache = DurableAgentState()
|
||||
self.persist_state()
|
||||
logger.debug("[AgentEntityStateProviderMixin.reset] State reset complete")
|
||||
|
||||
|
||||
class AgentEntity:
|
||||
"""Platform-agnostic agent execution logic.
|
||||
|
||||
This class encapsulates the core logic for executing an agent within a durable entity context.
|
||||
"""
|
||||
|
||||
agent: AgentProtocol
|
||||
callback: AgentResponseCallbackProtocol | None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent: AgentProtocol,
|
||||
callback: AgentResponseCallbackProtocol | None = None,
|
||||
*,
|
||||
state_provider: AgentEntityStateProviderMixin,
|
||||
) -> None:
|
||||
self.agent = agent
|
||||
self.callback = callback
|
||||
self._state_provider = state_provider
|
||||
|
||||
logger.debug("[AgentEntity] Initialized with agent type: %s", type(agent).__name__)
|
||||
|
||||
@property
|
||||
def state(self) -> DurableAgentState:
|
||||
return self._state_provider.state
|
||||
|
||||
@state.setter
|
||||
def state(self, value: DurableAgentState) -> None:
|
||||
self._state_provider.state = value
|
||||
|
||||
def persist_state(self) -> None:
|
||||
self._state_provider.persist_state()
|
||||
|
||||
def reset(self) -> None:
|
||||
self._state_provider.reset()
|
||||
|
||||
def _is_error_response(self, entry: DurableAgentStateEntry) -> bool:
|
||||
"""Check if a conversation history entry is an error response."""
|
||||
if isinstance(entry, DurableAgentStateResponse):
|
||||
return entry.is_error
|
||||
return False
|
||||
|
||||
async def run(
|
||||
self,
|
||||
request: RunRequest | dict[str, Any] | str,
|
||||
) -> AgentRunResponse:
|
||||
"""Execute the agent with a message."""
|
||||
if isinstance(request, str):
|
||||
run_request = RunRequest(message=request, role=Role.USER)
|
||||
elif isinstance(request, dict):
|
||||
run_request = RunRequest.from_dict(request)
|
||||
else:
|
||||
run_request = request
|
||||
|
||||
message = run_request.message
|
||||
thread_id = self._state_provider.thread_id
|
||||
correlation_id = run_request.correlation_id
|
||||
if not thread_id:
|
||||
raise ValueError("Entity State Provider must provide a thread_id")
|
||||
if not correlation_id:
|
||||
raise ValueError("RunRequest must include a correlation_id")
|
||||
response_format = run_request.response_format
|
||||
enable_tool_calls = run_request.enable_tool_calls
|
||||
|
||||
logger.debug("[AgentEntity.run] Received Message: %s", run_request)
|
||||
|
||||
state_request = DurableAgentStateRequest.from_run_request(run_request)
|
||||
self.state.data.conversation_history.append(state_request)
|
||||
|
||||
try:
|
||||
chat_messages: list[ChatMessage] = [
|
||||
m.to_chat_message()
|
||||
for entry in self.state.data.conversation_history
|
||||
if not self._is_error_response(entry)
|
||||
for m in entry.messages
|
||||
]
|
||||
|
||||
run_kwargs: dict[str, Any] = {"messages": chat_messages}
|
||||
if not enable_tool_calls:
|
||||
run_kwargs["tools"] = None
|
||||
if response_format:
|
||||
run_kwargs["response_format"] = response_format
|
||||
|
||||
agent_run_response: AgentRunResponse = await self._invoke_agent(
|
||||
run_kwargs=run_kwargs,
|
||||
correlation_id=correlation_id,
|
||||
thread_id=thread_id,
|
||||
request_message=message,
|
||||
)
|
||||
|
||||
state_response = DurableAgentStateResponse.from_run_response(correlation_id, agent_run_response)
|
||||
self.state.data.conversation_history.append(state_response)
|
||||
self.persist_state()
|
||||
|
||||
return agent_run_response
|
||||
|
||||
except Exception as exc:
|
||||
logger.exception("[AgentEntity.run] Agent execution failed.")
|
||||
|
||||
error_message = ChatMessage(
|
||||
role=Role.ASSISTANT, contents=[ErrorContent(message=str(exc), error_code=type(exc).__name__)]
|
||||
)
|
||||
error_response = AgentRunResponse(messages=[error_message])
|
||||
|
||||
error_state_response = DurableAgentStateResponse.from_run_response(correlation_id, error_response)
|
||||
error_state_response.is_error = True
|
||||
self.state.data.conversation_history.append(error_state_response)
|
||||
self.persist_state()
|
||||
|
||||
return error_response
|
||||
|
||||
async def _invoke_agent(
|
||||
self,
|
||||
run_kwargs: dict[str, Any],
|
||||
correlation_id: str,
|
||||
thread_id: str,
|
||||
request_message: str,
|
||||
) -> AgentRunResponse:
|
||||
"""Execute the agent, preferring streaming when available."""
|
||||
callback_context: AgentCallbackContext | None = None
|
||||
if self.callback is not None:
|
||||
callback_context = self._build_callback_context(
|
||||
correlation_id=correlation_id,
|
||||
thread_id=thread_id,
|
||||
request_message=request_message,
|
||||
)
|
||||
|
||||
run_stream_callable = getattr(self.agent, "run_stream", None)
|
||||
if callable(run_stream_callable):
|
||||
try:
|
||||
stream_candidate = run_stream_callable(**run_kwargs)
|
||||
if inspect.isawaitable(stream_candidate):
|
||||
stream_candidate = await stream_candidate
|
||||
|
||||
return await self._consume_stream(
|
||||
stream=cast(AsyncIterable[AgentRunResponseUpdate], stream_candidate),
|
||||
callback_context=callback_context,
|
||||
)
|
||||
except TypeError as type_error:
|
||||
if "__aiter__" not in str(type_error):
|
||||
raise
|
||||
logger.debug(
|
||||
"run_stream returned a non-async result; falling back to run(): %s",
|
||||
type_error,
|
||||
)
|
||||
except Exception as stream_error:
|
||||
logger.warning(
|
||||
"run_stream failed; falling back to run(): %s",
|
||||
stream_error,
|
||||
exc_info=True,
|
||||
)
|
||||
else:
|
||||
logger.debug("Agent does not expose run_stream; falling back to run().")
|
||||
|
||||
agent_run_response = await self._invoke_non_stream(run_kwargs)
|
||||
await self._notify_final_response(agent_run_response, callback_context)
|
||||
return agent_run_response
|
||||
|
||||
async def _consume_stream(
|
||||
self,
|
||||
stream: AsyncIterable[AgentRunResponseUpdate],
|
||||
callback_context: AgentCallbackContext | None = None,
|
||||
) -> AgentRunResponse:
|
||||
"""Consume streaming responses and build the final AgentRunResponse."""
|
||||
updates: list[AgentRunResponseUpdate] = []
|
||||
|
||||
async for update in stream:
|
||||
updates.append(update)
|
||||
await self._notify_stream_update(update, callback_context)
|
||||
|
||||
if updates:
|
||||
response = AgentRunResponse.from_agent_run_response_updates(updates)
|
||||
else:
|
||||
logger.debug("[AgentEntity] No streaming updates received; creating empty response")
|
||||
response = AgentRunResponse(messages=[])
|
||||
|
||||
await self._notify_final_response(response, callback_context)
|
||||
return response
|
||||
|
||||
async def _invoke_non_stream(self, run_kwargs: dict[str, Any]) -> AgentRunResponse:
|
||||
"""Invoke the agent without streaming support."""
|
||||
run_callable = getattr(self.agent, "run", None)
|
||||
if run_callable is None or not callable(run_callable):
|
||||
raise AttributeError("Agent does not implement run() method")
|
||||
|
||||
result = run_callable(**run_kwargs)
|
||||
if inspect.isawaitable(result):
|
||||
result = await result
|
||||
|
||||
if not isinstance(result, AgentRunResponse):
|
||||
raise TypeError(f"Agent run() must return an AgentRunResponse instance; received {type(result).__name__}")
|
||||
|
||||
return result
|
||||
|
||||
async def _notify_stream_update(
|
||||
self,
|
||||
update: AgentRunResponseUpdate,
|
||||
context: AgentCallbackContext | None,
|
||||
) -> None:
|
||||
"""Invoke the streaming callback if one is registered."""
|
||||
if self.callback is None or context is None:
|
||||
return
|
||||
|
||||
try:
|
||||
callback_result = self.callback.on_streaming_response_update(update, context)
|
||||
if inspect.isawaitable(callback_result):
|
||||
await callback_result
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"[AgentEntity] Streaming callback raised an exception: %s",
|
||||
exc,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
async def _notify_final_response(
|
||||
self,
|
||||
response: AgentRunResponse,
|
||||
context: AgentCallbackContext | None,
|
||||
) -> None:
|
||||
"""Invoke the final response callback if one is registered."""
|
||||
if self.callback is None or context is None:
|
||||
return
|
||||
|
||||
try:
|
||||
callback_result = self.callback.on_agent_response(response, context)
|
||||
if inspect.isawaitable(callback_result):
|
||||
await callback_result
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"[AgentEntity] Response callback raised an exception: %s",
|
||||
exc,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
def _build_callback_context(
|
||||
self,
|
||||
correlation_id: str,
|
||||
thread_id: str,
|
||||
request_message: str,
|
||||
) -> AgentCallbackContext:
|
||||
"""Create the callback context provided to consumers."""
|
||||
agent_name = getattr(self.agent, "name", None) or type(self.agent).__name__
|
||||
return AgentCallbackContext(
|
||||
agent_name=agent_name,
|
||||
correlation_id=correlation_id,
|
||||
thread_id=thread_id,
|
||||
request_message=request_message,
|
||||
)
|
||||
|
||||
|
||||
class DurableTaskEntityStateProvider(DurableEntity, AgentEntityStateProviderMixin):
|
||||
"""DurableTask Durable Entity state provider for AgentEntity.
|
||||
|
||||
This class utilizes the Durable Entity context from `durabletask` package
|
||||
to get and set the state of the agent entity.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def _get_state_dict(self) -> dict[str, Any]:
|
||||
raw = self.get_state(dict, default={})
|
||||
return cast(dict[str, Any], raw)
|
||||
|
||||
def _set_state_dict(self, state: dict[str, Any]) -> None:
|
||||
self.set_state(state)
|
||||
|
||||
def _get_thread_id_from_entity(self) -> str:
|
||||
return self.entity_context.entity_id.key
|
||||
@@ -101,7 +101,6 @@ class RunRequest:
|
||||
role: The role of the message sender (user, system, or assistant)
|
||||
response_format: Optional Pydantic BaseModel type describing the structured response format
|
||||
enable_tool_calls: Whether to enable tool calls for this request
|
||||
thread_id: Optional thread ID for tracking
|
||||
correlation_id: Optional correlation ID for tracking the response to this specific request
|
||||
created_at: Optional timestamp when the request was created
|
||||
orchestration_id: Optional ID of the orchestration that initiated this request
|
||||
@@ -112,7 +111,6 @@ class RunRequest:
|
||||
role: Role = Role.USER
|
||||
response_format: type[BaseModel] | None = None
|
||||
enable_tool_calls: bool = True
|
||||
thread_id: str | None = None
|
||||
correlation_id: str | None = None
|
||||
created_at: datetime | None = None
|
||||
orchestration_id: str | None = None
|
||||
@@ -124,7 +122,6 @@ class RunRequest:
|
||||
role: Role | str | None = Role.USER,
|
||||
response_format: type[BaseModel] | None = None,
|
||||
enable_tool_calls: bool = True,
|
||||
thread_id: str | None = None,
|
||||
correlation_id: str | None = None,
|
||||
created_at: datetime | None = None,
|
||||
orchestration_id: str | None = None,
|
||||
@@ -134,7 +131,6 @@ class RunRequest:
|
||||
self.response_format = response_format
|
||||
self.request_response_format = request_response_format
|
||||
self.enable_tool_calls = enable_tool_calls
|
||||
self.thread_id = thread_id
|
||||
self.correlation_id = correlation_id
|
||||
self.created_at = created_at
|
||||
self.orchestration_id = orchestration_id
|
||||
@@ -161,8 +157,6 @@ class RunRequest:
|
||||
}
|
||||
if self.response_format:
|
||||
result["response_format"] = serialize_response_format(self.response_format)
|
||||
if self.thread_id:
|
||||
result["thread_id"] = self.thread_id
|
||||
if self.correlation_id:
|
||||
result["correlationId"] = self.correlation_id
|
||||
if self.created_at:
|
||||
@@ -188,7 +182,6 @@ class RunRequest:
|
||||
role=cls.coerce_role(data.get("role")),
|
||||
response_format=_deserialize_response_format(data.get("response_format")),
|
||||
enable_tool_calls=data.get("enable_tool_calls", True),
|
||||
thread_id=data.get("thread_id"),
|
||||
correlation_id=data.get("correlationId"),
|
||||
created_at=created_at,
|
||||
orchestration_id=data.get("orchestrationId"),
|
||||
|
||||
@@ -0,0 +1,695 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Unit tests for AgentEntity.
|
||||
|
||||
Run with: pytest tests/test_entities.py -v
|
||||
"""
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from datetime import datetime
|
||||
from typing import Any, TypeVar
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import pytest
|
||||
from agent_framework import AgentRunResponse, AgentRunResponseUpdate, ChatMessage, ErrorContent, Role
|
||||
from pydantic import BaseModel
|
||||
|
||||
from agent_framework_durabletask import (
|
||||
AgentEntity,
|
||||
AgentEntityStateProviderMixin,
|
||||
DurableAgentState,
|
||||
DurableAgentStateData,
|
||||
DurableAgentStateMessage,
|
||||
DurableAgentStateRequest,
|
||||
DurableAgentStateTextContent,
|
||||
RunRequest,
|
||||
)
|
||||
from agent_framework_durabletask._entities import DurableTaskEntityStateProvider
|
||||
|
||||
TState = TypeVar("TState")
|
||||
|
||||
|
||||
class MockEntityContext:
|
||||
"""Minimal durabletask EntityContext shim for tests."""
|
||||
|
||||
def __init__(self, initial_state: Any = None) -> None:
|
||||
self._state = initial_state
|
||||
|
||||
def get_state(
|
||||
self,
|
||||
intended_type: type[TState] | None = None,
|
||||
default: TState | None = None,
|
||||
) -> Any:
|
||||
del intended_type
|
||||
if self._state is None:
|
||||
return default
|
||||
return self._state
|
||||
|
||||
def set_state(self, new_state: Any) -> None:
|
||||
self._state = new_state
|
||||
|
||||
|
||||
class _InMemoryStateProvider(AgentEntityStateProviderMixin):
|
||||
"""Test-only state provider for AgentEntity."""
|
||||
|
||||
def __init__(self, *, thread_id: str, initial_state: dict[str, Any] | None = None) -> None:
|
||||
self._thread_id = thread_id
|
||||
self._state_dict: dict[str, Any] = initial_state or {}
|
||||
|
||||
def _get_state_dict(self) -> dict[str, Any]:
|
||||
return self._state_dict
|
||||
|
||||
def _set_state_dict(self, state: dict[str, Any]) -> None:
|
||||
self._state_dict = state
|
||||
|
||||
def _get_thread_id_from_entity(self) -> str:
|
||||
return self._thread_id
|
||||
|
||||
|
||||
def _make_entity(agent: Any, callback: Any = None, *, thread_id: str = "test-thread") -> AgentEntity:
|
||||
return AgentEntity(agent, callback=callback, state_provider=_InMemoryStateProvider(thread_id=thread_id))
|
||||
|
||||
|
||||
def _role_value(chat_message: DurableAgentStateMessage) -> str:
|
||||
"""Helper to extract the string role from a ChatMessage."""
|
||||
role = getattr(chat_message, "role", None)
|
||||
role_value = getattr(role, "value", role)
|
||||
if role_value is None:
|
||||
return ""
|
||||
return str(role_value)
|
||||
|
||||
|
||||
def _agent_response(text: str | None) -> AgentRunResponse:
|
||||
"""Create an AgentRunResponse with a single assistant message."""
|
||||
message = (
|
||||
ChatMessage(role="assistant", text=text) if text is not None else ChatMessage(role="assistant", contents=[])
|
||||
)
|
||||
return AgentRunResponse(messages=[message])
|
||||
|
||||
|
||||
class RecordingCallback:
|
||||
"""Callback implementation capturing streaming and final responses for assertions."""
|
||||
|
||||
def __init__(self):
|
||||
self.stream_mock = AsyncMock()
|
||||
self.response_mock = AsyncMock()
|
||||
|
||||
async def on_streaming_response_update(
|
||||
self,
|
||||
update: AgentRunResponseUpdate,
|
||||
context: Any,
|
||||
) -> None:
|
||||
await self.stream_mock(update, context)
|
||||
|
||||
async def on_agent_response(self, response: AgentRunResponse, context: Any) -> None:
|
||||
await self.response_mock(response, context)
|
||||
|
||||
|
||||
class EntityStructuredResponse(BaseModel):
|
||||
answer: float
|
||||
|
||||
|
||||
class TestAgentEntityInit:
|
||||
"""Test suite for AgentEntity initialization."""
|
||||
|
||||
def test_init_creates_entity(self) -> None:
|
||||
"""Test that AgentEntity initializes correctly."""
|
||||
mock_agent = Mock()
|
||||
|
||||
entity = _make_entity(mock_agent)
|
||||
|
||||
assert entity.agent == mock_agent
|
||||
assert len(entity.state.data.conversation_history) == 0
|
||||
assert entity.state.data.extension_data is None
|
||||
assert entity.state.schema_version == DurableAgentState.SCHEMA_VERSION
|
||||
|
||||
def test_init_stores_agent_reference(self) -> None:
|
||||
"""Test that the agent reference is stored correctly."""
|
||||
mock_agent = Mock()
|
||||
mock_agent.name = "TestAgent"
|
||||
|
||||
entity = _make_entity(mock_agent)
|
||||
|
||||
assert entity.agent.name == "TestAgent"
|
||||
|
||||
def test_init_with_different_agent_types(self) -> None:
|
||||
"""Test initialization with different agent types."""
|
||||
agent1 = Mock()
|
||||
agent1.__class__.__name__ = "AzureOpenAIAgent"
|
||||
|
||||
agent2 = Mock()
|
||||
agent2.__class__.__name__ = "CustomAgent"
|
||||
|
||||
entity1 = _make_entity(agent1)
|
||||
entity2 = _make_entity(agent2)
|
||||
|
||||
assert entity1.agent.__class__.__name__ == "AzureOpenAIAgent"
|
||||
assert entity2.agent.__class__.__name__ == "CustomAgent"
|
||||
|
||||
|
||||
class TestDurableTaskEntityStateProvider:
|
||||
"""Tests for DurableTaskEntityStateProvider wrapper behavior and persistence wiring."""
|
||||
|
||||
def _make_durabletask_entity_provider(
|
||||
self,
|
||||
agent: Any,
|
||||
*,
|
||||
initial_state: dict[str, Any] | None = None,
|
||||
) -> tuple[DurableTaskEntityStateProvider, MockEntityContext]:
|
||||
"""Create a DurableTaskEntityStateProvider wired to an in-memory durabletask context."""
|
||||
entity = DurableTaskEntityStateProvider()
|
||||
ctx = MockEntityContext(initial_state)
|
||||
# DurableEntity provides this hook; required for get_state/set_state to work in unit tests.
|
||||
entity._initialize_entity_context(ctx) # type: ignore[attr-defined]
|
||||
return entity, ctx
|
||||
|
||||
def test_reset_persists_cleared_state(self) -> None:
|
||||
mock_agent = Mock()
|
||||
|
||||
existing_state = {
|
||||
"schemaVersion": "1.0.0",
|
||||
"data": {
|
||||
"conversationHistory": [
|
||||
{
|
||||
"$type": "request",
|
||||
"correlationId": "corr-existing-1",
|
||||
"createdAt": "2024-01-01T00:00:00Z",
|
||||
"messages": [{"role": "user", "contents": [{"$type": "text", "text": "msg1"}]}],
|
||||
}
|
||||
]
|
||||
},
|
||||
}
|
||||
|
||||
entity, ctx = self._make_durabletask_entity_provider(mock_agent, initial_state=existing_state)
|
||||
|
||||
entity.reset()
|
||||
|
||||
persisted = ctx.get_state(dict, default={})
|
||||
assert isinstance(persisted, dict)
|
||||
assert persisted["data"]["conversationHistory"] == []
|
||||
|
||||
|
||||
class TestAgentEntityRunAgent:
|
||||
"""Test suite for the run_agent operation."""
|
||||
|
||||
async def test_run_executes_agent(self) -> None:
|
||||
"""Test that run executes the agent."""
|
||||
mock_agent = Mock()
|
||||
mock_response = _agent_response("Test response")
|
||||
mock_agent.run = AsyncMock(return_value=mock_response)
|
||||
|
||||
entity = _make_entity(mock_agent)
|
||||
|
||||
result = await entity.run({
|
||||
"message": "Test message",
|
||||
"correlationId": "corr-entity-1",
|
||||
})
|
||||
|
||||
# Verify agent.run was called
|
||||
mock_agent.run.assert_called_once()
|
||||
_, kwargs = mock_agent.run.call_args
|
||||
sent_messages: list[Any] = kwargs.get("messages")
|
||||
assert len(sent_messages) == 1
|
||||
sent_message = sent_messages[0]
|
||||
assert isinstance(sent_message, ChatMessage)
|
||||
assert getattr(sent_message, "text", None) == "Test message"
|
||||
assert getattr(sent_message.role, "value", sent_message.role) == "user"
|
||||
|
||||
# Verify result
|
||||
assert isinstance(result, AgentRunResponse)
|
||||
assert result.text == "Test response"
|
||||
|
||||
async def test_run_agent_streaming_callbacks_invoked(self) -> None:
|
||||
"""Ensure streaming updates trigger callbacks and run() is not used."""
|
||||
updates = [
|
||||
AgentRunResponseUpdate(text="Hello"),
|
||||
AgentRunResponseUpdate(text=" world"),
|
||||
]
|
||||
|
||||
async def update_generator() -> AsyncIterator[AgentRunResponseUpdate]:
|
||||
for update in updates:
|
||||
yield update
|
||||
|
||||
mock_agent = Mock()
|
||||
mock_agent.name = "StreamingAgent"
|
||||
mock_agent.run_stream = Mock(return_value=update_generator())
|
||||
mock_agent.run = AsyncMock(side_effect=AssertionError("run() should not be called when streaming succeeds"))
|
||||
|
||||
callback = RecordingCallback()
|
||||
entity = _make_entity(mock_agent, callback=callback, thread_id="session-1")
|
||||
|
||||
result = await entity.run(
|
||||
{
|
||||
"message": "Tell me something",
|
||||
"correlationId": "corr-stream-1",
|
||||
},
|
||||
)
|
||||
|
||||
assert isinstance(result, AgentRunResponse)
|
||||
assert "Hello" in result.text
|
||||
assert callback.stream_mock.await_count == len(updates)
|
||||
assert callback.response_mock.await_count == 1
|
||||
mock_agent.run.assert_not_called()
|
||||
|
||||
# Validate callback arguments
|
||||
stream_calls = callback.stream_mock.await_args_list
|
||||
for expected_update, recorded_call in zip(updates, stream_calls, strict=True):
|
||||
assert recorded_call.args[0] is expected_update
|
||||
context = recorded_call.args[1]
|
||||
assert context.agent_name == "StreamingAgent"
|
||||
assert context.correlation_id == "corr-stream-1"
|
||||
assert context.thread_id == "session-1"
|
||||
assert context.request_message == "Tell me something"
|
||||
|
||||
final_call = callback.response_mock.await_args
|
||||
assert final_call is not None
|
||||
final_response, final_context = final_call.args
|
||||
assert final_context.agent_name == "StreamingAgent"
|
||||
assert final_context.correlation_id == "corr-stream-1"
|
||||
assert final_context.thread_id == "session-1"
|
||||
assert final_context.request_message == "Tell me something"
|
||||
assert getattr(final_response, "text", "").strip()
|
||||
|
||||
async def test_run_agent_final_callback_without_streaming(self) -> None:
|
||||
"""Ensure the final callback fires even when streaming is unavailable."""
|
||||
mock_agent = Mock()
|
||||
mock_agent.name = "NonStreamingAgent"
|
||||
mock_agent.run_stream = None
|
||||
agent_response = _agent_response("Final response")
|
||||
mock_agent.run = AsyncMock(return_value=agent_response)
|
||||
|
||||
callback = RecordingCallback()
|
||||
entity = _make_entity(mock_agent, callback=callback, thread_id="session-2")
|
||||
|
||||
result = await entity.run(
|
||||
{
|
||||
"message": "Hi",
|
||||
"correlationId": "corr-final-1",
|
||||
},
|
||||
)
|
||||
|
||||
assert isinstance(result, AgentRunResponse)
|
||||
assert result.text == "Final response"
|
||||
assert callback.stream_mock.await_count == 0
|
||||
assert callback.response_mock.await_count == 1
|
||||
|
||||
final_call = callback.response_mock.await_args
|
||||
assert final_call is not None
|
||||
assert final_call.args[0] is agent_response
|
||||
final_context = final_call.args[1]
|
||||
assert final_context.agent_name == "NonStreamingAgent"
|
||||
assert final_context.correlation_id == "corr-final-1"
|
||||
assert final_context.thread_id == "session-2"
|
||||
assert final_context.request_message == "Hi"
|
||||
|
||||
async def test_run_agent_updates_conversation_history(self) -> None:
|
||||
"""Test that run_agent updates the conversation history."""
|
||||
mock_agent = Mock()
|
||||
mock_response = _agent_response("Agent response")
|
||||
mock_agent.run = AsyncMock(return_value=mock_response)
|
||||
|
||||
entity = _make_entity(mock_agent)
|
||||
|
||||
await entity.run({"message": "User message", "correlationId": "corr-entity-2"})
|
||||
|
||||
# Should have 2 entries: user message + assistant response
|
||||
user_history = entity.state.data.conversation_history[0].messages
|
||||
assistant_history = entity.state.data.conversation_history[1].messages
|
||||
|
||||
assert len(user_history) == 1
|
||||
|
||||
user_msg = user_history[0]
|
||||
assert _role_value(user_msg) == "user"
|
||||
assert user_msg.text == "User message"
|
||||
|
||||
assistant_msg = assistant_history[0]
|
||||
assert _role_value(assistant_msg) == "assistant"
|
||||
assert assistant_msg.text == "Agent response"
|
||||
|
||||
async def test_run_agent_increments_message_count(self) -> None:
|
||||
"""Test that run_agent increments the message count."""
|
||||
mock_agent = Mock()
|
||||
mock_agent.run = AsyncMock(return_value=_agent_response("Response"))
|
||||
|
||||
entity = _make_entity(mock_agent)
|
||||
|
||||
assert len(entity.state.data.conversation_history) == 0
|
||||
|
||||
await entity.run({"message": "Message 1", "correlationId": "corr-entity-3a"})
|
||||
assert len(entity.state.data.conversation_history) == 2
|
||||
|
||||
await entity.run({"message": "Message 2", "correlationId": "corr-entity-3b"})
|
||||
assert len(entity.state.data.conversation_history) == 4
|
||||
|
||||
await entity.run({"message": "Message 3", "correlationId": "corr-entity-3c"})
|
||||
assert len(entity.state.data.conversation_history) == 6
|
||||
|
||||
async def test_run_requires_entity_thread_id(self) -> None:
|
||||
"""Test that AgentEntity.run rejects missing entity thread identifiers."""
|
||||
mock_agent = Mock()
|
||||
mock_agent.run = AsyncMock(return_value=_agent_response("Response"))
|
||||
|
||||
entity = _make_entity(mock_agent, thread_id="")
|
||||
|
||||
with pytest.raises(ValueError, match="thread_id"):
|
||||
await entity.run({"message": "Message", "correlationId": "corr-entity-5"})
|
||||
|
||||
async def test_run_agent_multiple_conversations(self) -> None:
|
||||
"""Test that run_agent maintains history across multiple messages."""
|
||||
mock_agent = Mock()
|
||||
mock_agent.run = AsyncMock(return_value=_agent_response("Response"))
|
||||
|
||||
entity = _make_entity(mock_agent)
|
||||
|
||||
# Send multiple messages
|
||||
await entity.run({"message": "Message 1", "correlationId": "corr-entity-8a"})
|
||||
await entity.run({"message": "Message 2", "correlationId": "corr-entity-8b"})
|
||||
await entity.run({"message": "Message 3", "correlationId": "corr-entity-8c"})
|
||||
|
||||
history = entity.state.data.conversation_history
|
||||
assert len(history) == 6
|
||||
assert entity.state.message_count == 6
|
||||
|
||||
|
||||
class TestAgentEntityReset:
|
||||
"""Test suite for the reset operation."""
|
||||
|
||||
def test_reset_clears_conversation_history(self) -> None:
|
||||
"""Test that reset clears the conversation history."""
|
||||
mock_agent = Mock()
|
||||
entity = _make_entity(mock_agent)
|
||||
|
||||
# Add some history with proper DurableAgentStateEntry objects
|
||||
entity.state.data.conversation_history = [
|
||||
DurableAgentStateRequest(
|
||||
correlation_id="test-1",
|
||||
created_at=datetime.now(),
|
||||
messages=[
|
||||
DurableAgentStateMessage(
|
||||
role="user",
|
||||
contents=[DurableAgentStateTextContent(text="msg1")],
|
||||
)
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
entity.reset()
|
||||
|
||||
assert entity.state.data.conversation_history == []
|
||||
|
||||
def test_reset_with_extension_data(self) -> None:
|
||||
"""Test that reset works when entity has extension data."""
|
||||
mock_agent = Mock()
|
||||
entity = _make_entity(mock_agent)
|
||||
|
||||
# Set up some initial state with conversation history
|
||||
entity.state.data = DurableAgentStateData(conversation_history=[], extension_data={"some_key": "some_value"})
|
||||
|
||||
entity.reset()
|
||||
|
||||
assert len(entity.state.data.conversation_history) == 0
|
||||
|
||||
def test_reset_clears_message_count(self) -> None:
|
||||
"""Test that reset clears the message count."""
|
||||
mock_agent = Mock()
|
||||
entity = _make_entity(mock_agent)
|
||||
|
||||
entity.reset()
|
||||
|
||||
assert len(entity.state.data.conversation_history) == 0
|
||||
|
||||
async def test_reset_after_conversation(self) -> None:
|
||||
"""Test reset after a full conversation."""
|
||||
mock_agent = Mock()
|
||||
mock_agent.run = AsyncMock(return_value=_agent_response("Response"))
|
||||
|
||||
entity = _make_entity(mock_agent)
|
||||
|
||||
# Have a conversation
|
||||
await entity.run({"message": "Message 1", "correlationId": "corr-entity-10a"})
|
||||
await entity.run({"message": "Message 2", "correlationId": "corr-entity-10b"})
|
||||
|
||||
# Verify state before reset
|
||||
assert entity.state.message_count == 4
|
||||
assert len(entity.state.data.conversation_history) == 4
|
||||
|
||||
# Reset
|
||||
entity.reset()
|
||||
|
||||
# Verify state after reset
|
||||
assert entity.state.message_count == 0
|
||||
assert len(entity.state.data.conversation_history) == 0
|
||||
|
||||
|
||||
class TestErrorHandling:
|
||||
"""Test suite for error handling in entities."""
|
||||
|
||||
async def test_run_agent_handles_agent_exception(self) -> None:
|
||||
"""Test that run_agent handles agent exceptions."""
|
||||
mock_agent = Mock()
|
||||
mock_agent.run = AsyncMock(side_effect=Exception("Agent failed"))
|
||||
|
||||
entity = _make_entity(mock_agent)
|
||||
|
||||
result = await entity.run({"message": "Message", "correlationId": "corr-entity-error-1"})
|
||||
|
||||
assert isinstance(result, AgentRunResponse)
|
||||
assert len(result.messages) == 1
|
||||
content = result.messages[0].contents[0]
|
||||
assert isinstance(content, ErrorContent)
|
||||
assert "Agent failed" in (content.message or "")
|
||||
assert content.error_code == "Exception"
|
||||
|
||||
async def test_run_agent_handles_value_error(self) -> None:
|
||||
"""Test that run_agent handles ValueError instances."""
|
||||
mock_agent = Mock()
|
||||
mock_agent.run = AsyncMock(side_effect=ValueError("Invalid input"))
|
||||
|
||||
entity = _make_entity(mock_agent)
|
||||
|
||||
result = await entity.run({"message": "Message", "correlationId": "corr-entity-error-2"})
|
||||
|
||||
assert isinstance(result, AgentRunResponse)
|
||||
assert len(result.messages) == 1
|
||||
content = result.messages[0].contents[0]
|
||||
assert isinstance(content, ErrorContent)
|
||||
assert content.error_code == "ValueError"
|
||||
assert "Invalid input" in str(content.message)
|
||||
|
||||
async def test_run_agent_handles_timeout_error(self) -> None:
|
||||
"""Test that run_agent handles TimeoutError instances."""
|
||||
mock_agent = Mock()
|
||||
mock_agent.run = AsyncMock(side_effect=TimeoutError("Request timeout"))
|
||||
|
||||
entity = _make_entity(mock_agent)
|
||||
|
||||
result = await entity.run({"message": "Message", "correlationId": "corr-entity-error-3"})
|
||||
|
||||
assert isinstance(result, AgentRunResponse)
|
||||
assert len(result.messages) == 1
|
||||
content = result.messages[0].contents[0]
|
||||
assert isinstance(content, ErrorContent)
|
||||
assert content.error_code == "TimeoutError"
|
||||
|
||||
async def test_run_agent_preserves_message_on_error(self) -> None:
|
||||
"""Test that run_agent preserves message information on error."""
|
||||
mock_agent = Mock()
|
||||
mock_agent.run = AsyncMock(side_effect=Exception("Error"))
|
||||
|
||||
entity = _make_entity(mock_agent)
|
||||
|
||||
result = await entity.run(
|
||||
{"message": "Test message", "correlationId": "corr-entity-error-4"},
|
||||
)
|
||||
|
||||
# Even on error, message info should be preserved
|
||||
assert isinstance(result, AgentRunResponse)
|
||||
assert len(result.messages) == 1
|
||||
content = result.messages[0].contents[0]
|
||||
assert isinstance(content, ErrorContent)
|
||||
|
||||
|
||||
class TestConversationHistory:
|
||||
"""Test suite for conversation history tracking."""
|
||||
|
||||
async def test_conversation_history_has_timestamps(self) -> None:
|
||||
"""Test that conversation history entries include timestamps."""
|
||||
mock_agent = Mock()
|
||||
mock_agent.run = AsyncMock(return_value=_agent_response("Response"))
|
||||
|
||||
entity = _make_entity(mock_agent)
|
||||
|
||||
await entity.run({"message": "Message", "correlationId": "corr-entity-history-1"})
|
||||
|
||||
# Check both user and assistant messages have timestamps
|
||||
for entry in entity.state.data.conversation_history:
|
||||
timestamp = entry.created_at
|
||||
assert timestamp is not None
|
||||
# Verify timestamp is in ISO format
|
||||
datetime.fromisoformat(str(timestamp))
|
||||
|
||||
async def test_conversation_history_ordering(self) -> None:
|
||||
"""Test that conversation history maintains the correct order."""
|
||||
mock_agent = Mock()
|
||||
|
||||
entity = _make_entity(mock_agent)
|
||||
|
||||
# Send multiple messages with different responses
|
||||
mock_agent.run = AsyncMock(return_value=_agent_response("Response 1"))
|
||||
await entity.run(
|
||||
{"message": "Message 1", "correlationId": "corr-entity-history-2a"},
|
||||
)
|
||||
|
||||
mock_agent.run = AsyncMock(return_value=_agent_response("Response 2"))
|
||||
await entity.run(
|
||||
{"message": "Message 2", "correlationId": "corr-entity-history-2b"},
|
||||
)
|
||||
|
||||
mock_agent.run = AsyncMock(return_value=_agent_response("Response 3"))
|
||||
await entity.run(
|
||||
{"message": "Message 3", "correlationId": "corr-entity-history-2c"},
|
||||
)
|
||||
|
||||
# Verify order
|
||||
history = entity.state.data.conversation_history
|
||||
# Each conversation turn creates 2 entries: request and response
|
||||
assert history[0].messages[0].text == "Message 1" # Request 1
|
||||
assert history[1].messages[0].text == "Response 1" # Response 1
|
||||
assert history[2].messages[0].text == "Message 2" # Request 2
|
||||
assert history[3].messages[0].text == "Response 2" # Response 2
|
||||
assert history[4].messages[0].text == "Message 3" # Request 3
|
||||
assert history[5].messages[0].text == "Response 3" # Response 3
|
||||
|
||||
async def test_conversation_history_role_alternation(self) -> None:
|
||||
"""Test that conversation history alternates between user and assistant roles."""
|
||||
mock_agent = Mock()
|
||||
mock_agent.run = AsyncMock(return_value=_agent_response("Response"))
|
||||
|
||||
entity = _make_entity(mock_agent)
|
||||
|
||||
await entity.run(
|
||||
{"message": "Message 1", "correlationId": "corr-entity-history-3a"},
|
||||
)
|
||||
await entity.run(
|
||||
{"message": "Message 2", "correlationId": "corr-entity-history-3b"},
|
||||
)
|
||||
|
||||
# Check role alternation
|
||||
history = entity.state.data.conversation_history
|
||||
# Each conversation turn creates 2 entries: request and response
|
||||
assert history[0].messages[0].role == "user" # Request 1
|
||||
assert history[1].messages[0].role == "assistant" # Response 1
|
||||
assert history[2].messages[0].role == "user" # Request 2
|
||||
assert history[3].messages[0].role == "assistant" # Response 2
|
||||
|
||||
|
||||
class TestRunRequestSupport:
|
||||
"""Test suite for RunRequest support in entities."""
|
||||
|
||||
async def test_run_agent_with_run_request_object(self) -> None:
|
||||
"""Test run_agent with a RunRequest object."""
|
||||
mock_agent = Mock()
|
||||
mock_agent.run = AsyncMock(return_value=_agent_response("Response"))
|
||||
|
||||
entity = _make_entity(mock_agent)
|
||||
|
||||
request = RunRequest(
|
||||
message="Test message",
|
||||
role=Role.USER,
|
||||
enable_tool_calls=True,
|
||||
correlation_id="corr-runreq-1",
|
||||
)
|
||||
|
||||
result = await entity.run(request)
|
||||
|
||||
assert isinstance(result, AgentRunResponse)
|
||||
assert result.text == "Response"
|
||||
|
||||
async def test_run_agent_with_dict_request(self) -> None:
|
||||
"""Test run_agent with a dictionary request."""
|
||||
mock_agent = Mock()
|
||||
mock_agent.run = AsyncMock(return_value=_agent_response("Response"))
|
||||
|
||||
entity = _make_entity(mock_agent)
|
||||
|
||||
request_dict = {
|
||||
"message": "Test message",
|
||||
"role": "system",
|
||||
"enable_tool_calls": False,
|
||||
"correlationId": "corr-runreq-2",
|
||||
}
|
||||
|
||||
result = await entity.run(request_dict)
|
||||
|
||||
assert isinstance(result, AgentRunResponse)
|
||||
assert result.text == "Response"
|
||||
|
||||
async def test_run_agent_with_string_raises_without_correlation(self) -> None:
|
||||
"""Test that run_agent rejects legacy string input without correlation ID."""
|
||||
mock_agent = Mock()
|
||||
mock_agent.run = AsyncMock(return_value=_agent_response("Response"))
|
||||
|
||||
entity = _make_entity(mock_agent)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await entity.run("Simple message")
|
||||
|
||||
async def test_run_agent_stores_role_in_history(self) -> None:
|
||||
"""Test that run_agent stores the role in conversation history."""
|
||||
mock_agent = Mock()
|
||||
mock_agent.run = AsyncMock(return_value=_agent_response("Response"))
|
||||
|
||||
entity = _make_entity(mock_agent)
|
||||
|
||||
# Send as system role
|
||||
request = RunRequest(
|
||||
message="System message",
|
||||
role=Role.SYSTEM,
|
||||
correlation_id="corr-runreq-3",
|
||||
)
|
||||
|
||||
await entity.run(request)
|
||||
|
||||
# Check that system role was stored
|
||||
history = entity.state.data.conversation_history
|
||||
assert history[0].messages[0].role == "system"
|
||||
assert history[0].messages[0].text == "System message"
|
||||
|
||||
async def test_run_agent_with_response_format(self) -> None:
|
||||
"""Test run_agent with a JSON response format."""
|
||||
mock_agent = Mock()
|
||||
# Return JSON response
|
||||
mock_agent.run = AsyncMock(return_value=_agent_response('{"answer": 42}'))
|
||||
|
||||
entity = _make_entity(mock_agent)
|
||||
|
||||
request = RunRequest(
|
||||
message="What is the answer?",
|
||||
response_format=EntityStructuredResponse,
|
||||
correlation_id="corr-runreq-4",
|
||||
)
|
||||
|
||||
result = await entity.run(request)
|
||||
|
||||
assert isinstance(result, AgentRunResponse)
|
||||
assert result.text == '{"answer": 42}'
|
||||
assert result.value is None
|
||||
|
||||
async def test_run_agent_disable_tool_calls(self) -> None:
|
||||
"""Test run_agent with tool calls disabled."""
|
||||
mock_agent = Mock()
|
||||
mock_agent.run = AsyncMock(return_value=_agent_response("Response"))
|
||||
|
||||
entity = _make_entity(mock_agent)
|
||||
|
||||
request = RunRequest(message="Test", enable_tool_calls=False, correlation_id="corr-runreq-5")
|
||||
|
||||
result = await entity.run(request)
|
||||
|
||||
assert isinstance(result, AgentRunResponse)
|
||||
# Agent should have been called (tool disabling is framework-dependent)
|
||||
mock_agent.run.assert_called_once()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v", "--tb=short"])
|
||||
@@ -18,20 +18,18 @@ class TestRunRequest:
|
||||
|
||||
def test_init_with_defaults(self) -> None:
|
||||
"""Test RunRequest initialization with defaults."""
|
||||
request = RunRequest(message="Hello", thread_id="thread-default")
|
||||
request = RunRequest(message="Hello")
|
||||
|
||||
assert request.message == "Hello"
|
||||
assert request.role == Role.USER
|
||||
assert request.response_format is None
|
||||
assert request.enable_tool_calls is True
|
||||
assert request.thread_id == "thread-default"
|
||||
|
||||
def test_init_with_all_fields(self) -> None:
|
||||
"""Test RunRequest initialization with all fields."""
|
||||
schema = ModuleStructuredResponse
|
||||
request = RunRequest(
|
||||
message="Hello",
|
||||
thread_id="thread-123",
|
||||
role=Role.SYSTEM,
|
||||
response_format=schema,
|
||||
enable_tool_calls=False,
|
||||
@@ -41,31 +39,29 @@ class TestRunRequest:
|
||||
assert request.role == Role.SYSTEM
|
||||
assert request.response_format is schema
|
||||
assert request.enable_tool_calls is False
|
||||
assert request.thread_id == "thread-123"
|
||||
|
||||
def test_init_coerces_string_role(self) -> None:
|
||||
"""Ensure string role values are coerced into Role instances."""
|
||||
request = RunRequest(message="Hello", thread_id="thread-str-role", role="system") # type: ignore[arg-type]
|
||||
request = RunRequest(message="Hello", role="system") # type: ignore[arg-type]
|
||||
|
||||
assert request.role == Role.SYSTEM
|
||||
|
||||
def test_to_dict_with_defaults(self) -> None:
|
||||
"""Test to_dict with default values."""
|
||||
request = RunRequest(message="Test message", thread_id="thread-to-dict")
|
||||
request = RunRequest(message="Test message")
|
||||
data = request.to_dict()
|
||||
|
||||
assert data["message"] == "Test message"
|
||||
assert data["enable_tool_calls"] is True
|
||||
assert data["role"] == "user"
|
||||
assert "response_format" not in data or data["response_format"] is None
|
||||
assert data["thread_id"] == "thread-to-dict"
|
||||
assert "thread_id" not in data
|
||||
|
||||
def test_to_dict_with_all_fields(self) -> None:
|
||||
"""Test to_dict with all fields."""
|
||||
schema = ModuleStructuredResponse
|
||||
request = RunRequest(
|
||||
message="Hello",
|
||||
thread_id="thread-456",
|
||||
role=Role.ASSISTANT,
|
||||
response_format=schema,
|
||||
enable_tool_calls=False,
|
||||
@@ -78,17 +74,22 @@ class TestRunRequest:
|
||||
assert data["response_format"]["module"] == schema.__module__
|
||||
assert data["response_format"]["qualname"] == schema.__qualname__
|
||||
assert data["enable_tool_calls"] is False
|
||||
assert data["thread_id"] == "thread-456"
|
||||
assert "thread_id" not in data
|
||||
|
||||
def test_from_dict_with_defaults(self) -> None:
|
||||
"""Test from_dict with minimal data."""
|
||||
data = {"message": "Hello", "thread_id": "thread-from-dict"}
|
||||
data = {"message": "Hello"}
|
||||
request = RunRequest.from_dict(data)
|
||||
|
||||
assert request.message == "Hello"
|
||||
assert request.role == Role.USER
|
||||
assert request.enable_tool_calls is True
|
||||
assert request.thread_id == "thread-from-dict"
|
||||
|
||||
def test_from_dict_ignores_thread_id_field(self) -> None:
|
||||
"""Ensure legacy thread_id input does not break RunRequest parsing."""
|
||||
request = RunRequest.from_dict({"message": "Hello", "thread_id": "ignored"})
|
||||
|
||||
assert request.message == "Hello"
|
||||
|
||||
def test_from_dict_with_all_fields(self) -> None:
|
||||
"""Test from_dict with all fields."""
|
||||
@@ -101,7 +102,6 @@ class TestRunRequest:
|
||||
"qualname": ModuleStructuredResponse.__qualname__,
|
||||
},
|
||||
"enable_tool_calls": False,
|
||||
"thread_id": "thread-789",
|
||||
}
|
||||
request = RunRequest.from_dict(data)
|
||||
|
||||
@@ -109,11 +109,10 @@ class TestRunRequest:
|
||||
assert request.role == Role.SYSTEM
|
||||
assert request.response_format is ModuleStructuredResponse
|
||||
assert request.enable_tool_calls is False
|
||||
assert request.thread_id == "thread-789"
|
||||
|
||||
def test_from_dict_with_unknown_role_preserves_value(self) -> None:
|
||||
"""Test from_dict keeps custom roles intact."""
|
||||
data = {"message": "Test", "role": "reviewer", "thread_id": "thread-with-custom-role"}
|
||||
data = {"message": "Test", "role": "reviewer"}
|
||||
request = RunRequest.from_dict(data)
|
||||
|
||||
assert request.role.value == "reviewer"
|
||||
@@ -121,18 +120,15 @@ class TestRunRequest:
|
||||
|
||||
def test_from_dict_empty_message(self) -> None:
|
||||
"""Test from_dict with empty message."""
|
||||
data = {"thread_id": "thread-empty"}
|
||||
request = RunRequest.from_dict(data)
|
||||
request = RunRequest.from_dict({})
|
||||
|
||||
assert request.message == ""
|
||||
assert request.role == Role.USER
|
||||
assert request.thread_id == "thread-empty"
|
||||
|
||||
def test_round_trip_dict_conversion(self) -> None:
|
||||
"""Test round-trip to_dict and from_dict."""
|
||||
original = RunRequest(
|
||||
message="Test message",
|
||||
thread_id="thread-123",
|
||||
role=Role.SYSTEM,
|
||||
response_format=ModuleStructuredResponse,
|
||||
enable_tool_calls=False,
|
||||
@@ -145,13 +141,11 @@ class TestRunRequest:
|
||||
assert restored.role == original.role
|
||||
assert restored.response_format is ModuleStructuredResponse
|
||||
assert restored.enable_tool_calls == original.enable_tool_calls
|
||||
assert restored.thread_id == original.thread_id
|
||||
|
||||
def test_round_trip_with_pydantic_response_format(self) -> None:
|
||||
"""Ensure Pydantic response formats serialize and deserialize properly."""
|
||||
original = RunRequest(
|
||||
message="Structured",
|
||||
thread_id="thread-pydantic",
|
||||
response_format=ModuleStructuredResponse,
|
||||
)
|
||||
|
||||
@@ -166,14 +160,14 @@ class TestRunRequest:
|
||||
|
||||
def test_init_with_correlationId(self) -> None:
|
||||
"""Test RunRequest initialization with correlationId."""
|
||||
request = RunRequest(message="Test message", thread_id="thread-corr-init", correlation_id="corr-123")
|
||||
request = RunRequest(message="Test message", correlation_id="corr-123")
|
||||
|
||||
assert request.message == "Test message"
|
||||
assert request.correlation_id == "corr-123"
|
||||
|
||||
def test_to_dict_with_correlationId(self) -> None:
|
||||
"""Test to_dict includes correlationId."""
|
||||
request = RunRequest(message="Test", thread_id="thread-corr-to-dict", correlation_id="corr-456")
|
||||
request = RunRequest(message="Test", correlation_id="corr-456")
|
||||
data = request.to_dict()
|
||||
|
||||
assert data["message"] == "Test"
|
||||
@@ -181,18 +175,16 @@ class TestRunRequest:
|
||||
|
||||
def test_from_dict_with_correlationId(self) -> None:
|
||||
"""Test from_dict with correlationId."""
|
||||
data = {"message": "Test", "correlationId": "corr-789", "thread_id": "thread-corr-from-dict"}
|
||||
data = {"message": "Test", "correlationId": "corr-789"}
|
||||
request = RunRequest.from_dict(data)
|
||||
|
||||
assert request.message == "Test"
|
||||
assert request.correlation_id == "corr-789"
|
||||
assert request.thread_id == "thread-corr-from-dict"
|
||||
|
||||
def test_round_trip_with_correlationId(self) -> None:
|
||||
"""Test round-trip to_dict and from_dict with correlationId."""
|
||||
original = RunRequest(
|
||||
message="Test message",
|
||||
thread_id="thread-123",
|
||||
role=Role.SYSTEM,
|
||||
correlation_id="corr-123",
|
||||
)
|
||||
@@ -203,13 +195,11 @@ class TestRunRequest:
|
||||
assert restored.message == original.message
|
||||
assert restored.role == original.role
|
||||
assert restored.correlation_id == original.correlation_id
|
||||
assert restored.thread_id == original.thread_id
|
||||
|
||||
def test_init_with_orchestration_id(self) -> None:
|
||||
"""Test RunRequest initialization with orchestration_id."""
|
||||
request = RunRequest(
|
||||
message="Test message",
|
||||
thread_id="thread-orch-init",
|
||||
orchestration_id="orch-123",
|
||||
)
|
||||
|
||||
@@ -220,7 +210,6 @@ class TestRunRequest:
|
||||
"""Test to_dict includes orchestrationId."""
|
||||
request = RunRequest(
|
||||
message="Test",
|
||||
thread_id="thread-orch-to-dict",
|
||||
orchestration_id="orch-456",
|
||||
)
|
||||
data = request.to_dict()
|
||||
@@ -232,7 +221,6 @@ class TestRunRequest:
|
||||
"""Test to_dict excludes orchestrationId when not set."""
|
||||
request = RunRequest(
|
||||
message="Test",
|
||||
thread_id="thread-orch-none",
|
||||
)
|
||||
data = request.to_dict()
|
||||
|
||||
@@ -243,19 +231,16 @@ class TestRunRequest:
|
||||
data = {
|
||||
"message": "Test",
|
||||
"orchestrationId": "orch-789",
|
||||
"thread_id": "thread-orch-from-dict",
|
||||
}
|
||||
request = RunRequest.from_dict(data)
|
||||
|
||||
assert request.message == "Test"
|
||||
assert request.orchestration_id == "orch-789"
|
||||
assert request.thread_id == "thread-orch-from-dict"
|
||||
|
||||
def test_round_trip_with_orchestration_id(self) -> None:
|
||||
"""Test round-trip to_dict and from_dict with orchestration_id."""
|
||||
original = RunRequest(
|
||||
message="Test message",
|
||||
thread_id="thread-123",
|
||||
role=Role.SYSTEM,
|
||||
correlation_id="corr-123",
|
||||
orchestration_id="orch-123",
|
||||
@@ -268,7 +253,6 @@ class TestRunRequest:
|
||||
assert restored.role == original.role
|
||||
assert restored.correlation_id == original.correlation_id
|
||||
assert restored.orchestration_id == original.orchestration_id
|
||||
assert restored.thread_id == original.thread_id
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user