mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Add orchestration ID to durable agent entity state and code refactor (#2484)
* Initial plan * Add orchestration ID to durable agent entity state for Python Co-authored-by: larohra <41490930+larohra@users.noreply.github.com> * Fix type safety checks * Fix tests --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: larohra <41490930+larohra@users.noreply.github.com> Co-authored-by: Laveesh Rohra <larohra@microsoft.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
0d5d10d24a
commit
cb343dd707
@@ -562,6 +562,7 @@ class AgentFunctionApp(DFAppBase):
|
||||
logger.debug("[MCP Tool Trigger] Received invocation for agent: %s", agent_name)
|
||||
return await self._handle_mcp_tool_invocation(agent_name=agent_name, context=context, client=client)
|
||||
|
||||
_ = mcp_tool_handler
|
||||
logger.debug("[AgentFunctionApp] Registered MCP tool trigger for agent: %s", agent_name)
|
||||
|
||||
async def _handle_mcp_tool_invocation(
|
||||
@@ -587,15 +588,17 @@ class AgentFunctionApp(DFAppBase):
|
||||
|
||||
# Parse JSON context string
|
||||
try:
|
||||
parsed_context = json.loads(context)
|
||||
parsed_context: Any = json.loads(context)
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"Invalid MCP context format: {e}") from e
|
||||
|
||||
parsed_context = cast(Mapping[str, Any], parsed_context) if isinstance(parsed_context, dict) else {}
|
||||
|
||||
# Extract arguments from MCP context
|
||||
arguments = parsed_context.get("arguments", {}) if isinstance(parsed_context, dict) else {}
|
||||
arguments: dict[str, Any] = parsed_context.get("arguments", {})
|
||||
|
||||
# Validate required 'query' argument
|
||||
query = arguments.get("query")
|
||||
query: Any = arguments.get("query")
|
||||
if not query or not isinstance(query, str):
|
||||
raise ValueError("MCP Tool invocation is missing required 'query' argument of type string.")
|
||||
|
||||
@@ -951,10 +954,9 @@ class AgentFunctionApp(DFAppBase):
|
||||
"""Create a lowercase header mapping from the incoming request."""
|
||||
headers: dict[str, str] = {}
|
||||
raw_headers = req.headers
|
||||
if isinstance(raw_headers, Mapping):
|
||||
for key, value in raw_headers.items():
|
||||
if value is not None:
|
||||
headers[str(key).lower()] = str(value)
|
||||
for key, value in cast(Mapping[str, str], raw_headers).items():
|
||||
headers[key.lower()] = value
|
||||
|
||||
return headers
|
||||
|
||||
@staticmethod
|
||||
|
||||
+154
-120
@@ -32,7 +32,7 @@ from __future__ import annotations
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
from agent_framework import (
|
||||
AgentRunResponse,
|
||||
@@ -74,6 +74,130 @@ def _parse_created_at(value: Any) -> datetime:
|
||||
return datetime.now(tz=timezone.utc)
|
||||
|
||||
|
||||
def _parse_messages(data: dict[str, Any]) -> list[DurableAgentStateMessage]:
|
||||
"""Parse messages from a dictionary, converting dicts to DurableAgentStateMessage objects.
|
||||
|
||||
Args:
|
||||
data: Dictionary containing a 'messages' key with a list of message data
|
||||
|
||||
Returns:
|
||||
List of DurableAgentStateMessage objects
|
||||
"""
|
||||
messages: list[DurableAgentStateMessage] = []
|
||||
raw_messages: list[Any] = data.get("messages", [])
|
||||
for raw_msg in raw_messages:
|
||||
if isinstance(raw_msg, dict):
|
||||
messages.append(DurableAgentStateMessage.from_dict(cast(dict[str, Any], raw_msg)))
|
||||
elif isinstance(raw_msg, DurableAgentStateMessage):
|
||||
messages.append(raw_msg)
|
||||
return messages
|
||||
|
||||
|
||||
def _parse_history_entries(data_dict: dict[str, Any]) -> list[DurableAgentStateEntry]:
|
||||
"""Parse conversation history entries from a dictionary.
|
||||
|
||||
Args:
|
||||
data_dict: Dictionary containing a 'conversationHistory' key with a list of entry data
|
||||
|
||||
Returns:
|
||||
List of DurableAgentStateEntry objects (requests and responses)
|
||||
"""
|
||||
history_data: list[Any] = data_dict.get("conversationHistory", [])
|
||||
deserialized_history: list[DurableAgentStateEntry] = []
|
||||
for raw_entry in history_data:
|
||||
if isinstance(raw_entry, dict):
|
||||
entry_dict = cast(dict[str, Any], raw_entry)
|
||||
entry_type = entry_dict.get("$type") or entry_dict.get("json_type")
|
||||
if entry_type == DurableAgentStateEntryJsonType.RESPONSE:
|
||||
deserialized_history.append(DurableAgentStateResponse.from_dict(entry_dict))
|
||||
elif entry_type == DurableAgentStateEntryJsonType.REQUEST:
|
||||
deserialized_history.append(DurableAgentStateRequest.from_dict(entry_dict))
|
||||
else:
|
||||
deserialized_history.append(DurableAgentStateEntry.from_dict(entry_dict))
|
||||
elif isinstance(raw_entry, DurableAgentStateEntry):
|
||||
deserialized_history.append(raw_entry)
|
||||
return deserialized_history
|
||||
|
||||
|
||||
def _parse_contents(data: dict[str, Any]) -> list[DurableAgentStateContent]:
|
||||
"""Parse content items from a dictionary.
|
||||
|
||||
Args:
|
||||
data: Dictionary containing a 'contents' key with a list of content data
|
||||
|
||||
Returns:
|
||||
List of DurableAgentStateContent objects
|
||||
"""
|
||||
contents: list[DurableAgentStateContent] = []
|
||||
raw_contents: list[Any] = data.get("contents", [])
|
||||
for raw_content in raw_contents:
|
||||
if isinstance(raw_content, dict):
|
||||
content_dict = cast(dict[str, Any], raw_content)
|
||||
content_type: str | None = content_dict.get("$type")
|
||||
if content_type == DurableAgentStateTextContent.type:
|
||||
contents.append(DurableAgentStateTextContent(text=content_dict.get("text")))
|
||||
elif content_type == DurableAgentStateDataContent.type:
|
||||
contents.append(
|
||||
DurableAgentStateDataContent(
|
||||
uri=str(content_dict.get("uri", "")),
|
||||
media_type=content_dict.get("mediaType"),
|
||||
)
|
||||
)
|
||||
elif content_type == DurableAgentStateErrorContent.type:
|
||||
contents.append(
|
||||
DurableAgentStateErrorContent(
|
||||
message=content_dict.get("message"),
|
||||
error_code=content_dict.get("errorCode"),
|
||||
details=content_dict.get("details"),
|
||||
)
|
||||
)
|
||||
elif content_type == DurableAgentStateFunctionCallContent.type:
|
||||
contents.append(
|
||||
DurableAgentStateFunctionCallContent(
|
||||
call_id=str(content_dict.get("callId", "")),
|
||||
name=str(content_dict.get("name", "")),
|
||||
arguments=content_dict.get("arguments", {}),
|
||||
)
|
||||
)
|
||||
elif content_type == DurableAgentStateFunctionResultContent.type:
|
||||
contents.append(
|
||||
DurableAgentStateFunctionResultContent(
|
||||
call_id=str(content_dict.get("callId", "")),
|
||||
result=content_dict.get("result"),
|
||||
)
|
||||
)
|
||||
elif content_type == DurableAgentStateHostedFileContent.type:
|
||||
contents.append(DurableAgentStateHostedFileContent(file_id=str(content_dict.get("fileId", ""))))
|
||||
elif content_type == DurableAgentStateHostedVectorStoreContent.type:
|
||||
contents.append(
|
||||
DurableAgentStateHostedVectorStoreContent(
|
||||
vector_store_id=str(content_dict.get("vectorStoreId", ""))
|
||||
)
|
||||
)
|
||||
elif content_type == DurableAgentStateTextReasoningContent.type:
|
||||
contents.append(DurableAgentStateTextReasoningContent(text=content_dict.get("text")))
|
||||
elif content_type == DurableAgentStateUriContent.type:
|
||||
contents.append(
|
||||
DurableAgentStateUriContent(
|
||||
uri=str(content_dict.get("uri", "")),
|
||||
media_type=str(content_dict.get("mediaType", "")),
|
||||
)
|
||||
)
|
||||
elif content_type == DurableAgentStateUsageContent.type:
|
||||
usage_data = content_dict.get("usage")
|
||||
if usage_data and isinstance(usage_data, dict):
|
||||
contents.append(
|
||||
DurableAgentStateUsageContent(
|
||||
usage=DurableAgentStateUsage.from_dict(cast(dict[str, Any], usage_data))
|
||||
)
|
||||
)
|
||||
elif content_type == DurableAgentStateUnknownContent.type:
|
||||
contents.append(DurableAgentStateUnknownContent(content=content_dict.get("content", {})))
|
||||
elif isinstance(raw_content, DurableAgentStateContent):
|
||||
contents.append(raw_content)
|
||||
return contents
|
||||
|
||||
|
||||
class DurableAgentStateContent:
|
||||
"""Base class for all content types in durable agent state messages.
|
||||
|
||||
@@ -197,25 +321,8 @@ class DurableAgentStateData:
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data_dict: dict[str, Any]) -> DurableAgentStateData:
|
||||
# Restore the conversation history - deserialize entries from dicts to objects
|
||||
history_data = data_dict.get("conversationHistory", [])
|
||||
deserialized_history: list[DurableAgentStateEntry] = []
|
||||
for entry_dict in history_data:
|
||||
if isinstance(entry_dict, dict):
|
||||
# Deserialize based on $type discriminator
|
||||
entry_type = entry_dict.get("$type") or entry_dict.get("json_type")
|
||||
if entry_type == DurableAgentStateEntryJsonType.RESPONSE:
|
||||
deserialized_history.append(DurableAgentStateResponse.from_dict(entry_dict))
|
||||
elif entry_type == DurableAgentStateEntryJsonType.REQUEST:
|
||||
deserialized_history.append(DurableAgentStateRequest.from_dict(entry_dict))
|
||||
else:
|
||||
deserialized_history.append(DurableAgentStateEntry.from_dict(entry_dict))
|
||||
else:
|
||||
# Already an object
|
||||
deserialized_history.append(entry_dict)
|
||||
|
||||
return cls(
|
||||
conversation_history=deserialized_history,
|
||||
conversation_history=_parse_history_entries(data_dict),
|
||||
extension_data=data_dict.get("extensionData"),
|
||||
)
|
||||
|
||||
@@ -227,7 +334,7 @@ class DurableAgentState:
|
||||
in Azure Durable Entities. It maintains the conversation history as a sequence of request
|
||||
and response entries, each with their messages, timestamps, and metadata.
|
||||
|
||||
The state follows a versioned schema (currently 1.0.0) that defines the structure for:
|
||||
The state follows a versioned schema (see SCHEMA_VERSION class constant) that defines the structure for:
|
||||
- Request entries: User/system messages with optional response format specifications
|
||||
- Response entries: Assistant messages with token usage information
|
||||
- Messages: Individual chat messages with role, content items, and timestamps
|
||||
@@ -235,7 +342,7 @@ class DurableAgentState:
|
||||
|
||||
State is serialized to JSON with this structure:
|
||||
{
|
||||
"schemaVersion": "1.0.0",
|
||||
"schemaVersion": "<SCHEMA_VERSION>",
|
||||
"data": {
|
||||
"conversationHistory": [
|
||||
{"$type": "request", "correlationId": "...", "createdAt": "...", "messages": [...]},
|
||||
@@ -246,17 +353,20 @@ class DurableAgentState:
|
||||
|
||||
Attributes:
|
||||
data: Container for conversation history and optional extension data
|
||||
schema_version: Schema version string (defaults to "1.0.0")
|
||||
schema_version: Schema version string (defaults to SCHEMA_VERSION)
|
||||
"""
|
||||
|
||||
data: DurableAgentStateData
|
||||
schema_version: str = "1.0.0"
|
||||
# Durable Agent Schema version
|
||||
SCHEMA_VERSION: str = "1.1.0"
|
||||
|
||||
def __init__(self, schema_version: str = "1.0.0"):
|
||||
data: DurableAgentStateData
|
||||
schema_version: str = SCHEMA_VERSION
|
||||
|
||||
def __init__(self, schema_version: str = SCHEMA_VERSION):
|
||||
"""Initialize a new durable agent state.
|
||||
|
||||
Args:
|
||||
schema_version: Schema version to use (defaults to "1.0.0")
|
||||
schema_version: Schema version to use (defaults to SCHEMA_VERSION)
|
||||
"""
|
||||
self.data = DurableAgentStateData()
|
||||
self.schema_version = schema_version
|
||||
@@ -283,7 +393,7 @@ class DurableAgentState:
|
||||
logger.warning("Resetting state as it is incompatible with the current schema, all history will be lost")
|
||||
return cls()
|
||||
|
||||
instance = cls(schema_version=state.get("schemaVersion", "1.0.0"))
|
||||
instance = cls(schema_version=state.get("schemaVersion", DurableAgentState.SCHEMA_VERSION))
|
||||
instance.data = DurableAgentStateData.from_dict(state.get("data", {}))
|
||||
|
||||
return instance
|
||||
@@ -325,7 +435,7 @@ class DurableAgentState:
|
||||
if entry.correlation_id == correlation_id and isinstance(entry, DurableAgentStateResponse):
|
||||
# Found the entry, extract response data
|
||||
# Get the text content from assistant messages only
|
||||
content = "\n".join(message.text for message in entry.messages if message.text is not None)
|
||||
content = "\n".join(message.text for message in entry.messages if message.text)
|
||||
|
||||
return {"content": content, "message_count": self.message_count, "correlationId": correlation_id}
|
||||
return None
|
||||
@@ -388,28 +498,17 @@ class DurableAgentStateEntry:
|
||||
self.extension_data = extension_data
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
# Ensure createdAt is never null
|
||||
created_at_value = self.created_at
|
||||
if created_at_value is None:
|
||||
created_at_value = datetime.now(tz=timezone.utc)
|
||||
|
||||
return {
|
||||
"$type": self.json_type,
|
||||
"correlationId": self.correlation_id,
|
||||
"createdAt": created_at_value.isoformat() if isinstance(created_at_value, datetime) else created_at_value,
|
||||
"createdAt": self.created_at.isoformat(),
|
||||
"messages": [m.to_dict() for m in self.messages],
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> DurableAgentStateEntry:
|
||||
created_at = _parse_created_at(data.get("created_at"))
|
||||
|
||||
messages = []
|
||||
for msg_dict in data.get("messages", []):
|
||||
if isinstance(msg_dict, dict):
|
||||
messages.append(DurableAgentStateMessage.from_dict(msg_dict))
|
||||
else:
|
||||
messages.append(msg_dict)
|
||||
messages = _parse_messages(data)
|
||||
|
||||
return cls(
|
||||
json_type=DurableAgentStateEntryJsonType(data.get("$type", "entry")),
|
||||
@@ -430,6 +529,7 @@ class DurableAgentStateRequest(DurableAgentStateEntry):
|
||||
Attributes:
|
||||
response_type: Expected response type ("text" or "json")
|
||||
response_schema: JSON schema for structured responses (when response_type is "json")
|
||||
orchestration_id: ID of the orchestration that initiated this request (if any)
|
||||
correlationId: Unique identifier linking this request to its response
|
||||
created_at: Timestamp when the request was created
|
||||
messages: List of messages included in this request
|
||||
@@ -438,6 +538,7 @@ class DurableAgentStateRequest(DurableAgentStateEntry):
|
||||
|
||||
response_type: str | None = None
|
||||
response_schema: dict[str, Any] | None = None
|
||||
orchestration_id: str | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -447,6 +548,7 @@ class DurableAgentStateRequest(DurableAgentStateEntry):
|
||||
extension_data: dict[str, Any] | None = None,
|
||||
response_type: str | None = None,
|
||||
response_schema: dict[str, Any] | None = None,
|
||||
orchestration_id: str | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
json_type=DurableAgentStateEntryJsonType.REQUEST,
|
||||
@@ -457,9 +559,12 @@ class DurableAgentStateRequest(DurableAgentStateEntry):
|
||||
)
|
||||
self.response_type = response_type
|
||||
self.response_schema = response_schema
|
||||
self.orchestration_id = orchestration_id
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
data = super().to_dict()
|
||||
if self.orchestration_id is not None:
|
||||
data["orchestrationId"] = self.orchestration_id
|
||||
if self.response_type is not None:
|
||||
data["responseType"] = self.response_type
|
||||
if self.response_schema is not None:
|
||||
@@ -469,13 +574,7 @@ class DurableAgentStateRequest(DurableAgentStateEntry):
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> DurableAgentStateRequest:
|
||||
created_at = _parse_created_at(data.get("created_at"))
|
||||
|
||||
messages = []
|
||||
for msg_dict in data.get("messages", []):
|
||||
if isinstance(msg_dict, dict):
|
||||
messages.append(DurableAgentStateMessage.from_dict(msg_dict))
|
||||
else:
|
||||
messages.append(msg_dict)
|
||||
messages = _parse_messages(data)
|
||||
|
||||
return cls(
|
||||
correlation_id=data.get("correlationId", ""),
|
||||
@@ -484,6 +583,7 @@ class DurableAgentStateRequest(DurableAgentStateEntry):
|
||||
extension_data=data.get("extensionData"),
|
||||
response_type=data.get("responseType"),
|
||||
response_schema=data.get("responseSchema"),
|
||||
orchestration_id=data.get("orchestrationId"),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -495,6 +595,7 @@ class DurableAgentStateRequest(DurableAgentStateEntry):
|
||||
created_at=datetime.now(tz=timezone.utc),
|
||||
response_type=request.request_response_format,
|
||||
response_schema=serialize_response_format(request.response_format),
|
||||
orchestration_id=request.orchestration_id,
|
||||
)
|
||||
|
||||
|
||||
@@ -545,20 +646,12 @@ class DurableAgentStateResponse(DurableAgentStateEntry):
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> DurableAgentStateResponse:
|
||||
created_at = _parse_created_at(data.get("created_at"))
|
||||
|
||||
messages = []
|
||||
for msg_dict in data.get("messages", []):
|
||||
if isinstance(msg_dict, dict):
|
||||
messages.append(DurableAgentStateMessage.from_dict(msg_dict))
|
||||
else:
|
||||
messages.append(msg_dict)
|
||||
messages = _parse_messages(data)
|
||||
|
||||
usage_dict = data.get("usage")
|
||||
usage = None
|
||||
usage: DurableAgentStateUsage | None = None
|
||||
if usage_dict and isinstance(usage_dict, dict):
|
||||
usage = DurableAgentStateUsage.from_dict(usage_dict)
|
||||
elif usage_dict:
|
||||
usage = usage_dict
|
||||
usage = DurableAgentStateUsage.from_dict(cast(dict[str, Any], usage_dict))
|
||||
|
||||
return cls(
|
||||
correlation_id=data.get("correlationId", ""),
|
||||
@@ -639,68 +732,9 @@ class DurableAgentStateMessage:
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> DurableAgentStateMessage:
|
||||
contents: list[DurableAgentStateContent] = []
|
||||
for content_dict in data.get("contents", []):
|
||||
if isinstance(content_dict, dict):
|
||||
content_type = content_dict.get("$type")
|
||||
if content_type == DurableAgentStateTextContent.type:
|
||||
contents.append(DurableAgentStateTextContent(text=content_dict.get("text")))
|
||||
elif content_type == DurableAgentStateDataContent.type:
|
||||
contents.append(
|
||||
DurableAgentStateDataContent(
|
||||
uri=content_dict.get("uri", ""), media_type=content_dict.get("mediaType")
|
||||
)
|
||||
)
|
||||
elif content_type == DurableAgentStateErrorContent.type:
|
||||
contents.append(
|
||||
DurableAgentStateErrorContent(
|
||||
message=content_dict.get("message"),
|
||||
error_code=content_dict.get("errorCode"),
|
||||
details=content_dict.get("details"),
|
||||
)
|
||||
)
|
||||
elif content_type == DurableAgentStateFunctionCallContent.type:
|
||||
contents.append(
|
||||
DurableAgentStateFunctionCallContent(
|
||||
call_id=content_dict.get("callId", ""),
|
||||
name=content_dict.get("name", ""),
|
||||
arguments=content_dict.get("arguments", {}),
|
||||
)
|
||||
)
|
||||
elif content_type == DurableAgentStateFunctionResultContent.type:
|
||||
contents.append(
|
||||
DurableAgentStateFunctionResultContent(
|
||||
call_id=content_dict.get("callId", ""), result=content_dict.get("result")
|
||||
)
|
||||
)
|
||||
elif content_type == DurableAgentStateHostedFileContent.type:
|
||||
contents.append(DurableAgentStateHostedFileContent(file_id=content_dict.get("fileId", "")))
|
||||
elif content_type == DurableAgentStateHostedVectorStoreContent.type:
|
||||
contents.append(
|
||||
DurableAgentStateHostedVectorStoreContent(vector_store_id=content_dict.get("vectorStoreId", ""))
|
||||
)
|
||||
elif content_type == DurableAgentStateTextReasoningContent.type:
|
||||
contents.append(DurableAgentStateTextReasoningContent(text=content_dict.get("text")))
|
||||
elif content_type == DurableAgentStateUriContent.type:
|
||||
contents.append(
|
||||
DurableAgentStateUriContent(
|
||||
uri=content_dict.get("uri", ""), media_type=content_dict.get("mediaType", "")
|
||||
)
|
||||
)
|
||||
elif content_type == DurableAgentStateUsageContent.type:
|
||||
usage_data = content_dict.get("usage")
|
||||
if usage_data and isinstance(usage_data, dict):
|
||||
contents.append(
|
||||
DurableAgentStateUsageContent(usage=DurableAgentStateUsage.from_dict(usage_data))
|
||||
)
|
||||
elif content_type == DurableAgentStateUnknownContent.type:
|
||||
contents.append(DurableAgentStateUnknownContent(content=content_dict.get("content", {})))
|
||||
else:
|
||||
contents.append(content_dict) # type: ignore
|
||||
|
||||
return cls(
|
||||
role=data.get("role", ""),
|
||||
contents=contents,
|
||||
contents=_parse_contents(data),
|
||||
author_name=data.get("authorName"),
|
||||
created_at=_parse_created_at(data.get("createdAt")),
|
||||
extension_data=data.get("extensionData"),
|
||||
@@ -709,7 +743,7 @@ class DurableAgentStateMessage:
|
||||
@property
|
||||
def text(self) -> str:
|
||||
"""Extract text from the contents list."""
|
||||
text_parts = []
|
||||
text_parts: list[str] = []
|
||||
for content in self.contents:
|
||||
if isinstance(content, DurableAgentStateTextContent):
|
||||
text_parts.append(content.text or "")
|
||||
|
||||
@@ -287,6 +287,7 @@ class RunRequest:
|
||||
thread_id: Optional thread ID for tracking
|
||||
correlation_id: Optional correlation ID for tracking the response to this specific request
|
||||
created_at: Optional timestamp when the request was created
|
||||
orchestration_id: Optional ID of the orchestration that initiated this request
|
||||
"""
|
||||
|
||||
message: str
|
||||
@@ -297,6 +298,7 @@ class RunRequest:
|
||||
thread_id: str | None = None
|
||||
correlation_id: str | None = None
|
||||
created_at: str | None = None
|
||||
orchestration_id: str | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -308,6 +310,7 @@ class RunRequest:
|
||||
thread_id: str | None = None,
|
||||
correlation_id: str | None = None,
|
||||
created_at: str | None = None,
|
||||
orchestration_id: str | None = None,
|
||||
) -> None:
|
||||
self.message = message
|
||||
self.role = self.coerce_role(role)
|
||||
@@ -317,6 +320,7 @@ class RunRequest:
|
||||
self.thread_id = thread_id
|
||||
self.correlation_id = correlation_id
|
||||
self.created_at = created_at
|
||||
self.orchestration_id = orchestration_id
|
||||
|
||||
@staticmethod
|
||||
def coerce_role(value: Role | str | None) -> Role:
|
||||
@@ -346,6 +350,8 @@ class RunRequest:
|
||||
result["correlationId"] = self.correlation_id
|
||||
if self.created_at:
|
||||
result["created_at"] = self.created_at
|
||||
if self.orchestration_id:
|
||||
result["orchestrationId"] = self.orchestration_id
|
||||
|
||||
return result
|
||||
|
||||
@@ -361,4 +367,5 @@ class RunRequest:
|
||||
thread_id=data.get("thread_id"),
|
||||
correlation_id=data.get("correlationId"),
|
||||
created_at=data.get("created_at"),
|
||||
orchestration_id=data.get("orchestrationId"),
|
||||
)
|
||||
|
||||
@@ -272,12 +272,14 @@ class DurableAIAgent(AgentProtocol):
|
||||
)
|
||||
|
||||
# Prepare the request using RunRequest model
|
||||
# Include the orchestration's instance_id so it can be stored in the agent's entity state
|
||||
run_request = RunRequest(
|
||||
message=message_str,
|
||||
enable_tool_calls=enable_tool_calls,
|
||||
correlation_id=correlation_id,
|
||||
thread_id=session_id.key,
|
||||
response_format=response_format,
|
||||
orchestration_id=self.context.instance_id,
|
||||
)
|
||||
|
||||
logger.debug("[DurableAIAgent] Calling entity %s with message: %s", entity_id, message_str[:100])
|
||||
|
||||
@@ -79,7 +79,7 @@ class TestAgentEntityInit:
|
||||
assert entity.agent == mock_agent
|
||||
assert len(entity.state.data.conversation_history) == 0
|
||||
assert entity.state.data.extension_data is None
|
||||
assert entity.state.schema_version == "1.0.0"
|
||||
assert entity.state.schema_version == DurableAgentState.SCHEMA_VERSION
|
||||
|
||||
def test_init_stores_agent_reference(self) -> None:
|
||||
"""Test that the agent reference is stored correctly."""
|
||||
@@ -124,8 +124,7 @@ class TestAgentEntityRunAgent:
|
||||
# Verify agent.run was called
|
||||
mock_agent.run.assert_called_once()
|
||||
_, kwargs = mock_agent.run.call_args
|
||||
sent_messages = kwargs.get("messages")
|
||||
assert isinstance(sent_messages, list)
|
||||
sent_messages: list[Any] = kwargs.get("messages")
|
||||
assert len(sent_messages) == 1
|
||||
sent_message = sent_messages[0]
|
||||
assert isinstance(sent_message, ChatMessage)
|
||||
@@ -910,5 +909,98 @@ class TestRunRequestSupport:
|
||||
assert text_found, f"Response text not found in message: {message}"
|
||||
|
||||
|
||||
class TestDurableAgentStateRequestOrchestrationId:
|
||||
"""Test suite for DurableAgentStateRequest orchestration_id field."""
|
||||
|
||||
def test_request_with_orchestration_id(self) -> None:
|
||||
"""Test creating a request with an orchestration_id."""
|
||||
request = DurableAgentStateRequest(
|
||||
correlation_id="corr-123",
|
||||
created_at=datetime.now(),
|
||||
messages=[
|
||||
DurableAgentStateMessage(
|
||||
role="user",
|
||||
contents=[DurableAgentStateTextContent(text="test")],
|
||||
)
|
||||
],
|
||||
orchestration_id="orch-456",
|
||||
)
|
||||
|
||||
assert request.orchestration_id == "orch-456"
|
||||
|
||||
def test_request_to_dict_includes_orchestration_id(self) -> None:
|
||||
"""Test that to_dict includes orchestrationId when set."""
|
||||
request = DurableAgentStateRequest(
|
||||
correlation_id="corr-123",
|
||||
created_at=datetime.now(),
|
||||
messages=[
|
||||
DurableAgentStateMessage(
|
||||
role="user",
|
||||
contents=[DurableAgentStateTextContent(text="test")],
|
||||
)
|
||||
],
|
||||
orchestration_id="orch-789",
|
||||
)
|
||||
|
||||
data = request.to_dict()
|
||||
|
||||
assert "orchestrationId" in data
|
||||
assert data["orchestrationId"] == "orch-789"
|
||||
|
||||
def test_request_to_dict_excludes_orchestration_id_when_none(self) -> None:
|
||||
"""Test that to_dict excludes orchestrationId when not set."""
|
||||
request = DurableAgentStateRequest(
|
||||
correlation_id="corr-123",
|
||||
created_at=datetime.now(),
|
||||
messages=[
|
||||
DurableAgentStateMessage(
|
||||
role="user",
|
||||
contents=[DurableAgentStateTextContent(text="test")],
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
data = request.to_dict()
|
||||
|
||||
assert "orchestrationId" not in data
|
||||
|
||||
def test_request_from_dict_with_orchestration_id(self) -> None:
|
||||
"""Test from_dict correctly parses orchestrationId."""
|
||||
data = {
|
||||
"$type": "request",
|
||||
"correlationId": "corr-123",
|
||||
"createdAt": "2024-01-01T00:00:00Z",
|
||||
"messages": [{"role": "user", "contents": [{"$type": "text", "text": "test"}]}],
|
||||
"orchestrationId": "orch-from-dict",
|
||||
}
|
||||
|
||||
request = DurableAgentStateRequest.from_dict(data)
|
||||
|
||||
assert request.orchestration_id == "orch-from-dict"
|
||||
|
||||
def test_request_from_run_request_with_orchestration_id(self) -> None:
|
||||
"""Test from_run_request correctly transfers orchestration_id."""
|
||||
run_request = RunRequest(
|
||||
message="test message",
|
||||
correlation_id="corr-run",
|
||||
orchestration_id="orch-from-run-request",
|
||||
)
|
||||
|
||||
durable_request = DurableAgentStateRequest.from_run_request(run_request)
|
||||
|
||||
assert durable_request.orchestration_id == "orch-from-run-request"
|
||||
|
||||
def test_request_from_run_request_without_orchestration_id(self) -> None:
|
||||
"""Test from_run_request correctly handles missing orchestration_id."""
|
||||
run_request = RunRequest(
|
||||
message="test message",
|
||||
correlation_id="corr-run",
|
||||
)
|
||||
|
||||
durable_request = DurableAgentStateRequest.from_run_request(run_request)
|
||||
|
||||
assert durable_request.orchestration_id is None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v", "--tb=short"])
|
||||
|
||||
@@ -336,6 +336,71 @@ class TestRunRequest:
|
||||
assert restored.correlation_id == original.correlation_id
|
||||
assert restored.thread_id == original.thread_id
|
||||
|
||||
def test_init_with_orchestration_id(self) -> None:
|
||||
"""Test RunRequest initialization with orchestration_id."""
|
||||
request = RunRequest(
|
||||
message="Test message",
|
||||
thread_id="thread-orch-init",
|
||||
orchestration_id="orch-123",
|
||||
)
|
||||
|
||||
assert request.message == "Test message"
|
||||
assert request.orchestration_id == "orch-123"
|
||||
|
||||
def test_to_dict_with_orchestration_id(self) -> None:
|
||||
"""Test to_dict includes orchestrationId."""
|
||||
request = RunRequest(
|
||||
message="Test",
|
||||
thread_id="thread-orch-to-dict",
|
||||
orchestration_id="orch-456",
|
||||
)
|
||||
data = request.to_dict()
|
||||
|
||||
assert data["message"] == "Test"
|
||||
assert data["orchestrationId"] == "orch-456"
|
||||
|
||||
def test_to_dict_excludes_orchestration_id_when_none(self) -> None:
|
||||
"""Test to_dict excludes orchestrationId when not set."""
|
||||
request = RunRequest(
|
||||
message="Test",
|
||||
thread_id="thread-orch-none",
|
||||
)
|
||||
data = request.to_dict()
|
||||
|
||||
assert "orchestrationId" not in data
|
||||
|
||||
def test_from_dict_with_orchestration_id(self) -> None:
|
||||
"""Test from_dict with orchestrationId."""
|
||||
data = {
|
||||
"message": "Test",
|
||||
"orchestrationId": "orch-789",
|
||||
"thread_id": "thread-orch-from-dict",
|
||||
}
|
||||
request = RunRequest.from_dict(data)
|
||||
|
||||
assert request.message == "Test"
|
||||
assert request.orchestration_id == "orch-789"
|
||||
assert request.thread_id == "thread-orch-from-dict"
|
||||
|
||||
def test_round_trip_with_orchestration_id(self) -> None:
|
||||
"""Test round-trip to_dict and from_dict with orchestration_id."""
|
||||
original = RunRequest(
|
||||
message="Test message",
|
||||
thread_id="thread-123",
|
||||
role=Role.SYSTEM,
|
||||
correlation_id="corr-123",
|
||||
orchestration_id="orch-123",
|
||||
)
|
||||
|
||||
data = original.to_dict()
|
||||
restored = RunRequest.from_dict(data)
|
||||
|
||||
assert restored.message == original.message
|
||||
assert restored.role == original.role
|
||||
assert restored.correlation_id == original.correlation_id
|
||||
assert restored.orchestration_id == original.orchestration_id
|
||||
assert restored.thread_id == original.thread_id
|
||||
|
||||
|
||||
class TestModelIntegration:
|
||||
"""Test suite for integration between models."""
|
||||
|
||||
@@ -302,6 +302,28 @@ class TestDurableAIAgent:
|
||||
assert request["correlationId"] == "correlation-guid"
|
||||
assert "thread_id" in request
|
||||
assert request["thread_id"] == "thread-guid"
|
||||
# Verify orchestration ID is set from context.instance_id
|
||||
assert "orchestrationId" in request
|
||||
assert request["orchestrationId"] == "test-instance-001"
|
||||
|
||||
def test_run_sets_orchestration_id(self) -> None:
|
||||
"""Test that run() sets the orchestration_id from context.instance_id."""
|
||||
mock_context = Mock()
|
||||
mock_context.instance_id = "my-orchestration-123"
|
||||
mock_context.new_uuid = Mock(side_effect=["thread-guid", "correlation-guid"])
|
||||
|
||||
entity_task = _create_entity_task()
|
||||
mock_context.call_entity = Mock(return_value=entity_task)
|
||||
|
||||
agent = DurableAIAgent(mock_context, "TestAgent")
|
||||
thread = agent.get_new_thread()
|
||||
|
||||
agent.run(messages="Test", thread=thread)
|
||||
|
||||
call_args = mock_context.call_entity.call_args
|
||||
request = call_args[0][2]
|
||||
|
||||
assert request["orchestrationId"] == "my-orchestration-123"
|
||||
|
||||
def test_run_without_thread(self) -> None:
|
||||
"""Test that run() works without explicit thread (creates unique session key)."""
|
||||
|
||||
Generated
+1
-1
@@ -1806,7 +1806,7 @@ name = "exceptiongroup"
|
||||
version = "1.3.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "typing-extensions", marker = "(python_full_version < '3.11' and sys_platform == 'darwin') or (python_full_version < '3.11' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform == 'win32')" },
|
||||
{ name = "typing-extensions", marker = "(python_full_version < '3.13' and sys_platform == 'darwin') or (python_full_version < '3.13' and sys_platform == 'linux') or (python_full_version < '3.13' and sys_platform == 'win32')" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/50/79/66800aadf48771f6b62f7eb014e352e5d06856655206165d775e675a02c9/exceptiongroup-1.3.1.tar.gz", hash = "sha256:8b412432c6055b0b7d14c310000ae93352ed6754f70fa8f7c34141f91c4e3219", size = 30371, upload-time = "2025-11-21T23:01:54.787Z" }
|
||||
wheels = [
|
||||
|
||||
Reference in New Issue
Block a user