mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
semi working
This commit is contained in:
@@ -16,11 +16,11 @@ import azure.functions as func
|
||||
from agent_framework import AgentProtocol, get_logger
|
||||
|
||||
from ._callbacks import AgentResponseCallbackProtocol
|
||||
from ._durable_agent_state import DurableAgentState
|
||||
from ._entities import create_agent_entity
|
||||
from ._errors import IncomingRequestError
|
||||
from ._models import AgentSessionId, RunRequest
|
||||
from ._orchestration import AgentOrchestrationContextType, DurableAIAgent
|
||||
from ._state import AgentState
|
||||
|
||||
logger = get_logger("agent_framework.azurefunctions")
|
||||
|
||||
@@ -34,9 +34,6 @@ WAIT_FOR_RESPONSE_HEADER: str = "x-ms-wait-for-response"
|
||||
EntityHandler = Callable[[df.DurableEntityContext], None]
|
||||
HandlerT = TypeVar("HandlerT", bound=Callable[..., Any])
|
||||
|
||||
DEFAULT_MAX_POLL_RETRIES: int = 30
|
||||
DEFAULT_POLL_INTERVAL_SECONDS: float = 1.0
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
class DFAppBase:
|
||||
@@ -73,16 +70,17 @@ class AgentFunctionApp(DFAppBase):
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from agent_framework.azure import AgentFunctionApp, AzureOpenAIChatClient
|
||||
from agent_framework.azure import AgentFunctionApp
|
||||
from agent_framework.azure import AzureOpenAIAssistantsClient
|
||||
|
||||
# Create agents with unique names
|
||||
weather_agent = AzureOpenAIChatClient(...).create_agent(
|
||||
weather_agent = AzureOpenAIAssistantsClient(...).create_agent(
|
||||
name="WeatherAgent",
|
||||
instructions="You are a helpful weather agent.",
|
||||
tools=[get_weather],
|
||||
)
|
||||
|
||||
math_agent = AzureOpenAIChatClient(...).create_agent(
|
||||
math_agent = AzureOpenAIAssistantsClient(...).create_agent(
|
||||
name="MathAgent",
|
||||
instructions="You are a helpful math assistant.",
|
||||
tools=[calculate],
|
||||
@@ -130,23 +128,23 @@ class AgentFunctionApp(DFAppBase):
|
||||
http_auth_level: func.AuthLevel = func.AuthLevel.FUNCTION,
|
||||
enable_health_check: bool = True,
|
||||
enable_http_endpoints: bool = True,
|
||||
max_poll_retries: int = DEFAULT_MAX_POLL_RETRIES,
|
||||
poll_interval_seconds: float = DEFAULT_POLL_INTERVAL_SECONDS,
|
||||
max_poll_retries: int = 30,
|
||||
poll_interval_seconds: float = 1,
|
||||
default_callback: AgentResponseCallbackProtocol | None = None,
|
||||
):
|
||||
"""Initialize the AgentFunctionApp.
|
||||
|
||||
:param agents: List of agent instances to register.
|
||||
:param http_auth_level: HTTP authentication level (default: ``func.AuthLevel.FUNCTION``).
|
||||
:param enable_health_check: Enable the built-in health check endpoint (default: ``True``).
|
||||
:param enable_http_endpoints: Enable HTTP endpoints for agents (default: ``True``).
|
||||
:param max_poll_retries: Maximum polling attempts when waiting for a response.
|
||||
Defaults to ``DEFAULT_MAX_POLL_RETRIES``.
|
||||
:param poll_interval_seconds: Delay in seconds between polling attempts.
|
||||
Defaults to ``DEFAULT_POLL_INTERVAL_SECONDS``.
|
||||
:param default_callback: Optional callback invoked for agents without specific callbacks.
|
||||
Args:
|
||||
agents: List of agent instances to register
|
||||
http_auth_level: HTTP authentication level (default: FUNCTION)
|
||||
enable_health_check: Enable built-in health check endpoint (default: True)
|
||||
enable_http_endpoints: Enable HTTP endpoints for agents (default: True)
|
||||
max_poll_retries: Maximum number of polling attempts when waiting for a response
|
||||
poll_interval_seconds: Delay (in seconds) between polling attempts
|
||||
default_callback: Optional callback invoked for agents without specific callbacks
|
||||
|
||||
:note: If no agents are provided, they can be added later using :meth:`add_agent`.
|
||||
Note:
|
||||
If no agents are provided, they can be added later using add_agent().
|
||||
"""
|
||||
logger.debug("[AgentFunctionApp] Initializing with Durable Entities...")
|
||||
|
||||
@@ -163,14 +161,14 @@ class AgentFunctionApp(DFAppBase):
|
||||
try:
|
||||
retries = int(max_poll_retries)
|
||||
except (TypeError, ValueError):
|
||||
retries = DEFAULT_MAX_POLL_RETRIES
|
||||
retries = 10
|
||||
self.max_poll_retries = max(1, retries)
|
||||
|
||||
try:
|
||||
interval = float(poll_interval_seconds)
|
||||
except (TypeError, ValueError):
|
||||
interval = DEFAULT_POLL_INTERVAL_SECONDS
|
||||
self.poll_interval_seconds = interval if interval > 0 else DEFAULT_POLL_INTERVAL_SECONDS
|
||||
interval = 0.5
|
||||
self.poll_interval_seconds = interval if interval > 0 else 0.5
|
||||
|
||||
if agents:
|
||||
# Register all provided agents
|
||||
@@ -339,10 +337,10 @@ class AgentFunctionApp(DFAppBase):
|
||||
)
|
||||
|
||||
session_id = self._create_session_id(agent_name, thread_id)
|
||||
correlation_id = self._generate_unique_id()
|
||||
correlationId = self._generate_unique_id()
|
||||
|
||||
logger.debug(f"[HTTP Trigger] Using session ID: {session_id}")
|
||||
logger.debug(f"[HTTP Trigger] Generated correlation ID: {correlation_id}")
|
||||
logger.debug(f"[HTTP Trigger] Generated correlation ID: {correlationId}")
|
||||
logger.debug("[HTTP Trigger] Calling entity to run agent...")
|
||||
|
||||
entity_instance_id = session_id.to_entity_id()
|
||||
@@ -350,7 +348,7 @@ class AgentFunctionApp(DFAppBase):
|
||||
req_body,
|
||||
message,
|
||||
thread_id,
|
||||
correlation_id,
|
||||
correlationId,
|
||||
)
|
||||
logger.debug("Signalling entity %s with request: %s", entity_instance_id, run_request)
|
||||
await client.signal_entity(entity_instance_id, "run_agent", run_request)
|
||||
@@ -361,7 +359,7 @@ class AgentFunctionApp(DFAppBase):
|
||||
result = await self._get_response_from_entity(
|
||||
client=client,
|
||||
entity_instance_id=entity_instance_id,
|
||||
correlation_id=correlation_id,
|
||||
correlationId=correlationId,
|
||||
message=message,
|
||||
thread_id=thread_id,
|
||||
)
|
||||
@@ -377,7 +375,7 @@ class AgentFunctionApp(DFAppBase):
|
||||
logger.debug("[HTTP Trigger] wait_for_response disabled; returning correlation ID")
|
||||
|
||||
accepted_response = self._build_accepted_response(
|
||||
message=message, thread_id=thread_id, correlation_id=correlation_id
|
||||
message=message, thread_id=thread_id, correlationId=correlationId
|
||||
)
|
||||
|
||||
return self._create_http_response(
|
||||
@@ -491,7 +489,7 @@ class AgentFunctionApp(DFAppBase):
|
||||
self,
|
||||
client: df.DurableOrchestrationClient,
|
||||
entity_instance_id: df.EntityId,
|
||||
) -> AgentState | None:
|
||||
) -> DurableAgentState | None:
|
||||
state_response = await client.read_entity_state(entity_instance_id)
|
||||
if not state_response or not state_response.entity_exists:
|
||||
return None
|
||||
@@ -502,7 +500,7 @@ class AgentFunctionApp(DFAppBase):
|
||||
|
||||
typed_state_payload = cast(dict[str, Any], state_payload)
|
||||
|
||||
agent_state = AgentState()
|
||||
agent_state = DurableAgentState()
|
||||
agent_state.restore_state(typed_state_payload)
|
||||
return agent_state
|
||||
|
||||
@@ -510,7 +508,7 @@ class AgentFunctionApp(DFAppBase):
|
||||
self,
|
||||
client: df.DurableOrchestrationClient,
|
||||
entity_instance_id: df.EntityId,
|
||||
correlation_id: str,
|
||||
correlationId: str,
|
||||
message: str,
|
||||
thread_id: str,
|
||||
) -> dict[str, Any]:
|
||||
@@ -522,7 +520,7 @@ class AgentFunctionApp(DFAppBase):
|
||||
retry_count = 0
|
||||
result: dict[str, Any] | None = None
|
||||
|
||||
logger.debug(f"[HTTP Trigger] Waiting for response with correlation ID: {correlation_id}")
|
||||
logger.debug(f"[HTTP Trigger] Waiting for response with correlation ID: {correlationId}")
|
||||
|
||||
while retry_count < max_retries:
|
||||
await asyncio.sleep(interval)
|
||||
@@ -530,7 +528,7 @@ class AgentFunctionApp(DFAppBase):
|
||||
result = await self._poll_entity_for_response(
|
||||
client=client,
|
||||
entity_instance_id=entity_instance_id,
|
||||
correlation_id=correlation_id,
|
||||
correlationId=correlationId,
|
||||
message=message,
|
||||
thread_id=thread_id,
|
||||
)
|
||||
@@ -544,16 +542,16 @@ class AgentFunctionApp(DFAppBase):
|
||||
return result
|
||||
|
||||
logger.warning(
|
||||
f"[HTTP Trigger] Response with correlation ID {correlation_id} "
|
||||
f"[HTTP Trigger] Response with correlation ID {correlationId} "
|
||||
f"not found in time (waited {max_retries * interval} seconds)"
|
||||
)
|
||||
return await self._build_timeout_result(message=message, thread_id=thread_id, correlation_id=correlation_id)
|
||||
return await self._build_timeout_result(message=message, thread_id=thread_id, correlationId=correlationId)
|
||||
|
||||
async def _poll_entity_for_response(
|
||||
self,
|
||||
client: df.DurableOrchestrationClient,
|
||||
entity_instance_id: df.EntityId,
|
||||
correlation_id: str,
|
||||
correlationId: str,
|
||||
message: str,
|
||||
thread_id: str,
|
||||
) -> dict[str, Any] | None:
|
||||
@@ -564,34 +562,34 @@ class AgentFunctionApp(DFAppBase):
|
||||
if state is None:
|
||||
return None
|
||||
|
||||
agent_response = state.try_get_agent_response(correlation_id)
|
||||
agent_response = state.try_get_agent_response(correlationId)
|
||||
if agent_response:
|
||||
result = self._build_success_result(
|
||||
response_data=agent_response,
|
||||
message=message,
|
||||
thread_id=thread_id,
|
||||
correlation_id=correlation_id,
|
||||
correlationId=correlationId,
|
||||
state=state,
|
||||
)
|
||||
logger.debug(f"[HTTP Trigger] Found response for correlation ID: {correlation_id}")
|
||||
logger.debug(f"[HTTP Trigger] Found response for correlation ID: {correlationId}")
|
||||
|
||||
except Exception as exc:
|
||||
logger.warning(f"[HTTP Trigger] Error reading entity state: {exc}")
|
||||
|
||||
return result
|
||||
|
||||
async def _build_timeout_result(self, message: str, thread_id: str, correlation_id: str) -> dict[str, Any]:
|
||||
async def _build_timeout_result(self, message: str, thread_id: str, correlationId: str) -> dict[str, Any]:
|
||||
"""Create the timeout response."""
|
||||
return {
|
||||
"response": "Agent is still processing or timed out...",
|
||||
"message": message,
|
||||
THREAD_ID_FIELD: thread_id,
|
||||
"status": "timeout",
|
||||
"correlation_id": correlation_id,
|
||||
"correlationId": correlationId,
|
||||
}
|
||||
|
||||
def _build_success_result(
|
||||
self, response_data: dict[str, Any], message: str, thread_id: str, correlation_id: str, state: AgentState
|
||||
self, response_data: dict[str, Any], message: str, thread_id: str, correlationId: str, state: DurableAgentState
|
||||
) -> dict[str, Any]:
|
||||
"""Build the success result returned to the HTTP caller."""
|
||||
return {
|
||||
@@ -600,11 +598,11 @@ class AgentFunctionApp(DFAppBase):
|
||||
THREAD_ID_FIELD: thread_id,
|
||||
"status": "success",
|
||||
"message_count": response_data.get("message_count", state.message_count),
|
||||
"correlation_id": correlation_id,
|
||||
"correlationId": correlationId,
|
||||
}
|
||||
|
||||
def _build_request_data(
|
||||
self, req_body: dict[str, Any], message: str, thread_id: str, correlation_id: str
|
||||
self, req_body: dict[str, Any], message: str, thread_id: str, correlationId: str
|
||||
) -> dict[str, Any]:
|
||||
"""Create the durable entity request payload."""
|
||||
enable_tool_calls_value = req_body.get("enable_tool_calls")
|
||||
@@ -616,17 +614,17 @@ class AgentFunctionApp(DFAppBase):
|
||||
response_format=req_body.get("response_format"),
|
||||
enable_tool_calls=enable_tool_calls,
|
||||
thread_id=thread_id,
|
||||
correlation_id=correlation_id,
|
||||
correlationId=correlationId,
|
||||
).to_dict()
|
||||
|
||||
def _build_accepted_response(self, message: str, thread_id: str, correlation_id: str) -> dict[str, Any]:
|
||||
def _build_accepted_response(self, message: str, thread_id: str, correlationId: str) -> dict[str, Any]:
|
||||
"""Build the response returned when not waiting for completion."""
|
||||
return {
|
||||
"response": "Agent request accepted",
|
||||
"message": message,
|
||||
THREAD_ID_FIELD: thread_id,
|
||||
"status": "accepted",
|
||||
"correlation_id": correlation_id,
|
||||
"correlationId": correlationId,
|
||||
}
|
||||
|
||||
def _create_http_response(
|
||||
@@ -712,8 +710,7 @@ class AgentFunctionApp(DFAppBase):
|
||||
headers: dict[str, str] = {}
|
||||
raw_headers = req.headers
|
||||
if isinstance(raw_headers, Mapping):
|
||||
header_mapping: Mapping[str, Any] = cast(Mapping[str, Any], raw_headers)
|
||||
for key, value in header_mapping.items():
|
||||
for key, value in raw_headers.items():
|
||||
if value is not None:
|
||||
headers[str(key).lower()] = str(value)
|
||||
return headers
|
||||
@@ -774,8 +771,13 @@ class AgentFunctionApp(DFAppBase):
|
||||
|
||||
def _should_wait_for_response(self, req: func.HttpRequest, req_body: dict[str, Any]) -> bool:
|
||||
"""Determine whether the caller requested to wait for the response."""
|
||||
headers: dict[str, str] = self._extract_normalized_headers(req)
|
||||
header_value: str | None = headers.get(WAIT_FOR_RESPONSE_HEADER)
|
||||
header_value = None
|
||||
raw_headers = req.headers
|
||||
if isinstance(raw_headers, Mapping):
|
||||
for key, value in raw_headers.items():
|
||||
if str(key).lower() == WAIT_FOR_RESPONSE_HEADER:
|
||||
header_value = value
|
||||
break
|
||||
|
||||
if header_value is not None:
|
||||
return self._coerce_to_bool(header_value)
|
||||
@@ -799,4 +801,4 @@ class AgentFunctionApp(DFAppBase):
|
||||
return bool(value)
|
||||
if isinstance(value, str):
|
||||
return value.strip().lower() in {"true", "1", "yes", "y", "on"}
|
||||
return False
|
||||
return False
|
||||
@@ -19,7 +19,7 @@ class AgentCallbackContext:
|
||||
"""Context supplied to callback invocations."""
|
||||
|
||||
agent_name: str
|
||||
correlation_id: str
|
||||
correlationId: str
|
||||
thread_id: str | None = None
|
||||
request_message: str | None = None
|
||||
|
||||
|
||||
+821
@@ -0,0 +1,821 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
from typing import Any, List, Dict, Optional, cast
|
||||
from datetime import datetime, timezone
|
||||
|
||||
# Base content type
|
||||
|
||||
class DurableAgentStateContent:
|
||||
extensionData: Optional[Dict]
|
||||
|
||||
def to_ai_content(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def from_ai_content(content):
|
||||
# Map AI content type to appropriate DurableAgentStateContent subclass
|
||||
from agent_framework import (
|
||||
DataContent, ErrorContent, FunctionCallContent, FunctionResultContent,
|
||||
HostedFileContent, HostedVectorStoreContent, TextContent,
|
||||
TextReasoningContent, UriContent, UsageContent
|
||||
)
|
||||
|
||||
if isinstance(content, DataContent):
|
||||
return DurableAgentStateDataContent.from_data_content(content)
|
||||
elif isinstance(content, ErrorContent):
|
||||
return DurableAgentStateErrorContent.from_error_content(content)
|
||||
elif isinstance(content, FunctionCallContent):
|
||||
return DurableAgentStateFunctionCallContent.from_function_call_content(content)
|
||||
elif isinstance(content, FunctionResultContent):
|
||||
return DurableAgentStateFunctionResultContent.from_function_result_content(content)
|
||||
elif isinstance(content, HostedFileContent):
|
||||
return DurableAgentStateHostedFileContent.from_hosted_file_content(content)
|
||||
elif isinstance(content, HostedVectorStoreContent):
|
||||
return DurableAgentStateHostedVectorStoreContent.from_hosted_vector_store_content(content)
|
||||
elif isinstance(content, TextContent):
|
||||
return DurableAgentStateTextContent.from_text_content(content)
|
||||
elif isinstance(content, TextReasoningContent):
|
||||
return DurableAgentStateTextReasoningContent.from_text_reasoning_content(content)
|
||||
elif isinstance(content, UriContent):
|
||||
return DurableAgentStateUriContent.from_uri_content(content)
|
||||
elif isinstance(content, UsageContent):
|
||||
return DurableAgentStateUsageContent.from_usage_content(content)
|
||||
else:
|
||||
return DurableAgentStateUnknownContent.from_unknown_content(content)
|
||||
|
||||
# Core state classes
|
||||
|
||||
class DurableAgentStateData:
|
||||
conversationHistory: List['DurableAgentStateEntry']
|
||||
extensionData: Optional[Dict]
|
||||
|
||||
def __init__(self, conversationHistory=None, extensionData=None):
|
||||
self.conversationHistory = conversationHistory or []
|
||||
self.extensionData = extensionData
|
||||
|
||||
|
||||
class DurableAgentState:
|
||||
data: DurableAgentStateData
|
||||
schema_version: str = "1.0.0"
|
||||
|
||||
def __init__(self, schema_version: str = "1.0.0"):
|
||||
self.data = DurableAgentStateData()
|
||||
self.schema_version = schema_version
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
# Serialize conversationHistory
|
||||
serialized_history = []
|
||||
for entry in self.data.conversationHistory:
|
||||
# Properly serialize each entry to a dictionary
|
||||
if hasattr(entry, 'to_dict'):
|
||||
serialized_history.append(entry.to_dict())
|
||||
else:
|
||||
# Fallback for already-serialized entries
|
||||
serialized_history.append(entry)
|
||||
|
||||
return {
|
||||
"schemaVersion": self.schema_version,
|
||||
"data": {
|
||||
"conversationHistory": serialized_history,
|
||||
"extensionData": self.data.extensionData
|
||||
},
|
||||
"message_count": self.message_count,
|
||||
"last_response": self.last_response,
|
||||
}
|
||||
|
||||
def to_json(self) -> str:
|
||||
return json.dumps(self.to_dict())
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, obj: Dict[str, Any]) -> "DurableAgentState":
|
||||
schema_version = obj.get("schemaVersion")
|
||||
if not schema_version:
|
||||
raise ValueError("The durable agent state is missing the 'schemaVersion' property.")
|
||||
|
||||
if not schema_version.startswith("1."):
|
||||
raise ValueError(f"The durable agent state schema version '{schema_version}' is not supported.")
|
||||
|
||||
data_dict = obj.get("data")
|
||||
if data_dict is None:
|
||||
raise ValueError("The durable agent state is missing the 'data' property.")
|
||||
|
||||
instance = cls(schema_version=schema_version)
|
||||
# Deserialize the data dict into DurableAgentStateData
|
||||
if isinstance(data_dict, dict):
|
||||
instance.data = DurableAgentStateData(
|
||||
conversationHistory=data_dict.get("conversationHistory", []),
|
||||
extensionData=data_dict.get("extensionData")
|
||||
)
|
||||
return instance
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, json_str: str) -> "DurableAgentState":
|
||||
try:
|
||||
obj = json.loads(json_str)
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError("The durable agent state is not valid JSON.") from e
|
||||
|
||||
return cls.from_dict(obj)
|
||||
|
||||
def restore_state(self, state: dict[str, Any]) -> None:
|
||||
"""Restore state from a dictionary.
|
||||
|
||||
Args:
|
||||
state: Dictionary containing schemaVersion and data (full state structure)
|
||||
"""
|
||||
# Extract the data portion from the state
|
||||
data_dict = state.get("data", {})
|
||||
|
||||
# Restore the conversation history - deserialize entries from dicts to objects
|
||||
history_data = data_dict.get("conversationHistory", [])
|
||||
deserialized_history = []
|
||||
for entry_dict in history_data:
|
||||
if isinstance(entry_dict, dict):
|
||||
# Deserialize based on whether it's a request or response
|
||||
if "usage" in entry_dict:
|
||||
deserialized_history.append(DurableAgentStateResponse.from_dict(entry_dict))
|
||||
elif "response_type" in entry_dict:
|
||||
deserialized_history.append(DurableAgentStateRequest.from_dict(entry_dict))
|
||||
else:
|
||||
deserialized_history.append(DurableAgentStateEntry.from_dict(entry_dict))
|
||||
else:
|
||||
# Already an object
|
||||
deserialized_history.append(entry_dict)
|
||||
|
||||
self.data.conversationHistory = deserialized_history
|
||||
self.data.extensionData = data_dict.get("extensionData")
|
||||
|
||||
@property
|
||||
def message_count(self) -> int:
|
||||
"""Get the count of conversation entries (requests + responses)."""
|
||||
return len(self.data.conversationHistory)
|
||||
|
||||
@property
|
||||
def last_response(self) -> str | None:
|
||||
"""Get the text from the last assistant response in the conversation history."""
|
||||
# Iterate through messages in reverse to find the last assistant message
|
||||
for entry in reversed(self.data.conversationHistory):
|
||||
for message in reversed(entry.messages):
|
||||
if message.role == "assistant":
|
||||
return message.text
|
||||
return None
|
||||
|
||||
def add_assistant_message(self, content: str, agent_run_response, correlationId: str) -> None:
|
||||
"""Add an assistant message to the conversation history.
|
||||
|
||||
Args:
|
||||
content: The message content
|
||||
agent_run_response: The agent's run response
|
||||
correlationId: The correlation ID for this response
|
||||
"""
|
||||
# This method is called from the entity after storing the response
|
||||
# The response has already been added to conversationHistory, so we don't need to do anything here
|
||||
pass
|
||||
|
||||
def try_get_agent_response(self, correlationId: str) -> Dict[str, Any] | None:
|
||||
"""Try to get an agent response by correlation ID.
|
||||
|
||||
Args:
|
||||
correlationId: The correlation ID to search for
|
||||
|
||||
Returns:
|
||||
Response data dict if found, None otherwise
|
||||
"""
|
||||
# Search through conversation history for a response with this correlationId
|
||||
for entry in self.data.conversationHistory:
|
||||
if hasattr(entry, 'correlationId') and entry.correlationId == correlationId:
|
||||
# Found the entry, extract response data
|
||||
if isinstance(entry, DurableAgentStateResponse):
|
||||
# Get the text content from assistant messages only
|
||||
content = ""
|
||||
for message in entry.messages:
|
||||
if hasattr(message, 'role') and message.role == "assistant" and hasattr(message, 'text'):
|
||||
content += message.text
|
||||
|
||||
return {
|
||||
"content": content,
|
||||
"message_count": self.message_count,
|
||||
"correlationId": correlationId
|
||||
}
|
||||
return None
|
||||
|
||||
# Entry classes
|
||||
|
||||
class DurableAgentStateEntry:
|
||||
json_type: str
|
||||
correlationId: str
|
||||
created_at: datetime
|
||||
messages: List['DurableAgentStateMessage']
|
||||
extensionData: Optional[Dict]
|
||||
|
||||
# Request-only
|
||||
responseType: Optional[str] = None
|
||||
responseSchema: Optional[dict] = None
|
||||
|
||||
# Response-only
|
||||
usage: Optional["DurableAgentStateUsage"] = None
|
||||
|
||||
|
||||
def __init__(self, json_type, correlationId, created_at, messages, extensionData=None, responseType=None, responseSchema=None, usage=None):
|
||||
self.json_type = json_type
|
||||
self.correlationId = correlationId
|
||||
self.created_at = created_at
|
||||
self.messages = messages
|
||||
self.extensionData = extensionData
|
||||
self.responseType = responseType
|
||||
self.responseSchema = responseSchema
|
||||
self.usage = usage
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
data = {
|
||||
"$type": self.json_type,
|
||||
"correlationId": self.correlationId,
|
||||
"createdAt": self.created_at.isoformat() if isinstance(self.created_at, datetime) else self.created_at,
|
||||
"messages": [m.to_dict() if hasattr(m, 'to_dict') else m for m in self.messages],
|
||||
"extensionData": self.extensionData
|
||||
}
|
||||
if self.json_type == "request":
|
||||
data.update({
|
||||
"responseType": self.responseType,
|
||||
"responseSchema": self.responseSchema,
|
||||
})
|
||||
elif self.json_type == "response":
|
||||
if self.usage:
|
||||
data["usage"] = self.usage.to_dict()
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'DurableAgentStateEntry':
|
||||
from dateutil import parser as date_parser
|
||||
created_at = data.get("created_at")
|
||||
if isinstance(created_at, str):
|
||||
created_at = date_parser.parse(created_at)
|
||||
|
||||
messages = []
|
||||
for msg_dict in data.get("messages", []):
|
||||
if isinstance(msg_dict, dict):
|
||||
messages.append(DurableAgentStateMessage.from_dict(msg_dict))
|
||||
else:
|
||||
messages.append(msg_dict)
|
||||
|
||||
return cls(
|
||||
correlationId=data.get("correlationId"),
|
||||
created_at=created_at,
|
||||
messages=messages,
|
||||
extensionData=data.get("extensionData")
|
||||
)
|
||||
|
||||
|
||||
class DurableAgentStateRequest(DurableAgentStateEntry):
|
||||
response_type: Optional[str] = None
|
||||
response_schema: Optional[Dict] = None
|
||||
|
||||
def __init__(self, correlationId, created_at, messages, json_type, extensionData=None, response_type=None, response_schema=None):
|
||||
self.correlationId = correlationId
|
||||
self.created_at = created_at
|
||||
self.messages = messages
|
||||
self.json_type = json_type
|
||||
self.extensionData = extensionData
|
||||
self.response_type = response_type
|
||||
self.response_schema = response_schema
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
base_dict = super().to_dict()
|
||||
base_dict["response_type"] = self.response_type
|
||||
base_dict["response_schema"] = self.response_schema
|
||||
return base_dict
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'DurableAgentStateRequest':
|
||||
from dateutil import parser as date_parser
|
||||
created_at = data.get("created_at")
|
||||
if isinstance(created_at, str):
|
||||
created_at = date_parser.parse(created_at)
|
||||
|
||||
messages = []
|
||||
for msg_dict in data.get("messages", []):
|
||||
if isinstance(msg_dict, dict):
|
||||
messages.append(DurableAgentStateMessage.from_dict(msg_dict))
|
||||
else:
|
||||
messages.append(msg_dict)
|
||||
|
||||
return cls(
|
||||
json_type=data.get("$type", "request"),
|
||||
correlationId=data.get("correlationId"),
|
||||
created_at=created_at,
|
||||
messages=messages,
|
||||
extensionData=data.get("extensionData"),
|
||||
response_type=data.get("response_type"),
|
||||
response_schema=data.get("response_schema")
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def from_run_request(content):
|
||||
from agent_framework import TextContent
|
||||
return DurableAgentStateRequest(correlationId=content.correlationId,
|
||||
messages=[DurableAgentStateMessage.from_chat_message(content)],
|
||||
created_at=datetime.now(tz=timezone.utc),
|
||||
json_type="request",
|
||||
extensionData=content.extensionData if hasattr(content, 'extensionData') else None,
|
||||
response_type="text" if isinstance(content.response_format, TextContent) else "json",
|
||||
response_schema=content.response_format)
|
||||
|
||||
|
||||
class DurableAgentStateResponse(DurableAgentStateEntry):
|
||||
usage: Optional['DurableAgentStateUsage'] = None
|
||||
|
||||
def __init__(self, json_type, correlationId, created_at, messages, extensionData=None, usage=None):
|
||||
self.json_type = json_type
|
||||
self.correlationId = correlationId
|
||||
self.created_at = created_at
|
||||
self.messages = messages
|
||||
self.extensionData = extensionData
|
||||
self.usage = usage
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
base_dict = super().to_dict()
|
||||
base_dict["usage"] = self.usage.to_dict() if self.usage and hasattr(self.usage, 'to_dict') else self.usage
|
||||
return base_dict
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'DurableAgentStateResponse':
|
||||
from dateutil import parser as date_parser
|
||||
created_at = data.get("created_at")
|
||||
if isinstance(created_at, str):
|
||||
created_at = date_parser.parse(created_at)
|
||||
|
||||
messages = []
|
||||
for msg_dict in data.get("messages", []):
|
||||
if isinstance(msg_dict, dict):
|
||||
messages.append(DurableAgentStateMessage.from_dict(msg_dict))
|
||||
else:
|
||||
messages.append(msg_dict)
|
||||
|
||||
usage_dict = data.get("usage")
|
||||
usage = None
|
||||
if usage_dict and isinstance(usage_dict, dict):
|
||||
usage = DurableAgentStateUsage.from_dict(usage_dict)
|
||||
elif usage_dict:
|
||||
usage = usage_dict
|
||||
|
||||
return cls(
|
||||
json_type=data.get("$type", "response"),
|
||||
correlationId=data.get("correlationId"),
|
||||
created_at=created_at,
|
||||
messages=messages,
|
||||
extensionData=data.get("extensionData"),
|
||||
usage=usage
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def from_run_response(correlationId: str, response) -> DurableAgentStateResponse:
|
||||
"""
|
||||
Creates a DurableAgentStateResponse from an AgentRunResponse.
|
||||
"""
|
||||
# Determine the earliest created_at timestamp among messages (if available)
|
||||
timestamps = [m.created_at for m in response.messages if hasattr(m, 'created_at') and m.created_at is not None]
|
||||
created_at = min(timestamps) if timestamps else datetime.now(tz=timezone.utc)
|
||||
|
||||
return DurableAgentStateResponse(
|
||||
json_type="response",
|
||||
correlationId=correlationId,
|
||||
created_at=created_at,
|
||||
messages=[DurableAgentStateMessage.from_chat_message(m) for m in response.messages],
|
||||
usage=DurableAgentStateUsage.from_usage(response.usage) if hasattr(response, 'usage') and response.usage else None
|
||||
)
|
||||
|
||||
def to_run_response(self):
|
||||
"""
|
||||
Converts this DurableAgentStateResponse back to an AgentRunResponse.
|
||||
"""
|
||||
from agent_framework import AgentRunResponse
|
||||
|
||||
return AgentRunResponse(
|
||||
created_at=self.created_at,
|
||||
messages=[m.to_chat_message() for m in self.messages],
|
||||
usage=self.usage.to_usage_details() if self.usage else None
|
||||
)
|
||||
|
||||
# Message class
|
||||
|
||||
class DurableAgentStateMessage:
|
||||
role: str
|
||||
contents: List[DurableAgentStateContent]
|
||||
author_name: Optional[str] = None
|
||||
created_at: Optional[datetime] = None
|
||||
extensionData: Optional[Dict] = None
|
||||
|
||||
def __init__(self, role, contents, author_name=None, created_at=None, extensionData=None):
|
||||
self.role = role
|
||||
self.contents = contents
|
||||
self.author_name = author_name
|
||||
self.created_at = created_at
|
||||
self.extensionData = extensionData
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"role": self.role,
|
||||
"contents": [
|
||||
{"$type": c.to_dict().get("type", "text"), **{k: v for k, v in c.to_dict().items() if k != "type"}} for c in self.contents
|
||||
],
|
||||
"authorName": self.author_name,
|
||||
"createdAt": self.created_at.isoformat() if isinstance(self.created_at, datetime) else self.created_at,
|
||||
"extensionData": self.extensionData
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'DurableAgentStateMessage':
|
||||
from dateutil import parser as date_parser
|
||||
created_at = data.get("created_at")
|
||||
if created_at and isinstance(created_at, str):
|
||||
created_at = date_parser.parse(created_at)
|
||||
|
||||
contents = []
|
||||
for content_dict in data.get("contents", []):
|
||||
if isinstance(content_dict, dict):
|
||||
content_type = content_dict.get("type")
|
||||
if content_type == "text":
|
||||
contents.append(DurableAgentStateTextContent(text=content_dict.get("text")))
|
||||
elif content_type == "data":
|
||||
contents.append(DurableAgentStateDataContent(uri=content_dict.get("uri"), media_type=content_dict.get("media_type")))
|
||||
elif content_type == "error":
|
||||
contents.append(DurableAgentStateErrorContent(message=content_dict.get("message"), error_code=content_dict.get("error_code"), details=content_dict.get("details")))
|
||||
elif content_type == "function_call":
|
||||
contents.append(DurableAgentStateFunctionCallContent(call_id=content_dict.get("call_id"), name=content_dict.get("name"), arguments=content_dict.get("arguments")))
|
||||
elif content_type == "function_result":
|
||||
contents.append(DurableAgentStateFunctionResultContent(call_id=content_dict.get("call_id"), result=content_dict.get("result")))
|
||||
elif content_type == "hosted_file":
|
||||
contents.append(DurableAgentStateHostedFileContent(file_id=content_dict.get("file_id")))
|
||||
elif content_type == "hosted_vector_store":
|
||||
contents.append(DurableAgentStateHostedVectorStoreContent(vector_store_id=content_dict.get("vector_store_id")))
|
||||
elif content_type == "text_reasoning":
|
||||
contents.append(DurableAgentStateTextReasoningContent(text=content_dict.get("text")))
|
||||
elif content_type == "uri":
|
||||
contents.append(DurableAgentStateUriContent(uri=content_dict.get("uri"), media_type=content_dict.get("media_type")))
|
||||
elif content_type == "usage":
|
||||
usage_data = content_dict.get("usage")
|
||||
if usage_data and isinstance(usage_data, dict):
|
||||
contents.append(DurableAgentStateUsageContent(usage=DurableAgentStateUsage.from_dict(usage_data)))
|
||||
elif content_type == "unknown":
|
||||
contents.append(DurableAgentStateUnknownContent(content=content_dict.get("content")))
|
||||
else:
|
||||
contents.append(content_dict)
|
||||
|
||||
return cls(
|
||||
role=data.get("role"),
|
||||
contents=contents,
|
||||
author_name=data.get("author_name"),
|
||||
created_at=created_at,
|
||||
extensionData=data.get("extensionData")
|
||||
)
|
||||
|
||||
@property
|
||||
def text(self) -> str:
|
||||
"""Extract text from the contents list."""
|
||||
text_parts = []
|
||||
for content in self.contents:
|
||||
if isinstance(content, DurableAgentStateTextContent):
|
||||
text_parts.append(content.text or "")
|
||||
return "".join(text_parts)
|
||||
|
||||
@staticmethod
|
||||
def from_chat_message(content):
|
||||
# Convert to a list of DurableAgentStateContent objects
|
||||
contents_list = []
|
||||
|
||||
if hasattr(content, 'message') and isinstance(content.message, str):
|
||||
# RunRequest with 'message' attribute
|
||||
contents_list = [DurableAgentStateTextContent(text=content.message)]
|
||||
elif hasattr(content, 'contents') and content.contents:
|
||||
# ChatMessage with 'contents' attribute - convert each content object
|
||||
for c in content.contents:
|
||||
converted = DurableAgentStateContent.from_ai_content(c)
|
||||
contents_list.append(converted)
|
||||
|
||||
# Convert role enum to string if needed
|
||||
role_value = content.role.value if hasattr(content.role, 'value') else str(content.role)
|
||||
|
||||
return DurableAgentStateMessage(
|
||||
role=role_value,
|
||||
contents=contents_list,
|
||||
author_name=content.author_name if hasattr(content, 'author_name') else None,
|
||||
created_at=content.created_at if hasattr(content, 'created_at') else None,
|
||||
extensionData=content.extensionData if hasattr(content, 'extensionData') else None
|
||||
)
|
||||
|
||||
def to_chat_message(self):
|
||||
from agent_framework import ChatMessage
|
||||
# Convert DurableAgentStateContent objects back to agent_framework content objects
|
||||
ai_contents = [c.to_ai_content() for c in self.contents]
|
||||
return ChatMessage(role=self.role, contents=ai_contents, author_name=self.author_name, created_at=self.created_at, extensionData=self.extensionData)
|
||||
|
||||
# Content subclasses
|
||||
|
||||
class DurableAgentStateDataContent(DurableAgentStateContent):
|
||||
uri: str = ""
|
||||
media_type: Optional[str] = None
|
||||
|
||||
def __init__(self, uri, media_type=None):
|
||||
self.uri = uri
|
||||
self.media_type = media_type
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"type": "data",
|
||||
"uri": self.uri,
|
||||
"mediaType": self.media_type
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def from_data_content(content):
|
||||
return DurableAgentStateDataContent(uri=content.uri, media_type=content.media_type)
|
||||
|
||||
def to_ai_content(self):
|
||||
from agent_framework import DataContent
|
||||
return DataContent(uri=self.uri, media_type=self.media_type)
|
||||
|
||||
|
||||
class DurableAgentStateErrorContent(DurableAgentStateContent):
|
||||
message: Optional[str] = None
|
||||
error_code: Optional[str] = None
|
||||
details: Optional[str] = None
|
||||
|
||||
def __init__(self, message=None, error_code=None, details=None):
|
||||
self.message = message
|
||||
self.error_code = error_code
|
||||
self.details = details
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"type": "error",
|
||||
"message": self.message,
|
||||
"errorCode": self.error_code,
|
||||
"details": self.details
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def from_error_content(content):
|
||||
return DurableAgentStateErrorContent(message=content.message, error_code=content.error_code, details=content.details)
|
||||
|
||||
def to_ai_content(self):
|
||||
from agent_framework import ErrorContent
|
||||
return ErrorContent(message=self.message, error_code=self.error_code, details=self.details)
|
||||
|
||||
|
||||
class DurableAgentStateFunctionCallContent(DurableAgentStateContent):
|
||||
call_id: str
|
||||
name: str
|
||||
arguments: Dict[str, object]
|
||||
|
||||
def __init__(self, call_id, name, arguments):
|
||||
self.call_id = call_id
|
||||
self.name = name
|
||||
self.arguments = arguments
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"type": "function_call",
|
||||
"callId": self.call_id,
|
||||
"name": self.name,
|
||||
"arguments": self.arguments
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def from_function_call_content(content):
|
||||
return DurableAgentStateFunctionCallContent(
|
||||
call_id=content.call_id,
|
||||
name=content.name,
|
||||
arguments=content.arguments if content.arguments else {}
|
||||
)
|
||||
|
||||
def to_ai_content(self):
|
||||
from agent_framework import FunctionCallContent
|
||||
return FunctionCallContent(call_id=self.call_id, name=self.name, arguments=self.arguments)
|
||||
|
||||
|
||||
class DurableAgentStateFunctionResultContent(DurableAgentStateContent):
|
||||
call_id: str
|
||||
result: Optional[object] = None
|
||||
|
||||
def __init__(self, call_id, result=None):
|
||||
self.call_id = call_id
|
||||
self.result = result
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"type": "function_result",
|
||||
"callId": self.call_id,
|
||||
"result": self.result
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def from_function_result_content(content):
|
||||
return DurableAgentStateFunctionResultContent(call_id=content.call_id, result=content.result)
|
||||
|
||||
def to_ai_content(self):
|
||||
from agent_framework import FunctionResultContent
|
||||
return FunctionResultContent(call_id=self.call_id, result=self.result)
|
||||
|
||||
|
||||
class DurableAgentStateHostedFileContent(DurableAgentStateContent):
|
||||
file_id: str
|
||||
|
||||
def __init__(self, file_id):
|
||||
self.file_id = file_id
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"type": "hosted_file",
|
||||
"fileId": self.file_id
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def from_hosted_file_content(content):
|
||||
return DurableAgentStateHostedFileContent(file_id=content.file_id)
|
||||
|
||||
def to_ai_content(self):
|
||||
from agent_framework import HostedFileContent
|
||||
return HostedFileContent(file_id=self.file_id)
|
||||
|
||||
|
||||
class DurableAgentStateHostedVectorStoreContent(DurableAgentStateContent):
|
||||
vector_store_id: str
|
||||
|
||||
def __init__(self, vector_store_id):
|
||||
self.vector_store_id = vector_store_id
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"type": "hosted_vector_store",
|
||||
"vectorStoreId": self.vector_store_id
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def from_hosted_vector_store_content(content):
|
||||
return DurableAgentStateHostedVectorStoreContent(vector_store_id=content.vector_store_id)
|
||||
|
||||
def to_ai_content(self):
|
||||
from agent_framework import HostedVectorStoreContent
|
||||
return HostedVectorStoreContent(vector_store_id=self.vector_store_id)
|
||||
|
||||
|
||||
class DurableAgentStateTextContent(DurableAgentStateContent):
|
||||
text: Optional[str] = None
|
||||
|
||||
def __init__(self, text):
|
||||
self.text = text
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"type": "text",
|
||||
"text": self.text
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def from_text_content(content):
|
||||
return DurableAgentStateTextContent(text=content.text)
|
||||
|
||||
def to_ai_content(self):
|
||||
from agent_framework import TextContent
|
||||
return TextContent(text=self.text)
|
||||
|
||||
|
||||
class DurableAgentStateTextReasoningContent(DurableAgentStateContent):
|
||||
text: Optional[str] = None
|
||||
|
||||
def __init__(self, text):
|
||||
self.text = text
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"type": "text_reasoning",
|
||||
"text": self.text
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def from_text_reasoning_content(content):
|
||||
return DurableAgentStateTextReasoningContent(text=content.text)
|
||||
|
||||
def to_ai_content(self):
|
||||
from agent_framework import TextReasoningContent
|
||||
return TextReasoningContent(text=self.text)
|
||||
|
||||
|
||||
class DurableAgentStateUriContent(DurableAgentStateContent):
|
||||
uri: str
|
||||
media_type: str
|
||||
|
||||
def __init__(self, uri, media_type):
|
||||
self.uri = uri
|
||||
self.media_type = media_type
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"type": "uri",
|
||||
"uri": self.uri,
|
||||
"mediaType": self.media_type
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def from_uri_content(content):
|
||||
return DurableAgentStateUriContent(uri=content.uri, media_type=content.media_type)
|
||||
|
||||
def to_ai_content(self):
|
||||
from agent_framework import UriContent
|
||||
return UriContent(uri=self.uri, media_type=self.media_type)
|
||||
|
||||
|
||||
class DurableAgentStateUsage:
|
||||
input_token_count: Optional[int] = None
|
||||
output_token_count: Optional[int] = None
|
||||
total_token_count: Optional[int] = None
|
||||
extensionData: Optional[Dict] = None
|
||||
|
||||
def __init__(self, input_token_count=None, output_token_count=None, total_token_count=None, extensionData=None):
|
||||
self.input_token_count = input_token_count
|
||||
self.output_token_count = output_token_count
|
||||
self.total_token_count = total_token_count
|
||||
self.extensionData = extensionData
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"inputTokenCount": self.input_token_count,
|
||||
"outputTokenCount": self.output_token_count,
|
||||
"totalTokenCount": self.total_token_count,
|
||||
"extensionData": self.extensionData
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'DurableAgentStateUsage':
|
||||
return cls(
|
||||
input_token_count=data.get("input_token_count"),
|
||||
output_token_count=data.get("output_token_count"),
|
||||
total_token_count=data.get("total_token_count"),
|
||||
extensionData=data.get("extensionData")
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def from_usage(usage):
|
||||
if usage is None:
|
||||
return None
|
||||
return DurableAgentStateUsage(
|
||||
input_token_count=usage.input_token_count,
|
||||
output_token_count=usage.output_token_count,
|
||||
total_token_count=usage.total_token_count
|
||||
)
|
||||
|
||||
def to_usage_details(self):
|
||||
# Convert back to AI SDK UsageDetails
|
||||
from agent_framework import UsageDetails
|
||||
return UsageDetails(
|
||||
input_token_count=self.input_token_count,
|
||||
output_token_count=self.output_token_count,
|
||||
total_token_count=self.total_token_count
|
||||
)
|
||||
|
||||
|
||||
class DurableAgentStateUsageContent(DurableAgentStateContent):
|
||||
usage: DurableAgentStateUsage = DurableAgentStateUsage()
|
||||
|
||||
def __init__(self, usage):
|
||||
self.usage = usage
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"type": "usage",
|
||||
"usage": self.usage.to_dict() if hasattr(self.usage, 'to_dict') else self.usage
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def from_usage_content(content):
|
||||
return DurableAgentStateUsageContent(usage=DurableAgentStateUsage.from_usage(content.details))
|
||||
|
||||
def to_ai_content(self):
|
||||
from agent_framework import UsageContent
|
||||
return UsageContent(details=self.usage.to_usage_details())
|
||||
|
||||
|
||||
class DurableAgentStateUnknownContent(DurableAgentStateContent):
|
||||
content: dict
|
||||
|
||||
def __init__(self, content):
|
||||
self.content = content
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"type": "unknown",
|
||||
"content": self.content
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def from_unknown_content(content):
|
||||
return DurableAgentStateUnknownContent(content=json.loads(content))
|
||||
|
||||
def to_ai_content(self):
|
||||
from agent_framework import BaseContent
|
||||
if not self.content:
|
||||
raise Exception(f"The content is missing and cannot be converted to valid AI content.")
|
||||
return BaseContent(content=json.loads(self.content))
|
||||
@@ -10,15 +10,21 @@ allows for long-running agent conversations.
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
from collections.abc import AsyncIterable, Callable
|
||||
from typing import Any, cast
|
||||
from collections.abc import AsyncIterable
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, cast, Callable
|
||||
|
||||
import azure.durable_functions as df
|
||||
from agent_framework import AgentProtocol, AgentRunResponse, AgentRunResponseUpdate, Role, get_logger
|
||||
|
||||
from ._callbacks import AgentCallbackContext, AgentResponseCallbackProtocol
|
||||
from ._durable_agent_state import (
|
||||
DurableAgentState,
|
||||
DurableAgentStateData,
|
||||
DurableAgentStateRequest,
|
||||
DurableAgentStateResponse,
|
||||
)
|
||||
from ._models import AgentResponse, RunRequest
|
||||
from ._state import AgentState
|
||||
|
||||
logger = get_logger("agent_framework.azurefunctions.entities")
|
||||
|
||||
@@ -38,11 +44,11 @@ class AgentEntity:
|
||||
|
||||
Attributes:
|
||||
agent: The AgentProtocol instance
|
||||
state: The AgentState managing conversation history
|
||||
state: The DurableAgentState managing conversation history
|
||||
"""
|
||||
|
||||
agent: AgentProtocol
|
||||
state: AgentState
|
||||
state: DurableAgentState
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -56,8 +62,9 @@ class AgentEntity:
|
||||
callback: Optional callback invoked during streaming updates and final responses
|
||||
"""
|
||||
self.agent = agent
|
||||
self.state = AgentState()
|
||||
self.state = DurableAgentState()
|
||||
self.callback = callback
|
||||
self._pending_requests: dict[str, DurableAgentStateRequest] = {}
|
||||
|
||||
logger.debug(f"[AgentEntity] Initialized with agent type: {type(agent).__name__}")
|
||||
|
||||
@@ -89,31 +96,49 @@ class AgentEntity:
|
||||
|
||||
message = run_request.message
|
||||
thread_id = run_request.thread_id
|
||||
correlation_id = run_request.correlation_id
|
||||
correlationId = run_request.correlationId
|
||||
if not thread_id:
|
||||
raise ValueError("RunRequest must include a thread_id")
|
||||
if not correlation_id:
|
||||
raise ValueError("RunRequest must include a correlation_id")
|
||||
if not correlationId:
|
||||
raise ValueError("RunRequest must include a correlationId")
|
||||
role = run_request.role or Role.USER
|
||||
response_format = run_request.response_format
|
||||
enable_tool_calls = run_request.enable_tool_calls
|
||||
|
||||
# Store request in pending (will be combined with response later)
|
||||
state_request = DurableAgentStateRequest.from_run_request(run_request)
|
||||
self.state.data.conversationHistory.append(state_request)
|
||||
self._pending_requests[correlationId] = state_request
|
||||
|
||||
logger.debug(f"[AgentEntity.run_agent] Received message: {message}")
|
||||
logger.debug(f"[AgentEntity.run_agent] Thread ID: {thread_id}")
|
||||
logger.debug(f"[AgentEntity.run_agent] Correlation ID: {correlation_id}")
|
||||
logger.debug(f"[AgentEntity.run_agent] Correlation ID: {correlationId}")
|
||||
logger.debug(f"[AgentEntity.run_agent] Role: {role.value}")
|
||||
logger.debug(f"[AgentEntity.run_agent] Enable tool calls: {enable_tool_calls}")
|
||||
logger.debug(f"[AgentEntity.run_agent] Response format: {'provided' if response_format else 'none'}")
|
||||
|
||||
# Store message in history with role
|
||||
self.state.add_user_message(message, role=role, correlation_id=correlation_id)
|
||||
|
||||
logger.debug(f"[AgentEntity.run_agent] Saved state request: {state_request}")
|
||||
logger.debug("[AgentEntity.run_agent] Executing agent...")
|
||||
|
||||
try:
|
||||
logger.debug("[AgentEntity.run_agent] Starting agent invocation")
|
||||
|
||||
run_kwargs: dict[str, Any] = {"messages": self.state.get_chat_messages()}
|
||||
# Build messages from conversation history plus the current request
|
||||
chat_messages = [
|
||||
m.to_chat_message()
|
||||
for entry in self.state.data.conversationHistory
|
||||
for m in entry.messages
|
||||
]
|
||||
# Add the current request message
|
||||
# for m in state_request.messages:
|
||||
# chat_messages.append(m.to_chat_message())
|
||||
|
||||
# Strip additional_properties from all messages to avoid metadata being sent to Azure OpenAI
|
||||
# Azure OpenAI doesn't support the 'metadata' field in messages
|
||||
for msg in chat_messages:
|
||||
if hasattr(msg, 'additional_properties'):
|
||||
msg.additional_properties = {}
|
||||
|
||||
run_kwargs: dict[str, Any] = {"messages": chat_messages}
|
||||
if not enable_tool_calls:
|
||||
run_kwargs["tools"] = None
|
||||
if response_format:
|
||||
@@ -121,7 +146,7 @@ class AgentEntity:
|
||||
|
||||
agent_run_response: AgentRunResponse = await self._invoke_agent(
|
||||
run_kwargs=run_kwargs,
|
||||
correlation_id=correlation_id,
|
||||
correlationId=correlationId,
|
||||
thread_id=thread_id,
|
||||
request_message=message,
|
||||
)
|
||||
@@ -131,6 +156,17 @@ class AgentEntity:
|
||||
type(agent_run_response).__name__,
|
||||
)
|
||||
|
||||
# Convert response into DurableAgentStateResponse and combine with request
|
||||
state_response = DurableAgentStateResponse.from_run_response(correlationId, agent_run_response)
|
||||
|
||||
# Get the pending request and combine its messages with the response messages
|
||||
# pending_request = self._pending_requests.pop(correlationId, None)
|
||||
# if pending_request:
|
||||
# # Combine request and response messages into a single entry
|
||||
# state_response.messages = pending_request.messages + state_response.messages
|
||||
|
||||
self.state.data.conversationHistory.append(state_response)
|
||||
|
||||
response_text = None
|
||||
structured_response = None
|
||||
|
||||
@@ -161,13 +197,13 @@ class AgentEntity:
|
||||
message=str(message),
|
||||
thread_id=str(thread_id),
|
||||
status="success",
|
||||
message_count=self.state.message_count,
|
||||
message_count=len(self.state.data.conversationHistory),
|
||||
structured_response=structured_response,
|
||||
)
|
||||
result = agent_response.to_dict()
|
||||
|
||||
content = json.dumps(structured_response) if structured_response else (response_text or "")
|
||||
self.state.add_assistant_message(content, agent_run_response, correlation_id)
|
||||
self.state.add_assistant_message(content, agent_run_response, correlationId)
|
||||
logger.debug("[AgentEntity.run_agent] AgentRunResponse stored in conversation history")
|
||||
|
||||
return result
|
||||
@@ -181,12 +217,39 @@ class AgentEntity:
|
||||
logger.error(f"Error type: {type(exc).__name__}")
|
||||
logger.error(f"Full traceback:\n{error_traceback}")
|
||||
|
||||
# Create error response and store it in conversation history so polling can find it
|
||||
from agent_framework import ChatMessage, ErrorContent
|
||||
|
||||
# Get the pending request
|
||||
pending_request = self._pending_requests.pop(correlationId, None)
|
||||
|
||||
# Create error message
|
||||
error_message = DurableAgentStateMessage.from_chat_message(
|
||||
ChatMessage(role="assistant", contents=[ErrorContent(message=str(exc), error_code=type(exc).__name__)])
|
||||
)
|
||||
|
||||
# Combine request and error response messages
|
||||
messages = []
|
||||
if pending_request:
|
||||
messages.extend(pending_request.messages)
|
||||
messages.append(error_message)
|
||||
|
||||
# Create and store error response in conversation history
|
||||
error_state_response = DurableAgentStateResponse(
|
||||
correlationId=correlationId,
|
||||
createdAt=datetime.now(tz=timezone.utc),
|
||||
messages=messages,
|
||||
extensionData=None,
|
||||
usage=None
|
||||
)
|
||||
self.state.data.conversationHistory.append(error_state_response)
|
||||
|
||||
error_response = AgentResponse(
|
||||
response=f"Error: {exc!s}",
|
||||
message=str(message),
|
||||
thread_id=str(thread_id),
|
||||
status="error",
|
||||
message_count=self.state.message_count,
|
||||
message_count=len(self.state.data.conversationHistory),
|
||||
error=str(exc),
|
||||
error_type=type(exc).__name__,
|
||||
)
|
||||
@@ -195,7 +258,7 @@ class AgentEntity:
|
||||
async def _invoke_agent(
|
||||
self,
|
||||
run_kwargs: dict[str, Any],
|
||||
correlation_id: str,
|
||||
correlationId: str,
|
||||
thread_id: str,
|
||||
request_message: str,
|
||||
) -> AgentRunResponse:
|
||||
@@ -203,7 +266,7 @@ class AgentEntity:
|
||||
callback_context: AgentCallbackContext | None = None
|
||||
if self.callback is not None:
|
||||
callback_context = self._build_callback_context(
|
||||
correlation_id=correlation_id,
|
||||
correlationId=correlationId,
|
||||
thread_id=thread_id,
|
||||
request_message=request_message,
|
||||
)
|
||||
@@ -317,7 +380,7 @@ class AgentEntity:
|
||||
|
||||
def _build_callback_context(
|
||||
self,
|
||||
correlation_id: str,
|
||||
correlationId: str,
|
||||
thread_id: str,
|
||||
request_message: str,
|
||||
) -> AgentCallbackContext:
|
||||
@@ -325,7 +388,7 @@ class AgentEntity:
|
||||
agent_name = getattr(self.agent, "name", None) or type(self.agent).__name__
|
||||
return AgentCallbackContext(
|
||||
agent_name=agent_name,
|
||||
correlation_id=correlation_id,
|
||||
correlationId=correlationId,
|
||||
thread_id=thread_id,
|
||||
request_message=request_message,
|
||||
)
|
||||
@@ -333,7 +396,7 @@ class AgentEntity:
|
||||
def reset(self, context: df.DurableEntityContext) -> None:
|
||||
"""Reset the entity state (clear conversation history)."""
|
||||
logger.debug("[AgentEntity.reset] Resetting entity state")
|
||||
self.state.reset()
|
||||
self.state.data = DurableAgentStateData(conversationHistory=[])
|
||||
logger.debug("[AgentEntity.reset] State reset complete")
|
||||
|
||||
|
||||
@@ -392,8 +455,9 @@ def create_agent_entity(
|
||||
logger.error("[entity_function] Unknown operation: %s", operation)
|
||||
context.set_result({"error": f"Unknown operation: {operation}"})
|
||||
|
||||
logger.info("State dict: %s", str(entity.state.to_dict()))
|
||||
context.set_state(entity.state.to_dict())
|
||||
logger.debug(f"[entity_function] Operation {operation} completed successfully")
|
||||
logger.info(f"[entity_function] Operation {operation} completed successfully")
|
||||
|
||||
except Exception as exc:
|
||||
import traceback
|
||||
@@ -424,4 +488,4 @@ def create_agent_entity(
|
||||
logger.error("[entity_function] Unexpected error executing entity: %s", exc, exc_info=True)
|
||||
context.set_result({"error": str(exc), "status": "error"})
|
||||
|
||||
return entity_function
|
||||
return entity_function
|
||||
@@ -282,7 +282,7 @@ class RunRequest:
|
||||
response_format: Optional Pydantic BaseModel type describing the structured response format
|
||||
enable_tool_calls: Whether to enable tool calls for this request
|
||||
thread_id: Optional thread ID for tracking
|
||||
correlation_id: Optional correlation ID for tracking the response to this specific request
|
||||
correlationId: Optional correlation ID for tracking the response to this specific request
|
||||
"""
|
||||
|
||||
message: str
|
||||
@@ -290,7 +290,10 @@ class RunRequest:
|
||||
response_format: type[BaseModel] | None = None
|
||||
enable_tool_calls: bool = True
|
||||
thread_id: str | None = None
|
||||
correlation_id: str | None = None
|
||||
correlationId: str | None = None
|
||||
author_name: str | None = None
|
||||
created_at: str | None = None
|
||||
extension_data: dict[str, Any] | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -299,14 +302,14 @@ class RunRequest:
|
||||
response_format: type[BaseModel] | None = None,
|
||||
enable_tool_calls: bool = True,
|
||||
thread_id: str | None = None,
|
||||
correlation_id: str | None = None,
|
||||
correlationId: str | None = None,
|
||||
) -> None:
|
||||
self.message = message
|
||||
self.role = self.coerce_role(role)
|
||||
self.response_format = response_format
|
||||
self.enable_tool_calls = enable_tool_calls
|
||||
self.thread_id = thread_id
|
||||
self.correlation_id = correlation_id
|
||||
self.correlationId = correlationId
|
||||
|
||||
@staticmethod
|
||||
def coerce_role(value: Role | str | None) -> Role:
|
||||
@@ -331,8 +334,14 @@ class RunRequest:
|
||||
result["response_format"] = _serialize_response_format(self.response_format)
|
||||
if self.thread_id:
|
||||
result["thread_id"] = self.thread_id
|
||||
if self.correlation_id:
|
||||
result["correlation_id"] = self.correlation_id
|
||||
if self.correlationId:
|
||||
result["correlationId"] = self.correlationId
|
||||
if self.author_name:
|
||||
result["author_name"] = self.author_name
|
||||
if self.created_at:
|
||||
result["created_at"] = self.created_at
|
||||
if self.extension_data:
|
||||
result["extension_data"] = self.extension_data
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
@@ -344,7 +353,7 @@ class RunRequest:
|
||||
response_format=_deserialize_response_format(data.get("response_format")),
|
||||
enable_tool_calls=data.get("enable_tool_calls", True),
|
||||
thread_id=data.get("thread_id"),
|
||||
correlation_id=data.get("correlation_id"),
|
||||
correlationId=data.get("correlationId"),
|
||||
)
|
||||
|
||||
|
||||
@@ -392,4 +401,4 @@ class AgentResponse:
|
||||
if self.error_type:
|
||||
result["error_type"] = self.error_type
|
||||
|
||||
return result
|
||||
return result
|
||||
@@ -129,13 +129,13 @@ class DurableAIAgent(AgentProtocol):
|
||||
|
||||
# Generate a deterministic correlation ID for this call
|
||||
# This is required by the entity and must be unique per call
|
||||
correlation_id = str(self.context.new_uuid())
|
||||
correlationId = str(self.context.new_uuid())
|
||||
|
||||
# Prepare the request using RunRequest model
|
||||
run_request = RunRequest(
|
||||
message=message_str,
|
||||
enable_tool_calls=enable_tool_calls,
|
||||
correlation_id=correlation_id,
|
||||
correlationId=correlationId,
|
||||
thread_id=session_id.key,
|
||||
response_format=response_format,
|
||||
)
|
||||
|
||||
@@ -1,179 +0,0 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Agent State Management.
|
||||
|
||||
This module defines the AgentState class for managing conversation state and
|
||||
serializing agent framework responses.
|
||||
"""
|
||||
|
||||
from collections.abc import MutableMapping
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, cast
|
||||
|
||||
from agent_framework import AgentRunResponse, ChatMessage, Role, get_logger
|
||||
|
||||
logger = get_logger("agent_framework.azurefunctions.state")
|
||||
|
||||
|
||||
class AgentState:
|
||||
"""Manages agent conversation state using agent_framework types (ChatMessage, AgentRunResponse).
|
||||
|
||||
This class handles:
|
||||
- Conversation history tracking using ChatMessage objects
|
||||
- Agent response storage using AgentRunResponse objects with correlation IDs
|
||||
- State persistence and restoration
|
||||
- Message counting
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize empty agent state."""
|
||||
self.conversation_history: list[ChatMessage] = []
|
||||
self.last_response: str | None = None
|
||||
self.message_count: int = 0
|
||||
|
||||
def _current_timestamp(self) -> str:
|
||||
"""Return an ISO 8601 UTC timestamp."""
|
||||
return datetime.now(timezone.utc).isoformat()
|
||||
|
||||
def add_user_message(
|
||||
self,
|
||||
content: str,
|
||||
role: Role = Role.USER,
|
||||
correlation_id: str | None = None,
|
||||
) -> None:
|
||||
"""Add a user message to the conversation history as a ChatMessage object.
|
||||
|
||||
Args:
|
||||
content: The message content
|
||||
role: The message role (user, system, etc.)
|
||||
correlation_id: Optional correlation identifier associated with the user message
|
||||
"""
|
||||
self.message_count += 1
|
||||
timestamp = self._current_timestamp()
|
||||
additional_props: MutableMapping[str, Any] = {"timestamp": timestamp}
|
||||
if correlation_id is not None:
|
||||
additional_props["correlation_id"] = correlation_id
|
||||
chat_message = ChatMessage(role=role, text=content, additional_properties=additional_props)
|
||||
self.conversation_history.append(chat_message)
|
||||
logger.debug(f"Added {role} ChatMessage to history (message #{self.message_count})")
|
||||
|
||||
def add_assistant_message(
|
||||
self, content: str, agent_response: AgentRunResponse, correlation_id: str | None = None
|
||||
) -> None:
|
||||
"""Add an assistant message to the conversation history with full agent response.
|
||||
|
||||
Args:
|
||||
content: The text content of the response
|
||||
agent_response: The AgentRunResponse object from the agent framework
|
||||
correlation_id: Optional correlation ID for tracking this response
|
||||
"""
|
||||
self.last_response = content
|
||||
timestamp = self._current_timestamp()
|
||||
serialized_response = self.serialize_response(agent_response)
|
||||
|
||||
# Create a ChatMessage for the assistant response
|
||||
# The agent_response already contains messages, but we store it as a custom ChatMessage
|
||||
# with the agent_response stored in additional_properties for full metadata preservation
|
||||
additional_props: dict[str, Any] = {
|
||||
"agent_response": serialized_response,
|
||||
"correlation_id": correlation_id,
|
||||
"timestamp": timestamp,
|
||||
"message_count": self.message_count,
|
||||
}
|
||||
chat_message = ChatMessage(role="assistant", text=content, additional_properties=additional_props)
|
||||
|
||||
self.conversation_history.append(chat_message)
|
||||
|
||||
logger.debug(
|
||||
f"Added assistant ChatMessage to history with AgentRunResponse metadata (correlation_id: {correlation_id})"
|
||||
)
|
||||
|
||||
def get_chat_messages(self) -> list[ChatMessage]:
|
||||
"""Return a copy of the full conversation history."""
|
||||
return list(self.conversation_history)
|
||||
|
||||
def try_get_agent_response(self, correlation_id: str) -> dict[str, Any] | None:
|
||||
"""Get an agent response by correlation ID.
|
||||
|
||||
Args:
|
||||
correlation_id: The correlation ID to look up
|
||||
|
||||
Returns:
|
||||
The agent response data if found, None otherwise
|
||||
"""
|
||||
for message in reversed(self.conversation_history):
|
||||
metadata = getattr(message, "additional_properties", {}) or {}
|
||||
if metadata.get("correlation_id") == correlation_id:
|
||||
return self._build_agent_response_payload(message, metadata)
|
||||
|
||||
return None
|
||||
|
||||
def serialize_response(self, response: AgentRunResponse) -> dict[str, Any]:
|
||||
"""Serialize an ``AgentRunResponse`` to a dictionary.
|
||||
|
||||
Args:
|
||||
response: The agent framework response object
|
||||
|
||||
Returns:
|
||||
Dictionary containing all response fields
|
||||
"""
|
||||
try:
|
||||
return response.to_dict()
|
||||
except Exception as exc: # pragma: no cover - defensive logging path
|
||||
logger.warning(f"Error serializing response: {exc}")
|
||||
return {"response": str(response), "serialization_error": str(exc)}
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Get the current state as a dictionary for persistence.
|
||||
|
||||
Returns:
|
||||
Dictionary containing conversation_history (as serialized ChatMessages),
|
||||
last_response, and message_count
|
||||
"""
|
||||
return {
|
||||
"conversation_history": [msg.to_dict() for msg in self.conversation_history],
|
||||
"last_response": self.last_response,
|
||||
"message_count": self.message_count,
|
||||
}
|
||||
|
||||
def restore_state(self, state: dict[str, Any]) -> None:
|
||||
"""Restore state from a dictionary, reconstructing ChatMessage objects.
|
||||
|
||||
Args:
|
||||
state: Dictionary containing conversation_history, last_response, and message_count
|
||||
"""
|
||||
# Restore conversation history as ChatMessage objects
|
||||
history_data = state.get("conversation_history", [])
|
||||
restored_history: list[ChatMessage] = []
|
||||
for raw_message in history_data:
|
||||
if isinstance(raw_message, dict):
|
||||
restored_history.append(ChatMessage.from_dict(cast(dict[str, Any], raw_message)))
|
||||
else:
|
||||
restored_history.append(cast(ChatMessage, raw_message))
|
||||
|
||||
self.conversation_history = restored_history
|
||||
|
||||
self.last_response = state.get("last_response")
|
||||
self.message_count = state.get("message_count", 0)
|
||||
logger.debug("Restored state: %s ChatMessages in history", len(self.conversation_history))
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset the state to empty."""
|
||||
self.conversation_history = []
|
||||
self.last_response = None
|
||||
self.message_count = 0
|
||||
logger.debug("State reset to empty")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""String representation of the state."""
|
||||
return f"AgentState(messages={self.message_count}, history_length={len(self.conversation_history)})"
|
||||
|
||||
def _build_agent_response_payload(self, message: ChatMessage, metadata: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Construct the agent response payload returned to callers."""
|
||||
return {
|
||||
"content": message.text,
|
||||
"agent_response": metadata.get("agent_response"),
|
||||
"message_count": metadata.get("message_count", self.message_count),
|
||||
"timestamp": metadata.get("timestamp"),
|
||||
"correlation_id": metadata.get("correlation_id"),
|
||||
}
|
||||
Reference in New Issue
Block a user