mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Add Durable Agent Wrapper code (#1913)
* add initial changes * Move code and add single sample * Update logger * Remove unused code * address PR comments * cleanup code and address comments --------- Co-authored-by: Dmytro Struk <13853051+dmytrostruk@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
5686a009fb
commit
1762cda5f7
@@ -1 +1,18 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
"""Azure Durable Agent Function App.
|
||||
|
||||
This package provides integration between Microsoft Agent Framework and Azure Durable Functions,
|
||||
enabling durable, stateful AI agents deployed as Azure Function Apps.
|
||||
"""
|
||||
|
||||
from ._app import AgentFunctionApp
|
||||
from ._callbacks import AgentCallbackContext, AgentResponseCallbackProtocol
|
||||
from ._orchestration import DurableAIAgent, get_agent
|
||||
|
||||
__all__ = [
|
||||
"AgentCallbackContext",
|
||||
"AgentFunctionApp",
|
||||
"AgentResponseCallbackProtocol",
|
||||
"DurableAIAgent",
|
||||
"get_agent",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,657 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""AgentFunctionApp - Main application class.
|
||||
|
||||
This module provides the AgentFunctionApp class that integrates Microsoft Agent Framework
|
||||
with Azure Durable Entities, enabling stateful and durable AI agent execution.
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, cast
|
||||
|
||||
import azure.durable_functions as df
|
||||
import azure.functions as func
|
||||
from agent_framework import AgentProtocol, get_logger
|
||||
|
||||
from ._callbacks import AgentResponseCallbackProtocol
|
||||
from ._entities import create_agent_entity
|
||||
from ._errors import IncomingRequestError
|
||||
from ._models import AgentSessionId, ChatRole, RunRequest
|
||||
from ._state import AgentState
|
||||
|
||||
logger = get_logger("agent_framework.azurefunctions")
|
||||
|
||||
SESSION_ID_FIELD: str = "sessionId"
|
||||
SESSION_KEY_FIELD: str = "sessionKey"
|
||||
SESSION_IDENTIFIER_KEYS: tuple[str, str] = (
|
||||
SESSION_ID_FIELD,
|
||||
SESSION_KEY_FIELD,
|
||||
)
|
||||
|
||||
|
||||
class AgentFunctionApp(df.DFApp):
|
||||
"""Main application class for creating durable agent function apps using Durable Entities.
|
||||
|
||||
This class uses Durable Entities pattern for agent execution, providing:
|
||||
- Stateful agent conversations
|
||||
- Conversation history management
|
||||
- Signal-based operation invocation
|
||||
- Better state management than orchestrations
|
||||
|
||||
Usage:
|
||||
```python
|
||||
from agent_framework.azurefunctions import AgentFunctionApp
|
||||
from agent_framework.azure import AzureOpenAIAssistantsClient
|
||||
|
||||
# Create agents with unique names
|
||||
weather_agent = AzureOpenAIAssistantsClient(...).create_agent(
|
||||
name="WeatherAgent",
|
||||
instructions="You are a helpful weather agent.",
|
||||
tools=[get_weather],
|
||||
)
|
||||
|
||||
math_agent = AzureOpenAIAssistantsClient(...).create_agent(
|
||||
name="MathAgent",
|
||||
instructions="You are a helpful math assistant.",
|
||||
tools=[calculate],
|
||||
)
|
||||
|
||||
# Option 1: Pass list of agents during initialization
|
||||
app = AgentFunctionApp(agents=[weather_agent, math_agent])
|
||||
|
||||
# Option 2: Add agents after initialization
|
||||
app = AgentFunctionApp()
|
||||
app.add_agent(weather_agent)
|
||||
app.add_agent(math_agent)
|
||||
```
|
||||
|
||||
This creates:
|
||||
- HTTP trigger endpoint for each agent's requests (if enabled)
|
||||
- Durable entity for each agent's state management and execution
|
||||
- Full access to all Azure Functions capabilities
|
||||
|
||||
Attributes:
|
||||
agents: Dictionary of agent name to AgentProtocol instance
|
||||
enable_health_check: Whether health check endpoint is enabled
|
||||
enable_http_endpoints: Whether HTTP endpoints are created for agents
|
||||
max_poll_retries: Maximum polling attempts when waiting for responses
|
||||
poll_interval_seconds: Delay (seconds) between polling attempts
|
||||
"""
|
||||
|
||||
agents: dict[str, AgentProtocol]
|
||||
enable_health_check: bool
|
||||
enable_http_endpoints: bool
|
||||
agent_http_endpoint_flags: dict[str, bool]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agents: list[AgentProtocol] | None = None,
|
||||
http_auth_level: func.AuthLevel = func.AuthLevel.ANONYMOUS,
|
||||
enable_health_check: bool = True,
|
||||
enable_http_endpoints: bool = True,
|
||||
max_poll_retries: int = 10,
|
||||
poll_interval_seconds: float = 0.5,
|
||||
default_callback: AgentResponseCallbackProtocol | None = None,
|
||||
):
|
||||
"""Initialize the AgentFunctionApp.
|
||||
|
||||
Args:
|
||||
agents: List of agent instances to register
|
||||
http_auth_level: HTTP authentication level (default: ANONYMOUS)
|
||||
enable_health_check: Enable built-in health check endpoint (default: True)
|
||||
enable_http_endpoints: Enable HTTP endpoints for agents (default: True)
|
||||
max_poll_retries: Maximum number of polling attempts when waiting for a response
|
||||
poll_interval_seconds: Delay (in seconds) between polling attempts
|
||||
default_callback: Optional callback invoked for agents without specific callbacks
|
||||
|
||||
Note:
|
||||
If no agents are provided, they can be added later using add_agent().
|
||||
"""
|
||||
logger.debug("[AgentFunctionApp] Initializing with Durable Entities...")
|
||||
|
||||
# Initialize parent DFApp
|
||||
super().__init__(http_auth_level=http_auth_level)
|
||||
|
||||
# Initialize agents dictionary
|
||||
self.agents = {}
|
||||
self.agent_http_endpoint_flags = {}
|
||||
self.enable_health_check = enable_health_check
|
||||
self.enable_http_endpoints = enable_http_endpoints
|
||||
self.default_callback = default_callback
|
||||
|
||||
try:
|
||||
retries = int(max_poll_retries)
|
||||
except (TypeError, ValueError):
|
||||
retries = 10
|
||||
self.max_poll_retries = max(1, retries)
|
||||
|
||||
try:
|
||||
interval = float(poll_interval_seconds)
|
||||
except (TypeError, ValueError):
|
||||
interval = 0.5
|
||||
self.poll_interval_seconds = interval if interval > 0 else 0.5
|
||||
|
||||
if agents:
|
||||
# Register all provided agents
|
||||
logger.debug(f"[AgentFunctionApp] Registering {len(agents)} agent(s)")
|
||||
for agent_instance in agents:
|
||||
self.add_agent(agent_instance)
|
||||
|
||||
# Setup health check if enabled
|
||||
if self.enable_health_check:
|
||||
self._setup_health_route()
|
||||
|
||||
logger.debug("[AgentFunctionApp] Initialization complete")
|
||||
|
||||
def add_agent(
|
||||
self,
|
||||
agent: AgentProtocol,
|
||||
callback: AgentResponseCallbackProtocol | None = None,
|
||||
enable_http_endpoint: bool | None = None,
|
||||
) -> None:
|
||||
"""Add an agent to the function app after initialization.
|
||||
|
||||
Args:
|
||||
agent: The Microsoft Agent Framework agent instance (must implement AgentProtocol)
|
||||
The agent must have a 'name' attribute.
|
||||
callback: Optional callback invoked during agent execution
|
||||
enable_http_endpoint: Optional flag that overrides the app-level
|
||||
HTTP endpoint setting for this agent
|
||||
|
||||
Raises:
|
||||
ValueError: If the agent doesn't have a 'name' attribute or if an agent
|
||||
with the same name is already registered
|
||||
"""
|
||||
# Get agent name from the agent's name attribute
|
||||
name = getattr(agent, "name", None)
|
||||
if name is None:
|
||||
raise ValueError("Agent does not have a 'name' attribute. All agents must have a 'name' attribute.")
|
||||
|
||||
if name in self.agents:
|
||||
raise ValueError(f"Agent with name '{name}' is already registered. Each agent must have a unique name.")
|
||||
|
||||
effective_enable_http_endpoint = (
|
||||
self.enable_http_endpoints if enable_http_endpoint is None else self._coerce_to_bool(enable_http_endpoint)
|
||||
)
|
||||
|
||||
logger.debug(f"[AgentFunctionApp] Adding agent: {name}")
|
||||
logger.debug(f"[AgentFunctionApp] Route: /api/agents/{name}")
|
||||
logger.debug(
|
||||
"[AgentFunctionApp] HTTP endpoint %s for agent '%s'",
|
||||
"enabled" if effective_enable_http_endpoint else "disabled",
|
||||
name,
|
||||
)
|
||||
|
||||
self.agents[name] = agent
|
||||
self.agent_http_endpoint_flags[name] = effective_enable_http_endpoint
|
||||
|
||||
effective_callback = callback or self.default_callback
|
||||
|
||||
self._setup_agent_functions(
|
||||
agent,
|
||||
name,
|
||||
effective_callback,
|
||||
effective_enable_http_endpoint,
|
||||
)
|
||||
|
||||
logger.debug(f"[AgentFunctionApp] Agent '{name}' added successfully")
|
||||
|
||||
def _setup_agent_functions(
|
||||
self,
|
||||
agent: AgentProtocol,
|
||||
agent_name: str,
|
||||
callback: AgentResponseCallbackProtocol | None,
|
||||
enable_http_endpoint: bool,
|
||||
) -> None:
|
||||
"""Set up the HTTP trigger and entity for a specific agent.
|
||||
|
||||
Args:
|
||||
agent: The agent instance
|
||||
agent_name: The name to use for routing and entity registration
|
||||
callback: Optional callback to receive response updates
|
||||
enable_http_endpoint: Whether the HTTP run route is enabled for
|
||||
this agent
|
||||
"""
|
||||
logger.debug(f"[AgentFunctionApp] Setting up functions for agent '{agent_name}'...")
|
||||
|
||||
if enable_http_endpoint:
|
||||
self._setup_http_run_route(agent_name)
|
||||
else:
|
||||
logger.debug(
|
||||
"[AgentFunctionApp] HTTP run route disabled for agent '%s'",
|
||||
agent_name,
|
||||
)
|
||||
self._setup_agent_entity(agent, agent_name, callback)
|
||||
|
||||
def _setup_http_run_route(self, agent_name: str) -> None:
|
||||
"""Register the POST route that triggers agent execution.
|
||||
|
||||
Args:
|
||||
agent_name: The agent name (used for both routing and entity identification)
|
||||
"""
|
||||
run_function_name = self._build_function_name(agent_name, "run")
|
||||
|
||||
@self.function_name(run_function_name)
|
||||
@self.route(route=f"agents/{agent_name}/run", methods=["POST"])
|
||||
@self.durable_client_input(client_name="client")
|
||||
async def http_start(req: func.HttpRequest, client: df.DurableOrchestrationClient) -> func.HttpResponse:
|
||||
"""HTTP trigger that calls a durable entity to execute the agent and returns the result.
|
||||
|
||||
Expected request body (RunRequest format):
|
||||
{
|
||||
"message": "user message to agent",
|
||||
"sessionId": "optional session id (or sessionKey)",
|
||||
"role": "user|system" (optional, default: "user"),
|
||||
"response_format": {...} (optional JSON schema for structured responses),
|
||||
"enable_tool_calls": true|false (optional, default: true)
|
||||
}
|
||||
"""
|
||||
logger.debug(f"[HTTP Trigger] Received request on route: /api/agents/{agent_name}/run")
|
||||
|
||||
try:
|
||||
req_body, message = self._parse_incoming_request(req)
|
||||
session_key = self._resolve_session_key(req=req, req_body=req_body)
|
||||
wait_for_completion = self._should_wait_for_completion(req=req, req_body=req_body)
|
||||
|
||||
logger.debug(f"[HTTP Trigger] Message: {message}")
|
||||
logger.debug(f"[HTTP Trigger] Session Key: {session_key}")
|
||||
logger.debug(f"[HTTP Trigger] wait_for_completion: {wait_for_completion}")
|
||||
|
||||
if not message:
|
||||
logger.warning("[HTTP Trigger] Request rejected: Missing message")
|
||||
return func.HttpResponse(
|
||||
json.dumps({"error": "Message is required"}), status_code=400, mimetype="application/json"
|
||||
)
|
||||
|
||||
session_id = self._create_session_id(agent_name, session_key)
|
||||
correlation_id = self._generate_unique_id()
|
||||
|
||||
logger.debug(f"[HTTP Trigger] Using session ID: {session_id}")
|
||||
logger.debug(f"[HTTP Trigger] Generated correlation ID: {correlation_id}")
|
||||
logger.debug("[HTTP Trigger] Calling entity to run agent...")
|
||||
|
||||
entity_instance_id = session_id.to_entity_id()
|
||||
run_request = self._build_request_data(
|
||||
req_body,
|
||||
message,
|
||||
session_key,
|
||||
correlation_id,
|
||||
)
|
||||
logger.debug("Signalling entity %s with request: %s", entity_instance_id, run_request)
|
||||
await client.signal_entity(entity_instance_id, "run_agent", run_request)
|
||||
|
||||
logger.debug(f"[HTTP Trigger] Signal sent to entity {session_id}")
|
||||
|
||||
if wait_for_completion:
|
||||
result = await self._get_response_from_entity(
|
||||
client=client,
|
||||
entity_instance_id=entity_instance_id,
|
||||
correlation_id=correlation_id,
|
||||
message=message,
|
||||
session_key=session_key,
|
||||
)
|
||||
|
||||
logger.debug(f"[HTTP Trigger] Result status: {result.get('status', 'unknown')}")
|
||||
return func.HttpResponse(
|
||||
json.dumps(result),
|
||||
status_code=200 if result.get("status") == "success" else 500,
|
||||
mimetype="application/json",
|
||||
)
|
||||
|
||||
logger.debug("[HTTP Trigger] wait_for_completion disabled; returning correlation ID")
|
||||
|
||||
accepted_response = self._build_accepted_response(
|
||||
message=message, session_key=session_key, correlation_id=correlation_id
|
||||
)
|
||||
|
||||
return func.HttpResponse(json.dumps(accepted_response), status_code=202, mimetype="application/json")
|
||||
|
||||
except IncomingRequestError as exc:
|
||||
logger.warning(f"[HTTP Trigger] Request rejected: {exc!s}")
|
||||
return func.HttpResponse(
|
||||
json.dumps({"error": str(exc)}), status_code=exc.status_code, mimetype="application/json"
|
||||
)
|
||||
except ValueError as exc:
|
||||
logger.error(f"[HTTP Trigger] Invalid JSON: {exc!s}")
|
||||
return func.HttpResponse(
|
||||
json.dumps({"error": "Invalid JSON"}), status_code=400, mimetype="application/json"
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"[HTTP Trigger] Error: {exc!s}", exc_info=True)
|
||||
return func.HttpResponse(json.dumps({"error": str(exc)}), status_code=500, mimetype="application/json")
|
||||
|
||||
_ = http_start
|
||||
|
||||
def _setup_agent_entity(
|
||||
self,
|
||||
agent: AgentProtocol,
|
||||
agent_name: str,
|
||||
callback: AgentResponseCallbackProtocol | None,
|
||||
) -> None:
|
||||
"""Register the durable entity responsible for agent state.
|
||||
|
||||
Args:
|
||||
agent: The agent instance
|
||||
agent_name: The agent name (used for both entity identification and function naming)
|
||||
callback: Optional callback for response updates
|
||||
"""
|
||||
# Use the prefixed entity name for both registration and function naming
|
||||
entity_name_with_prefix = AgentSessionId.to_entity_name(agent_name)
|
||||
|
||||
def entity_function(context: df.DurableEntityContext) -> None:
|
||||
"""Durable entity that manages agent execution and conversation state.
|
||||
|
||||
Operations:
|
||||
- run_agent: Execute the agent with a message
|
||||
- reset: Clear conversation history
|
||||
"""
|
||||
entity_handler = create_agent_entity(agent, callback)
|
||||
entity_handler(context)
|
||||
|
||||
# Set function name for Azure Functions (used in function.json generation)
|
||||
# Use the prefixed entity name as the function name too.
|
||||
entity_function.__name__ = entity_name_with_prefix
|
||||
self.entity_trigger(context_name="context", entity_name=entity_name_with_prefix)(entity_function)
|
||||
|
||||
def _setup_health_route(self) -> None:
|
||||
"""Register the optional health check route."""
|
||||
|
||||
@self.route(route="health", methods=["GET"])
|
||||
def health_check(req: func.HttpRequest) -> func.HttpResponse:
|
||||
"""Built-in health check endpoint."""
|
||||
agent_info = [
|
||||
{
|
||||
"name": name,
|
||||
"type": type(agent).__name__,
|
||||
"httpEndpointEnabled": self.agent_http_endpoint_flags.get(
|
||||
name,
|
||||
self.enable_http_endpoints,
|
||||
),
|
||||
}
|
||||
for name, agent in self.agents.items()
|
||||
]
|
||||
return func.HttpResponse(
|
||||
json.dumps({"status": "healthy", "agents": agent_info, "agent_count": len(self.agents)}),
|
||||
status_code=200,
|
||||
mimetype="application/json",
|
||||
)
|
||||
|
||||
_ = health_check
|
||||
|
||||
@staticmethod
|
||||
def _build_function_name(agent_name: str, suffix: str) -> str:
|
||||
"""Generate a unique, Azure Functions-compliant name for an agent function."""
|
||||
sanitized = re.sub(r"[^0-9a-zA-Z_]", "_", agent_name or "agent").strip("_")
|
||||
|
||||
if not sanitized:
|
||||
sanitized = "agent"
|
||||
|
||||
if sanitized[0].isdigit():
|
||||
sanitized = f"agent_{sanitized}"
|
||||
|
||||
return f"{sanitized}_{suffix}"
|
||||
|
||||
async def _read_cached_state(
|
||||
self,
|
||||
client: df.DurableOrchestrationClient,
|
||||
entity_instance_id: df.EntityId,
|
||||
) -> AgentState | None:
|
||||
state_response = await client.read_entity_state(entity_instance_id)
|
||||
if not state_response or not state_response.entity_exists:
|
||||
return None
|
||||
|
||||
state_payload = state_response.entity_state
|
||||
if not isinstance(state_payload, dict):
|
||||
return None
|
||||
|
||||
typed_state_payload = cast(dict[str, Any], state_payload)
|
||||
|
||||
agent_state = AgentState()
|
||||
agent_state.restore_state(typed_state_payload)
|
||||
return agent_state
|
||||
|
||||
async def _get_response_from_entity(
|
||||
self,
|
||||
client: df.DurableOrchestrationClient,
|
||||
entity_instance_id: df.EntityId,
|
||||
correlation_id: str,
|
||||
message: str,
|
||||
session_key: str,
|
||||
) -> dict[str, Any]:
|
||||
"""Poll the entity state until a response is available or timeout occurs."""
|
||||
import asyncio
|
||||
|
||||
max_retries = self.max_poll_retries
|
||||
interval = self.poll_interval_seconds
|
||||
retry_count = 0
|
||||
result: dict[str, Any] | None = None
|
||||
|
||||
logger.debug(f"[HTTP Trigger] Waiting for response with correlation ID: {correlation_id}")
|
||||
|
||||
while retry_count < max_retries:
|
||||
await asyncio.sleep(interval)
|
||||
|
||||
result = await self._poll_entity_for_response(
|
||||
client=client,
|
||||
entity_instance_id=entity_instance_id,
|
||||
correlation_id=correlation_id,
|
||||
message=message,
|
||||
session_key=session_key,
|
||||
)
|
||||
if result is not None:
|
||||
break
|
||||
|
||||
logger.debug(f"[HTTP Trigger] Response not available yet (retry {retry_count})")
|
||||
retry_count += 1
|
||||
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
logger.warning(
|
||||
f"[HTTP Trigger] Response with correlation ID {correlation_id} "
|
||||
f"not found in time (waited {max_retries * interval} seconds)"
|
||||
)
|
||||
return await self._build_timeout_result(message=message, session_key=session_key, correlation_id=correlation_id)
|
||||
|
||||
async def _poll_entity_for_response(
|
||||
self,
|
||||
client: df.DurableOrchestrationClient,
|
||||
entity_instance_id: df.EntityId,
|
||||
correlation_id: str,
|
||||
message: str,
|
||||
session_key: str,
|
||||
) -> dict[str, Any] | None:
|
||||
result: dict[str, Any] | None = None
|
||||
try:
|
||||
state = await self._read_cached_state(client, entity_instance_id)
|
||||
|
||||
if state is None:
|
||||
return None
|
||||
|
||||
agent_response = state.try_get_agent_response(correlation_id)
|
||||
if agent_response:
|
||||
result = self._build_success_result(
|
||||
response_data=agent_response,
|
||||
message=message,
|
||||
session_key=session_key,
|
||||
correlation_id=correlation_id,
|
||||
state=state,
|
||||
)
|
||||
logger.debug(f"[HTTP Trigger] Found response for correlation ID: {correlation_id}")
|
||||
|
||||
except Exception as exc:
|
||||
logger.warning(f"[HTTP Trigger] Error reading entity state: {exc}")
|
||||
|
||||
return result
|
||||
|
||||
async def _build_timeout_result(self, message: str, session_key: str, correlation_id: str) -> dict[str, Any]:
|
||||
"""Create the timeout response."""
|
||||
return {
|
||||
"response": "Agent is still processing or timed out...",
|
||||
"message": message,
|
||||
SESSION_ID_FIELD: session_key,
|
||||
"status": "timeout",
|
||||
"correlationId": correlation_id,
|
||||
}
|
||||
|
||||
def _build_success_result(
|
||||
self, response_data: dict[str, Any], message: str, session_key: str, correlation_id: str, state: AgentState
|
||||
) -> dict[str, Any]:
|
||||
"""Build the success result returned to the HTTP caller."""
|
||||
return {
|
||||
"response": response_data.get("content"),
|
||||
"message": message,
|
||||
SESSION_ID_FIELD: session_key,
|
||||
"status": "success",
|
||||
"message_count": response_data.get("message_count", state.message_count),
|
||||
"correlationId": correlation_id,
|
||||
}
|
||||
|
||||
def _build_request_data(
|
||||
self, req_body: dict[str, Any], message: str, conversation_id: str, correlation_id: str
|
||||
) -> dict[str, Any]:
|
||||
"""Create the durable entity request payload."""
|
||||
enable_tool_calls_value = req_body.get("enable_tool_calls")
|
||||
enable_tool_calls = True if enable_tool_calls_value is None else self._coerce_to_bool(enable_tool_calls_value)
|
||||
|
||||
role = self._coerce_chat_role(req_body.get("role"))
|
||||
|
||||
return RunRequest(
|
||||
message=message,
|
||||
role=role,
|
||||
response_format=req_body.get("response_format"),
|
||||
enable_tool_calls=enable_tool_calls,
|
||||
conversation_id=conversation_id,
|
||||
correlation_id=correlation_id,
|
||||
).to_dict()
|
||||
|
||||
def _build_accepted_response(self, message: str, session_key: str, correlation_id: str) -> dict[str, Any]:
|
||||
"""Build the response returned when not waiting for completion."""
|
||||
return {
|
||||
"response": "Agent request accepted",
|
||||
"message": message,
|
||||
SESSION_ID_FIELD: session_key,
|
||||
"status": "accepted",
|
||||
"correlationId": correlation_id,
|
||||
}
|
||||
|
||||
def _generate_unique_id(self) -> str:
|
||||
"""Generate a new unique identifier."""
|
||||
import uuid
|
||||
|
||||
return uuid.uuid4().hex
|
||||
|
||||
def _create_session_id(self, func_name: str, session_key: str | None) -> AgentSessionId:
|
||||
"""Create a session identifier using the provided key or a random value."""
|
||||
if session_key:
|
||||
return AgentSessionId(name=func_name, key=session_key)
|
||||
return AgentSessionId.with_random_key(name=func_name)
|
||||
|
||||
def _resolve_session_key(self, req: func.HttpRequest, req_body: dict[str, Any]) -> str:
|
||||
"""Retrieve the session key from request body or query parameters."""
|
||||
params = req.params or {}
|
||||
|
||||
for key in SESSION_IDENTIFIER_KEYS:
|
||||
if key in req_body:
|
||||
value = req_body.get(key)
|
||||
if value is not None:
|
||||
return str(value)
|
||||
|
||||
for key in SESSION_IDENTIFIER_KEYS:
|
||||
if key in params:
|
||||
value = params.get(key)
|
||||
if value is not None:
|
||||
return str(value)
|
||||
|
||||
logger.debug("[HTTP Trigger] No session identifier provided; using random session key")
|
||||
return self._generate_unique_id()
|
||||
|
||||
def _parse_incoming_request(self, req: func.HttpRequest) -> tuple[dict[str, Any], Any]:
|
||||
"""Parse the incoming run request supporting JSON and plain text bodies."""
|
||||
headers: dict[str, str] = {}
|
||||
raw_headers = req.headers
|
||||
if isinstance(raw_headers, Mapping):
|
||||
headers_mapping = cast(Mapping[Any, Any], raw_headers)
|
||||
for key, value in headers_mapping.items():
|
||||
if value is not None:
|
||||
headers[str(key)] = str(value)
|
||||
|
||||
content_type_header = headers.get("content-type")
|
||||
|
||||
normalized_content_type = ""
|
||||
if content_type_header:
|
||||
normalized_content_type = content_type_header.split(";")[0].strip().lower()
|
||||
|
||||
if normalized_content_type in {"application/json"} or normalized_content_type.endswith("+json"):
|
||||
parser = self._parse_json_body
|
||||
else:
|
||||
parser = self._parse_text_body
|
||||
|
||||
return parser(req)
|
||||
|
||||
@staticmethod
|
||||
def _parse_json_body(req: func.HttpRequest) -> tuple[dict[str, Any], Any]:
|
||||
req_body = req.get_json()
|
||||
if not isinstance(req_body, dict):
|
||||
raise IncomingRequestError("Invalid JSON payload. Expected an object.")
|
||||
|
||||
typed_req_body = cast(dict[str, Any], req_body)
|
||||
message_value = typed_req_body.get("message", "")
|
||||
message = message_value if isinstance(message_value, str) else str(message_value)
|
||||
return typed_req_body, message
|
||||
|
||||
@staticmethod
|
||||
def _parse_text_body(req: func.HttpRequest) -> tuple[dict[str, Any], Any]:
|
||||
body_bytes = req.get_body()
|
||||
text_body = body_bytes.decode("utf-8", errors="replace") if body_bytes else ""
|
||||
message = text_body.strip()
|
||||
|
||||
if not message:
|
||||
raise IncomingRequestError("Message is required")
|
||||
|
||||
return {}, message
|
||||
|
||||
def _should_wait_for_completion(self, req: func.HttpRequest, req_body: dict[str, Any]) -> bool:
|
||||
"""Determine whether the caller requested to wait for completion."""
|
||||
header_value = None
|
||||
raw_headers = req.headers
|
||||
if isinstance(raw_headers, Mapping):
|
||||
headers_mapping = cast(Mapping[Any, Any], raw_headers)
|
||||
for key, value in headers_mapping.items():
|
||||
if str(key).lower() == "x-wait-for-completion":
|
||||
header_value = value
|
||||
break
|
||||
|
||||
if header_value is not None:
|
||||
return self._coerce_to_bool(header_value)
|
||||
|
||||
for key in ("wait_for_completion", "waitForCompletion", "WaitForCompletion"):
|
||||
if key in req_body:
|
||||
return self._coerce_to_bool(req_body.get(key))
|
||||
|
||||
return False
|
||||
|
||||
def _coerce_chat_role(self, value: Any) -> ChatRole:
|
||||
"""Convert user-provided role to ChatRole, defaulting to user on error."""
|
||||
if isinstance(value, ChatRole):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
return ChatRole(value.strip().lower())
|
||||
except ValueError:
|
||||
logger.warning("[AgentFunctionApp] Invalid role '%s'; defaulting to user", value)
|
||||
return ChatRole.USER
|
||||
|
||||
def _coerce_to_bool(self, value: Any) -> bool:
|
||||
"""Convert various representations into a boolean flag."""
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
if value is None:
|
||||
return False
|
||||
if isinstance(value, (int, float)):
|
||||
return bool(value)
|
||||
if isinstance(value, str):
|
||||
return value.strip().lower() in {"true", "1", "yes", "y", "on"}
|
||||
return False
|
||||
@@ -0,0 +1,42 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Callback interfaces for Durable Agent executions.
|
||||
|
||||
This module enables callers of AgentFunctionApp to supply streaming and final-response callbacks that are
|
||||
invoked during durable entity execution.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Protocol
|
||||
|
||||
from agent_framework import AgentRunResponse, AgentRunResponseUpdate
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AgentCallbackContext:
|
||||
"""Context supplied to callback invocations."""
|
||||
|
||||
agent_name: str
|
||||
correlation_id: str
|
||||
conversation_id: str | None = None
|
||||
request_message: str | None = None
|
||||
|
||||
|
||||
class AgentResponseCallbackProtocol(Protocol):
|
||||
"""Protocol describing the callbacks invoked during agent execution."""
|
||||
|
||||
async def on_streaming_response_update(
|
||||
self,
|
||||
update: AgentRunResponseUpdate,
|
||||
context: AgentCallbackContext,
|
||||
) -> None:
|
||||
"""Handle a streaming response update emitted by the agent."""
|
||||
|
||||
async def on_agent_response(
|
||||
self,
|
||||
response: AgentRunResponse,
|
||||
context: AgentCallbackContext,
|
||||
) -> None:
|
||||
"""Handle the final agent response."""
|
||||
@@ -0,0 +1,431 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Durable Entity for Agent Execution.
|
||||
|
||||
This module defines a durable entity that manages agent state and execution.
|
||||
Using entities instead of orchestrations provides better state management and
|
||||
allows for long-running agent conversations.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
from collections.abc import AsyncIterable
|
||||
from typing import Any, cast
|
||||
|
||||
import azure.durable_functions as df
|
||||
from agent_framework import AgentProtocol, AgentRunResponse, AgentRunResponseUpdate, get_logger
|
||||
|
||||
from ._callbacks import AgentCallbackContext, AgentResponseCallbackProtocol
|
||||
from ._models import AgentResponse, ChatRole, RunRequest
|
||||
from ._state import AgentState
|
||||
|
||||
logger = get_logger("agent_framework.azurefunctions.entities")
|
||||
|
||||
|
||||
class AgentEntity:
|
||||
"""Durable entity that manages agent execution and conversation state.
|
||||
|
||||
This entity:
|
||||
- Maintains conversation history
|
||||
- Executes agent with messages
|
||||
- Stores agent responses
|
||||
- Handles tool execution
|
||||
|
||||
Operations:
|
||||
- run_agent: Execute the agent with a message
|
||||
- reset: Clear conversation history
|
||||
|
||||
Attributes:
|
||||
agent: The AgentProtocol instance
|
||||
state: The AgentState managing conversation history
|
||||
"""
|
||||
|
||||
agent: AgentProtocol
|
||||
state: AgentState
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent: AgentProtocol,
|
||||
callback: AgentResponseCallbackProtocol | None = None,
|
||||
):
|
||||
"""Initialize the agent entity.
|
||||
|
||||
Args:
|
||||
agent: The Microsoft Agent Framework agent instance (must implement AgentProtocol)
|
||||
callback: Optional callback invoked during streaming updates and final responses
|
||||
"""
|
||||
self.agent = agent
|
||||
self.state = AgentState()
|
||||
self.callback = callback
|
||||
|
||||
logger.debug(f"[AgentEntity] Initialized with agent type: {type(agent).__name__}")
|
||||
|
||||
async def run_agent(
|
||||
self,
|
||||
context: df.DurableEntityContext,
|
||||
request: RunRequest | dict[str, Any] | str,
|
||||
) -> dict[str, Any]:
|
||||
"""Execute the agent with a message directly in the entity.
|
||||
|
||||
Args:
|
||||
context: Entity context
|
||||
request: RunRequest object, dict, or string message (for backward compatibility)
|
||||
|
||||
Returns:
|
||||
Dict with status information and response (serialized AgentResponse)
|
||||
|
||||
Note:
|
||||
The agent returns an AgentRunResponse object which is stored in state.
|
||||
This method extracts the text/structured response and returns an AgentResponse dict.
|
||||
"""
|
||||
# Convert string or dict to RunRequest
|
||||
if isinstance(request, str):
|
||||
run_request = RunRequest(message=request, role=ChatRole.USER)
|
||||
elif isinstance(request, dict):
|
||||
run_request = RunRequest.from_dict(request)
|
||||
else:
|
||||
run_request = request
|
||||
|
||||
message = run_request.message
|
||||
conversation_id = run_request.conversation_id
|
||||
correlation_id = run_request.correlation_id
|
||||
if not conversation_id:
|
||||
raise ValueError("RunRequest must include a conversation_id")
|
||||
if not correlation_id:
|
||||
raise ValueError("RunRequest must include a correlation_id")
|
||||
role = run_request.role or ChatRole.USER
|
||||
response_format = run_request.response_format
|
||||
enable_tool_calls = run_request.enable_tool_calls
|
||||
|
||||
logger.debug(f"[AgentEntity.run_agent] Received message: {message}")
|
||||
logger.debug(f"[AgentEntity.run_agent] Conversation ID: {conversation_id}")
|
||||
logger.debug(f"[AgentEntity.run_agent] Correlation ID: {correlation_id}")
|
||||
logger.debug(f"[AgentEntity.run_agent] Role: {role.value if isinstance(role, ChatRole) else role}")
|
||||
logger.debug(f"[AgentEntity.run_agent] Enable tool calls: {enable_tool_calls}")
|
||||
logger.debug(f"[AgentEntity.run_agent] Response format: {'provided' if response_format else 'none'}")
|
||||
|
||||
# Store message in history with role
|
||||
role_str = role.value if isinstance(role, ChatRole) else role
|
||||
self.state.add_user_message(message, role=role_str, correlation_id=correlation_id)
|
||||
|
||||
logger.debug("[AgentEntity.run_agent] Executing agent...")
|
||||
|
||||
try:
|
||||
logger.debug("[AgentEntity.run_agent] Starting agent invocation")
|
||||
|
||||
run_kwargs: dict[str, Any] = {"messages": self.state.get_chat_messages()}
|
||||
if not enable_tool_calls:
|
||||
run_kwargs["tools"] = None
|
||||
if response_format:
|
||||
run_kwargs["response_format"] = response_format
|
||||
|
||||
agent_run_response: AgentRunResponse = await self._invoke_agent(
|
||||
run_kwargs=run_kwargs,
|
||||
correlation_id=correlation_id,
|
||||
conversation_id=conversation_id,
|
||||
request_message=message,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"[AgentEntity.run_agent] Agent invocation completed - response type: %s",
|
||||
type(agent_run_response).__name__,
|
||||
)
|
||||
|
||||
response_text = None
|
||||
structured_response = None
|
||||
|
||||
response_str: str | None = None
|
||||
try:
|
||||
if response_format:
|
||||
try:
|
||||
response_str = agent_run_response.text
|
||||
structured_response = json.loads(response_str)
|
||||
logger.debug("Parsed structured JSON response")
|
||||
except json.JSONDecodeError as decode_error:
|
||||
logger.warning(f"Failed to parse JSON response: {decode_error}")
|
||||
response_text = response_str
|
||||
else:
|
||||
raw_text = agent_run_response.text
|
||||
response_text = raw_text if raw_text else "No response"
|
||||
preview = response_text
|
||||
logger.debug(f"Response: {preview[:100]}..." if len(preview) > 100 else f"Response: {preview}")
|
||||
except Exception as extraction_error:
|
||||
logger.error(
|
||||
f"Error extracting response: {extraction_error}",
|
||||
exc_info=True,
|
||||
)
|
||||
response_text = "Error extracting response"
|
||||
|
||||
agent_response = AgentResponse(
|
||||
response=response_text,
|
||||
message=str(message),
|
||||
conversation_id=str(conversation_id),
|
||||
status="success",
|
||||
message_count=self.state.message_count,
|
||||
structured_response=structured_response,
|
||||
)
|
||||
result = agent_response.to_dict()
|
||||
|
||||
content = json.dumps(structured_response) if structured_response else (response_text or "")
|
||||
self.state.add_assistant_message(content, agent_run_response, correlation_id)
|
||||
logger.debug("[AgentEntity.run_agent] AgentRunResponse stored in conversation history")
|
||||
|
||||
return result
|
||||
|
||||
except Exception as exc:
|
||||
import traceback
|
||||
|
||||
error_traceback = traceback.format_exc()
|
||||
logger.error("[AgentEntity.run_agent] Agent execution failed")
|
||||
logger.error(f"Error: {exc!s}")
|
||||
logger.error(f"Error type: {type(exc).__name__}")
|
||||
logger.error(f"Full traceback:\n{error_traceback}")
|
||||
|
||||
error_response = AgentResponse(
|
||||
response=f"Error: {exc!s}",
|
||||
message=str(message),
|
||||
conversation_id=str(conversation_id),
|
||||
status="error",
|
||||
message_count=self.state.message_count,
|
||||
error=str(exc),
|
||||
error_type=type(exc).__name__,
|
||||
)
|
||||
return error_response.to_dict()
|
||||
|
||||
async def _invoke_agent(
|
||||
self,
|
||||
run_kwargs: dict[str, Any],
|
||||
correlation_id: str,
|
||||
conversation_id: str,
|
||||
request_message: str,
|
||||
) -> AgentRunResponse:
|
||||
"""Execute the agent, preferring streaming when available."""
|
||||
callback_context: AgentCallbackContext | None = None
|
||||
if self.callback is not None:
|
||||
callback_context = self._build_callback_context(
|
||||
correlation_id=correlation_id,
|
||||
conversation_id=conversation_id,
|
||||
request_message=request_message,
|
||||
)
|
||||
|
||||
run_stream_callable = getattr(self.agent, "run_stream", None)
|
||||
if callable(run_stream_callable):
|
||||
try:
|
||||
stream_candidate = run_stream_callable(**run_kwargs)
|
||||
if inspect.isawaitable(stream_candidate):
|
||||
stream_candidate = await stream_candidate
|
||||
|
||||
return await self._consume_stream(
|
||||
stream=cast(AsyncIterable[AgentRunResponseUpdate], stream_candidate),
|
||||
callback_context=callback_context,
|
||||
)
|
||||
except TypeError as type_error:
|
||||
if "__aiter__" not in str(type_error):
|
||||
raise
|
||||
logger.debug(
|
||||
"run_stream returned a non-async result; falling back to run(): %s",
|
||||
type_error,
|
||||
)
|
||||
except Exception as stream_error:
|
||||
logger.warning(
|
||||
"run_stream failed; falling back to run(): %s",
|
||||
stream_error,
|
||||
exc_info=True,
|
||||
)
|
||||
else:
|
||||
logger.debug("Agent does not expose run_stream; falling back to run().")
|
||||
|
||||
agent_run_response = await self._invoke_non_stream(run_kwargs)
|
||||
await self._notify_final_response(agent_run_response, callback_context)
|
||||
return agent_run_response
|
||||
|
||||
async def _consume_stream(
|
||||
self,
|
||||
stream: AsyncIterable[AgentRunResponseUpdate],
|
||||
callback_context: AgentCallbackContext | None = None,
|
||||
) -> AgentRunResponse:
|
||||
"""Consume streaming responses and build the final AgentRunResponse."""
|
||||
updates: list[AgentRunResponseUpdate] = []
|
||||
|
||||
async for update in stream:
|
||||
updates.append(update)
|
||||
await self._notify_stream_update(update, callback_context)
|
||||
|
||||
if updates:
|
||||
response = AgentRunResponse.from_agent_run_response_updates(updates)
|
||||
else:
|
||||
logger.debug("[AgentEntity] No streaming updates received; creating empty response")
|
||||
response = AgentRunResponse(messages=[])
|
||||
|
||||
await self._notify_final_response(response, callback_context)
|
||||
return response
|
||||
|
||||
async def _invoke_non_stream(self, run_kwargs: dict[str, Any]) -> AgentRunResponse:
|
||||
"""Invoke the agent without streaming support."""
|
||||
run_callable = getattr(self.agent, "run", None)
|
||||
if run_callable is None or not callable(run_callable):
|
||||
raise AttributeError("Agent does not implement run() method")
|
||||
|
||||
result = run_callable(**run_kwargs)
|
||||
if inspect.isawaitable(result):
|
||||
result = await result
|
||||
|
||||
if not isinstance(result, AgentRunResponse):
|
||||
raise TypeError(f"Agent run() must return an AgentRunResponse instance; received {type(result).__name__}")
|
||||
|
||||
return result
|
||||
|
||||
async def _notify_stream_update(
|
||||
self,
|
||||
update: AgentRunResponseUpdate,
|
||||
context: AgentCallbackContext | None,
|
||||
) -> None:
|
||||
"""Invoke the streaming callback if one is registered."""
|
||||
if self.callback is None or context is None:
|
||||
return
|
||||
|
||||
try:
|
||||
callback_result = self.callback.on_streaming_response_update(update, context)
|
||||
if inspect.isawaitable(callback_result):
|
||||
await callback_result
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"[AgentEntity] Streaming callback raised an exception: %s",
|
||||
exc,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
async def _notify_final_response(
|
||||
self,
|
||||
response: AgentRunResponse,
|
||||
context: AgentCallbackContext | None,
|
||||
) -> None:
|
||||
"""Invoke the final response callback if one is registered."""
|
||||
if self.callback is None or context is None:
|
||||
return
|
||||
|
||||
try:
|
||||
callback_result = self.callback.on_agent_response(response, context)
|
||||
if inspect.isawaitable(callback_result):
|
||||
await callback_result
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"[AgentEntity] Response callback raised an exception: %s",
|
||||
exc,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
def _build_callback_context(
|
||||
self,
|
||||
correlation_id: str,
|
||||
conversation_id: str,
|
||||
request_message: str,
|
||||
) -> AgentCallbackContext:
|
||||
"""Create the callback context provided to consumers."""
|
||||
agent_name = getattr(self.agent, "name", None) or type(self.agent).__name__
|
||||
return AgentCallbackContext(
|
||||
agent_name=agent_name,
|
||||
correlation_id=correlation_id,
|
||||
conversation_id=conversation_id,
|
||||
request_message=request_message,
|
||||
)
|
||||
|
||||
def reset(self, context: df.DurableEntityContext) -> None:
|
||||
"""Reset the entity state (clear conversation history)."""
|
||||
logger.debug("[AgentEntity.reset] Resetting entity state")
|
||||
self.state.reset()
|
||||
logger.debug("[AgentEntity.reset] State reset complete")
|
||||
|
||||
|
||||
def create_agent_entity(
|
||||
agent: AgentProtocol,
|
||||
callback: AgentResponseCallbackProtocol | None = None,
|
||||
):
|
||||
"""Factory function to create an agent entity class.
|
||||
|
||||
Args:
|
||||
agent: The Microsoft Agent Framework agent instance (must implement AgentProtocol)
|
||||
callback: Optional callback invoked during streaming and final responses
|
||||
|
||||
Returns:
|
||||
Entity function configured with the agent
|
||||
"""
|
||||
|
||||
async def _entity_coroutine(context: df.DurableEntityContext) -> None:
|
||||
"""Async handler that executes the entity operations."""
|
||||
try:
|
||||
logger.debug("[entity_function] Entity triggered")
|
||||
logger.debug(f"[entity_function] Operation: {context.operation_name}")
|
||||
|
||||
current_state = context.get_state(lambda: None)
|
||||
logger.debug("Retrieved state: %s", str(current_state)[:100])
|
||||
entity = AgentEntity(agent, callback)
|
||||
|
||||
if current_state is not None:
|
||||
entity.state.restore_state(current_state)
|
||||
logger.debug(
|
||||
"[entity_function] Restored entity from state (message_count: %s)", entity.state.message_count
|
||||
)
|
||||
else:
|
||||
logger.debug("[entity_function] Created new entity instance")
|
||||
|
||||
operation = context.operation_name
|
||||
|
||||
if operation == "run_agent":
|
||||
input_data: Any = context.get_input()
|
||||
|
||||
# Support both old format (message + conversation_id) and new format (RunRequest dict)
|
||||
# This provides backward compatibility
|
||||
request: str | dict[str, Any]
|
||||
if isinstance(input_data, dict) and "message" in input_data:
|
||||
# Input can be either old format or new RunRequest format
|
||||
request = cast(dict[str, Any], input_data)
|
||||
else:
|
||||
# Fall back to treating input as message string
|
||||
request = "" if input_data is None else str(cast(object, input_data))
|
||||
|
||||
result = await entity.run_agent(context, request)
|
||||
context.set_result(result)
|
||||
|
||||
elif operation == "reset":
|
||||
entity.reset(context)
|
||||
context.set_result({"status": "reset"})
|
||||
|
||||
else:
|
||||
logger.error("[entity_function] Unknown operation: %s", operation)
|
||||
context.set_result({"error": f"Unknown operation: {operation}"})
|
||||
|
||||
context.set_state(entity.state.to_dict())
|
||||
logger.debug(f"[entity_function] Operation {operation} completed successfully")
|
||||
|
||||
except Exception as exc:
|
||||
import traceback
|
||||
|
||||
logger.error("[entity_function] Error in entity: %s", exc)
|
||||
logger.error(f"[entity_function] Traceback:\n{traceback.format_exc()}")
|
||||
context.set_result({"error": str(exc), "status": "error"})
|
||||
|
||||
def entity_function(context: df.DurableEntityContext) -> None:
|
||||
"""Synchronous wrapper invoked by the Durable Functions runtime."""
|
||||
try:
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
if loop.is_running():
|
||||
temp_loop = asyncio.new_event_loop()
|
||||
try:
|
||||
temp_loop.run_until_complete(_entity_coroutine(context))
|
||||
finally:
|
||||
temp_loop.close()
|
||||
else:
|
||||
loop.run_until_complete(_entity_coroutine(context))
|
||||
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
logger.error("[entity_function] Unexpected error executing entity: %s", exc, exc_info=True)
|
||||
context.set_result({"error": str(exc), "status": "error"})
|
||||
|
||||
return entity_function
|
||||
@@ -0,0 +1,13 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Custom exception types for the durable agent framework."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
class IncomingRequestError(ValueError):
|
||||
"""Raised when an incoming HTTP request cannot be parsed or validated."""
|
||||
|
||||
def __init__(self, message: str, status_code: int = 400) -> None:
|
||||
super().__init__(message)
|
||||
self.status_code = status_code
|
||||
@@ -0,0 +1,381 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Data models for Durable Agent Framework.
|
||||
|
||||
This module defines the request and response models used by the framework.
|
||||
"""
|
||||
|
||||
import inspect
|
||||
import uuid
|
||||
from collections.abc import MutableMapping
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from importlib import import_module
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
import azure.durable_functions as df
|
||||
from agent_framework import AgentThread
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover - type checking imports only
|
||||
from pydantic import BaseModel
|
||||
|
||||
_PydanticBaseModel: type["BaseModel"] | None
|
||||
try:
|
||||
from pydantic import BaseModel as _RuntimeBaseModel
|
||||
except ImportError: # pragma: no cover - optional dependency
|
||||
_PydanticBaseModel = None
|
||||
else:
|
||||
_PydanticBaseModel = _RuntimeBaseModel
|
||||
|
||||
|
||||
class ChatRole(str, Enum):
|
||||
"""Chat message role enum."""
|
||||
|
||||
USER = "user"
|
||||
SYSTEM = "system"
|
||||
ASSISTANT = "assistant"
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentSessionId:
|
||||
"""Represents an agent session ID, which is used to identify a long-running agent session.
|
||||
|
||||
Attributes:
|
||||
name: The name of the agent that owns the session (case-insensitive)
|
||||
key: The unique key of the agent session (case-sensitive)
|
||||
"""
|
||||
|
||||
name: str
|
||||
key: str
|
||||
|
||||
ENTITY_NAME_PREFIX: str = "dafx-"
|
||||
|
||||
@staticmethod
|
||||
def to_entity_name(name: str) -> str:
|
||||
"""Converts an agent name to an entity name by adding the DAFx prefix.
|
||||
|
||||
Args:
|
||||
name: The agent name
|
||||
|
||||
Returns:
|
||||
The entity name with the dafx- prefix
|
||||
"""
|
||||
return f"{AgentSessionId.ENTITY_NAME_PREFIX}{name}"
|
||||
|
||||
@staticmethod
|
||||
def with_random_key(name: str) -> "AgentSessionId":
|
||||
"""Creates a new AgentSessionId with the specified name and a randomly generated key.
|
||||
|
||||
Args:
|
||||
name: The name of the agent that owns the session
|
||||
|
||||
Returns:
|
||||
A new AgentSessionId with the specified name and a random GUID key
|
||||
"""
|
||||
return AgentSessionId(name=name, key=uuid.uuid4().hex)
|
||||
|
||||
def to_entity_id(self) -> df.EntityId:
|
||||
"""Converts this AgentSessionId to a Durable Functions EntityId.
|
||||
|
||||
Returns:
|
||||
EntityId for use with Durable Functions APIs
|
||||
"""
|
||||
return df.EntityId(self.to_entity_name(self.name), self.key)
|
||||
|
||||
@staticmethod
|
||||
def from_entity_id(entity_id: df.EntityId) -> "AgentSessionId":
|
||||
"""Creates an AgentSessionId from a Durable Functions EntityId.
|
||||
|
||||
Args:
|
||||
entity_id: The EntityId to convert
|
||||
|
||||
Returns:
|
||||
AgentSessionId instance
|
||||
|
||||
Raises:
|
||||
ValueError: If the entity ID does not have the expected prefix
|
||||
"""
|
||||
if not entity_id.name.startswith(AgentSessionId.ENTITY_NAME_PREFIX):
|
||||
raise ValueError(
|
||||
f"'{entity_id}' is not a valid agent session ID. "
|
||||
f"Expected entity name to start with '{AgentSessionId.ENTITY_NAME_PREFIX}'"
|
||||
)
|
||||
|
||||
agent_name = entity_id.name[len(AgentSessionId.ENTITY_NAME_PREFIX) :]
|
||||
return AgentSessionId(name=agent_name, key=entity_id.key)
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Returns a string representation in the form @name@key."""
|
||||
return f"@{self.name}@{self.key}"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Returns a detailed string representation."""
|
||||
return f"AgentSessionId(name='{self.name}', key='{self.key}')"
|
||||
|
||||
@staticmethod
|
||||
def parse(session_id_string: str) -> "AgentSessionId":
|
||||
"""Parses a string representation of an agent session ID.
|
||||
|
||||
Args:
|
||||
session_id_string: A string in the form @name@key
|
||||
|
||||
Returns:
|
||||
AgentSessionId instance
|
||||
|
||||
Raises:
|
||||
ValueError: If the string format is invalid
|
||||
"""
|
||||
if not session_id_string.startswith("@"):
|
||||
raise ValueError(f"Invalid agent session ID format: {session_id_string}")
|
||||
|
||||
parts = session_id_string[1:].split("@", 1)
|
||||
if len(parts) != 2:
|
||||
raise ValueError(f"Invalid agent session ID format: {session_id_string}")
|
||||
|
||||
return AgentSessionId(name=parts[0], key=parts[1])
|
||||
|
||||
|
||||
class DurableAgentThread(AgentThread):
|
||||
"""Durable agent thread that tracks the owning :class:`AgentSessionId`."""
|
||||
|
||||
_SERIALIZED_SESSION_ID_KEY = "durable_session_id"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
session_id: AgentSessionId | None = None,
|
||||
service_thread_id: str | None = None,
|
||||
message_store: Any = None,
|
||||
context_provider: Any = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
service_thread_id=service_thread_id,
|
||||
message_store=message_store,
|
||||
context_provider=context_provider,
|
||||
)
|
||||
self._session_id: AgentSessionId | None = session_id
|
||||
|
||||
@property
|
||||
def session_id(self) -> AgentSessionId | None:
|
||||
"""Returns the durable agent session identifier for this thread."""
|
||||
return self._session_id
|
||||
|
||||
def attach_session(self, session_id: AgentSessionId) -> None:
|
||||
"""Associates the thread with the provided :class:`AgentSessionId`."""
|
||||
self._session_id = session_id
|
||||
|
||||
@classmethod
|
||||
def from_session_id(
|
||||
cls,
|
||||
session_id: AgentSessionId,
|
||||
*,
|
||||
service_thread_id: str | None = None,
|
||||
message_store: Any = None,
|
||||
context_provider: Any = None,
|
||||
) -> "DurableAgentThread":
|
||||
"""Creates a durable thread pre-associated with the supplied session ID."""
|
||||
return cls(
|
||||
session_id=session_id,
|
||||
service_thread_id=service_thread_id,
|
||||
message_store=message_store,
|
||||
context_provider=context_provider,
|
||||
)
|
||||
|
||||
async def serialize(self, **kwargs: Any) -> dict[str, Any]:
|
||||
"""Serializes thread state including the durable session identifier."""
|
||||
state = await super().serialize(**kwargs)
|
||||
if self._session_id is not None:
|
||||
state[self._SERIALIZED_SESSION_ID_KEY] = str(self._session_id)
|
||||
return state
|
||||
|
||||
@classmethod
|
||||
async def deserialize(
|
||||
cls,
|
||||
serialized_thread_state: MutableMapping[str, Any],
|
||||
*,
|
||||
message_store: Any = None,
|
||||
**kwargs: Any,
|
||||
) -> "DurableAgentThread":
|
||||
"""Restores a durable thread, rehydrating the stored session identifier."""
|
||||
session_id_value = serialized_thread_state.get(cls._SERIALIZED_SESSION_ID_KEY)
|
||||
thread = await super().deserialize(
|
||||
serialized_thread_state,
|
||||
message_store=message_store,
|
||||
**kwargs,
|
||||
)
|
||||
if not isinstance(thread, DurableAgentThread):
|
||||
raise TypeError("Deserialized thread is not a DurableAgentThread instance")
|
||||
|
||||
if session_id_value is None:
|
||||
return thread
|
||||
|
||||
if not isinstance(session_id_value, str):
|
||||
raise ValueError("durable_session_id must be a string when present in serialized state")
|
||||
|
||||
thread.attach_session(AgentSessionId.parse(session_id_value))
|
||||
return thread
|
||||
|
||||
|
||||
def _serialize_response_format(response_format: type["BaseModel"] | None) -> Any:
|
||||
"""Serialize response format for transport across durable function boundaries."""
|
||||
if response_format is None:
|
||||
return None
|
||||
|
||||
if _PydanticBaseModel is None:
|
||||
raise RuntimeError("pydantic is required to use structured response formats")
|
||||
|
||||
if not inspect.isclass(response_format) or not issubclass(response_format, _PydanticBaseModel):
|
||||
raise TypeError("response_format must be a Pydantic BaseModel type")
|
||||
|
||||
return {
|
||||
"__response_schema_type__": "pydantic_model",
|
||||
"module": response_format.__module__,
|
||||
"qualname": response_format.__qualname__,
|
||||
}
|
||||
|
||||
|
||||
def _deserialize_response_format(response_format: Any) -> type["BaseModel"] | None:
|
||||
"""Deserialize response format back into actionable type if possible."""
|
||||
if response_format is None:
|
||||
return None
|
||||
|
||||
if (
|
||||
_PydanticBaseModel is not None
|
||||
and inspect.isclass(response_format)
|
||||
and issubclass(response_format, _PydanticBaseModel)
|
||||
):
|
||||
return response_format
|
||||
|
||||
if not isinstance(response_format, dict):
|
||||
return None
|
||||
|
||||
response_dict = cast(dict[str, Any], response_format)
|
||||
|
||||
if response_dict.get("__response_schema_type__") != "pydantic_model":
|
||||
return None
|
||||
|
||||
module_name = response_dict.get("module")
|
||||
qualname = response_dict.get("qualname")
|
||||
if not module_name or not qualname:
|
||||
return None
|
||||
|
||||
try:
|
||||
module = import_module(module_name)
|
||||
except ImportError: # pragma: no cover - user provided module missing
|
||||
return None
|
||||
|
||||
attr: Any = module
|
||||
for part in qualname.split("."):
|
||||
try:
|
||||
attr = getattr(attr, part)
|
||||
except AttributeError: # pragma: no cover - invalid qualname
|
||||
return None
|
||||
|
||||
if _PydanticBaseModel is not None and inspect.isclass(attr) and issubclass(attr, _PydanticBaseModel):
|
||||
return attr
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class RunRequest:
|
||||
"""Represents a request to run an agent with a specific message and configuration.
|
||||
|
||||
Attributes:
|
||||
message: The message to send to the agent
|
||||
role: The role of the message sender (user, system, or assistant)
|
||||
response_format: Optional Pydantic BaseModel type describing the structured response format
|
||||
enable_tool_calls: Whether to enable tool calls for this request
|
||||
conversation_id: Optional conversation/session ID for tracking
|
||||
correlation_id: Optional correlation ID for tracking the response to this specific request
|
||||
"""
|
||||
|
||||
message: str
|
||||
role: ChatRole = ChatRole.USER
|
||||
response_format: type["BaseModel"] | None = None
|
||||
enable_tool_calls: bool = True
|
||||
conversation_id: str | None = None
|
||||
correlation_id: str | None = None
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary for JSON serialization."""
|
||||
result = {
|
||||
"message": self.message,
|
||||
"enable_tool_calls": self.enable_tool_calls,
|
||||
"role": self.role.value,
|
||||
}
|
||||
if self.response_format:
|
||||
result["response_format"] = _serialize_response_format(self.response_format)
|
||||
if self.conversation_id:
|
||||
result["conversation_id"] = self.conversation_id
|
||||
if self.correlation_id:
|
||||
result["correlation_id"] = self.correlation_id
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "RunRequest":
|
||||
"""Create RunRequest from dictionary."""
|
||||
role_str = data.get("role")
|
||||
if role_str:
|
||||
try:
|
||||
role = ChatRole(role_str.lower())
|
||||
except ValueError:
|
||||
role = ChatRole.USER # Default to USER if invalid
|
||||
else:
|
||||
role = ChatRole.USER
|
||||
|
||||
return cls(
|
||||
message=data.get("message", ""),
|
||||
role=role,
|
||||
response_format=_deserialize_response_format(data.get("response_format")),
|
||||
enable_tool_calls=data.get("enable_tool_calls", True),
|
||||
conversation_id=data.get("conversation_id"),
|
||||
correlation_id=data.get("correlation_id"),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentResponse:
|
||||
"""Response from agent execution.
|
||||
|
||||
Attributes:
|
||||
response: The agent's text response (or None for structured responses)
|
||||
message: The original message sent to the agent
|
||||
conversation_id: The conversation/session ID
|
||||
status: Status of the execution (success, error, etc.)
|
||||
message_count: Number of messages in the conversation
|
||||
error: Error message if status is error
|
||||
error_type: Type of error if status is error
|
||||
structured_response: Structured response if response_format was provided
|
||||
"""
|
||||
|
||||
response: str | None
|
||||
message: str
|
||||
conversation_id: str | None
|
||||
status: str
|
||||
message_count: int = 0
|
||||
error: str | None = None
|
||||
error_type: str | None = None
|
||||
structured_response: dict[str, Any] | None = None
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary for JSON serialization."""
|
||||
result = {
|
||||
"message": self.message,
|
||||
"conversation_id": self.conversation_id,
|
||||
"status": self.status,
|
||||
"message_count": self.message_count,
|
||||
}
|
||||
|
||||
# Add response or structured_response based on what's available
|
||||
if self.structured_response is not None:
|
||||
result["structured_response"] = self.structured_response
|
||||
elif self.response is not None:
|
||||
result["response"] = self.response
|
||||
|
||||
if self.error:
|
||||
result["error"] = self.error
|
||||
if self.error_type:
|
||||
result["error_type"] = self.error_type
|
||||
|
||||
return result
|
||||
@@ -0,0 +1,235 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Orchestration Support for Durable Agents.
|
||||
|
||||
This module provides support for using agents inside Durable Function orchestrations.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import TYPE_CHECKING, Any, TypeAlias, cast
|
||||
|
||||
from agent_framework import AgentProtocol, AgentRunResponseUpdate, AgentThread, ChatMessage, get_logger
|
||||
|
||||
from ._models import AgentSessionId, DurableAgentThread, RunRequest
|
||||
|
||||
logger = get_logger("agent_framework.azurefunctions.orchestration")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from azure.durable_functions import DurableOrchestrationContext as _DurableOrchestrationContext
|
||||
|
||||
AgentOrchestrationContextType: TypeAlias = _DurableOrchestrationContext
|
||||
else:
|
||||
AgentOrchestrationContextType = Any
|
||||
|
||||
|
||||
class DurableAIAgent(AgentProtocol):
|
||||
"""A durable agent implementation that uses entity methods to interact with agent entities.
|
||||
|
||||
This class implements AgentProtocol and provides methods to work with Azure Durable Functions
|
||||
orchestrations, which use generators and yield instead of async/await.
|
||||
|
||||
Key methods:
|
||||
- get_new_thread(): Create a new conversation thread
|
||||
- run(): Execute the agent and return a Task for yielding in orchestrations
|
||||
|
||||
Note: The run() method is NOT async. It returns a Task directly that must be
|
||||
yielded in orchestrations to wait for the entity call to complete.
|
||||
|
||||
Example usage in orchestration:
|
||||
writer = get_agent(context, "WriterAgent")
|
||||
thread = writer.get_new_thread() # NOT yielded - returns immediately
|
||||
|
||||
response = yield writer.run( # Yielded - waits for entity call
|
||||
message="Write a haiku about coding",
|
||||
thread=thread
|
||||
)
|
||||
"""
|
||||
|
||||
def __init__(self, context: AgentOrchestrationContextType, agent_name: str):
|
||||
"""Initialize the DurableAIAgent.
|
||||
|
||||
Args:
|
||||
context: The orchestration context
|
||||
agent_name: Name of the agent (used to construct entity ID)
|
||||
"""
|
||||
self.context = context
|
||||
self.agent_name = agent_name
|
||||
self._id = str(uuid.uuid4())
|
||||
self._name = agent_name
|
||||
self._display_name = agent_name
|
||||
self._description = f"Durable agent proxy for {agent_name}"
|
||||
logger.debug(f"[DurableAIAgent] Initialized for agent: {agent_name}")
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
"""Get the unique identifier for this agent."""
|
||||
return self._id
|
||||
|
||||
@property
|
||||
def name(self) -> str | None:
|
||||
"""Get the name of the agent."""
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def display_name(self) -> str:
|
||||
"""Get the display name of the agent."""
|
||||
return self._display_name
|
||||
|
||||
@property
|
||||
def description(self) -> str | None:
|
||||
"""Get the description of the agent."""
|
||||
return self._description
|
||||
|
||||
def run(
|
||||
self,
|
||||
messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
|
||||
*,
|
||||
thread: AgentThread | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Any: # TODO(msft-team): Add a wrapper to respond correctly with `AgentRunResponse`
|
||||
"""Execute the agent with messages and return a Task for orchestrations.
|
||||
|
||||
This method implements AgentProtocol and returns a Task that can be yielded
|
||||
in Durable Functions orchestrations.
|
||||
|
||||
Args:
|
||||
messages: The message(s) to send to the agent
|
||||
thread: Optional agent thread for conversation context
|
||||
**kwargs: Additional arguments (enable_tool_calls, response_format, etc.)
|
||||
|
||||
Returns:
|
||||
Task that will resolve to the agent response
|
||||
|
||||
Example:
|
||||
@app.orchestration_trigger(context_name="context")
|
||||
def my_orchestration(context):
|
||||
agent = get_agent(context, "MyAgent")
|
||||
thread = agent.get_new_thread()
|
||||
result = yield agent.run("Hello", thread=thread)
|
||||
"""
|
||||
message_str = self._normalize_messages(messages)
|
||||
|
||||
# Extract optional parameters from kwargs
|
||||
enable_tool_calls = kwargs.get("enable_tool_calls", True)
|
||||
response_format = kwargs.get("response_format")
|
||||
|
||||
# Get the session ID for the entity
|
||||
if isinstance(thread, DurableAgentThread) and thread.session_id is not None:
|
||||
session_id = thread.session_id
|
||||
else:
|
||||
# Create a unique session ID for each call when no thread is provided
|
||||
# This ensures each call gets its own conversation context
|
||||
session_key = str(self.context.new_uuid())
|
||||
session_id = AgentSessionId(name=self.agent_name, key=session_key)
|
||||
logger.warning(f"[DurableAIAgent] No thread provided, created unique session_id: {session_id}")
|
||||
|
||||
# Create entity ID from session ID
|
||||
entity_id = session_id.to_entity_id()
|
||||
|
||||
# Generate a deterministic correlation ID for this call
|
||||
# This is required by the entity and must be unique per call
|
||||
correlation_id = str(self.context.new_uuid())
|
||||
|
||||
# Prepare the request using RunRequest model
|
||||
run_request = RunRequest(
|
||||
message=message_str,
|
||||
enable_tool_calls=enable_tool_calls,
|
||||
correlation_id=correlation_id,
|
||||
conversation_id=session_id.key,
|
||||
response_format=response_format,
|
||||
)
|
||||
|
||||
logger.debug(f"[DurableAIAgent] Calling entity {entity_id} with message: {message_str[:100]}...")
|
||||
|
||||
# Call the entity and return the Task directly
|
||||
# The orchestration will yield this Task
|
||||
return self.context.call_entity(entity_id, "run_agent", run_request.to_dict())
|
||||
|
||||
def run_stream(
|
||||
self,
|
||||
messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
|
||||
*,
|
||||
thread: AgentThread | None = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[AgentRunResponseUpdate]:
|
||||
"""Run the agent with streaming (not supported for durable agents).
|
||||
|
||||
Raises:
|
||||
NotImplementedError: Streaming is not supported for durable agents.
|
||||
"""
|
||||
raise NotImplementedError("Streaming is not supported for durable agents in orchestrations.")
|
||||
|
||||
def get_new_thread(self, **kwargs: Any) -> AgentThread:
|
||||
"""Create a new agent thread for this orchestration instance.
|
||||
|
||||
Each call creates a unique thread with its own conversation context.
|
||||
The session ID is deterministic (uses context.new_uuid()) to ensure
|
||||
orchestration replay works correctly.
|
||||
|
||||
Returns:
|
||||
A new AgentThread instance with a unique session ID
|
||||
"""
|
||||
# Generate a deterministic unique key for this thread
|
||||
# Using context.new_uuid() ensures the same GUID is generated during replay
|
||||
session_key = str(self.context.new_uuid())
|
||||
|
||||
# Create AgentSessionId with agent name and session key
|
||||
session_id = AgentSessionId(name=self.agent_name, key=session_key)
|
||||
|
||||
thread = DurableAgentThread.from_session_id(session_id, **kwargs)
|
||||
|
||||
logger.debug(f"[DurableAIAgent] Created new thread with session_id: {session_id}")
|
||||
return thread
|
||||
|
||||
def _messages_to_string(self, messages: list[ChatMessage]) -> str:
|
||||
"""Convert a list of ChatMessage objects to a single string.
|
||||
|
||||
Args:
|
||||
messages: List of ChatMessage objects
|
||||
|
||||
Returns:
|
||||
Concatenated string of message contents
|
||||
"""
|
||||
return "\n".join([msg.text or "" for msg in messages])
|
||||
|
||||
def _normalize_messages(self, messages: str | ChatMessage | list[str] | list[ChatMessage] | None) -> str:
|
||||
"""Convert supported message inputs to a single string."""
|
||||
if messages is None:
|
||||
return ""
|
||||
if isinstance(messages, str):
|
||||
return messages
|
||||
if isinstance(messages, ChatMessage):
|
||||
return messages.text or ""
|
||||
if isinstance(messages, list):
|
||||
if not messages:
|
||||
return ""
|
||||
first_item = messages[0]
|
||||
if isinstance(first_item, str):
|
||||
return "\n".join(cast(list[str], messages))
|
||||
return self._messages_to_string(cast(list[ChatMessage], messages))
|
||||
return str(messages)
|
||||
|
||||
|
||||
def get_agent(context: AgentOrchestrationContextType, agent_name: str) -> DurableAIAgent:
|
||||
"""Return a :class:`DurableAIAgent` proxy scoped to ``agent_name``.
|
||||
|
||||
Usage::
|
||||
|
||||
from agent_framework.azurefunctions import get_agent
|
||||
|
||||
|
||||
@app.orchestration_trigger(context_name="context")
|
||||
def my_orchestration(context: DurableOrchestrationContext):
|
||||
writer = get_agent(context, "WriterAgent")
|
||||
thread = writer.get_new_thread()
|
||||
response = yield writer.run("Write a haiku", thread=thread)
|
||||
|
||||
Args:
|
||||
context: The orchestration context provided by Durable Functions.
|
||||
agent_name: Name of the durable agent entity to call.
|
||||
|
||||
Returns:
|
||||
DurableAIAgent wrapper for the specified agent.
|
||||
"""
|
||||
return DurableAIAgent(context, agent_name)
|
||||
@@ -0,0 +1,179 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Agent State Management.
|
||||
|
||||
This module defines the AgentState class for managing conversation state and
|
||||
serializing agent framework responses.
|
||||
"""
|
||||
|
||||
from collections.abc import MutableMapping
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
from agent_framework import AgentRunResponse, ChatMessage, get_logger
|
||||
|
||||
logger = get_logger("agent_framework.azurefunctions.state")
|
||||
|
||||
|
||||
class AgentState:
|
||||
"""Manages agent conversation state using agent_framework types (ChatMessage, AgentRunResponse).
|
||||
|
||||
This class handles:
|
||||
- Conversation history tracking using ChatMessage objects
|
||||
- Agent response storage using AgentRunResponse objects with correlation IDs
|
||||
- State persistence and restoration
|
||||
- Message counting
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize empty agent state."""
|
||||
self.conversation_history: list[ChatMessage] = []
|
||||
self.last_response: str | None = None
|
||||
self.message_count: int = 0
|
||||
|
||||
def _current_timestamp(self) -> str:
|
||||
"""Return an ISO 8601 UTC timestamp."""
|
||||
return datetime.now(timezone.utc).isoformat()
|
||||
|
||||
def add_user_message(
|
||||
self,
|
||||
content: str,
|
||||
role: Literal["user", "system", "assistant", "tool"] = "user",
|
||||
correlation_id: str | None = None,
|
||||
) -> None:
|
||||
"""Add a user message to the conversation history as a ChatMessage object.
|
||||
|
||||
Args:
|
||||
content: The message content
|
||||
role: The message role (user, system, etc.)
|
||||
correlation_id: Optional correlation identifier associated with the user message
|
||||
"""
|
||||
self.message_count += 1
|
||||
timestamp = self._current_timestamp()
|
||||
additional_props: MutableMapping[str, Any] = {"timestamp": timestamp}
|
||||
if correlation_id is not None:
|
||||
additional_props["correlation_id"] = correlation_id
|
||||
chat_message = ChatMessage(role=role, text=content, additional_properties=additional_props)
|
||||
self.conversation_history.append(chat_message)
|
||||
logger.debug(f"Added {role} ChatMessage to history (message #{self.message_count})")
|
||||
|
||||
def add_assistant_message(
|
||||
self, content: str, agent_response: AgentRunResponse, correlation_id: str | None = None
|
||||
) -> None:
|
||||
"""Add an assistant message to the conversation history with full agent response.
|
||||
|
||||
Args:
|
||||
content: The text content of the response
|
||||
agent_response: The AgentRunResponse object from the agent framework
|
||||
correlation_id: Optional correlation ID for tracking this response
|
||||
"""
|
||||
self.last_response = content
|
||||
timestamp = self._current_timestamp()
|
||||
serialized_response = self.serialize_response(agent_response)
|
||||
|
||||
# Create a ChatMessage for the assistant response
|
||||
# The agent_response already contains messages, but we store it as a custom ChatMessage
|
||||
# with the agent_response stored in additional_properties for full metadata preservation
|
||||
additional_props: dict[str, Any] = {
|
||||
"agent_response": serialized_response,
|
||||
"correlation_id": correlation_id,
|
||||
"timestamp": timestamp,
|
||||
"message_count": self.message_count,
|
||||
}
|
||||
chat_message = ChatMessage(role="assistant", text=content, additional_properties=additional_props)
|
||||
|
||||
self.conversation_history.append(chat_message)
|
||||
|
||||
logger.debug(
|
||||
f"Added assistant ChatMessage to history with AgentRunResponse metadata (correlation_id: {correlation_id})"
|
||||
)
|
||||
|
||||
def get_chat_messages(self) -> list[ChatMessage]:
|
||||
"""Return a copy of the full conversation history."""
|
||||
return list(self.conversation_history)
|
||||
|
||||
def try_get_agent_response(self, correlation_id: str) -> dict[str, Any] | None:
|
||||
"""Get an agent response by correlation ID.
|
||||
|
||||
Args:
|
||||
correlation_id: The correlation ID to look up
|
||||
|
||||
Returns:
|
||||
The agent response data if found, None otherwise
|
||||
"""
|
||||
for message in reversed(self.conversation_history):
|
||||
metadata = getattr(message, "additional_properties", {}) or {}
|
||||
if metadata.get("correlation_id") == correlation_id:
|
||||
return self._build_agent_response_payload(message, metadata)
|
||||
|
||||
return None
|
||||
|
||||
def serialize_response(self, response: AgentRunResponse) -> dict[str, Any]:
|
||||
"""Serialize an ``AgentRunResponse`` to a dictionary.
|
||||
|
||||
Args:
|
||||
response: The agent framework response object
|
||||
|
||||
Returns:
|
||||
Dictionary containing all response fields
|
||||
"""
|
||||
try:
|
||||
return response.to_dict()
|
||||
except Exception as exc: # pragma: no cover - defensive logging path
|
||||
logger.warning(f"Error serializing response: {exc}")
|
||||
return {"response": str(response), "serialization_error": str(exc)}
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Get the current state as a dictionary for persistence.
|
||||
|
||||
Returns:
|
||||
Dictionary containing conversation_history (as serialized ChatMessages),
|
||||
last_response, and message_count
|
||||
"""
|
||||
return {
|
||||
"conversation_history": [msg.to_dict() for msg in self.conversation_history],
|
||||
"last_response": self.last_response,
|
||||
"message_count": self.message_count,
|
||||
}
|
||||
|
||||
def restore_state(self, state: dict[str, Any]) -> None:
|
||||
"""Restore state from a dictionary, reconstructing ChatMessage objects.
|
||||
|
||||
Args:
|
||||
state: Dictionary containing conversation_history, last_response, and message_count
|
||||
"""
|
||||
# Restore conversation history as ChatMessage objects
|
||||
history_data = state.get("conversation_history", [])
|
||||
restored_history: list[ChatMessage] = []
|
||||
for raw_message in history_data:
|
||||
if isinstance(raw_message, dict):
|
||||
restored_history.append(ChatMessage.from_dict(cast(dict[str, Any], raw_message)))
|
||||
else:
|
||||
restored_history.append(cast(ChatMessage, raw_message))
|
||||
|
||||
self.conversation_history = restored_history
|
||||
|
||||
self.last_response = state.get("last_response")
|
||||
self.message_count = state.get("message_count", 0)
|
||||
logger.debug("Restored state: %s ChatMessages in history", len(self.conversation_history))
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset the state to empty."""
|
||||
self.conversation_history = []
|
||||
self.last_response = None
|
||||
self.message_count = 0
|
||||
logger.debug("State reset to empty")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""String representation of the state."""
|
||||
return f"AgentState(messages={self.message_count}, history_length={len(self.conversation_history)})"
|
||||
|
||||
def _build_agent_response_payload(self, message: ChatMessage, metadata: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Construct the agent response payload returned to callers."""
|
||||
return {
|
||||
"content": message.text,
|
||||
"agent_response": metadata.get("agent_response"),
|
||||
"message_count": metadata.get("message_count", self.message_count),
|
||||
"timestamp": metadata.get("timestamp"),
|
||||
"correlation_id": metadata.get("correlation_id"),
|
||||
}
|
||||
@@ -0,0 +1,29 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import importlib
|
||||
from typing import Any
|
||||
|
||||
PACKAGE_NAME = "agent_framework_azurefunctions"
|
||||
PACKAGE_EXTRA = "azurefunctions"
|
||||
_IMPORTS = [
|
||||
"AgentFunctionApp",
|
||||
"DurableAIAgent",
|
||||
"get_agent",
|
||||
"AgentCallbackContext",
|
||||
"AgentResponseCallbackProtocol",
|
||||
]
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
if name in _IMPORTS:
|
||||
try:
|
||||
return getattr(importlib.import_module(PACKAGE_NAME), name)
|
||||
except ModuleNotFoundError as exc:
|
||||
raise ModuleNotFoundError(
|
||||
f"The '{PACKAGE_EXTRA}' extra is not installed, please do `pip install agent-framework-{PACKAGE_EXTRA}`"
|
||||
) from exc
|
||||
raise AttributeError(f"Module {PACKAGE_NAME} has no attribute {name}.")
|
||||
|
||||
|
||||
def __dir__() -> list[str]:
|
||||
return _IMPORTS
|
||||
@@ -26,6 +26,7 @@ dependencies = [
|
||||
"agent-framework-a2a",
|
||||
"agent-framework-anthropic",
|
||||
"agent-framework-azure-ai",
|
||||
"agent-framework-azurefunctions",
|
||||
"agent-framework-copilotstudio",
|
||||
"agent-framework-devui",
|
||||
"agent-framework-lab",
|
||||
@@ -89,6 +90,7 @@ agent-framework = { workspace = true }
|
||||
agent-framework-core = { workspace = true }
|
||||
agent-framework-a2a = { workspace = true }
|
||||
agent-framework-azure-ai = { workspace = true }
|
||||
agent-framework-azurefunctions = { workspace = true }
|
||||
agent-framework-copilotstudio = { workspace = true }
|
||||
agent-framework-lab = { workspace = true }
|
||||
agent-framework-mem0 = { workspace = true }
|
||||
@@ -240,6 +242,7 @@ pytest --import-mode=importlib
|
||||
--cov=agent_framework
|
||||
--cov=agent_framework_a2a
|
||||
--cov=agent_framework_azure_ai
|
||||
--cov=agent_framework_azurefunctions
|
||||
--cov=agent_framework_copilotstudio
|
||||
--cov=agent_framework_mem0
|
||||
--cov=agent_framework_redis
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
# Single Agent Sample (Python)
|
||||
|
||||
This sample demonstrates how to use the Durable Extension for Agent Framework to create a simple Azure Functions app that hosts a single AI agent and provides direct HTTP API access for interactive conversations.
|
||||
|
||||
## Key Concepts Demonstrated
|
||||
|
||||
- Defining a simple agent with the Microsoft Agent Framework and wiring it into
|
||||
an Azure Functions app via the Durable Extension for Agent Framework.
|
||||
- Calling the agent through generated HTTP endpoints (`/api/agents/Joker/run`).
|
||||
- Managing conversation state with session identifiers, so multiple clients can
|
||||
interact with the agent concurrently without sharing context.
|
||||
|
||||
## Environment Setup
|
||||
|
||||
### 1. Create and activate a virtual environment
|
||||
|
||||
**Windows (PowerShell):**
|
||||
```powershell
|
||||
python -m venv .venv
|
||||
.venv\Scripts\Activate.ps1
|
||||
```
|
||||
|
||||
**Linux/macOS:**
|
||||
```bash
|
||||
python -m venv .venv
|
||||
source .venv/bin/activate
|
||||
```
|
||||
|
||||
### 2. Install dependencies
|
||||
|
||||
- Azure Functions Core Tools 4.x – install from the official docs so you can run `func start` locally.
|
||||
- Azurite storage emulator – the sample uses `AzureWebJobsStorage=UseDevelopmentStorage=true`; start Azurite before launching the app.
|
||||
- Durable Task local backend – `DURABLE_TASK_SCHEDULER_CONNECTION_STRING` expects the Durable Task scheduler listening on `http://localhost:8080` (start the Durable Functions emulator if it is not already running).
|
||||
- Python dependencies – from this folder, run `pip install -r requirements.txt` (or the equivalent in your active virtual environment).
|
||||
- Environment variables – update `AZURE_OPENAI_ENDPOINT` and `AZURE_OPENAI_CHAT_DEPLOYMENT_NAME` in `local.settings.json` with your Azure OpenAI resource details; keep the other values as provided unless you are using custom infrastructure.
|
||||
|
||||
## Running the Sample
|
||||
|
||||
With the environment configured and the Functions host running, you can interact
|
||||
with the Joker agent using the provided `demo.http` file or any HTTP client. For
|
||||
example:
|
||||
|
||||
```bash
|
||||
curl -X POST http://localhost:7071/api/agents/Joker/run \
|
||||
-H "Content-Type: text/plain" \
|
||||
-d "Tell me a short joke about cloud computing."
|
||||
```
|
||||
|
||||
The agent responds with a JSON payload that includes the generated joke.
|
||||
@@ -0,0 +1,26 @@
|
||||
### Joker Agent Sample Interactions
|
||||
@baseUrl = http://localhost:7071
|
||||
@agentName = Joker
|
||||
@agentRoute = {{baseUrl}}/api/agents/{{agentName}}
|
||||
@healthRoute = {{baseUrl}}/api/health
|
||||
|
||||
### Health Check
|
||||
GET {{healthRoute}}
|
||||
|
||||
### Ask for a joke (JSON payload)
|
||||
POST {{agentRoute}}/run
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"message": "Add a security element to it.",
|
||||
"sessionId": "session-003",
|
||||
"waitForCompletion": true
|
||||
}
|
||||
|
||||
### Ask for a joke (plain text payload)
|
||||
POST {{agentRoute}}/run
|
||||
|
||||
Give me a programming joke about race conditions.
|
||||
|
||||
### Retrieve conversation state
|
||||
GET {{agentRoute}}/session-001
|
||||
@@ -0,0 +1,61 @@
|
||||
"""Azure Functions single-agent sample showcasing how to host a single Azure OpenAI agent.
|
||||
|
||||
The sample reads the required endpoint and deployment environment variables, configures the Azure OpenAI chat client (using either an API key or Azure CLI credentials), and registers a joke-telling agent with an Azure Functions app that can optionally expose a health check.
|
||||
|
||||
Summary: Demonstrates configuring and deploying a single 'Joker' agent via Azure Functions."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from azure.identity import AzureCliCredential
|
||||
from agent_framework.azure import AzureOpenAIChatClient
|
||||
from agent_framework.azurefunctions import AgentFunctionApp
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
AZURE_OPENAI_ENDPOINT_ENV = "AZURE_OPENAI_ENDPOINT"
|
||||
AZURE_OPENAI_DEPLOYMENT_ENV = "AZURE_OPENAI_CHAT_DEPLOYMENT_NAME"
|
||||
AZURE_OPENAI_API_KEY_ENV = "AZURE_OPENAI_API_KEY"
|
||||
|
||||
|
||||
def _build_client_kwargs() -> dict[str, Any]:
|
||||
"""Construct Azure OpenAI client options."""
|
||||
|
||||
endpoint = os.getenv(AZURE_OPENAI_ENDPOINT_ENV)
|
||||
if not endpoint:
|
||||
raise RuntimeError(f"{AZURE_OPENAI_ENDPOINT_ENV} environment variable is required.")
|
||||
|
||||
deployment = os.getenv(AZURE_OPENAI_DEPLOYMENT_ENV)
|
||||
if not deployment:
|
||||
raise RuntimeError(f"{AZURE_OPENAI_DEPLOYMENT_ENV} environment variable is required.")
|
||||
|
||||
logger.info("[SingleAgent] Using deployment '%s' at '%s'", deployment, endpoint)
|
||||
|
||||
client_kwargs: dict[str, Any] = {
|
||||
"endpoint": endpoint,
|
||||
"deployment_name": deployment,
|
||||
}
|
||||
|
||||
api_key = os.getenv(AZURE_OPENAI_API_KEY_ENV)
|
||||
if api_key:
|
||||
client_kwargs["api_key"] = api_key
|
||||
else:
|
||||
client_kwargs["credential"] = AzureCliCredential()
|
||||
|
||||
return client_kwargs
|
||||
|
||||
|
||||
def _create_agent() -> Any:
|
||||
"""Create the Joker agent."""
|
||||
|
||||
client_kwargs = _build_client_kwargs()
|
||||
return AzureOpenAIChatClient(**client_kwargs).create_agent(
|
||||
name="Joker",
|
||||
instructions="You are good at telling jokes.",
|
||||
)
|
||||
|
||||
|
||||
app = AgentFunctionApp(agents=[_create_agent()], enable_health_check=True)
|
||||
@@ -0,0 +1,7 @@
|
||||
{
|
||||
"version": "2.0",
|
||||
"extensionBundle": {
|
||||
"id": "Microsoft.Azure.Functions.ExtensionBundle",
|
||||
"version": "[4.*, 5.0.0)"
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
{
|
||||
"IsEncrypted": false,
|
||||
"Values": {
|
||||
"FUNCTIONS_WORKER_RUNTIME": "python",
|
||||
"AzureWebJobsStorage": "UseDevelopmentStorage=true",
|
||||
"DURABLE_TASK_SCHEDULER_CONNECTION_STRING": "Endpoint=http://localhost:8080;TaskHub=default;Authentication=None",
|
||||
"AZURE_OPENAI_ENDPOINT": "<AZURE_OPENAI_ENDPOINT>",
|
||||
"AZURE_OPENAI_CHAT_DEPLOYMENT_NAME": "<AZURE_OPENAI_CHAT_DEPLOYMENT_NAME>"
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,2 @@
|
||||
agent-framework-azurefunctions
|
||||
azure-identity
|
||||
Generated
+3429
-3411
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user