semi working

This commit is contained in:
Victoria Hall
2025-11-14 16:27:37 -06:00
Unverified
parent a75590eb9b
commit cb1ee732ed
7 changed files with 984 additions and 267 deletions
@@ -16,11 +16,11 @@ import azure.functions as func
from agent_framework import AgentProtocol, get_logger
from ._callbacks import AgentResponseCallbackProtocol
from ._durable_agent_state import DurableAgentState
from ._entities import create_agent_entity
from ._errors import IncomingRequestError
from ._models import AgentSessionId, RunRequest
from ._orchestration import AgentOrchestrationContextType, DurableAIAgent
from ._state import AgentState
logger = get_logger("agent_framework.azurefunctions")
@@ -34,9 +34,6 @@ WAIT_FOR_RESPONSE_HEADER: str = "x-ms-wait-for-response"
EntityHandler = Callable[[df.DurableEntityContext], None]
HandlerT = TypeVar("HandlerT", bound=Callable[..., Any])
DEFAULT_MAX_POLL_RETRIES: int = 30
DEFAULT_POLL_INTERVAL_SECONDS: float = 1.0
if TYPE_CHECKING:
class DFAppBase:
@@ -73,16 +70,17 @@ class AgentFunctionApp(DFAppBase):
.. code-block:: python
from agent_framework.azure import AgentFunctionApp, AzureOpenAIChatClient
from agent_framework.azure import AgentFunctionApp
from agent_framework.azure import AzureOpenAIAssistantsClient
# Create agents with unique names
weather_agent = AzureOpenAIChatClient(...).create_agent(
weather_agent = AzureOpenAIAssistantsClient(...).create_agent(
name="WeatherAgent",
instructions="You are a helpful weather agent.",
tools=[get_weather],
)
math_agent = AzureOpenAIChatClient(...).create_agent(
math_agent = AzureOpenAIAssistantsClient(...).create_agent(
name="MathAgent",
instructions="You are a helpful math assistant.",
tools=[calculate],
@@ -130,23 +128,23 @@ class AgentFunctionApp(DFAppBase):
http_auth_level: func.AuthLevel = func.AuthLevel.FUNCTION,
enable_health_check: bool = True,
enable_http_endpoints: bool = True,
max_poll_retries: int = DEFAULT_MAX_POLL_RETRIES,
poll_interval_seconds: float = DEFAULT_POLL_INTERVAL_SECONDS,
max_poll_retries: int = 30,
poll_interval_seconds: float = 1,
default_callback: AgentResponseCallbackProtocol | None = None,
):
"""Initialize the AgentFunctionApp.
:param agents: List of agent instances to register.
:param http_auth_level: HTTP authentication level (default: ``func.AuthLevel.FUNCTION``).
:param enable_health_check: Enable the built-in health check endpoint (default: ``True``).
:param enable_http_endpoints: Enable HTTP endpoints for agents (default: ``True``).
:param max_poll_retries: Maximum polling attempts when waiting for a response.
Defaults to ``DEFAULT_MAX_POLL_RETRIES``.
:param poll_interval_seconds: Delay in seconds between polling attempts.
Defaults to ``DEFAULT_POLL_INTERVAL_SECONDS``.
:param default_callback: Optional callback invoked for agents without specific callbacks.
Args:
agents: List of agent instances to register
http_auth_level: HTTP authentication level (default: FUNCTION)
enable_health_check: Enable built-in health check endpoint (default: True)
enable_http_endpoints: Enable HTTP endpoints for agents (default: True)
max_poll_retries: Maximum number of polling attempts when waiting for a response
poll_interval_seconds: Delay (in seconds) between polling attempts
default_callback: Optional callback invoked for agents without specific callbacks
:note: If no agents are provided, they can be added later using :meth:`add_agent`.
Note:
If no agents are provided, they can be added later using add_agent().
"""
logger.debug("[AgentFunctionApp] Initializing with Durable Entities...")
@@ -163,14 +161,14 @@ class AgentFunctionApp(DFAppBase):
try:
retries = int(max_poll_retries)
except (TypeError, ValueError):
retries = DEFAULT_MAX_POLL_RETRIES
retries = 10
self.max_poll_retries = max(1, retries)
try:
interval = float(poll_interval_seconds)
except (TypeError, ValueError):
interval = DEFAULT_POLL_INTERVAL_SECONDS
self.poll_interval_seconds = interval if interval > 0 else DEFAULT_POLL_INTERVAL_SECONDS
interval = 0.5
self.poll_interval_seconds = interval if interval > 0 else 0.5
if agents:
# Register all provided agents
@@ -339,10 +337,10 @@ class AgentFunctionApp(DFAppBase):
)
session_id = self._create_session_id(agent_name, thread_id)
correlation_id = self._generate_unique_id()
correlationId = self._generate_unique_id()
logger.debug(f"[HTTP Trigger] Using session ID: {session_id}")
logger.debug(f"[HTTP Trigger] Generated correlation ID: {correlation_id}")
logger.debug(f"[HTTP Trigger] Generated correlation ID: {correlationId}")
logger.debug("[HTTP Trigger] Calling entity to run agent...")
entity_instance_id = session_id.to_entity_id()
@@ -350,7 +348,7 @@ class AgentFunctionApp(DFAppBase):
req_body,
message,
thread_id,
correlation_id,
correlationId,
)
logger.debug("Signalling entity %s with request: %s", entity_instance_id, run_request)
await client.signal_entity(entity_instance_id, "run_agent", run_request)
@@ -361,7 +359,7 @@ class AgentFunctionApp(DFAppBase):
result = await self._get_response_from_entity(
client=client,
entity_instance_id=entity_instance_id,
correlation_id=correlation_id,
correlationId=correlationId,
message=message,
thread_id=thread_id,
)
@@ -377,7 +375,7 @@ class AgentFunctionApp(DFAppBase):
logger.debug("[HTTP Trigger] wait_for_response disabled; returning correlation ID")
accepted_response = self._build_accepted_response(
message=message, thread_id=thread_id, correlation_id=correlation_id
message=message, thread_id=thread_id, correlationId=correlationId
)
return self._create_http_response(
@@ -491,7 +489,7 @@ class AgentFunctionApp(DFAppBase):
self,
client: df.DurableOrchestrationClient,
entity_instance_id: df.EntityId,
) -> AgentState | None:
) -> DurableAgentState | None:
state_response = await client.read_entity_state(entity_instance_id)
if not state_response or not state_response.entity_exists:
return None
@@ -502,7 +500,7 @@ class AgentFunctionApp(DFAppBase):
typed_state_payload = cast(dict[str, Any], state_payload)
agent_state = AgentState()
agent_state = DurableAgentState()
agent_state.restore_state(typed_state_payload)
return agent_state
@@ -510,7 +508,7 @@ class AgentFunctionApp(DFAppBase):
self,
client: df.DurableOrchestrationClient,
entity_instance_id: df.EntityId,
correlation_id: str,
correlationId: str,
message: str,
thread_id: str,
) -> dict[str, Any]:
@@ -522,7 +520,7 @@ class AgentFunctionApp(DFAppBase):
retry_count = 0
result: dict[str, Any] | None = None
logger.debug(f"[HTTP Trigger] Waiting for response with correlation ID: {correlation_id}")
logger.debug(f"[HTTP Trigger] Waiting for response with correlation ID: {correlationId}")
while retry_count < max_retries:
await asyncio.sleep(interval)
@@ -530,7 +528,7 @@ class AgentFunctionApp(DFAppBase):
result = await self._poll_entity_for_response(
client=client,
entity_instance_id=entity_instance_id,
correlation_id=correlation_id,
correlationId=correlationId,
message=message,
thread_id=thread_id,
)
@@ -544,16 +542,16 @@ class AgentFunctionApp(DFAppBase):
return result
logger.warning(
f"[HTTP Trigger] Response with correlation ID {correlation_id} "
f"[HTTP Trigger] Response with correlation ID {correlationId} "
f"not found in time (waited {max_retries * interval} seconds)"
)
return await self._build_timeout_result(message=message, thread_id=thread_id, correlation_id=correlation_id)
return await self._build_timeout_result(message=message, thread_id=thread_id, correlationId=correlationId)
async def _poll_entity_for_response(
self,
client: df.DurableOrchestrationClient,
entity_instance_id: df.EntityId,
correlation_id: str,
correlationId: str,
message: str,
thread_id: str,
) -> dict[str, Any] | None:
@@ -564,34 +562,34 @@ class AgentFunctionApp(DFAppBase):
if state is None:
return None
agent_response = state.try_get_agent_response(correlation_id)
agent_response = state.try_get_agent_response(correlationId)
if agent_response:
result = self._build_success_result(
response_data=agent_response,
message=message,
thread_id=thread_id,
correlation_id=correlation_id,
correlationId=correlationId,
state=state,
)
logger.debug(f"[HTTP Trigger] Found response for correlation ID: {correlation_id}")
logger.debug(f"[HTTP Trigger] Found response for correlation ID: {correlationId}")
except Exception as exc:
logger.warning(f"[HTTP Trigger] Error reading entity state: {exc}")
return result
async def _build_timeout_result(self, message: str, thread_id: str, correlation_id: str) -> dict[str, Any]:
async def _build_timeout_result(self, message: str, thread_id: str, correlationId: str) -> dict[str, Any]:
"""Create the timeout response."""
return {
"response": "Agent is still processing or timed out...",
"message": message,
THREAD_ID_FIELD: thread_id,
"status": "timeout",
"correlation_id": correlation_id,
"correlationId": correlationId,
}
def _build_success_result(
self, response_data: dict[str, Any], message: str, thread_id: str, correlation_id: str, state: AgentState
self, response_data: dict[str, Any], message: str, thread_id: str, correlationId: str, state: DurableAgentState
) -> dict[str, Any]:
"""Build the success result returned to the HTTP caller."""
return {
@@ -600,11 +598,11 @@ class AgentFunctionApp(DFAppBase):
THREAD_ID_FIELD: thread_id,
"status": "success",
"message_count": response_data.get("message_count", state.message_count),
"correlation_id": correlation_id,
"correlationId": correlationId,
}
def _build_request_data(
self, req_body: dict[str, Any], message: str, thread_id: str, correlation_id: str
self, req_body: dict[str, Any], message: str, thread_id: str, correlationId: str
) -> dict[str, Any]:
"""Create the durable entity request payload."""
enable_tool_calls_value = req_body.get("enable_tool_calls")
@@ -616,17 +614,17 @@ class AgentFunctionApp(DFAppBase):
response_format=req_body.get("response_format"),
enable_tool_calls=enable_tool_calls,
thread_id=thread_id,
correlation_id=correlation_id,
correlationId=correlationId,
).to_dict()
def _build_accepted_response(self, message: str, thread_id: str, correlation_id: str) -> dict[str, Any]:
def _build_accepted_response(self, message: str, thread_id: str, correlationId: str) -> dict[str, Any]:
"""Build the response returned when not waiting for completion."""
return {
"response": "Agent request accepted",
"message": message,
THREAD_ID_FIELD: thread_id,
"status": "accepted",
"correlation_id": correlation_id,
"correlationId": correlationId,
}
def _create_http_response(
@@ -712,8 +710,7 @@ class AgentFunctionApp(DFAppBase):
headers: dict[str, str] = {}
raw_headers = req.headers
if isinstance(raw_headers, Mapping):
header_mapping: Mapping[str, Any] = cast(Mapping[str, Any], raw_headers)
for key, value in header_mapping.items():
for key, value in raw_headers.items():
if value is not None:
headers[str(key).lower()] = str(value)
return headers
@@ -774,8 +771,13 @@ class AgentFunctionApp(DFAppBase):
def _should_wait_for_response(self, req: func.HttpRequest, req_body: dict[str, Any]) -> bool:
"""Determine whether the caller requested to wait for the response."""
headers: dict[str, str] = self._extract_normalized_headers(req)
header_value: str | None = headers.get(WAIT_FOR_RESPONSE_HEADER)
header_value = None
raw_headers = req.headers
if isinstance(raw_headers, Mapping):
for key, value in raw_headers.items():
if str(key).lower() == WAIT_FOR_RESPONSE_HEADER:
header_value = value
break
if header_value is not None:
return self._coerce_to_bool(header_value)
@@ -799,4 +801,4 @@ class AgentFunctionApp(DFAppBase):
return bool(value)
if isinstance(value, str):
return value.strip().lower() in {"true", "1", "yes", "y", "on"}
return False
return False
@@ -19,7 +19,7 @@ class AgentCallbackContext:
"""Context supplied to callback invocations."""
agent_name: str
correlation_id: str
correlationId: str
thread_id: str | None = None
request_message: str | None = None
@@ -0,0 +1,821 @@
# Copyright (c) Microsoft. All rights reserved.
from __future__ import annotations
import json
from typing import Any, List, Dict, Optional, cast
from datetime import datetime, timezone
# Base content type
class DurableAgentStateContent:
extensionData: Optional[Dict]
def to_ai_content(self):
raise NotImplementedError
@staticmethod
def from_ai_content(content):
# Map AI content type to appropriate DurableAgentStateContent subclass
from agent_framework import (
DataContent, ErrorContent, FunctionCallContent, FunctionResultContent,
HostedFileContent, HostedVectorStoreContent, TextContent,
TextReasoningContent, UriContent, UsageContent
)
if isinstance(content, DataContent):
return DurableAgentStateDataContent.from_data_content(content)
elif isinstance(content, ErrorContent):
return DurableAgentStateErrorContent.from_error_content(content)
elif isinstance(content, FunctionCallContent):
return DurableAgentStateFunctionCallContent.from_function_call_content(content)
elif isinstance(content, FunctionResultContent):
return DurableAgentStateFunctionResultContent.from_function_result_content(content)
elif isinstance(content, HostedFileContent):
return DurableAgentStateHostedFileContent.from_hosted_file_content(content)
elif isinstance(content, HostedVectorStoreContent):
return DurableAgentStateHostedVectorStoreContent.from_hosted_vector_store_content(content)
elif isinstance(content, TextContent):
return DurableAgentStateTextContent.from_text_content(content)
elif isinstance(content, TextReasoningContent):
return DurableAgentStateTextReasoningContent.from_text_reasoning_content(content)
elif isinstance(content, UriContent):
return DurableAgentStateUriContent.from_uri_content(content)
elif isinstance(content, UsageContent):
return DurableAgentStateUsageContent.from_usage_content(content)
else:
return DurableAgentStateUnknownContent.from_unknown_content(content)
# Core state classes
class DurableAgentStateData:
conversationHistory: List['DurableAgentStateEntry']
extensionData: Optional[Dict]
def __init__(self, conversationHistory=None, extensionData=None):
self.conversationHistory = conversationHistory or []
self.extensionData = extensionData
class DurableAgentState:
data: DurableAgentStateData
schema_version: str = "1.0.0"
def __init__(self, schema_version: str = "1.0.0"):
self.data = DurableAgentStateData()
self.schema_version = schema_version
def to_dict(self) -> Dict[str, Any]:
# Serialize conversationHistory
serialized_history = []
for entry in self.data.conversationHistory:
# Properly serialize each entry to a dictionary
if hasattr(entry, 'to_dict'):
serialized_history.append(entry.to_dict())
else:
# Fallback for already-serialized entries
serialized_history.append(entry)
return {
"schemaVersion": self.schema_version,
"data": {
"conversationHistory": serialized_history,
"extensionData": self.data.extensionData
},
"message_count": self.message_count,
"last_response": self.last_response,
}
def to_json(self) -> str:
return json.dumps(self.to_dict())
@classmethod
def from_dict(cls, obj: Dict[str, Any]) -> "DurableAgentState":
schema_version = obj.get("schemaVersion")
if not schema_version:
raise ValueError("The durable agent state is missing the 'schemaVersion' property.")
if not schema_version.startswith("1."):
raise ValueError(f"The durable agent state schema version '{schema_version}' is not supported.")
data_dict = obj.get("data")
if data_dict is None:
raise ValueError("The durable agent state is missing the 'data' property.")
instance = cls(schema_version=schema_version)
# Deserialize the data dict into DurableAgentStateData
if isinstance(data_dict, dict):
instance.data = DurableAgentStateData(
conversationHistory=data_dict.get("conversationHistory", []),
extensionData=data_dict.get("extensionData")
)
return instance
@classmethod
def from_json(cls, json_str: str) -> "DurableAgentState":
try:
obj = json.loads(json_str)
except json.JSONDecodeError as e:
raise ValueError("The durable agent state is not valid JSON.") from e
return cls.from_dict(obj)
def restore_state(self, state: dict[str, Any]) -> None:
"""Restore state from a dictionary.
Args:
state: Dictionary containing schemaVersion and data (full state structure)
"""
# Extract the data portion from the state
data_dict = state.get("data", {})
# Restore the conversation history - deserialize entries from dicts to objects
history_data = data_dict.get("conversationHistory", [])
deserialized_history = []
for entry_dict in history_data:
if isinstance(entry_dict, dict):
# Deserialize based on whether it's a request or response
if "usage" in entry_dict:
deserialized_history.append(DurableAgentStateResponse.from_dict(entry_dict))
elif "response_type" in entry_dict:
deserialized_history.append(DurableAgentStateRequest.from_dict(entry_dict))
else:
deserialized_history.append(DurableAgentStateEntry.from_dict(entry_dict))
else:
# Already an object
deserialized_history.append(entry_dict)
self.data.conversationHistory = deserialized_history
self.data.extensionData = data_dict.get("extensionData")
@property
def message_count(self) -> int:
"""Get the count of conversation entries (requests + responses)."""
return len(self.data.conversationHistory)
@property
def last_response(self) -> str | None:
"""Get the text from the last assistant response in the conversation history."""
# Iterate through messages in reverse to find the last assistant message
for entry in reversed(self.data.conversationHistory):
for message in reversed(entry.messages):
if message.role == "assistant":
return message.text
return None
def add_assistant_message(self, content: str, agent_run_response, correlationId: str) -> None:
"""Add an assistant message to the conversation history.
Args:
content: The message content
agent_run_response: The agent's run response
correlationId: The correlation ID for this response
"""
# This method is called from the entity after storing the response
# The response has already been added to conversationHistory, so we don't need to do anything here
pass
def try_get_agent_response(self, correlationId: str) -> Dict[str, Any] | None:
"""Try to get an agent response by correlation ID.
Args:
correlationId: The correlation ID to search for
Returns:
Response data dict if found, None otherwise
"""
# Search through conversation history for a response with this correlationId
for entry in self.data.conversationHistory:
if hasattr(entry, 'correlationId') and entry.correlationId == correlationId:
# Found the entry, extract response data
if isinstance(entry, DurableAgentStateResponse):
# Get the text content from assistant messages only
content = ""
for message in entry.messages:
if hasattr(message, 'role') and message.role == "assistant" and hasattr(message, 'text'):
content += message.text
return {
"content": content,
"message_count": self.message_count,
"correlationId": correlationId
}
return None
# Entry classes
class DurableAgentStateEntry:
json_type: str
correlationId: str
created_at: datetime
messages: List['DurableAgentStateMessage']
extensionData: Optional[Dict]
# Request-only
responseType: Optional[str] = None
responseSchema: Optional[dict] = None
# Response-only
usage: Optional["DurableAgentStateUsage"] = None
def __init__(self, json_type, correlationId, created_at, messages, extensionData=None, responseType=None, responseSchema=None, usage=None):
self.json_type = json_type
self.correlationId = correlationId
self.created_at = created_at
self.messages = messages
self.extensionData = extensionData
self.responseType = responseType
self.responseSchema = responseSchema
self.usage = usage
def to_dict(self) -> Dict[str, Any]:
data = {
"$type": self.json_type,
"correlationId": self.correlationId,
"createdAt": self.created_at.isoformat() if isinstance(self.created_at, datetime) else self.created_at,
"messages": [m.to_dict() if hasattr(m, 'to_dict') else m for m in self.messages],
"extensionData": self.extensionData
}
if self.json_type == "request":
data.update({
"responseType": self.responseType,
"responseSchema": self.responseSchema,
})
elif self.json_type == "response":
if self.usage:
data["usage"] = self.usage.to_dict()
return data
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'DurableAgentStateEntry':
from dateutil import parser as date_parser
created_at = data.get("created_at")
if isinstance(created_at, str):
created_at = date_parser.parse(created_at)
messages = []
for msg_dict in data.get("messages", []):
if isinstance(msg_dict, dict):
messages.append(DurableAgentStateMessage.from_dict(msg_dict))
else:
messages.append(msg_dict)
return cls(
correlationId=data.get("correlationId"),
created_at=created_at,
messages=messages,
extensionData=data.get("extensionData")
)
class DurableAgentStateRequest(DurableAgentStateEntry):
response_type: Optional[str] = None
response_schema: Optional[Dict] = None
def __init__(self, correlationId, created_at, messages, json_type, extensionData=None, response_type=None, response_schema=None):
self.correlationId = correlationId
self.created_at = created_at
self.messages = messages
self.json_type = json_type
self.extensionData = extensionData
self.response_type = response_type
self.response_schema = response_schema
def to_dict(self) -> Dict[str, Any]:
base_dict = super().to_dict()
base_dict["response_type"] = self.response_type
base_dict["response_schema"] = self.response_schema
return base_dict
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'DurableAgentStateRequest':
from dateutil import parser as date_parser
created_at = data.get("created_at")
if isinstance(created_at, str):
created_at = date_parser.parse(created_at)
messages = []
for msg_dict in data.get("messages", []):
if isinstance(msg_dict, dict):
messages.append(DurableAgentStateMessage.from_dict(msg_dict))
else:
messages.append(msg_dict)
return cls(
json_type=data.get("$type", "request"),
correlationId=data.get("correlationId"),
created_at=created_at,
messages=messages,
extensionData=data.get("extensionData"),
response_type=data.get("response_type"),
response_schema=data.get("response_schema")
)
@staticmethod
def from_run_request(content):
from agent_framework import TextContent
return DurableAgentStateRequest(correlationId=content.correlationId,
messages=[DurableAgentStateMessage.from_chat_message(content)],
created_at=datetime.now(tz=timezone.utc),
json_type="request",
extensionData=content.extensionData if hasattr(content, 'extensionData') else None,
response_type="text" if isinstance(content.response_format, TextContent) else "json",
response_schema=content.response_format)
class DurableAgentStateResponse(DurableAgentStateEntry):
usage: Optional['DurableAgentStateUsage'] = None
def __init__(self, json_type, correlationId, created_at, messages, extensionData=None, usage=None):
self.json_type = json_type
self.correlationId = correlationId
self.created_at = created_at
self.messages = messages
self.extensionData = extensionData
self.usage = usage
def to_dict(self) -> Dict[str, Any]:
base_dict = super().to_dict()
base_dict["usage"] = self.usage.to_dict() if self.usage and hasattr(self.usage, 'to_dict') else self.usage
return base_dict
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'DurableAgentStateResponse':
from dateutil import parser as date_parser
created_at = data.get("created_at")
if isinstance(created_at, str):
created_at = date_parser.parse(created_at)
messages = []
for msg_dict in data.get("messages", []):
if isinstance(msg_dict, dict):
messages.append(DurableAgentStateMessage.from_dict(msg_dict))
else:
messages.append(msg_dict)
usage_dict = data.get("usage")
usage = None
if usage_dict and isinstance(usage_dict, dict):
usage = DurableAgentStateUsage.from_dict(usage_dict)
elif usage_dict:
usage = usage_dict
return cls(
json_type=data.get("$type", "response"),
correlationId=data.get("correlationId"),
created_at=created_at,
messages=messages,
extensionData=data.get("extensionData"),
usage=usage
)
@staticmethod
def from_run_response(correlationId: str, response) -> DurableAgentStateResponse:
"""
Creates a DurableAgentStateResponse from an AgentRunResponse.
"""
# Determine the earliest created_at timestamp among messages (if available)
timestamps = [m.created_at for m in response.messages if hasattr(m, 'created_at') and m.created_at is not None]
created_at = min(timestamps) if timestamps else datetime.now(tz=timezone.utc)
return DurableAgentStateResponse(
json_type="response",
correlationId=correlationId,
created_at=created_at,
messages=[DurableAgentStateMessage.from_chat_message(m) for m in response.messages],
usage=DurableAgentStateUsage.from_usage(response.usage) if hasattr(response, 'usage') and response.usage else None
)
def to_run_response(self):
"""
Converts this DurableAgentStateResponse back to an AgentRunResponse.
"""
from agent_framework import AgentRunResponse
return AgentRunResponse(
created_at=self.created_at,
messages=[m.to_chat_message() for m in self.messages],
usage=self.usage.to_usage_details() if self.usage else None
)
# Message class
class DurableAgentStateMessage:
role: str
contents: List[DurableAgentStateContent]
author_name: Optional[str] = None
created_at: Optional[datetime] = None
extensionData: Optional[Dict] = None
def __init__(self, role, contents, author_name=None, created_at=None, extensionData=None):
self.role = role
self.contents = contents
self.author_name = author_name
self.created_at = created_at
self.extensionData = extensionData
def to_dict(self) -> Dict[str, Any]:
return {
"role": self.role,
"contents": [
{"$type": c.to_dict().get("type", "text"), **{k: v for k, v in c.to_dict().items() if k != "type"}} for c in self.contents
],
"authorName": self.author_name,
"createdAt": self.created_at.isoformat() if isinstance(self.created_at, datetime) else self.created_at,
"extensionData": self.extensionData
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'DurableAgentStateMessage':
from dateutil import parser as date_parser
created_at = data.get("created_at")
if created_at and isinstance(created_at, str):
created_at = date_parser.parse(created_at)
contents = []
for content_dict in data.get("contents", []):
if isinstance(content_dict, dict):
content_type = content_dict.get("type")
if content_type == "text":
contents.append(DurableAgentStateTextContent(text=content_dict.get("text")))
elif content_type == "data":
contents.append(DurableAgentStateDataContent(uri=content_dict.get("uri"), media_type=content_dict.get("media_type")))
elif content_type == "error":
contents.append(DurableAgentStateErrorContent(message=content_dict.get("message"), error_code=content_dict.get("error_code"), details=content_dict.get("details")))
elif content_type == "function_call":
contents.append(DurableAgentStateFunctionCallContent(call_id=content_dict.get("call_id"), name=content_dict.get("name"), arguments=content_dict.get("arguments")))
elif content_type == "function_result":
contents.append(DurableAgentStateFunctionResultContent(call_id=content_dict.get("call_id"), result=content_dict.get("result")))
elif content_type == "hosted_file":
contents.append(DurableAgentStateHostedFileContent(file_id=content_dict.get("file_id")))
elif content_type == "hosted_vector_store":
contents.append(DurableAgentStateHostedVectorStoreContent(vector_store_id=content_dict.get("vector_store_id")))
elif content_type == "text_reasoning":
contents.append(DurableAgentStateTextReasoningContent(text=content_dict.get("text")))
elif content_type == "uri":
contents.append(DurableAgentStateUriContent(uri=content_dict.get("uri"), media_type=content_dict.get("media_type")))
elif content_type == "usage":
usage_data = content_dict.get("usage")
if usage_data and isinstance(usage_data, dict):
contents.append(DurableAgentStateUsageContent(usage=DurableAgentStateUsage.from_dict(usage_data)))
elif content_type == "unknown":
contents.append(DurableAgentStateUnknownContent(content=content_dict.get("content")))
else:
contents.append(content_dict)
return cls(
role=data.get("role"),
contents=contents,
author_name=data.get("author_name"),
created_at=created_at,
extensionData=data.get("extensionData")
)
@property
def text(self) -> str:
"""Extract text from the contents list."""
text_parts = []
for content in self.contents:
if isinstance(content, DurableAgentStateTextContent):
text_parts.append(content.text or "")
return "".join(text_parts)
@staticmethod
def from_chat_message(content):
# Convert to a list of DurableAgentStateContent objects
contents_list = []
if hasattr(content, 'message') and isinstance(content.message, str):
# RunRequest with 'message' attribute
contents_list = [DurableAgentStateTextContent(text=content.message)]
elif hasattr(content, 'contents') and content.contents:
# ChatMessage with 'contents' attribute - convert each content object
for c in content.contents:
converted = DurableAgentStateContent.from_ai_content(c)
contents_list.append(converted)
# Convert role enum to string if needed
role_value = content.role.value if hasattr(content.role, 'value') else str(content.role)
return DurableAgentStateMessage(
role=role_value,
contents=contents_list,
author_name=content.author_name if hasattr(content, 'author_name') else None,
created_at=content.created_at if hasattr(content, 'created_at') else None,
extensionData=content.extensionData if hasattr(content, 'extensionData') else None
)
def to_chat_message(self):
from agent_framework import ChatMessage
# Convert DurableAgentStateContent objects back to agent_framework content objects
ai_contents = [c.to_ai_content() for c in self.contents]
return ChatMessage(role=self.role, contents=ai_contents, author_name=self.author_name, created_at=self.created_at, extensionData=self.extensionData)
# Content subclasses
class DurableAgentStateDataContent(DurableAgentStateContent):
uri: str = ""
media_type: Optional[str] = None
def __init__(self, uri, media_type=None):
self.uri = uri
self.media_type = media_type
def to_dict(self) -> Dict[str, Any]:
return {
"type": "data",
"uri": self.uri,
"mediaType": self.media_type
}
@staticmethod
def from_data_content(content):
return DurableAgentStateDataContent(uri=content.uri, media_type=content.media_type)
def to_ai_content(self):
from agent_framework import DataContent
return DataContent(uri=self.uri, media_type=self.media_type)
class DurableAgentStateErrorContent(DurableAgentStateContent):
message: Optional[str] = None
error_code: Optional[str] = None
details: Optional[str] = None
def __init__(self, message=None, error_code=None, details=None):
self.message = message
self.error_code = error_code
self.details = details
def to_dict(self) -> Dict[str, Any]:
return {
"type": "error",
"message": self.message,
"errorCode": self.error_code,
"details": self.details
}
@staticmethod
def from_error_content(content):
return DurableAgentStateErrorContent(message=content.message, error_code=content.error_code, details=content.details)
def to_ai_content(self):
from agent_framework import ErrorContent
return ErrorContent(message=self.message, error_code=self.error_code, details=self.details)
class DurableAgentStateFunctionCallContent(DurableAgentStateContent):
call_id: str
name: str
arguments: Dict[str, object]
def __init__(self, call_id, name, arguments):
self.call_id = call_id
self.name = name
self.arguments = arguments
def to_dict(self) -> Dict[str, Any]:
return {
"type": "function_call",
"callId": self.call_id,
"name": self.name,
"arguments": self.arguments
}
@staticmethod
def from_function_call_content(content):
return DurableAgentStateFunctionCallContent(
call_id=content.call_id,
name=content.name,
arguments=content.arguments if content.arguments else {}
)
def to_ai_content(self):
from agent_framework import FunctionCallContent
return FunctionCallContent(call_id=self.call_id, name=self.name, arguments=self.arguments)
class DurableAgentStateFunctionResultContent(DurableAgentStateContent):
call_id: str
result: Optional[object] = None
def __init__(self, call_id, result=None):
self.call_id = call_id
self.result = result
def to_dict(self) -> Dict[str, Any]:
return {
"type": "function_result",
"callId": self.call_id,
"result": self.result
}
@staticmethod
def from_function_result_content(content):
return DurableAgentStateFunctionResultContent(call_id=content.call_id, result=content.result)
def to_ai_content(self):
from agent_framework import FunctionResultContent
return FunctionResultContent(call_id=self.call_id, result=self.result)
class DurableAgentStateHostedFileContent(DurableAgentStateContent):
file_id: str
def __init__(self, file_id):
self.file_id = file_id
def to_dict(self) -> Dict[str, Any]:
return {
"type": "hosted_file",
"fileId": self.file_id
}
@staticmethod
def from_hosted_file_content(content):
return DurableAgentStateHostedFileContent(file_id=content.file_id)
def to_ai_content(self):
from agent_framework import HostedFileContent
return HostedFileContent(file_id=self.file_id)
class DurableAgentStateHostedVectorStoreContent(DurableAgentStateContent):
vector_store_id: str
def __init__(self, vector_store_id):
self.vector_store_id = vector_store_id
def to_dict(self) -> Dict[str, Any]:
return {
"type": "hosted_vector_store",
"vectorStoreId": self.vector_store_id
}
@staticmethod
def from_hosted_vector_store_content(content):
return DurableAgentStateHostedVectorStoreContent(vector_store_id=content.vector_store_id)
def to_ai_content(self):
from agent_framework import HostedVectorStoreContent
return HostedVectorStoreContent(vector_store_id=self.vector_store_id)
class DurableAgentStateTextContent(DurableAgentStateContent):
text: Optional[str] = None
def __init__(self, text):
self.text = text
def to_dict(self) -> Dict[str, Any]:
return {
"type": "text",
"text": self.text
}
@staticmethod
def from_text_content(content):
return DurableAgentStateTextContent(text=content.text)
def to_ai_content(self):
from agent_framework import TextContent
return TextContent(text=self.text)
class DurableAgentStateTextReasoningContent(DurableAgentStateContent):
text: Optional[str] = None
def __init__(self, text):
self.text = text
def to_dict(self) -> Dict[str, Any]:
return {
"type": "text_reasoning",
"text": self.text
}
@staticmethod
def from_text_reasoning_content(content):
return DurableAgentStateTextReasoningContent(text=content.text)
def to_ai_content(self):
from agent_framework import TextReasoningContent
return TextReasoningContent(text=self.text)
class DurableAgentStateUriContent(DurableAgentStateContent):
uri: str
media_type: str
def __init__(self, uri, media_type):
self.uri = uri
self.media_type = media_type
def to_dict(self) -> Dict[str, Any]:
return {
"type": "uri",
"uri": self.uri,
"mediaType": self.media_type
}
@staticmethod
def from_uri_content(content):
return DurableAgentStateUriContent(uri=content.uri, media_type=content.media_type)
def to_ai_content(self):
from agent_framework import UriContent
return UriContent(uri=self.uri, media_type=self.media_type)
class DurableAgentStateUsage:
input_token_count: Optional[int] = None
output_token_count: Optional[int] = None
total_token_count: Optional[int] = None
extensionData: Optional[Dict] = None
def __init__(self, input_token_count=None, output_token_count=None, total_token_count=None, extensionData=None):
self.input_token_count = input_token_count
self.output_token_count = output_token_count
self.total_token_count = total_token_count
self.extensionData = extensionData
def to_dict(self) -> Dict[str, Any]:
return {
"inputTokenCount": self.input_token_count,
"outputTokenCount": self.output_token_count,
"totalTokenCount": self.total_token_count,
"extensionData": self.extensionData
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'DurableAgentStateUsage':
return cls(
input_token_count=data.get("input_token_count"),
output_token_count=data.get("output_token_count"),
total_token_count=data.get("total_token_count"),
extensionData=data.get("extensionData")
)
@staticmethod
def from_usage(usage):
if usage is None:
return None
return DurableAgentStateUsage(
input_token_count=usage.input_token_count,
output_token_count=usage.output_token_count,
total_token_count=usage.total_token_count
)
def to_usage_details(self):
# Convert back to AI SDK UsageDetails
from agent_framework import UsageDetails
return UsageDetails(
input_token_count=self.input_token_count,
output_token_count=self.output_token_count,
total_token_count=self.total_token_count
)
class DurableAgentStateUsageContent(DurableAgentStateContent):
usage: DurableAgentStateUsage = DurableAgentStateUsage()
def __init__(self, usage):
self.usage = usage
def to_dict(self) -> Dict[str, Any]:
return {
"type": "usage",
"usage": self.usage.to_dict() if hasattr(self.usage, 'to_dict') else self.usage
}
@staticmethod
def from_usage_content(content):
return DurableAgentStateUsageContent(usage=DurableAgentStateUsage.from_usage(content.details))
def to_ai_content(self):
from agent_framework import UsageContent
return UsageContent(details=self.usage.to_usage_details())
class DurableAgentStateUnknownContent(DurableAgentStateContent):
content: dict
def __init__(self, content):
self.content = content
def to_dict(self) -> Dict[str, Any]:
return {
"type": "unknown",
"content": self.content
}
@staticmethod
def from_unknown_content(content):
return DurableAgentStateUnknownContent(content=json.loads(content))
def to_ai_content(self):
from agent_framework import BaseContent
if not self.content:
raise Exception(f"The content is missing and cannot be converted to valid AI content.")
return BaseContent(content=json.loads(self.content))
@@ -10,15 +10,21 @@ allows for long-running agent conversations.
import asyncio
import inspect
import json
from collections.abc import AsyncIterable, Callable
from typing import Any, cast
from collections.abc import AsyncIterable
from datetime import datetime, timezone
from typing import Any, cast, Callable
import azure.durable_functions as df
from agent_framework import AgentProtocol, AgentRunResponse, AgentRunResponseUpdate, Role, get_logger
from ._callbacks import AgentCallbackContext, AgentResponseCallbackProtocol
from ._durable_agent_state import (
DurableAgentState,
DurableAgentStateData,
DurableAgentStateRequest,
DurableAgentStateResponse,
)
from ._models import AgentResponse, RunRequest
from ._state import AgentState
logger = get_logger("agent_framework.azurefunctions.entities")
@@ -38,11 +44,11 @@ class AgentEntity:
Attributes:
agent: The AgentProtocol instance
state: The AgentState managing conversation history
state: The DurableAgentState managing conversation history
"""
agent: AgentProtocol
state: AgentState
state: DurableAgentState
def __init__(
self,
@@ -56,8 +62,9 @@ class AgentEntity:
callback: Optional callback invoked during streaming updates and final responses
"""
self.agent = agent
self.state = AgentState()
self.state = DurableAgentState()
self.callback = callback
self._pending_requests: dict[str, DurableAgentStateRequest] = {}
logger.debug(f"[AgentEntity] Initialized with agent type: {type(agent).__name__}")
@@ -89,31 +96,49 @@ class AgentEntity:
message = run_request.message
thread_id = run_request.thread_id
correlation_id = run_request.correlation_id
correlationId = run_request.correlationId
if not thread_id:
raise ValueError("RunRequest must include a thread_id")
if not correlation_id:
raise ValueError("RunRequest must include a correlation_id")
if not correlationId:
raise ValueError("RunRequest must include a correlationId")
role = run_request.role or Role.USER
response_format = run_request.response_format
enable_tool_calls = run_request.enable_tool_calls
# Store request in pending (will be combined with response later)
state_request = DurableAgentStateRequest.from_run_request(run_request)
self.state.data.conversationHistory.append(state_request)
self._pending_requests[correlationId] = state_request
logger.debug(f"[AgentEntity.run_agent] Received message: {message}")
logger.debug(f"[AgentEntity.run_agent] Thread ID: {thread_id}")
logger.debug(f"[AgentEntity.run_agent] Correlation ID: {correlation_id}")
logger.debug(f"[AgentEntity.run_agent] Correlation ID: {correlationId}")
logger.debug(f"[AgentEntity.run_agent] Role: {role.value}")
logger.debug(f"[AgentEntity.run_agent] Enable tool calls: {enable_tool_calls}")
logger.debug(f"[AgentEntity.run_agent] Response format: {'provided' if response_format else 'none'}")
# Store message in history with role
self.state.add_user_message(message, role=role, correlation_id=correlation_id)
logger.debug(f"[AgentEntity.run_agent] Saved state request: {state_request}")
logger.debug("[AgentEntity.run_agent] Executing agent...")
try:
logger.debug("[AgentEntity.run_agent] Starting agent invocation")
run_kwargs: dict[str, Any] = {"messages": self.state.get_chat_messages()}
# Build messages from conversation history plus the current request
chat_messages = [
m.to_chat_message()
for entry in self.state.data.conversationHistory
for m in entry.messages
]
# Add the current request message
# for m in state_request.messages:
# chat_messages.append(m.to_chat_message())
# Strip additional_properties from all messages to avoid metadata being sent to Azure OpenAI
# Azure OpenAI doesn't support the 'metadata' field in messages
for msg in chat_messages:
if hasattr(msg, 'additional_properties'):
msg.additional_properties = {}
run_kwargs: dict[str, Any] = {"messages": chat_messages}
if not enable_tool_calls:
run_kwargs["tools"] = None
if response_format:
@@ -121,7 +146,7 @@ class AgentEntity:
agent_run_response: AgentRunResponse = await self._invoke_agent(
run_kwargs=run_kwargs,
correlation_id=correlation_id,
correlationId=correlationId,
thread_id=thread_id,
request_message=message,
)
@@ -131,6 +156,17 @@ class AgentEntity:
type(agent_run_response).__name__,
)
# Convert response into DurableAgentStateResponse and combine with request
state_response = DurableAgentStateResponse.from_run_response(correlationId, agent_run_response)
# Get the pending request and combine its messages with the response messages
# pending_request = self._pending_requests.pop(correlationId, None)
# if pending_request:
# # Combine request and response messages into a single entry
# state_response.messages = pending_request.messages + state_response.messages
self.state.data.conversationHistory.append(state_response)
response_text = None
structured_response = None
@@ -161,13 +197,13 @@ class AgentEntity:
message=str(message),
thread_id=str(thread_id),
status="success",
message_count=self.state.message_count,
message_count=len(self.state.data.conversationHistory),
structured_response=structured_response,
)
result = agent_response.to_dict()
content = json.dumps(structured_response) if structured_response else (response_text or "")
self.state.add_assistant_message(content, agent_run_response, correlation_id)
self.state.add_assistant_message(content, agent_run_response, correlationId)
logger.debug("[AgentEntity.run_agent] AgentRunResponse stored in conversation history")
return result
@@ -181,12 +217,39 @@ class AgentEntity:
logger.error(f"Error type: {type(exc).__name__}")
logger.error(f"Full traceback:\n{error_traceback}")
# Create error response and store it in conversation history so polling can find it
from agent_framework import ChatMessage, ErrorContent
# Get the pending request
pending_request = self._pending_requests.pop(correlationId, None)
# Create error message
error_message = DurableAgentStateMessage.from_chat_message(
ChatMessage(role="assistant", contents=[ErrorContent(message=str(exc), error_code=type(exc).__name__)])
)
# Combine request and error response messages
messages = []
if pending_request:
messages.extend(pending_request.messages)
messages.append(error_message)
# Create and store error response in conversation history
error_state_response = DurableAgentStateResponse(
correlationId=correlationId,
createdAt=datetime.now(tz=timezone.utc),
messages=messages,
extensionData=None,
usage=None
)
self.state.data.conversationHistory.append(error_state_response)
error_response = AgentResponse(
response=f"Error: {exc!s}",
message=str(message),
thread_id=str(thread_id),
status="error",
message_count=self.state.message_count,
message_count=len(self.state.data.conversationHistory),
error=str(exc),
error_type=type(exc).__name__,
)
@@ -195,7 +258,7 @@ class AgentEntity:
async def _invoke_agent(
self,
run_kwargs: dict[str, Any],
correlation_id: str,
correlationId: str,
thread_id: str,
request_message: str,
) -> AgentRunResponse:
@@ -203,7 +266,7 @@ class AgentEntity:
callback_context: AgentCallbackContext | None = None
if self.callback is not None:
callback_context = self._build_callback_context(
correlation_id=correlation_id,
correlationId=correlationId,
thread_id=thread_id,
request_message=request_message,
)
@@ -317,7 +380,7 @@ class AgentEntity:
def _build_callback_context(
self,
correlation_id: str,
correlationId: str,
thread_id: str,
request_message: str,
) -> AgentCallbackContext:
@@ -325,7 +388,7 @@ class AgentEntity:
agent_name = getattr(self.agent, "name", None) or type(self.agent).__name__
return AgentCallbackContext(
agent_name=agent_name,
correlation_id=correlation_id,
correlationId=correlationId,
thread_id=thread_id,
request_message=request_message,
)
@@ -333,7 +396,7 @@ class AgentEntity:
def reset(self, context: df.DurableEntityContext) -> None:
"""Reset the entity state (clear conversation history)."""
logger.debug("[AgentEntity.reset] Resetting entity state")
self.state.reset()
self.state.data = DurableAgentStateData(conversationHistory=[])
logger.debug("[AgentEntity.reset] State reset complete")
@@ -392,8 +455,9 @@ def create_agent_entity(
logger.error("[entity_function] Unknown operation: %s", operation)
context.set_result({"error": f"Unknown operation: {operation}"})
logger.info("State dict: %s", str(entity.state.to_dict()))
context.set_state(entity.state.to_dict())
logger.debug(f"[entity_function] Operation {operation} completed successfully")
logger.info(f"[entity_function] Operation {operation} completed successfully")
except Exception as exc:
import traceback
@@ -424,4 +488,4 @@ def create_agent_entity(
logger.error("[entity_function] Unexpected error executing entity: %s", exc, exc_info=True)
context.set_result({"error": str(exc), "status": "error"})
return entity_function
return entity_function
@@ -282,7 +282,7 @@ class RunRequest:
response_format: Optional Pydantic BaseModel type describing the structured response format
enable_tool_calls: Whether to enable tool calls for this request
thread_id: Optional thread ID for tracking
correlation_id: Optional correlation ID for tracking the response to this specific request
correlationId: Optional correlation ID for tracking the response to this specific request
"""
message: str
@@ -290,7 +290,10 @@ class RunRequest:
response_format: type[BaseModel] | None = None
enable_tool_calls: bool = True
thread_id: str | None = None
correlation_id: str | None = None
correlationId: str | None = None
author_name: str | None = None
created_at: str | None = None
extension_data: dict[str, Any] | None = None
def __init__(
self,
@@ -299,14 +302,14 @@ class RunRequest:
response_format: type[BaseModel] | None = None,
enable_tool_calls: bool = True,
thread_id: str | None = None,
correlation_id: str | None = None,
correlationId: str | None = None,
) -> None:
self.message = message
self.role = self.coerce_role(role)
self.response_format = response_format
self.enable_tool_calls = enable_tool_calls
self.thread_id = thread_id
self.correlation_id = correlation_id
self.correlationId = correlationId
@staticmethod
def coerce_role(value: Role | str | None) -> Role:
@@ -331,8 +334,14 @@ class RunRequest:
result["response_format"] = _serialize_response_format(self.response_format)
if self.thread_id:
result["thread_id"] = self.thread_id
if self.correlation_id:
result["correlation_id"] = self.correlation_id
if self.correlationId:
result["correlationId"] = self.correlationId
if self.author_name:
result["author_name"] = self.author_name
if self.created_at:
result["created_at"] = self.created_at
if self.extension_data:
result["extension_data"] = self.extension_data
return result
@classmethod
@@ -344,7 +353,7 @@ class RunRequest:
response_format=_deserialize_response_format(data.get("response_format")),
enable_tool_calls=data.get("enable_tool_calls", True),
thread_id=data.get("thread_id"),
correlation_id=data.get("correlation_id"),
correlationId=data.get("correlationId"),
)
@@ -392,4 +401,4 @@ class AgentResponse:
if self.error_type:
result["error_type"] = self.error_type
return result
return result
@@ -129,13 +129,13 @@ class DurableAIAgent(AgentProtocol):
# Generate a deterministic correlation ID for this call
# This is required by the entity and must be unique per call
correlation_id = str(self.context.new_uuid())
correlationId = str(self.context.new_uuid())
# Prepare the request using RunRequest model
run_request = RunRequest(
message=message_str,
enable_tool_calls=enable_tool_calls,
correlation_id=correlation_id,
correlationId=correlationId,
thread_id=session_id.key,
response_format=response_format,
)
@@ -1,179 +0,0 @@
# Copyright (c) Microsoft. All rights reserved.
"""Agent State Management.
This module defines the AgentState class for managing conversation state and
serializing agent framework responses.
"""
from collections.abc import MutableMapping
from datetime import datetime, timezone
from typing import Any, cast
from agent_framework import AgentRunResponse, ChatMessage, Role, get_logger
logger = get_logger("agent_framework.azurefunctions.state")
class AgentState:
"""Manages agent conversation state using agent_framework types (ChatMessage, AgentRunResponse).
This class handles:
- Conversation history tracking using ChatMessage objects
- Agent response storage using AgentRunResponse objects with correlation IDs
- State persistence and restoration
- Message counting
"""
def __init__(self) -> None:
"""Initialize empty agent state."""
self.conversation_history: list[ChatMessage] = []
self.last_response: str | None = None
self.message_count: int = 0
def _current_timestamp(self) -> str:
"""Return an ISO 8601 UTC timestamp."""
return datetime.now(timezone.utc).isoformat()
def add_user_message(
self,
content: str,
role: Role = Role.USER,
correlation_id: str | None = None,
) -> None:
"""Add a user message to the conversation history as a ChatMessage object.
Args:
content: The message content
role: The message role (user, system, etc.)
correlation_id: Optional correlation identifier associated with the user message
"""
self.message_count += 1
timestamp = self._current_timestamp()
additional_props: MutableMapping[str, Any] = {"timestamp": timestamp}
if correlation_id is not None:
additional_props["correlation_id"] = correlation_id
chat_message = ChatMessage(role=role, text=content, additional_properties=additional_props)
self.conversation_history.append(chat_message)
logger.debug(f"Added {role} ChatMessage to history (message #{self.message_count})")
def add_assistant_message(
self, content: str, agent_response: AgentRunResponse, correlation_id: str | None = None
) -> None:
"""Add an assistant message to the conversation history with full agent response.
Args:
content: The text content of the response
agent_response: The AgentRunResponse object from the agent framework
correlation_id: Optional correlation ID for tracking this response
"""
self.last_response = content
timestamp = self._current_timestamp()
serialized_response = self.serialize_response(agent_response)
# Create a ChatMessage for the assistant response
# The agent_response already contains messages, but we store it as a custom ChatMessage
# with the agent_response stored in additional_properties for full metadata preservation
additional_props: dict[str, Any] = {
"agent_response": serialized_response,
"correlation_id": correlation_id,
"timestamp": timestamp,
"message_count": self.message_count,
}
chat_message = ChatMessage(role="assistant", text=content, additional_properties=additional_props)
self.conversation_history.append(chat_message)
logger.debug(
f"Added assistant ChatMessage to history with AgentRunResponse metadata (correlation_id: {correlation_id})"
)
def get_chat_messages(self) -> list[ChatMessage]:
"""Return a copy of the full conversation history."""
return list(self.conversation_history)
def try_get_agent_response(self, correlation_id: str) -> dict[str, Any] | None:
"""Get an agent response by correlation ID.
Args:
correlation_id: The correlation ID to look up
Returns:
The agent response data if found, None otherwise
"""
for message in reversed(self.conversation_history):
metadata = getattr(message, "additional_properties", {}) or {}
if metadata.get("correlation_id") == correlation_id:
return self._build_agent_response_payload(message, metadata)
return None
def serialize_response(self, response: AgentRunResponse) -> dict[str, Any]:
"""Serialize an ``AgentRunResponse`` to a dictionary.
Args:
response: The agent framework response object
Returns:
Dictionary containing all response fields
"""
try:
return response.to_dict()
except Exception as exc: # pragma: no cover - defensive logging path
logger.warning(f"Error serializing response: {exc}")
return {"response": str(response), "serialization_error": str(exc)}
def to_dict(self) -> dict[str, Any]:
"""Get the current state as a dictionary for persistence.
Returns:
Dictionary containing conversation_history (as serialized ChatMessages),
last_response, and message_count
"""
return {
"conversation_history": [msg.to_dict() for msg in self.conversation_history],
"last_response": self.last_response,
"message_count": self.message_count,
}
def restore_state(self, state: dict[str, Any]) -> None:
"""Restore state from a dictionary, reconstructing ChatMessage objects.
Args:
state: Dictionary containing conversation_history, last_response, and message_count
"""
# Restore conversation history as ChatMessage objects
history_data = state.get("conversation_history", [])
restored_history: list[ChatMessage] = []
for raw_message in history_data:
if isinstance(raw_message, dict):
restored_history.append(ChatMessage.from_dict(cast(dict[str, Any], raw_message)))
else:
restored_history.append(cast(ChatMessage, raw_message))
self.conversation_history = restored_history
self.last_response = state.get("last_response")
self.message_count = state.get("message_count", 0)
logger.debug("Restored state: %s ChatMessages in history", len(self.conversation_history))
def reset(self) -> None:
"""Reset the state to empty."""
self.conversation_history = []
self.last_response = None
self.message_count = 0
logger.debug("State reset to empty")
def __repr__(self) -> str:
"""String representation of the state."""
return f"AgentState(messages={self.message_count}, history_length={len(self.conversation_history)})"
def _build_agent_response_payload(self, message: ChatMessage, metadata: dict[str, Any]) -> dict[str, Any]:
"""Construct the agent response payload returned to callers."""
return {
"content": message.text,
"agent_response": metadata.get("agent_response"),
"message_count": metadata.get("message_count", self.message_count),
"timestamp": metadata.get("timestamp"),
"correlation_id": metadata.get("correlation_id"),
}