diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py index bb7c4398fc..6109ea468f 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py @@ -16,11 +16,11 @@ import azure.functions as func from agent_framework import AgentProtocol, get_logger from ._callbacks import AgentResponseCallbackProtocol +from ._durable_agent_state import DurableAgentState from ._entities import create_agent_entity from ._errors import IncomingRequestError from ._models import AgentSessionId, RunRequest from ._orchestration import AgentOrchestrationContextType, DurableAIAgent -from ._state import AgentState logger = get_logger("agent_framework.azurefunctions") @@ -34,9 +34,6 @@ WAIT_FOR_RESPONSE_HEADER: str = "x-ms-wait-for-response" EntityHandler = Callable[[df.DurableEntityContext], None] HandlerT = TypeVar("HandlerT", bound=Callable[..., Any]) -DEFAULT_MAX_POLL_RETRIES: int = 30 -DEFAULT_POLL_INTERVAL_SECONDS: float = 1.0 - if TYPE_CHECKING: class DFAppBase: @@ -73,16 +70,17 @@ class AgentFunctionApp(DFAppBase): .. code-block:: python - from agent_framework.azure import AgentFunctionApp, AzureOpenAIChatClient + from agent_framework.azure import AgentFunctionApp + from agent_framework.azure import AzureOpenAIAssistantsClient # Create agents with unique names - weather_agent = AzureOpenAIChatClient(...).create_agent( + weather_agent = AzureOpenAIAssistantsClient(...).create_agent( name="WeatherAgent", instructions="You are a helpful weather agent.", tools=[get_weather], ) - math_agent = AzureOpenAIChatClient(...).create_agent( + math_agent = AzureOpenAIAssistantsClient(...).create_agent( name="MathAgent", instructions="You are a helpful math assistant.", tools=[calculate], @@ -130,23 +128,23 @@ class AgentFunctionApp(DFAppBase): http_auth_level: func.AuthLevel = func.AuthLevel.FUNCTION, enable_health_check: bool = True, enable_http_endpoints: bool = True, - max_poll_retries: int = DEFAULT_MAX_POLL_RETRIES, - poll_interval_seconds: float = DEFAULT_POLL_INTERVAL_SECONDS, + max_poll_retries: int = 30, + poll_interval_seconds: float = 1, default_callback: AgentResponseCallbackProtocol | None = None, ): """Initialize the AgentFunctionApp. - :param agents: List of agent instances to register. - :param http_auth_level: HTTP authentication level (default: ``func.AuthLevel.FUNCTION``). - :param enable_health_check: Enable the built-in health check endpoint (default: ``True``). - :param enable_http_endpoints: Enable HTTP endpoints for agents (default: ``True``). - :param max_poll_retries: Maximum polling attempts when waiting for a response. - Defaults to ``DEFAULT_MAX_POLL_RETRIES``. - :param poll_interval_seconds: Delay in seconds between polling attempts. - Defaults to ``DEFAULT_POLL_INTERVAL_SECONDS``. - :param default_callback: Optional callback invoked for agents without specific callbacks. + Args: + agents: List of agent instances to register + http_auth_level: HTTP authentication level (default: FUNCTION) + enable_health_check: Enable built-in health check endpoint (default: True) + enable_http_endpoints: Enable HTTP endpoints for agents (default: True) + max_poll_retries: Maximum number of polling attempts when waiting for a response + poll_interval_seconds: Delay (in seconds) between polling attempts + default_callback: Optional callback invoked for agents without specific callbacks - :note: If no agents are provided, they can be added later using :meth:`add_agent`. + Note: + If no agents are provided, they can be added later using add_agent(). """ logger.debug("[AgentFunctionApp] Initializing with Durable Entities...") @@ -163,14 +161,14 @@ class AgentFunctionApp(DFAppBase): try: retries = int(max_poll_retries) except (TypeError, ValueError): - retries = DEFAULT_MAX_POLL_RETRIES + retries = 10 self.max_poll_retries = max(1, retries) try: interval = float(poll_interval_seconds) except (TypeError, ValueError): - interval = DEFAULT_POLL_INTERVAL_SECONDS - self.poll_interval_seconds = interval if interval > 0 else DEFAULT_POLL_INTERVAL_SECONDS + interval = 0.5 + self.poll_interval_seconds = interval if interval > 0 else 0.5 if agents: # Register all provided agents @@ -339,10 +337,10 @@ class AgentFunctionApp(DFAppBase): ) session_id = self._create_session_id(agent_name, thread_id) - correlation_id = self._generate_unique_id() + correlationId = self._generate_unique_id() logger.debug(f"[HTTP Trigger] Using session ID: {session_id}") - logger.debug(f"[HTTP Trigger] Generated correlation ID: {correlation_id}") + logger.debug(f"[HTTP Trigger] Generated correlation ID: {correlationId}") logger.debug("[HTTP Trigger] Calling entity to run agent...") entity_instance_id = session_id.to_entity_id() @@ -350,7 +348,7 @@ class AgentFunctionApp(DFAppBase): req_body, message, thread_id, - correlation_id, + correlationId, ) logger.debug("Signalling entity %s with request: %s", entity_instance_id, run_request) await client.signal_entity(entity_instance_id, "run_agent", run_request) @@ -361,7 +359,7 @@ class AgentFunctionApp(DFAppBase): result = await self._get_response_from_entity( client=client, entity_instance_id=entity_instance_id, - correlation_id=correlation_id, + correlationId=correlationId, message=message, thread_id=thread_id, ) @@ -377,7 +375,7 @@ class AgentFunctionApp(DFAppBase): logger.debug("[HTTP Trigger] wait_for_response disabled; returning correlation ID") accepted_response = self._build_accepted_response( - message=message, thread_id=thread_id, correlation_id=correlation_id + message=message, thread_id=thread_id, correlationId=correlationId ) return self._create_http_response( @@ -491,7 +489,7 @@ class AgentFunctionApp(DFAppBase): self, client: df.DurableOrchestrationClient, entity_instance_id: df.EntityId, - ) -> AgentState | None: + ) -> DurableAgentState | None: state_response = await client.read_entity_state(entity_instance_id) if not state_response or not state_response.entity_exists: return None @@ -502,7 +500,7 @@ class AgentFunctionApp(DFAppBase): typed_state_payload = cast(dict[str, Any], state_payload) - agent_state = AgentState() + agent_state = DurableAgentState() agent_state.restore_state(typed_state_payload) return agent_state @@ -510,7 +508,7 @@ class AgentFunctionApp(DFAppBase): self, client: df.DurableOrchestrationClient, entity_instance_id: df.EntityId, - correlation_id: str, + correlationId: str, message: str, thread_id: str, ) -> dict[str, Any]: @@ -522,7 +520,7 @@ class AgentFunctionApp(DFAppBase): retry_count = 0 result: dict[str, Any] | None = None - logger.debug(f"[HTTP Trigger] Waiting for response with correlation ID: {correlation_id}") + logger.debug(f"[HTTP Trigger] Waiting for response with correlation ID: {correlationId}") while retry_count < max_retries: await asyncio.sleep(interval) @@ -530,7 +528,7 @@ class AgentFunctionApp(DFAppBase): result = await self._poll_entity_for_response( client=client, entity_instance_id=entity_instance_id, - correlation_id=correlation_id, + correlationId=correlationId, message=message, thread_id=thread_id, ) @@ -544,16 +542,16 @@ class AgentFunctionApp(DFAppBase): return result logger.warning( - f"[HTTP Trigger] Response with correlation ID {correlation_id} " + f"[HTTP Trigger] Response with correlation ID {correlationId} " f"not found in time (waited {max_retries * interval} seconds)" ) - return await self._build_timeout_result(message=message, thread_id=thread_id, correlation_id=correlation_id) + return await self._build_timeout_result(message=message, thread_id=thread_id, correlationId=correlationId) async def _poll_entity_for_response( self, client: df.DurableOrchestrationClient, entity_instance_id: df.EntityId, - correlation_id: str, + correlationId: str, message: str, thread_id: str, ) -> dict[str, Any] | None: @@ -564,34 +562,34 @@ class AgentFunctionApp(DFAppBase): if state is None: return None - agent_response = state.try_get_agent_response(correlation_id) + agent_response = state.try_get_agent_response(correlationId) if agent_response: result = self._build_success_result( response_data=agent_response, message=message, thread_id=thread_id, - correlation_id=correlation_id, + correlationId=correlationId, state=state, ) - logger.debug(f"[HTTP Trigger] Found response for correlation ID: {correlation_id}") + logger.debug(f"[HTTP Trigger] Found response for correlation ID: {correlationId}") except Exception as exc: logger.warning(f"[HTTP Trigger] Error reading entity state: {exc}") return result - async def _build_timeout_result(self, message: str, thread_id: str, correlation_id: str) -> dict[str, Any]: + async def _build_timeout_result(self, message: str, thread_id: str, correlationId: str) -> dict[str, Any]: """Create the timeout response.""" return { "response": "Agent is still processing or timed out...", "message": message, THREAD_ID_FIELD: thread_id, "status": "timeout", - "correlation_id": correlation_id, + "correlationId": correlationId, } def _build_success_result( - self, response_data: dict[str, Any], message: str, thread_id: str, correlation_id: str, state: AgentState + self, response_data: dict[str, Any], message: str, thread_id: str, correlationId: str, state: DurableAgentState ) -> dict[str, Any]: """Build the success result returned to the HTTP caller.""" return { @@ -600,11 +598,11 @@ class AgentFunctionApp(DFAppBase): THREAD_ID_FIELD: thread_id, "status": "success", "message_count": response_data.get("message_count", state.message_count), - "correlation_id": correlation_id, + "correlationId": correlationId, } def _build_request_data( - self, req_body: dict[str, Any], message: str, thread_id: str, correlation_id: str + self, req_body: dict[str, Any], message: str, thread_id: str, correlationId: str ) -> dict[str, Any]: """Create the durable entity request payload.""" enable_tool_calls_value = req_body.get("enable_tool_calls") @@ -616,17 +614,17 @@ class AgentFunctionApp(DFAppBase): response_format=req_body.get("response_format"), enable_tool_calls=enable_tool_calls, thread_id=thread_id, - correlation_id=correlation_id, + correlationId=correlationId, ).to_dict() - def _build_accepted_response(self, message: str, thread_id: str, correlation_id: str) -> dict[str, Any]: + def _build_accepted_response(self, message: str, thread_id: str, correlationId: str) -> dict[str, Any]: """Build the response returned when not waiting for completion.""" return { "response": "Agent request accepted", "message": message, THREAD_ID_FIELD: thread_id, "status": "accepted", - "correlation_id": correlation_id, + "correlationId": correlationId, } def _create_http_response( @@ -712,8 +710,7 @@ class AgentFunctionApp(DFAppBase): headers: dict[str, str] = {} raw_headers = req.headers if isinstance(raw_headers, Mapping): - header_mapping: Mapping[str, Any] = cast(Mapping[str, Any], raw_headers) - for key, value in header_mapping.items(): + for key, value in raw_headers.items(): if value is not None: headers[str(key).lower()] = str(value) return headers @@ -774,8 +771,13 @@ class AgentFunctionApp(DFAppBase): def _should_wait_for_response(self, req: func.HttpRequest, req_body: dict[str, Any]) -> bool: """Determine whether the caller requested to wait for the response.""" - headers: dict[str, str] = self._extract_normalized_headers(req) - header_value: str | None = headers.get(WAIT_FOR_RESPONSE_HEADER) + header_value = None + raw_headers = req.headers + if isinstance(raw_headers, Mapping): + for key, value in raw_headers.items(): + if str(key).lower() == WAIT_FOR_RESPONSE_HEADER: + header_value = value + break if header_value is not None: return self._coerce_to_bool(header_value) @@ -799,4 +801,4 @@ class AgentFunctionApp(DFAppBase): return bool(value) if isinstance(value, str): return value.strip().lower() in {"true", "1", "yes", "y", "on"} - return False + return False \ No newline at end of file diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_callbacks.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_callbacks.py index 31b31111ac..8c57645e7a 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_callbacks.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_callbacks.py @@ -19,7 +19,7 @@ class AgentCallbackContext: """Context supplied to callback invocations.""" agent_name: str - correlation_id: str + correlationId: str thread_id: str | None = None request_message: str | None = None diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_durable_agent_state.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_durable_agent_state.py new file mode 100644 index 0000000000..2204d575d7 --- /dev/null +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_durable_agent_state.py @@ -0,0 +1,821 @@ +# Copyright (c) Microsoft. All rights reserved. +from __future__ import annotations + +import json + +from typing import Any, List, Dict, Optional, cast +from datetime import datetime, timezone + +# Base content type + +class DurableAgentStateContent: + extensionData: Optional[Dict] + + def to_ai_content(self): + raise NotImplementedError + + @staticmethod + def from_ai_content(content): + # Map AI content type to appropriate DurableAgentStateContent subclass + from agent_framework import ( + DataContent, ErrorContent, FunctionCallContent, FunctionResultContent, + HostedFileContent, HostedVectorStoreContent, TextContent, + TextReasoningContent, UriContent, UsageContent + ) + + if isinstance(content, DataContent): + return DurableAgentStateDataContent.from_data_content(content) + elif isinstance(content, ErrorContent): + return DurableAgentStateErrorContent.from_error_content(content) + elif isinstance(content, FunctionCallContent): + return DurableAgentStateFunctionCallContent.from_function_call_content(content) + elif isinstance(content, FunctionResultContent): + return DurableAgentStateFunctionResultContent.from_function_result_content(content) + elif isinstance(content, HostedFileContent): + return DurableAgentStateHostedFileContent.from_hosted_file_content(content) + elif isinstance(content, HostedVectorStoreContent): + return DurableAgentStateHostedVectorStoreContent.from_hosted_vector_store_content(content) + elif isinstance(content, TextContent): + return DurableAgentStateTextContent.from_text_content(content) + elif isinstance(content, TextReasoningContent): + return DurableAgentStateTextReasoningContent.from_text_reasoning_content(content) + elif isinstance(content, UriContent): + return DurableAgentStateUriContent.from_uri_content(content) + elif isinstance(content, UsageContent): + return DurableAgentStateUsageContent.from_usage_content(content) + else: + return DurableAgentStateUnknownContent.from_unknown_content(content) + +# Core state classes + +class DurableAgentStateData: + conversationHistory: List['DurableAgentStateEntry'] + extensionData: Optional[Dict] + + def __init__(self, conversationHistory=None, extensionData=None): + self.conversationHistory = conversationHistory or [] + self.extensionData = extensionData + + +class DurableAgentState: + data: DurableAgentStateData + schema_version: str = "1.0.0" + + def __init__(self, schema_version: str = "1.0.0"): + self.data = DurableAgentStateData() + self.schema_version = schema_version + + def to_dict(self) -> Dict[str, Any]: + # Serialize conversationHistory + serialized_history = [] + for entry in self.data.conversationHistory: + # Properly serialize each entry to a dictionary + if hasattr(entry, 'to_dict'): + serialized_history.append(entry.to_dict()) + else: + # Fallback for already-serialized entries + serialized_history.append(entry) + + return { + "schemaVersion": self.schema_version, + "data": { + "conversationHistory": serialized_history, + "extensionData": self.data.extensionData + }, + "message_count": self.message_count, + "last_response": self.last_response, + } + + def to_json(self) -> str: + return json.dumps(self.to_dict()) + + @classmethod + def from_dict(cls, obj: Dict[str, Any]) -> "DurableAgentState": + schema_version = obj.get("schemaVersion") + if not schema_version: + raise ValueError("The durable agent state is missing the 'schemaVersion' property.") + + if not schema_version.startswith("1."): + raise ValueError(f"The durable agent state schema version '{schema_version}' is not supported.") + + data_dict = obj.get("data") + if data_dict is None: + raise ValueError("The durable agent state is missing the 'data' property.") + + instance = cls(schema_version=schema_version) + # Deserialize the data dict into DurableAgentStateData + if isinstance(data_dict, dict): + instance.data = DurableAgentStateData( + conversationHistory=data_dict.get("conversationHistory", []), + extensionData=data_dict.get("extensionData") + ) + return instance + + @classmethod + def from_json(cls, json_str: str) -> "DurableAgentState": + try: + obj = json.loads(json_str) + except json.JSONDecodeError as e: + raise ValueError("The durable agent state is not valid JSON.") from e + + return cls.from_dict(obj) + + def restore_state(self, state: dict[str, Any]) -> None: + """Restore state from a dictionary. + + Args: + state: Dictionary containing schemaVersion and data (full state structure) + """ + # Extract the data portion from the state + data_dict = state.get("data", {}) + + # Restore the conversation history - deserialize entries from dicts to objects + history_data = data_dict.get("conversationHistory", []) + deserialized_history = [] + for entry_dict in history_data: + if isinstance(entry_dict, dict): + # Deserialize based on whether it's a request or response + if "usage" in entry_dict: + deserialized_history.append(DurableAgentStateResponse.from_dict(entry_dict)) + elif "response_type" in entry_dict: + deserialized_history.append(DurableAgentStateRequest.from_dict(entry_dict)) + else: + deserialized_history.append(DurableAgentStateEntry.from_dict(entry_dict)) + else: + # Already an object + deserialized_history.append(entry_dict) + + self.data.conversationHistory = deserialized_history + self.data.extensionData = data_dict.get("extensionData") + + @property + def message_count(self) -> int: + """Get the count of conversation entries (requests + responses).""" + return len(self.data.conversationHistory) + + @property + def last_response(self) -> str | None: + """Get the text from the last assistant response in the conversation history.""" + # Iterate through messages in reverse to find the last assistant message + for entry in reversed(self.data.conversationHistory): + for message in reversed(entry.messages): + if message.role == "assistant": + return message.text + return None + + def add_assistant_message(self, content: str, agent_run_response, correlationId: str) -> None: + """Add an assistant message to the conversation history. + + Args: + content: The message content + agent_run_response: The agent's run response + correlationId: The correlation ID for this response + """ + # This method is called from the entity after storing the response + # The response has already been added to conversationHistory, so we don't need to do anything here + pass + + def try_get_agent_response(self, correlationId: str) -> Dict[str, Any] | None: + """Try to get an agent response by correlation ID. + + Args: + correlationId: The correlation ID to search for + + Returns: + Response data dict if found, None otherwise + """ + # Search through conversation history for a response with this correlationId + for entry in self.data.conversationHistory: + if hasattr(entry, 'correlationId') and entry.correlationId == correlationId: + # Found the entry, extract response data + if isinstance(entry, DurableAgentStateResponse): + # Get the text content from assistant messages only + content = "" + for message in entry.messages: + if hasattr(message, 'role') and message.role == "assistant" and hasattr(message, 'text'): + content += message.text + + return { + "content": content, + "message_count": self.message_count, + "correlationId": correlationId + } + return None + +# Entry classes + +class DurableAgentStateEntry: + json_type: str + correlationId: str + created_at: datetime + messages: List['DurableAgentStateMessage'] + extensionData: Optional[Dict] + + # Request-only + responseType: Optional[str] = None + responseSchema: Optional[dict] = None + + # Response-only + usage: Optional["DurableAgentStateUsage"] = None + + + def __init__(self, json_type, correlationId, created_at, messages, extensionData=None, responseType=None, responseSchema=None, usage=None): + self.json_type = json_type + self.correlationId = correlationId + self.created_at = created_at + self.messages = messages + self.extensionData = extensionData + self.responseType = responseType + self.responseSchema = responseSchema + self.usage = usage + + def to_dict(self) -> Dict[str, Any]: + data = { + "$type": self.json_type, + "correlationId": self.correlationId, + "createdAt": self.created_at.isoformat() if isinstance(self.created_at, datetime) else self.created_at, + "messages": [m.to_dict() if hasattr(m, 'to_dict') else m for m in self.messages], + "extensionData": self.extensionData + } + if self.json_type == "request": + data.update({ + "responseType": self.responseType, + "responseSchema": self.responseSchema, + }) + elif self.json_type == "response": + if self.usage: + data["usage"] = self.usage.to_dict() + return data + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> 'DurableAgentStateEntry': + from dateutil import parser as date_parser + created_at = data.get("created_at") + if isinstance(created_at, str): + created_at = date_parser.parse(created_at) + + messages = [] + for msg_dict in data.get("messages", []): + if isinstance(msg_dict, dict): + messages.append(DurableAgentStateMessage.from_dict(msg_dict)) + else: + messages.append(msg_dict) + + return cls( + correlationId=data.get("correlationId"), + created_at=created_at, + messages=messages, + extensionData=data.get("extensionData") + ) + + +class DurableAgentStateRequest(DurableAgentStateEntry): + response_type: Optional[str] = None + response_schema: Optional[Dict] = None + + def __init__(self, correlationId, created_at, messages, json_type, extensionData=None, response_type=None, response_schema=None): + self.correlationId = correlationId + self.created_at = created_at + self.messages = messages + self.json_type = json_type + self.extensionData = extensionData + self.response_type = response_type + self.response_schema = response_schema + + def to_dict(self) -> Dict[str, Any]: + base_dict = super().to_dict() + base_dict["response_type"] = self.response_type + base_dict["response_schema"] = self.response_schema + return base_dict + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> 'DurableAgentStateRequest': + from dateutil import parser as date_parser + created_at = data.get("created_at") + if isinstance(created_at, str): + created_at = date_parser.parse(created_at) + + messages = [] + for msg_dict in data.get("messages", []): + if isinstance(msg_dict, dict): + messages.append(DurableAgentStateMessage.from_dict(msg_dict)) + else: + messages.append(msg_dict) + + return cls( + json_type=data.get("$type", "request"), + correlationId=data.get("correlationId"), + created_at=created_at, + messages=messages, + extensionData=data.get("extensionData"), + response_type=data.get("response_type"), + response_schema=data.get("response_schema") + ) + + @staticmethod + def from_run_request(content): + from agent_framework import TextContent + return DurableAgentStateRequest(correlationId=content.correlationId, + messages=[DurableAgentStateMessage.from_chat_message(content)], + created_at=datetime.now(tz=timezone.utc), + json_type="request", + extensionData=content.extensionData if hasattr(content, 'extensionData') else None, + response_type="text" if isinstance(content.response_format, TextContent) else "json", + response_schema=content.response_format) + + +class DurableAgentStateResponse(DurableAgentStateEntry): + usage: Optional['DurableAgentStateUsage'] = None + + def __init__(self, json_type, correlationId, created_at, messages, extensionData=None, usage=None): + self.json_type = json_type + self.correlationId = correlationId + self.created_at = created_at + self.messages = messages + self.extensionData = extensionData + self.usage = usage + + def to_dict(self) -> Dict[str, Any]: + base_dict = super().to_dict() + base_dict["usage"] = self.usage.to_dict() if self.usage and hasattr(self.usage, 'to_dict') else self.usage + return base_dict + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> 'DurableAgentStateResponse': + from dateutil import parser as date_parser + created_at = data.get("created_at") + if isinstance(created_at, str): + created_at = date_parser.parse(created_at) + + messages = [] + for msg_dict in data.get("messages", []): + if isinstance(msg_dict, dict): + messages.append(DurableAgentStateMessage.from_dict(msg_dict)) + else: + messages.append(msg_dict) + + usage_dict = data.get("usage") + usage = None + if usage_dict and isinstance(usage_dict, dict): + usage = DurableAgentStateUsage.from_dict(usage_dict) + elif usage_dict: + usage = usage_dict + + return cls( + json_type=data.get("$type", "response"), + correlationId=data.get("correlationId"), + created_at=created_at, + messages=messages, + extensionData=data.get("extensionData"), + usage=usage + ) + + @staticmethod + def from_run_response(correlationId: str, response) -> DurableAgentStateResponse: + """ + Creates a DurableAgentStateResponse from an AgentRunResponse. + """ + # Determine the earliest created_at timestamp among messages (if available) + timestamps = [m.created_at for m in response.messages if hasattr(m, 'created_at') and m.created_at is not None] + created_at = min(timestamps) if timestamps else datetime.now(tz=timezone.utc) + + return DurableAgentStateResponse( + json_type="response", + correlationId=correlationId, + created_at=created_at, + messages=[DurableAgentStateMessage.from_chat_message(m) for m in response.messages], + usage=DurableAgentStateUsage.from_usage(response.usage) if hasattr(response, 'usage') and response.usage else None + ) + + def to_run_response(self): + """ + Converts this DurableAgentStateResponse back to an AgentRunResponse. + """ + from agent_framework import AgentRunResponse + + return AgentRunResponse( + created_at=self.created_at, + messages=[m.to_chat_message() for m in self.messages], + usage=self.usage.to_usage_details() if self.usage else None + ) + +# Message class + +class DurableAgentStateMessage: + role: str + contents: List[DurableAgentStateContent] + author_name: Optional[str] = None + created_at: Optional[datetime] = None + extensionData: Optional[Dict] = None + + def __init__(self, role, contents, author_name=None, created_at=None, extensionData=None): + self.role = role + self.contents = contents + self.author_name = author_name + self.created_at = created_at + self.extensionData = extensionData + + def to_dict(self) -> Dict[str, Any]: + return { + "role": self.role, + "contents": [ + {"$type": c.to_dict().get("type", "text"), **{k: v for k, v in c.to_dict().items() if k != "type"}} for c in self.contents + ], + "authorName": self.author_name, + "createdAt": self.created_at.isoformat() if isinstance(self.created_at, datetime) else self.created_at, + "extensionData": self.extensionData + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> 'DurableAgentStateMessage': + from dateutil import parser as date_parser + created_at = data.get("created_at") + if created_at and isinstance(created_at, str): + created_at = date_parser.parse(created_at) + + contents = [] + for content_dict in data.get("contents", []): + if isinstance(content_dict, dict): + content_type = content_dict.get("type") + if content_type == "text": + contents.append(DurableAgentStateTextContent(text=content_dict.get("text"))) + elif content_type == "data": + contents.append(DurableAgentStateDataContent(uri=content_dict.get("uri"), media_type=content_dict.get("media_type"))) + elif content_type == "error": + contents.append(DurableAgentStateErrorContent(message=content_dict.get("message"), error_code=content_dict.get("error_code"), details=content_dict.get("details"))) + elif content_type == "function_call": + contents.append(DurableAgentStateFunctionCallContent(call_id=content_dict.get("call_id"), name=content_dict.get("name"), arguments=content_dict.get("arguments"))) + elif content_type == "function_result": + contents.append(DurableAgentStateFunctionResultContent(call_id=content_dict.get("call_id"), result=content_dict.get("result"))) + elif content_type == "hosted_file": + contents.append(DurableAgentStateHostedFileContent(file_id=content_dict.get("file_id"))) + elif content_type == "hosted_vector_store": + contents.append(DurableAgentStateHostedVectorStoreContent(vector_store_id=content_dict.get("vector_store_id"))) + elif content_type == "text_reasoning": + contents.append(DurableAgentStateTextReasoningContent(text=content_dict.get("text"))) + elif content_type == "uri": + contents.append(DurableAgentStateUriContent(uri=content_dict.get("uri"), media_type=content_dict.get("media_type"))) + elif content_type == "usage": + usage_data = content_dict.get("usage") + if usage_data and isinstance(usage_data, dict): + contents.append(DurableAgentStateUsageContent(usage=DurableAgentStateUsage.from_dict(usage_data))) + elif content_type == "unknown": + contents.append(DurableAgentStateUnknownContent(content=content_dict.get("content"))) + else: + contents.append(content_dict) + + return cls( + role=data.get("role"), + contents=contents, + author_name=data.get("author_name"), + created_at=created_at, + extensionData=data.get("extensionData") + ) + + @property + def text(self) -> str: + """Extract text from the contents list.""" + text_parts = [] + for content in self.contents: + if isinstance(content, DurableAgentStateTextContent): + text_parts.append(content.text or "") + return "".join(text_parts) + + @staticmethod + def from_chat_message(content): + # Convert to a list of DurableAgentStateContent objects + contents_list = [] + + if hasattr(content, 'message') and isinstance(content.message, str): + # RunRequest with 'message' attribute + contents_list = [DurableAgentStateTextContent(text=content.message)] + elif hasattr(content, 'contents') and content.contents: + # ChatMessage with 'contents' attribute - convert each content object + for c in content.contents: + converted = DurableAgentStateContent.from_ai_content(c) + contents_list.append(converted) + + # Convert role enum to string if needed + role_value = content.role.value if hasattr(content.role, 'value') else str(content.role) + + return DurableAgentStateMessage( + role=role_value, + contents=contents_list, + author_name=content.author_name if hasattr(content, 'author_name') else None, + created_at=content.created_at if hasattr(content, 'created_at') else None, + extensionData=content.extensionData if hasattr(content, 'extensionData') else None + ) + + def to_chat_message(self): + from agent_framework import ChatMessage + # Convert DurableAgentStateContent objects back to agent_framework content objects + ai_contents = [c.to_ai_content() for c in self.contents] + return ChatMessage(role=self.role, contents=ai_contents, author_name=self.author_name, created_at=self.created_at, extensionData=self.extensionData) + +# Content subclasses + +class DurableAgentStateDataContent(DurableAgentStateContent): + uri: str = "" + media_type: Optional[str] = None + + def __init__(self, uri, media_type=None): + self.uri = uri + self.media_type = media_type + + def to_dict(self) -> Dict[str, Any]: + return { + "type": "data", + "uri": self.uri, + "mediaType": self.media_type + } + + @staticmethod + def from_data_content(content): + return DurableAgentStateDataContent(uri=content.uri, media_type=content.media_type) + + def to_ai_content(self): + from agent_framework import DataContent + return DataContent(uri=self.uri, media_type=self.media_type) + + +class DurableAgentStateErrorContent(DurableAgentStateContent): + message: Optional[str] = None + error_code: Optional[str] = None + details: Optional[str] = None + + def __init__(self, message=None, error_code=None, details=None): + self.message = message + self.error_code = error_code + self.details = details + + def to_dict(self) -> Dict[str, Any]: + return { + "type": "error", + "message": self.message, + "errorCode": self.error_code, + "details": self.details + } + + @staticmethod + def from_error_content(content): + return DurableAgentStateErrorContent(message=content.message, error_code=content.error_code, details=content.details) + + def to_ai_content(self): + from agent_framework import ErrorContent + return ErrorContent(message=self.message, error_code=self.error_code, details=self.details) + + +class DurableAgentStateFunctionCallContent(DurableAgentStateContent): + call_id: str + name: str + arguments: Dict[str, object] + + def __init__(self, call_id, name, arguments): + self.call_id = call_id + self.name = name + self.arguments = arguments + + def to_dict(self) -> Dict[str, Any]: + return { + "type": "function_call", + "callId": self.call_id, + "name": self.name, + "arguments": self.arguments + } + + @staticmethod + def from_function_call_content(content): + return DurableAgentStateFunctionCallContent( + call_id=content.call_id, + name=content.name, + arguments=content.arguments if content.arguments else {} + ) + + def to_ai_content(self): + from agent_framework import FunctionCallContent + return FunctionCallContent(call_id=self.call_id, name=self.name, arguments=self.arguments) + + +class DurableAgentStateFunctionResultContent(DurableAgentStateContent): + call_id: str + result: Optional[object] = None + + def __init__(self, call_id, result=None): + self.call_id = call_id + self.result = result + + def to_dict(self) -> Dict[str, Any]: + return { + "type": "function_result", + "callId": self.call_id, + "result": self.result + } + + @staticmethod + def from_function_result_content(content): + return DurableAgentStateFunctionResultContent(call_id=content.call_id, result=content.result) + + def to_ai_content(self): + from agent_framework import FunctionResultContent + return FunctionResultContent(call_id=self.call_id, result=self.result) + + +class DurableAgentStateHostedFileContent(DurableAgentStateContent): + file_id: str + + def __init__(self, file_id): + self.file_id = file_id + + def to_dict(self) -> Dict[str, Any]: + return { + "type": "hosted_file", + "fileId": self.file_id + } + + @staticmethod + def from_hosted_file_content(content): + return DurableAgentStateHostedFileContent(file_id=content.file_id) + + def to_ai_content(self): + from agent_framework import HostedFileContent + return HostedFileContent(file_id=self.file_id) + + +class DurableAgentStateHostedVectorStoreContent(DurableAgentStateContent): + vector_store_id: str + + def __init__(self, vector_store_id): + self.vector_store_id = vector_store_id + + def to_dict(self) -> Dict[str, Any]: + return { + "type": "hosted_vector_store", + "vectorStoreId": self.vector_store_id + } + + @staticmethod + def from_hosted_vector_store_content(content): + return DurableAgentStateHostedVectorStoreContent(vector_store_id=content.vector_store_id) + + def to_ai_content(self): + from agent_framework import HostedVectorStoreContent + return HostedVectorStoreContent(vector_store_id=self.vector_store_id) + + +class DurableAgentStateTextContent(DurableAgentStateContent): + text: Optional[str] = None + + def __init__(self, text): + self.text = text + + def to_dict(self) -> Dict[str, Any]: + return { + "type": "text", + "text": self.text + } + + @staticmethod + def from_text_content(content): + return DurableAgentStateTextContent(text=content.text) + + def to_ai_content(self): + from agent_framework import TextContent + return TextContent(text=self.text) + + +class DurableAgentStateTextReasoningContent(DurableAgentStateContent): + text: Optional[str] = None + + def __init__(self, text): + self.text = text + + def to_dict(self) -> Dict[str, Any]: + return { + "type": "text_reasoning", + "text": self.text + } + + @staticmethod + def from_text_reasoning_content(content): + return DurableAgentStateTextReasoningContent(text=content.text) + + def to_ai_content(self): + from agent_framework import TextReasoningContent + return TextReasoningContent(text=self.text) + + +class DurableAgentStateUriContent(DurableAgentStateContent): + uri: str + media_type: str + + def __init__(self, uri, media_type): + self.uri = uri + self.media_type = media_type + + def to_dict(self) -> Dict[str, Any]: + return { + "type": "uri", + "uri": self.uri, + "mediaType": self.media_type + } + + @staticmethod + def from_uri_content(content): + return DurableAgentStateUriContent(uri=content.uri, media_type=content.media_type) + + def to_ai_content(self): + from agent_framework import UriContent + return UriContent(uri=self.uri, media_type=self.media_type) + + +class DurableAgentStateUsage: + input_token_count: Optional[int] = None + output_token_count: Optional[int] = None + total_token_count: Optional[int] = None + extensionData: Optional[Dict] = None + + def __init__(self, input_token_count=None, output_token_count=None, total_token_count=None, extensionData=None): + self.input_token_count = input_token_count + self.output_token_count = output_token_count + self.total_token_count = total_token_count + self.extensionData = extensionData + + def to_dict(self) -> Dict[str, Any]: + return { + "inputTokenCount": self.input_token_count, + "outputTokenCount": self.output_token_count, + "totalTokenCount": self.total_token_count, + "extensionData": self.extensionData + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> 'DurableAgentStateUsage': + return cls( + input_token_count=data.get("input_token_count"), + output_token_count=data.get("output_token_count"), + total_token_count=data.get("total_token_count"), + extensionData=data.get("extensionData") + ) + + @staticmethod + def from_usage(usage): + if usage is None: + return None + return DurableAgentStateUsage( + input_token_count=usage.input_token_count, + output_token_count=usage.output_token_count, + total_token_count=usage.total_token_count + ) + + def to_usage_details(self): + # Convert back to AI SDK UsageDetails + from agent_framework import UsageDetails + return UsageDetails( + input_token_count=self.input_token_count, + output_token_count=self.output_token_count, + total_token_count=self.total_token_count + ) + + +class DurableAgentStateUsageContent(DurableAgentStateContent): + usage: DurableAgentStateUsage = DurableAgentStateUsage() + + def __init__(self, usage): + self.usage = usage + + def to_dict(self) -> Dict[str, Any]: + return { + "type": "usage", + "usage": self.usage.to_dict() if hasattr(self.usage, 'to_dict') else self.usage + } + + @staticmethod + def from_usage_content(content): + return DurableAgentStateUsageContent(usage=DurableAgentStateUsage.from_usage(content.details)) + + def to_ai_content(self): + from agent_framework import UsageContent + return UsageContent(details=self.usage.to_usage_details()) + + +class DurableAgentStateUnknownContent(DurableAgentStateContent): + content: dict + + def __init__(self, content): + self.content = content + + def to_dict(self) -> Dict[str, Any]: + return { + "type": "unknown", + "content": self.content + } + + @staticmethod + def from_unknown_content(content): + return DurableAgentStateUnknownContent(content=json.loads(content)) + + def to_ai_content(self): + from agent_framework import BaseContent + if not self.content: + raise Exception(f"The content is missing and cannot be converted to valid AI content.") + return BaseContent(content=json.loads(self.content)) \ No newline at end of file diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_entities.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_entities.py index 8df8e3f335..4cdc292972 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_entities.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_entities.py @@ -10,15 +10,21 @@ allows for long-running agent conversations. import asyncio import inspect import json -from collections.abc import AsyncIterable, Callable -from typing import Any, cast +from collections.abc import AsyncIterable +from datetime import datetime, timezone +from typing import Any, cast, Callable import azure.durable_functions as df from agent_framework import AgentProtocol, AgentRunResponse, AgentRunResponseUpdate, Role, get_logger from ._callbacks import AgentCallbackContext, AgentResponseCallbackProtocol +from ._durable_agent_state import ( + DurableAgentState, + DurableAgentStateData, + DurableAgentStateRequest, + DurableAgentStateResponse, +) from ._models import AgentResponse, RunRequest -from ._state import AgentState logger = get_logger("agent_framework.azurefunctions.entities") @@ -38,11 +44,11 @@ class AgentEntity: Attributes: agent: The AgentProtocol instance - state: The AgentState managing conversation history + state: The DurableAgentState managing conversation history """ agent: AgentProtocol - state: AgentState + state: DurableAgentState def __init__( self, @@ -56,8 +62,9 @@ class AgentEntity: callback: Optional callback invoked during streaming updates and final responses """ self.agent = agent - self.state = AgentState() + self.state = DurableAgentState() self.callback = callback + self._pending_requests: dict[str, DurableAgentStateRequest] = {} logger.debug(f"[AgentEntity] Initialized with agent type: {type(agent).__name__}") @@ -89,31 +96,49 @@ class AgentEntity: message = run_request.message thread_id = run_request.thread_id - correlation_id = run_request.correlation_id + correlationId = run_request.correlationId if not thread_id: raise ValueError("RunRequest must include a thread_id") - if not correlation_id: - raise ValueError("RunRequest must include a correlation_id") + if not correlationId: + raise ValueError("RunRequest must include a correlationId") role = run_request.role or Role.USER response_format = run_request.response_format enable_tool_calls = run_request.enable_tool_calls + # Store request in pending (will be combined with response later) + state_request = DurableAgentStateRequest.from_run_request(run_request) + self.state.data.conversationHistory.append(state_request) + self._pending_requests[correlationId] = state_request + logger.debug(f"[AgentEntity.run_agent] Received message: {message}") logger.debug(f"[AgentEntity.run_agent] Thread ID: {thread_id}") - logger.debug(f"[AgentEntity.run_agent] Correlation ID: {correlation_id}") + logger.debug(f"[AgentEntity.run_agent] Correlation ID: {correlationId}") logger.debug(f"[AgentEntity.run_agent] Role: {role.value}") logger.debug(f"[AgentEntity.run_agent] Enable tool calls: {enable_tool_calls}") logger.debug(f"[AgentEntity.run_agent] Response format: {'provided' if response_format else 'none'}") - - # Store message in history with role - self.state.add_user_message(message, role=role, correlation_id=correlation_id) - + logger.debug(f"[AgentEntity.run_agent] Saved state request: {state_request}") logger.debug("[AgentEntity.run_agent] Executing agent...") try: logger.debug("[AgentEntity.run_agent] Starting agent invocation") - run_kwargs: dict[str, Any] = {"messages": self.state.get_chat_messages()} + # Build messages from conversation history plus the current request + chat_messages = [ + m.to_chat_message() + for entry in self.state.data.conversationHistory + for m in entry.messages + ] + # Add the current request message + # for m in state_request.messages: + # chat_messages.append(m.to_chat_message()) + + # Strip additional_properties from all messages to avoid metadata being sent to Azure OpenAI + # Azure OpenAI doesn't support the 'metadata' field in messages + for msg in chat_messages: + if hasattr(msg, 'additional_properties'): + msg.additional_properties = {} + + run_kwargs: dict[str, Any] = {"messages": chat_messages} if not enable_tool_calls: run_kwargs["tools"] = None if response_format: @@ -121,7 +146,7 @@ class AgentEntity: agent_run_response: AgentRunResponse = await self._invoke_agent( run_kwargs=run_kwargs, - correlation_id=correlation_id, + correlationId=correlationId, thread_id=thread_id, request_message=message, ) @@ -131,6 +156,17 @@ class AgentEntity: type(agent_run_response).__name__, ) + # Convert response into DurableAgentStateResponse and combine with request + state_response = DurableAgentStateResponse.from_run_response(correlationId, agent_run_response) + + # Get the pending request and combine its messages with the response messages + # pending_request = self._pending_requests.pop(correlationId, None) + # if pending_request: + # # Combine request and response messages into a single entry + # state_response.messages = pending_request.messages + state_response.messages + + self.state.data.conversationHistory.append(state_response) + response_text = None structured_response = None @@ -161,13 +197,13 @@ class AgentEntity: message=str(message), thread_id=str(thread_id), status="success", - message_count=self.state.message_count, + message_count=len(self.state.data.conversationHistory), structured_response=structured_response, ) result = agent_response.to_dict() content = json.dumps(structured_response) if structured_response else (response_text or "") - self.state.add_assistant_message(content, agent_run_response, correlation_id) + self.state.add_assistant_message(content, agent_run_response, correlationId) logger.debug("[AgentEntity.run_agent] AgentRunResponse stored in conversation history") return result @@ -181,12 +217,39 @@ class AgentEntity: logger.error(f"Error type: {type(exc).__name__}") logger.error(f"Full traceback:\n{error_traceback}") + # Create error response and store it in conversation history so polling can find it + from agent_framework import ChatMessage, ErrorContent + + # Get the pending request + pending_request = self._pending_requests.pop(correlationId, None) + + # Create error message + error_message = DurableAgentStateMessage.from_chat_message( + ChatMessage(role="assistant", contents=[ErrorContent(message=str(exc), error_code=type(exc).__name__)]) + ) + + # Combine request and error response messages + messages = [] + if pending_request: + messages.extend(pending_request.messages) + messages.append(error_message) + + # Create and store error response in conversation history + error_state_response = DurableAgentStateResponse( + correlationId=correlationId, + createdAt=datetime.now(tz=timezone.utc), + messages=messages, + extensionData=None, + usage=None + ) + self.state.data.conversationHistory.append(error_state_response) + error_response = AgentResponse( response=f"Error: {exc!s}", message=str(message), thread_id=str(thread_id), status="error", - message_count=self.state.message_count, + message_count=len(self.state.data.conversationHistory), error=str(exc), error_type=type(exc).__name__, ) @@ -195,7 +258,7 @@ class AgentEntity: async def _invoke_agent( self, run_kwargs: dict[str, Any], - correlation_id: str, + correlationId: str, thread_id: str, request_message: str, ) -> AgentRunResponse: @@ -203,7 +266,7 @@ class AgentEntity: callback_context: AgentCallbackContext | None = None if self.callback is not None: callback_context = self._build_callback_context( - correlation_id=correlation_id, + correlationId=correlationId, thread_id=thread_id, request_message=request_message, ) @@ -317,7 +380,7 @@ class AgentEntity: def _build_callback_context( self, - correlation_id: str, + correlationId: str, thread_id: str, request_message: str, ) -> AgentCallbackContext: @@ -325,7 +388,7 @@ class AgentEntity: agent_name = getattr(self.agent, "name", None) or type(self.agent).__name__ return AgentCallbackContext( agent_name=agent_name, - correlation_id=correlation_id, + correlationId=correlationId, thread_id=thread_id, request_message=request_message, ) @@ -333,7 +396,7 @@ class AgentEntity: def reset(self, context: df.DurableEntityContext) -> None: """Reset the entity state (clear conversation history).""" logger.debug("[AgentEntity.reset] Resetting entity state") - self.state.reset() + self.state.data = DurableAgentStateData(conversationHistory=[]) logger.debug("[AgentEntity.reset] State reset complete") @@ -392,8 +455,9 @@ def create_agent_entity( logger.error("[entity_function] Unknown operation: %s", operation) context.set_result({"error": f"Unknown operation: {operation}"}) + logger.info("State dict: %s", str(entity.state.to_dict())) context.set_state(entity.state.to_dict()) - logger.debug(f"[entity_function] Operation {operation} completed successfully") + logger.info(f"[entity_function] Operation {operation} completed successfully") except Exception as exc: import traceback @@ -424,4 +488,4 @@ def create_agent_entity( logger.error("[entity_function] Unexpected error executing entity: %s", exc, exc_info=True) context.set_result({"error": str(exc), "status": "error"}) - return entity_function + return entity_function \ No newline at end of file diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_models.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_models.py index 015ca40754..45d7084aa9 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_models.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_models.py @@ -282,7 +282,7 @@ class RunRequest: response_format: Optional Pydantic BaseModel type describing the structured response format enable_tool_calls: Whether to enable tool calls for this request thread_id: Optional thread ID for tracking - correlation_id: Optional correlation ID for tracking the response to this specific request + correlationId: Optional correlation ID for tracking the response to this specific request """ message: str @@ -290,7 +290,10 @@ class RunRequest: response_format: type[BaseModel] | None = None enable_tool_calls: bool = True thread_id: str | None = None - correlation_id: str | None = None + correlationId: str | None = None + author_name: str | None = None + created_at: str | None = None + extension_data: dict[str, Any] | None = None def __init__( self, @@ -299,14 +302,14 @@ class RunRequest: response_format: type[BaseModel] | None = None, enable_tool_calls: bool = True, thread_id: str | None = None, - correlation_id: str | None = None, + correlationId: str | None = None, ) -> None: self.message = message self.role = self.coerce_role(role) self.response_format = response_format self.enable_tool_calls = enable_tool_calls self.thread_id = thread_id - self.correlation_id = correlation_id + self.correlationId = correlationId @staticmethod def coerce_role(value: Role | str | None) -> Role: @@ -331,8 +334,14 @@ class RunRequest: result["response_format"] = _serialize_response_format(self.response_format) if self.thread_id: result["thread_id"] = self.thread_id - if self.correlation_id: - result["correlation_id"] = self.correlation_id + if self.correlationId: + result["correlationId"] = self.correlationId + if self.author_name: + result["author_name"] = self.author_name + if self.created_at: + result["created_at"] = self.created_at + if self.extension_data: + result["extension_data"] = self.extension_data return result @classmethod @@ -344,7 +353,7 @@ class RunRequest: response_format=_deserialize_response_format(data.get("response_format")), enable_tool_calls=data.get("enable_tool_calls", True), thread_id=data.get("thread_id"), - correlation_id=data.get("correlation_id"), + correlationId=data.get("correlationId"), ) @@ -392,4 +401,4 @@ class AgentResponse: if self.error_type: result["error_type"] = self.error_type - return result + return result \ No newline at end of file diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_orchestration.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_orchestration.py index 2fd4522964..fbec321c45 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_orchestration.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_orchestration.py @@ -129,13 +129,13 @@ class DurableAIAgent(AgentProtocol): # Generate a deterministic correlation ID for this call # This is required by the entity and must be unique per call - correlation_id = str(self.context.new_uuid()) + correlationId = str(self.context.new_uuid()) # Prepare the request using RunRequest model run_request = RunRequest( message=message_str, enable_tool_calls=enable_tool_calls, - correlation_id=correlation_id, + correlationId=correlationId, thread_id=session_id.key, response_format=response_format, ) diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_state.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_state.py deleted file mode 100644 index c9d54b8333..0000000000 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_state.py +++ /dev/null @@ -1,179 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. - -"""Agent State Management. - -This module defines the AgentState class for managing conversation state and -serializing agent framework responses. -""" - -from collections.abc import MutableMapping -from datetime import datetime, timezone -from typing import Any, cast - -from agent_framework import AgentRunResponse, ChatMessage, Role, get_logger - -logger = get_logger("agent_framework.azurefunctions.state") - - -class AgentState: - """Manages agent conversation state using agent_framework types (ChatMessage, AgentRunResponse). - - This class handles: - - Conversation history tracking using ChatMessage objects - - Agent response storage using AgentRunResponse objects with correlation IDs - - State persistence and restoration - - Message counting - """ - - def __init__(self) -> None: - """Initialize empty agent state.""" - self.conversation_history: list[ChatMessage] = [] - self.last_response: str | None = None - self.message_count: int = 0 - - def _current_timestamp(self) -> str: - """Return an ISO 8601 UTC timestamp.""" - return datetime.now(timezone.utc).isoformat() - - def add_user_message( - self, - content: str, - role: Role = Role.USER, - correlation_id: str | None = None, - ) -> None: - """Add a user message to the conversation history as a ChatMessage object. - - Args: - content: The message content - role: The message role (user, system, etc.) - correlation_id: Optional correlation identifier associated with the user message - """ - self.message_count += 1 - timestamp = self._current_timestamp() - additional_props: MutableMapping[str, Any] = {"timestamp": timestamp} - if correlation_id is not None: - additional_props["correlation_id"] = correlation_id - chat_message = ChatMessage(role=role, text=content, additional_properties=additional_props) - self.conversation_history.append(chat_message) - logger.debug(f"Added {role} ChatMessage to history (message #{self.message_count})") - - def add_assistant_message( - self, content: str, agent_response: AgentRunResponse, correlation_id: str | None = None - ) -> None: - """Add an assistant message to the conversation history with full agent response. - - Args: - content: The text content of the response - agent_response: The AgentRunResponse object from the agent framework - correlation_id: Optional correlation ID for tracking this response - """ - self.last_response = content - timestamp = self._current_timestamp() - serialized_response = self.serialize_response(agent_response) - - # Create a ChatMessage for the assistant response - # The agent_response already contains messages, but we store it as a custom ChatMessage - # with the agent_response stored in additional_properties for full metadata preservation - additional_props: dict[str, Any] = { - "agent_response": serialized_response, - "correlation_id": correlation_id, - "timestamp": timestamp, - "message_count": self.message_count, - } - chat_message = ChatMessage(role="assistant", text=content, additional_properties=additional_props) - - self.conversation_history.append(chat_message) - - logger.debug( - f"Added assistant ChatMessage to history with AgentRunResponse metadata (correlation_id: {correlation_id})" - ) - - def get_chat_messages(self) -> list[ChatMessage]: - """Return a copy of the full conversation history.""" - return list(self.conversation_history) - - def try_get_agent_response(self, correlation_id: str) -> dict[str, Any] | None: - """Get an agent response by correlation ID. - - Args: - correlation_id: The correlation ID to look up - - Returns: - The agent response data if found, None otherwise - """ - for message in reversed(self.conversation_history): - metadata = getattr(message, "additional_properties", {}) or {} - if metadata.get("correlation_id") == correlation_id: - return self._build_agent_response_payload(message, metadata) - - return None - - def serialize_response(self, response: AgentRunResponse) -> dict[str, Any]: - """Serialize an ``AgentRunResponse`` to a dictionary. - - Args: - response: The agent framework response object - - Returns: - Dictionary containing all response fields - """ - try: - return response.to_dict() - except Exception as exc: # pragma: no cover - defensive logging path - logger.warning(f"Error serializing response: {exc}") - return {"response": str(response), "serialization_error": str(exc)} - - def to_dict(self) -> dict[str, Any]: - """Get the current state as a dictionary for persistence. - - Returns: - Dictionary containing conversation_history (as serialized ChatMessages), - last_response, and message_count - """ - return { - "conversation_history": [msg.to_dict() for msg in self.conversation_history], - "last_response": self.last_response, - "message_count": self.message_count, - } - - def restore_state(self, state: dict[str, Any]) -> None: - """Restore state from a dictionary, reconstructing ChatMessage objects. - - Args: - state: Dictionary containing conversation_history, last_response, and message_count - """ - # Restore conversation history as ChatMessage objects - history_data = state.get("conversation_history", []) - restored_history: list[ChatMessage] = [] - for raw_message in history_data: - if isinstance(raw_message, dict): - restored_history.append(ChatMessage.from_dict(cast(dict[str, Any], raw_message))) - else: - restored_history.append(cast(ChatMessage, raw_message)) - - self.conversation_history = restored_history - - self.last_response = state.get("last_response") - self.message_count = state.get("message_count", 0) - logger.debug("Restored state: %s ChatMessages in history", len(self.conversation_history)) - - def reset(self) -> None: - """Reset the state to empty.""" - self.conversation_history = [] - self.last_response = None - self.message_count = 0 - logger.debug("State reset to empty") - - def __repr__(self) -> str: - """String representation of the state.""" - return f"AgentState(messages={self.message_count}, history_length={len(self.conversation_history)})" - - def _build_agent_response_payload(self, message: ChatMessage, metadata: dict[str, Any]) -> dict[str, Any]: - """Construct the agent response payload returned to callers.""" - return { - "content": message.text, - "agent_response": metadata.get("agent_response"), - "message_count": metadata.get("message_count", self.message_count), - "timestamp": metadata.get("timestamp"), - "correlation_id": metadata.get("correlation_id"), - }