Python: Complete durableagent package (#3058)

* Add worker and clients

* Clean code and refactor common code

* Implement sample

* Add sample

* Update readmes

* Fix tests

* Fix tests

* Update requirements

* Fix typo

* Address comments

* use response.text
This commit is contained in:
Laveesh Rohra
2026-01-07 13:53:21 -08:00
committed by GitHub
Unverified
parent a5b36dc379
commit e3eff65a6b
46 changed files with 4477 additions and 1644 deletions
@@ -2,10 +2,9 @@
import importlib.metadata
from agent_framework_durabletask import AgentCallbackContext, AgentResponseCallbackProtocol
from agent_framework_durabletask import AgentCallbackContext, AgentResponseCallbackProtocol, DurableAIAgent
from ._app import AgentFunctionApp
from ._orchestration import DurableAIAgent
try:
__version__ = importlib.metadata.version(__name__)
@@ -8,6 +8,7 @@ with Azure Durable Entities, enabling stateful and durable AI agent execution.
import json
import re
import uuid
from collections.abc import Callable, Mapping
from dataclasses import dataclass
from datetime import datetime, timezone
@@ -28,14 +29,16 @@ from agent_framework_durabletask import (
WAIT_FOR_RESPONSE_FIELD,
WAIT_FOR_RESPONSE_HEADER,
AgentResponseCallbackProtocol,
AgentSessionId,
ApiResponseFields,
DurableAgentState,
DurableAIAgent,
RunRequest,
)
from ._entities import create_agent_entity
from ._errors import IncomingRequestError
from ._models import AgentSessionId
from ._orchestration import AgentOrchestrationContextType, DurableAIAgent
from ._orchestration import AgentOrchestrationContextType, AgentTask, AzureFunctionsAgentExecutor
logger = get_logger("agent_framework.azurefunctions")
@@ -296,7 +299,7 @@ class AgentFunctionApp(DFAppBase):
self,
context: AgentOrchestrationContextType,
agent_name: str,
) -> DurableAIAgent:
) -> DurableAIAgent[AgentTask]:
"""Return a DurableAIAgent proxy for a registered agent.
Args:
@@ -307,14 +310,15 @@ class AgentFunctionApp(DFAppBase):
ValueError: If the requested agent has not been registered.
Returns:
DurableAIAgent wrapper bound to the orchestration context.
DurableAIAgent[AgentTask] wrapper bound to the orchestration context.
"""
normalized_name = str(agent_name)
if normalized_name not in self._agent_metadata:
raise ValueError(f"Agent '{normalized_name}' is not registered with this app.")
return DurableAIAgent(context, normalized_name)
executor = AzureFunctionsAgentExecutor(context)
return DurableAIAgent(executor, normalized_name)
def _setup_agent_functions(
self,
@@ -377,8 +381,6 @@ class AgentFunctionApp(DFAppBase):
"enable_tool_calls": true|false (optional, default: true)
}
"""
logger.debug(f"[HTTP Trigger] Received request on route: /api/agents/{agent_name}/run")
request_response_format: str = REQUEST_RESPONSE_FORMAT_JSON
thread_id: str | None = None
@@ -387,9 +389,9 @@ class AgentFunctionApp(DFAppBase):
thread_id = self._resolve_thread_id(req=req, req_body=req_body)
wait_for_response = self._should_wait_for_response(req=req, req_body=req_body)
logger.debug(f"[HTTP Trigger] Message: {message}")
logger.debug(f"[HTTP Trigger] Thread ID: {thread_id}")
logger.debug(f"[HTTP Trigger] wait_for_response: {wait_for_response}")
logger.debug(
f"[HTTP Trigger] Message: {message}, Thread ID: {thread_id}, wait_for_response: {wait_for_response}"
)
if not message:
logger.warning("[HTTP Trigger] Request rejected: Missing message")
@@ -403,15 +405,18 @@ class AgentFunctionApp(DFAppBase):
session_id = self._create_session_id(agent_name, thread_id)
correlation_id = self._generate_unique_id()
logger.debug(f"[HTTP Trigger] Using session ID: {session_id}")
logger.debug(f"[HTTP Trigger] Generated correlation ID: {correlation_id}")
logger.debug("[HTTP Trigger] Calling entity to run agent...")
logger.debug(
f"[HTTP Trigger] Calling entity to run agent using session ID: {session_id} "
f"and correlation ID: {correlation_id}"
)
entity_instance_id = session_id.to_entity_id()
entity_instance_id = df.EntityId(
name=session_id.entity_name,
key=session_id.key,
)
run_request = self._build_request_data(
req_body,
message,
thread_id,
correlation_id,
request_response_format,
)
@@ -624,14 +629,16 @@ class AgentFunctionApp(DFAppBase):
session_id = AgentSessionId.with_random_key(agent_name)
# Build entity instance ID
entity_instance_id = session_id.to_entity_id()
entity_instance_id = df.EntityId(
name=session_id.entity_name,
key=session_id.key,
)
# Create run request
correlation_id = self._generate_unique_id()
run_request = self._build_request_data(
req_body={"message": query, "role": "user"},
message=query,
thread_id=str(session_id),
correlation_id=correlation_id,
request_response_format=REQUEST_RESPONSE_FORMAT_TEXT,
)
@@ -783,7 +790,7 @@ class AgentFunctionApp(DFAppBase):
agent_response = state.try_get_agent_response(correlation_id)
if agent_response:
result = self._build_success_result(
response_data=agent_response,
response_message=agent_response.text,
message=message,
thread_id=thread_id,
correlation_id=correlation_id,
@@ -829,23 +836,22 @@ class AgentFunctionApp(DFAppBase):
)
def _build_success_result(
self, response_data: dict[str, Any], message: str, thread_id: str, correlation_id: str, state: DurableAgentState
self, response_message: str, message: str, thread_id: str, correlation_id: str, state: DurableAgentState
) -> dict[str, Any]:
"""Build the success result returned to the HTTP caller."""
return self._build_response_payload(
response=response_data.get("content"),
response=response_message,
message=message,
thread_id=thread_id,
status="success",
correlation_id=correlation_id,
extra_fields={"message_count": response_data.get("message_count", state.message_count)},
extra_fields={ApiResponseFields.MESSAGE_COUNT: state.message_count},
)
def _build_request_data(
self,
req_body: dict[str, Any],
message: str,
thread_id: str,
correlation_id: str,
request_response_format: str,
) -> dict[str, Any]:
@@ -912,15 +918,13 @@ class AgentFunctionApp(DFAppBase):
def _generate_unique_id(self) -> str:
"""Generate a new unique identifier."""
import uuid
return uuid.uuid4().hex
def _create_session_id(self, func_name: str, thread_id: str | None) -> AgentSessionId:
def _create_session_id(self, agent_name: str, thread_id: str | None) -> AgentSessionId:
"""Create a session identifier using the provided thread id or a random value."""
if thread_id:
return AgentSessionId(name=func_name, key=thread_id)
return AgentSessionId.with_random_key(name=func_name)
return AgentSessionId(name=agent_name, key=thread_id)
return AgentSessionId.with_random_key(name=agent_name)
def _resolve_thread_id(self, req: func.HttpRequest, req_body: dict[str, Any]) -> str:
"""Retrieve the thread identifier from request body or query parameters."""
@@ -1,201 +0,0 @@
# Copyright (c) Microsoft. All rights reserved.
"""Azure Functions-specific data models for Durable Agent Framework.
This module contains Azure Functions-specific models:
- AgentSessionId: Entity ID management for Azure Durable Entities
- DurableAgentThread: Thread implementation that tracks AgentSessionId
Common models like RunRequest have been moved to agent-framework-durabletask.
"""
from __future__ import annotations
import uuid
from collections.abc import MutableMapping
from dataclasses import dataclass
from typing import Any
import azure.durable_functions as df
from agent_framework import AgentThread
@dataclass
class AgentSessionId:
"""Represents an agent session ID, which is used to identify a long-running agent session.
Attributes:
name: The name of the agent that owns the session (case-insensitive)
key: The unique key of the agent session (case-sensitive)
"""
name: str
key: str
ENTITY_NAME_PREFIX: str = "dafx-"
@staticmethod
def to_entity_name(name: str) -> str:
"""Converts an agent name to an entity name by adding the DAFx prefix.
Args:
name: The agent name
Returns:
The entity name with the dafx- prefix
"""
return f"{AgentSessionId.ENTITY_NAME_PREFIX}{name}"
@staticmethod
def with_random_key(name: str) -> AgentSessionId:
"""Creates a new AgentSessionId with the specified name and a randomly generated key.
Args:
name: The name of the agent that owns the session
Returns:
A new AgentSessionId with the specified name and a random GUID key
"""
return AgentSessionId(name=name, key=uuid.uuid4().hex)
def to_entity_id(self) -> df.EntityId:
"""Converts this AgentSessionId to a Durable Functions EntityId.
Returns:
EntityId for use with Durable Functions APIs
"""
return df.EntityId(self.to_entity_name(self.name), self.key)
@staticmethod
def from_entity_id(entity_id: df.EntityId) -> AgentSessionId:
"""Creates an AgentSessionId from a Durable Functions EntityId.
Args:
entity_id: The EntityId to convert
Returns:
AgentSessionId instance
Raises:
ValueError: If the entity ID does not have the expected prefix
"""
if not entity_id.name.startswith(AgentSessionId.ENTITY_NAME_PREFIX):
raise ValueError(
f"'{entity_id}' is not a valid agent session ID. "
f"Expected entity name to start with '{AgentSessionId.ENTITY_NAME_PREFIX}'"
)
agent_name = entity_id.name[len(AgentSessionId.ENTITY_NAME_PREFIX) :]
return AgentSessionId(name=agent_name, key=entity_id.key)
def __str__(self) -> str:
"""Returns a string representation in the form @name@key."""
return f"@{self.name}@{self.key}"
def __repr__(self) -> str:
"""Returns a detailed string representation."""
return f"AgentSessionId(name='{self.name}', key='{self.key}')"
@staticmethod
def parse(session_id_string: str) -> AgentSessionId:
"""Parses a string representation of an agent session ID.
Args:
session_id_string: A string in the form @name@key
Returns:
AgentSessionId instance
Raises:
ValueError: If the string format is invalid
"""
if not session_id_string.startswith("@"):
raise ValueError(f"Invalid agent session ID format: {session_id_string}")
parts = session_id_string[1:].split("@", 1)
if len(parts) != 2:
raise ValueError(f"Invalid agent session ID format: {session_id_string}")
return AgentSessionId(name=parts[0], key=parts[1])
class DurableAgentThread(AgentThread):
"""Durable agent thread that tracks the owning :class:`AgentSessionId`."""
_SERIALIZED_SESSION_ID_KEY = "durable_session_id"
def __init__(
self,
*,
session_id: AgentSessionId | None = None,
service_thread_id: str | None = None,
message_store: Any = None,
context_provider: Any = None,
) -> None:
super().__init__(
service_thread_id=service_thread_id,
message_store=message_store,
context_provider=context_provider,
)
self._session_id: AgentSessionId | None = session_id
@property
def session_id(self) -> AgentSessionId | None:
"""Returns the durable agent session identifier for this thread."""
return self._session_id
def attach_session(self, session_id: AgentSessionId) -> None:
"""Associates the thread with the provided :class:`AgentSessionId`."""
self._session_id = session_id
@classmethod
def from_session_id(
cls,
session_id: AgentSessionId,
*,
service_thread_id: str | None = None,
message_store: Any = None,
context_provider: Any = None,
) -> DurableAgentThread:
"""Creates a durable thread pre-associated with the supplied session ID."""
return cls(
session_id=session_id,
service_thread_id=service_thread_id,
message_store=message_store,
context_provider=context_provider,
)
async def serialize(self, **kwargs: Any) -> dict[str, Any]:
"""Serializes thread state including the durable session identifier."""
state = await super().serialize(**kwargs)
if self._session_id is not None:
state[self._SERIALIZED_SESSION_ID_KEY] = str(self._session_id)
return state
@classmethod
async def deserialize(
cls,
serialized_thread_state: MutableMapping[str, Any],
*,
message_store: Any = None,
**kwargs: Any,
) -> DurableAgentThread:
"""Restores a durable thread, rehydrating the stored session identifier."""
state_payload = dict(serialized_thread_state)
session_id_value = state_payload.pop(cls._SERIALIZED_SESSION_ID_KEY, None)
thread = await super().deserialize(
state_payload,
message_store=message_store,
**kwargs,
)
if not isinstance(thread, DurableAgentThread):
raise TypeError("Deserialized thread is not a DurableAgentThread instance")
if session_id_value is None:
return thread
if not isinstance(session_id_value, str):
raise ValueError("durable_session_id must be a string when present in serialized state")
thread.attach_session(AgentSessionId.parse(session_id_value))
return thread
@@ -5,25 +5,21 @@
This module provides support for using agents inside Durable Function orchestrations.
"""
import uuid
from collections.abc import AsyncIterator, Callable
from typing import TYPE_CHECKING, Any, TypeAlias, cast
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, TypeAlias
from agent_framework import (
AgentProtocol,
AgentRunResponse,
AgentRunResponseUpdate,
AgentThread,
ChatMessage,
get_logger,
import azure.durable_functions as df
from agent_framework import AgentThread, get_logger
from agent_framework_durabletask import (
DurableAgentExecutor,
RunRequest,
ensure_response_format,
load_agent_response,
)
from agent_framework_durabletask import RunRequest
from azure.durable_functions.models import TaskBase
from azure.durable_functions.models.Task import CompoundTask, TaskState
from pydantic import BaseModel
from ._models import AgentSessionId, DurableAgentThread
logger = get_logger("agent_framework.azurefunctions.orchestration")
CompoundActionConstructor: TypeAlias = Callable[[list[Any]], Any] | None
@@ -96,10 +92,10 @@ class AgentTask(_TypedCompoundTask):
)
try:
response = self._load_agent_response(raw_result)
response = load_agent_response(raw_result)
if self._response_format is not None:
self._ensure_response_format(
ensure_response_format(
self._response_format,
self._correlation_id,
response,
@@ -119,249 +115,60 @@ class AgentTask(_TypedCompoundTask):
self._first_error = child.result
self.set_value(is_error=True, value=self._first_error)
def _load_agent_response(self, agent_response: AgentRunResponse | dict[str, Any] | None) -> AgentRunResponse:
"""Convert raw payloads into AgentRunResponse instance."""
if agent_response is None:
raise ValueError("agent_response cannot be None")
logger.debug("[load_agent_response] Loading agent response of type: %s", type(agent_response))
class AzureFunctionsAgentExecutor(DurableAgentExecutor[AgentTask]):
"""Executor that executes durable agents inside Azure Functions orchestrations."""
if isinstance(agent_response, AgentRunResponse):
return agent_response
if isinstance(agent_response, dict):
logger.debug("[load_agent_response] Converting dict payload using AgentRunResponse.from_dict")
return AgentRunResponse.from_dict(agent_response)
raise TypeError(f"Unsupported type for agent_response: {type(agent_response)}")
def _ensure_response_format(
self,
response_format: type[BaseModel] | None,
correlation_id: str,
response: AgentRunResponse,
) -> None:
"""Ensure the AgentRunResponse value is parsed into the expected response_format."""
if response_format is not None and not isinstance(response.value, response_format):
response.try_parse_value(response_format)
logger.debug(
"[DurableAIAgent] Loaded AgentRunResponse.value for correlation_id %s with type: %s",
correlation_id,
type(response.value).__name__,
)
class DurableAIAgent(AgentProtocol):
"""A durable agent implementation that uses entity methods to interact with agent entities.
This class implements AgentProtocol and provides methods to work with Azure Durable Functions
orchestrations, which use generators and yield instead of async/await.
Key methods:
- get_new_thread(): Create a new conversation thread
- run(): Execute the agent and return a Task for yielding in orchestrations
Note: The run() method is NOT async. It returns a Task directly that must be
yielded in orchestrations to wait for the entity call to complete.
Example usage in orchestration:
writer = app.get_agent(context, "WriterAgent")
thread = writer.get_new_thread() # NOT yielded - returns immediately
response = yield writer.run( # Yielded - waits for entity call
message="Write a haiku about coding",
thread=thread
)
"""
def __init__(self, context: AgentOrchestrationContextType, agent_name: str):
"""Initialize the DurableAIAgent.
Args:
context: The orchestration context
agent_name: Name of the agent (used to construct entity ID)
"""
def __init__(self, context: AgentOrchestrationContextType):
self.context = context
self.agent_name = agent_name
self._id = str(uuid.uuid4())
self._name = agent_name
self._display_name = agent_name
self._description = f"Durable agent proxy for {agent_name}"
logger.debug("[DurableAIAgent] Initialized for agent: %s", agent_name)
@property
def id(self) -> str:
"""Get the unique identifier for this agent."""
return self._id
def generate_unique_id(self) -> str:
return str(self.context.new_uuid())
@property
def name(self) -> str | None:
"""Get the name of the agent."""
return self._name
@property
def display_name(self) -> str:
"""Get the display name of the agent."""
return self._display_name
@property
def description(self) -> str | None:
"""Get the description of the agent."""
return self._description
# We return an AgentTask here which is a TaskBase subclass.
# This is an intentional deviation from AgentProtocol which defines run() as async.
# The AgentTask can be yielded in Durable Functions orchestrations and will provide
# a typed AgentRunResponse result.
def run( # type: ignore[override]
def get_run_request(
self,
messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
*,
thread: AgentThread | None = None,
response_format: type[BaseModel] | None = None,
**kwargs: Any,
) -> AgentTask:
"""Execute the agent with messages and return an AgentTask for orchestrations.
This method implements AgentProtocol and returns an AgentTask (subclass of TaskBase)
that can be yielded in Durable Functions orchestrations. The task's result will be
a typed AgentRunResponse.
Args:
messages: The message(s) to send to the agent
thread: Optional agent thread for conversation context
response_format: Optional Pydantic model for response parsing
**kwargs: Additional arguments (enable_tool_calls)
message: str,
response_format: type[BaseModel] | None,
enable_tool_calls: bool,
) -> RunRequest:
"""Get the current run request from the orchestration context.
Returns:
An AgentTask that resolves to an AgentRunResponse when yielded
Example:
@app.orchestration_trigger(context_name="context")
def my_orchestration(context):
agent = app.get_agent(context, "MyAgent")
thread = agent.get_new_thread()
response = yield agent.run("Hello", thread=thread)
# response is typed as AgentRunResponse
RunRequest: The current run request
"""
message_str = self._normalize_messages(messages)
request = super().get_run_request(
message,
response_format,
enable_tool_calls,
)
request.orchestration_id = self.context.instance_id
return request
# Extract optional parameters from kwargs
enable_tool_calls = kwargs.get("enable_tool_calls", True)
def run_durable_agent(
self,
agent_name: str,
run_request: RunRequest,
thread: AgentThread | None = None,
) -> AgentTask:
# Get the session ID for the entity
if isinstance(thread, DurableAgentThread) and thread.session_id is not None:
session_id = thread.session_id
else:
# Create a unique session ID for each call when no thread is provided
# This ensures each call gets its own conversation context
session_key = str(self.context.new_uuid())
session_id = AgentSessionId(name=self.agent_name, key=session_key)
logger.debug("[DurableAIAgent] No thread provided, created unique session_id: %s", session_id)
# Resolve session
session_id = self._create_session_id(agent_name, thread)
# Create entity ID from session ID
entity_id = session_id.to_entity_id()
entity_id = df.EntityId(
name=session_id.entity_name,
key=session_id.key,
)
# Generate a deterministic correlation ID for this call
# This is required by the entity and must be unique per call
correlation_id = str(self.context.new_uuid())
logger.debug(
"[DurableAIAgent] Using correlation_id: %s for entity_id: %s for session_id: %s",
correlation_id,
"[AzureFunctionsAgentProvider] correlation_id: %s entity_id: %s session_id: %s",
run_request.correlation_id,
entity_id,
session_id,
)
# Prepare the request using RunRequest model
# Include the orchestration's instance_id so it can be stored in the agent's entity state
run_request = RunRequest(
message=message_str,
enable_tool_calls=enable_tool_calls,
correlation_id=correlation_id,
response_format=response_format,
orchestration_id=self.context.instance_id,
created_at=self.context.current_utc_datetime,
)
logger.debug("[DurableAIAgent] Calling entity %s with message: %s", entity_id, message_str[:100])
# Call the entity to get the underlying task
entity_task = self.context.call_entity(entity_id, "run", run_request.to_dict())
# Wrap it in an AgentTask that will convert the result to AgentRunResponse
agent_task = AgentTask(
return AgentTask(
entity_task=entity_task,
response_format=response_format,
correlation_id=correlation_id,
response_format=run_request.response_format,
correlation_id=run_request.correlation_id,
)
logger.debug(
"[DurableAIAgent] Created AgentTask for correlation_id %s",
correlation_id,
)
return agent_task
def run_stream(
self,
messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
*,
thread: AgentThread | None = None,
**kwargs: Any,
) -> AsyncIterator[AgentRunResponseUpdate]:
"""Run the agent with streaming (not supported for durable agents).
Raises:
NotImplementedError: Streaming is not supported for durable agents.
"""
raise NotImplementedError("Streaming is not supported for durable agents in orchestrations.")
def get_new_thread(self, **kwargs: Any) -> AgentThread:
"""Create a new agent thread for this orchestration instance.
Each call creates a unique thread with its own conversation context.
The session ID is deterministic (uses context.new_uuid()) to ensure
orchestration replay works correctly.
Returns:
A new AgentThread instance with a unique session ID
"""
# Generate a deterministic unique key for this thread
# Using context.new_uuid() ensures the same GUID is generated during replay
session_key = str(self.context.new_uuid())
# Create AgentSessionId with agent name and session key
session_id = AgentSessionId(name=self.agent_name, key=session_key)
thread = DurableAgentThread.from_session_id(session_id, **kwargs)
logger.debug("[DurableAIAgent] Created new thread with session_id: %s", session_id)
return thread
def _messages_to_string(self, messages: list[ChatMessage]) -> str:
"""Convert a list of ChatMessage objects to a single string.
Args:
messages: List of ChatMessage objects
Returns:
Concatenated string of message contents
"""
return "\n".join([msg.text or "" for msg in messages])
def _normalize_messages(self, messages: str | ChatMessage | list[str] | list[ChatMessage] | None) -> str:
"""Convert supported message inputs to a single string."""
if messages is None:
return ""
if isinstance(messages, str):
return messages
if isinstance(messages, ChatMessage):
return messages.text or ""
if isinstance(messages, list):
if not messages:
return ""
first_item = messages[0]
if isinstance(first_item, str):
return "\n".join(cast(list[str], messages))
return self._messages_to_string(cast(list[ChatMessage], messages))
return str(messages)
@@ -1,402 +0,0 @@
# Copyright (c) Microsoft. All rights reserved.
"""Unit tests for data models (AgentSessionId, RunRequest, AgentResponse)."""
import azure.durable_functions as df
import pytest
from agent_framework import Role
from agent_framework_durabletask import RunRequest
from pydantic import BaseModel
from agent_framework_azurefunctions._models import AgentSessionId
class ModuleStructuredResponse(BaseModel):
value: int
class TestAgentSessionId:
"""Test suite for AgentSessionId."""
def test_init_creates_session_id(self) -> None:
"""Test that AgentSessionId initializes correctly."""
session_id = AgentSessionId(name="AgentEntity", key="test-key-123")
assert session_id.name == "AgentEntity"
assert session_id.key == "test-key-123"
def test_with_random_key_generates_guid(self) -> None:
"""Test that with_random_key generates a GUID."""
session_id = AgentSessionId.with_random_key(name="AgentEntity")
assert session_id.name == "AgentEntity"
assert len(session_id.key) == 32 # UUID hex is 32 chars
# Verify it's a valid hex string
int(session_id.key, 16)
def test_with_random_key_unique_keys(self) -> None:
"""Test that with_random_key generates unique keys."""
session_id1 = AgentSessionId.with_random_key(name="AgentEntity")
session_id2 = AgentSessionId.with_random_key(name="AgentEntity")
assert session_id1.key != session_id2.key
def test_to_entity_id_conversion(self) -> None:
"""Test conversion to EntityId."""
session_id = AgentSessionId(name="AgentEntity", key="test-key")
entity_id = session_id.to_entity_id()
assert isinstance(entity_id, df.EntityId)
assert entity_id.name == "dafx-AgentEntity"
assert entity_id.key == "test-key"
def test_from_entity_id_conversion(self) -> None:
"""Test creation from EntityId."""
entity_id = df.EntityId(name="dafx-AgentEntity", key="test-key")
session_id = AgentSessionId.from_entity_id(entity_id)
assert isinstance(session_id, AgentSessionId)
assert session_id.name == "AgentEntity"
assert session_id.key == "test-key"
def test_round_trip_entity_id_conversion(self) -> None:
"""Test round-trip conversion to and from EntityId."""
original = AgentSessionId(name="AgentEntity", key="test-key")
entity_id = original.to_entity_id()
restored = AgentSessionId.from_entity_id(entity_id)
assert restored.name == original.name
assert restored.key == original.key
def test_str_representation(self) -> None:
"""Test string representation."""
session_id = AgentSessionId(name="AgentEntity", key="test-key-123")
str_repr = str(session_id)
assert str_repr == "@AgentEntity@test-key-123"
def test_repr_representation(self) -> None:
"""Test repr representation."""
session_id = AgentSessionId(name="AgentEntity", key="test-key")
repr_str = repr(session_id)
assert "AgentSessionId" in repr_str
assert "AgentEntity" in repr_str
assert "test-key" in repr_str
def test_parse_valid_session_id(self) -> None:
"""Test parsing valid session ID string."""
session_id = AgentSessionId.parse("@AgentEntity@test-key-123")
assert session_id.name == "AgentEntity"
assert session_id.key == "test-key-123"
def test_parse_invalid_format_no_prefix(self) -> None:
"""Test parsing invalid format without @ prefix."""
with pytest.raises(ValueError) as exc_info:
AgentSessionId.parse("AgentEntity@test-key")
assert "Invalid agent session ID format" in str(exc_info.value)
def test_parse_invalid_format_single_part(self) -> None:
"""Test parsing invalid format with single part."""
with pytest.raises(ValueError) as exc_info:
AgentSessionId.parse("@AgentEntity")
assert "Invalid agent session ID format" in str(exc_info.value)
def test_parse_with_multiple_at_signs_in_key(self) -> None:
"""Test parsing with @ signs in the key."""
session_id = AgentSessionId.parse("@AgentEntity@key-with@symbols")
assert session_id.name == "AgentEntity"
assert session_id.key == "key-with@symbols"
def test_parse_round_trip(self) -> None:
"""Test round-trip parse and string conversion."""
original = AgentSessionId(name="AgentEntity", key="test-key")
str_repr = str(original)
parsed = AgentSessionId.parse(str_repr)
assert parsed.name == original.name
assert parsed.key == original.key
def test_to_entity_name_adds_prefix(self) -> None:
"""Test that to_entity_name adds the dafx- prefix."""
entity_name = AgentSessionId.to_entity_name("TestAgent")
assert entity_name == "dafx-TestAgent"
def test_from_entity_id_strips_prefix(self) -> None:
"""Test that from_entity_id strips the dafx- prefix."""
entity_id = df.EntityId(name="dafx-TestAgent", key="key123")
session_id = AgentSessionId.from_entity_id(entity_id)
assert session_id.name == "TestAgent"
assert session_id.key == "key123"
def test_from_entity_id_raises_without_prefix(self) -> None:
"""Test that from_entity_id raises ValueError when entity name lacks the prefix."""
entity_id = df.EntityId(name="TestAgent", key="key123")
with pytest.raises(ValueError) as exc_info:
AgentSessionId.from_entity_id(entity_id)
assert "not a valid agent session ID" in str(exc_info.value)
assert "dafx-" in str(exc_info.value)
class TestRunRequest:
"""Test suite for RunRequest."""
def test_init_with_defaults(self) -> None:
"""Test RunRequest initialization with defaults."""
request = RunRequest(message="Hello")
assert request.message == "Hello"
assert request.role == Role.USER
assert request.response_format is None
assert request.enable_tool_calls is True
def test_init_with_all_fields(self) -> None:
"""Test RunRequest initialization with all fields."""
schema = ModuleStructuredResponse
request = RunRequest(
message="Hello",
role=Role.SYSTEM,
response_format=schema,
enable_tool_calls=False,
)
assert request.message == "Hello"
assert request.role == Role.SYSTEM
assert request.response_format is schema
assert request.enable_tool_calls is False
def test_init_coerces_string_role(self) -> None:
"""Ensure string role values are coerced into Role instances."""
request = RunRequest(message="Hello", role="system") # type: ignore[arg-type]
assert request.role == Role.SYSTEM
def test_to_dict_with_defaults(self) -> None:
"""Test to_dict with default values."""
request = RunRequest(message="Test message")
data = request.to_dict()
assert data["message"] == "Test message"
assert data["enable_tool_calls"] is True
assert data["role"] == "user"
assert "response_format" not in data or data["response_format"] is None
assert "thread_id" not in data
def test_to_dict_with_all_fields(self) -> None:
"""Test to_dict with all fields."""
schema = ModuleStructuredResponse
request = RunRequest(
message="Hello",
role=Role.ASSISTANT,
response_format=schema,
enable_tool_calls=False,
)
data = request.to_dict()
assert data["message"] == "Hello"
assert data["role"] == "assistant"
assert data["response_format"]["__response_schema_type__"] == "pydantic_model"
assert data["response_format"]["module"] == schema.__module__
assert data["response_format"]["qualname"] == schema.__qualname__
assert data["enable_tool_calls"] is False
assert "thread_id" not in data
def test_from_dict_with_defaults(self) -> None:
"""Test from_dict with minimal data."""
data = {"message": "Hello"}
request = RunRequest.from_dict(data)
assert request.message == "Hello"
assert request.role == Role.USER
assert request.enable_tool_calls is True
def test_from_dict_ignores_thread_id_field(self) -> None:
"""Ensure legacy thread_id input does not break RunRequest parsing."""
request = RunRequest.from_dict({"message": "Hello", "thread_id": "ignored"})
assert request.message == "Hello"
def test_from_dict_with_all_fields(self) -> None:
"""Test from_dict with all fields."""
data = {
"message": "Test",
"role": "system",
"response_format": {
"__response_schema_type__": "pydantic_model",
"module": ModuleStructuredResponse.__module__,
"qualname": ModuleStructuredResponse.__qualname__,
},
"enable_tool_calls": False,
}
request = RunRequest.from_dict(data)
assert request.message == "Test"
assert request.role == Role.SYSTEM
assert request.response_format is ModuleStructuredResponse
assert request.enable_tool_calls is False
def test_from_dict_with_unknown_role_preserves_value(self) -> None:
"""Test from_dict keeps custom roles intact."""
data = {"message": "Test", "role": "reviewer"}
request = RunRequest.from_dict(data)
assert request.role.value == "reviewer"
assert request.role != Role.USER
def test_from_dict_empty_message(self) -> None:
"""Test from_dict with empty message."""
request = RunRequest.from_dict({})
assert request.message == ""
assert request.role == Role.USER
def test_round_trip_dict_conversion(self) -> None:
"""Test round-trip to_dict and from_dict."""
original = RunRequest(
message="Test message",
role=Role.SYSTEM,
response_format=ModuleStructuredResponse,
enable_tool_calls=False,
)
data = original.to_dict()
restored = RunRequest.from_dict(data)
assert restored.message == original.message
assert restored.role == original.role
assert restored.response_format is ModuleStructuredResponse
assert restored.enable_tool_calls == original.enable_tool_calls
def test_round_trip_with_pydantic_response_format(self) -> None:
"""Ensure Pydantic response formats serialize and deserialize properly."""
original = RunRequest(
message="Structured",
response_format=ModuleStructuredResponse,
)
data = original.to_dict()
assert data["response_format"]["__response_schema_type__"] == "pydantic_model"
assert data["response_format"]["module"] == ModuleStructuredResponse.__module__
assert data["response_format"]["qualname"] == ModuleStructuredResponse.__qualname__
restored = RunRequest.from_dict(data)
assert restored.response_format is ModuleStructuredResponse
def test_init_with_correlationId(self) -> None:
"""Test RunRequest initialization with correlationId."""
request = RunRequest(message="Test message", correlation_id="corr-123")
assert request.message == "Test message"
assert request.correlation_id == "corr-123"
def test_to_dict_with_correlationId(self) -> None:
"""Test to_dict includes correlationId."""
request = RunRequest(message="Test", correlation_id="corr-456")
data = request.to_dict()
assert data["message"] == "Test"
assert data["correlationId"] == "corr-456"
def test_from_dict_with_correlationId(self) -> None:
"""Test from_dict with correlationId."""
data = {"message": "Test", "correlationId": "corr-789"}
request = RunRequest.from_dict(data)
assert request.message == "Test"
assert request.correlation_id == "corr-789"
def test_round_trip_with_correlationId(self) -> None:
"""Test round-trip to_dict and from_dict with correlationId."""
original = RunRequest(
message="Test message",
role=Role.SYSTEM,
correlation_id="corr-123",
)
data = original.to_dict()
restored = RunRequest.from_dict(data)
assert restored.message == original.message
assert restored.role == original.role
assert restored.correlation_id == original.correlation_id
def test_init_with_orchestration_id(self) -> None:
"""Test RunRequest initialization with orchestration_id."""
request = RunRequest(
message="Test message",
orchestration_id="orch-123",
)
assert request.message == "Test message"
assert request.orchestration_id == "orch-123"
def test_to_dict_with_orchestration_id(self) -> None:
"""Test to_dict includes orchestrationId."""
request = RunRequest(
message="Test",
orchestration_id="orch-456",
)
data = request.to_dict()
assert data["message"] == "Test"
assert data["orchestrationId"] == "orch-456"
def test_to_dict_excludes_orchestration_id_when_none(self) -> None:
"""Test to_dict excludes orchestrationId when not set."""
request = RunRequest(
message="Test",
)
data = request.to_dict()
assert "orchestrationId" not in data
def test_from_dict_with_orchestration_id(self) -> None:
"""Test from_dict with orchestrationId."""
data = {
"message": "Test",
"orchestrationId": "orch-789",
}
request = RunRequest.from_dict(data)
assert request.message == "Test"
assert request.orchestration_id == "orch-789"
def test_round_trip_with_orchestration_id(self) -> None:
"""Test round-trip to_dict and from_dict with orchestration_id."""
original = RunRequest(
message="Test message",
role=Role.SYSTEM,
correlation_id="corr-123",
orchestration_id="orch-123",
)
data = original.to_dict()
restored = RunRequest.from_dict(data)
assert restored.message == original.message
assert restored.role == original.role
assert restored.correlation_id == original.correlation_id
assert restored.orchestration_id == original.orchestration_id
class TestModelIntegration:
"""Test suite for integration between models."""
def test_run_request_with_session_id_string(self) -> None:
"""AgentSessionId string can still be used by callers, but is not stored on RunRequest."""
session_id = AgentSessionId.with_random_key("AgentEntity")
session_id_str = str(session_id)
assert session_id_str.startswith("@AgentEntity@")
if __name__ == "__main__":
pytest.main([__file__, "-v", "--tb=short"])
@@ -6,11 +6,11 @@ from typing import Any
from unittest.mock import Mock
import pytest
from agent_framework import AgentRunResponse, AgentThread, ChatMessage
from agent_framework import AgentRunResponse, ChatMessage
from agent_framework_durabletask import DurableAIAgent
from azure.durable_functions.models.Task import TaskBase, TaskState
from agent_framework_azurefunctions import AgentFunctionApp, DurableAIAgent
from agent_framework_azurefunctions._models import AgentSessionId, DurableAgentThread
from agent_framework_azurefunctions import AgentFunctionApp
from agent_framework_azurefunctions._orchestration import AgentTask
@@ -38,46 +38,96 @@ def _create_entity_task(task_id: int = 1) -> TaskBase:
return _FakeTask(task_id)
@pytest.fixture
def mock_context():
"""Create a mock orchestration context with UUID support."""
context = Mock()
context.instance_id = "test-instance"
context.current_utc_datetime = Mock()
return context
@pytest.fixture
def mock_context_with_uuid() -> tuple[Mock, str]:
"""Create a mock context with a single UUID."""
from uuid import UUID
context = Mock()
context.instance_id = "test-instance"
context.current_utc_datetime = Mock()
test_uuid = UUID("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa")
context.new_uuid = Mock(return_value=test_uuid)
return context, test_uuid.hex
@pytest.fixture
def mock_context_with_multiple_uuids() -> tuple[Mock, list[str]]:
"""Create a mock context with multiple UUIDs via side_effect."""
from uuid import UUID
context = Mock()
context.instance_id = "test-instance"
context.current_utc_datetime = Mock()
uuids = [
UUID("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"),
UUID("bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb"),
UUID("cccccccc-cccc-cccc-cccc-cccccccccccc"),
]
context.new_uuid = Mock(side_effect=uuids)
# Return the hex versions for assertion checking
hex_uuids = [uuid.hex for uuid in uuids]
return context, hex_uuids
@pytest.fixture
def executor_with_uuid() -> tuple[Any, Mock, str]:
"""Create an executor with a mocked generate_unique_id method."""
from agent_framework_azurefunctions._orchestration import AzureFunctionsAgentExecutor
context = Mock()
context.instance_id = "test-instance"
context.current_utc_datetime = Mock()
executor = AzureFunctionsAgentExecutor(context)
test_uuid_hex = "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
executor.generate_unique_id = Mock(return_value=test_uuid_hex)
return executor, context, test_uuid_hex
@pytest.fixture
def executor_with_multiple_uuids() -> tuple[Any, Mock, list[str]]:
"""Create an executor with multiple mocked UUIDs."""
from agent_framework_azurefunctions._orchestration import AzureFunctionsAgentExecutor
context = Mock()
context.instance_id = "test-instance"
context.current_utc_datetime = Mock()
executor = AzureFunctionsAgentExecutor(context)
uuid_hexes = [
"aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa",
"bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb",
"cccccccc-cccc-cccc-cccc-cccccccccccc",
"dddddddd-dddd-dddd-dddd-dddddddddddd",
"eeeeeeee-eeee-eeee-eeee-eeeeeeeeeeee",
]
executor.generate_unique_id = Mock(side_effect=uuid_hexes)
return executor, context, uuid_hexes
@pytest.fixture
def executor_with_context(mock_context_with_uuid: tuple[Mock, str]) -> tuple[Any, Mock]:
"""Create an executor with a mocked context."""
from agent_framework_azurefunctions._orchestration import AzureFunctionsAgentExecutor
context, _ = mock_context_with_uuid
return AzureFunctionsAgentExecutor(context), context
class TestAgentResponseHelpers:
"""Tests for helper utilities that prepare AgentRunResponse values."""
@staticmethod
def _create_agent_task() -> AgentTask:
entity_task = _create_entity_task()
return AgentTask(entity_task, None, "correlation-id")
def test_load_agent_response_from_instance(self) -> None:
task = self._create_agent_task()
response = AgentRunResponse(messages=[ChatMessage(role="assistant", text='{"foo": "bar"}')])
loaded = task._load_agent_response(response)
assert loaded is response
assert loaded.value is None
def test_load_agent_response_from_serialized(self) -> None:
task = self._create_agent_task()
serialized = AgentRunResponse(messages=[ChatMessage(role="assistant", text="structured")]).to_dict()
serialized["value"] = {"answer": 42}
loaded = task._load_agent_response(serialized)
assert loaded is not None
assert loaded.value == {"answer": 42}
loaded_dict = loaded.to_dict()
assert loaded_dict["type"] == "agent_run_response"
def test_load_agent_response_rejects_none(self) -> None:
task = self._create_agent_task()
with pytest.raises(ValueError):
task._load_agent_response(None)
def test_load_agent_response_rejects_unsupported_type(self) -> None:
task = self._create_agent_task()
with pytest.raises(TypeError, match="Unsupported type"):
task._load_agent_response(["invalid", "list"]) # type: ignore[arg-type]
"""Tests for response handling through public AgentTask API."""
def test_try_set_value_success(self) -> None:
"""Test try_set_value correctly processes successful task completion."""
@@ -144,335 +194,10 @@ class TestAgentResponseHelpers:
assert isinstance(task.result.value, TestSchema)
assert task.result.value.answer == "42"
def test_ensure_response_format_parses_value(self) -> None:
"""Test _ensure_response_format correctly parses response value."""
from pydantic import BaseModel
class SampleSchema(BaseModel):
name: str
task = self._create_agent_task()
response = AgentRunResponse(messages=[ChatMessage(role="assistant", text='{"name": "test"}')])
# Value should be None initially
assert response.value is None
# Parse the value
task._ensure_response_format(SampleSchema, "test-correlation", response)
# Value should now be parsed
assert isinstance(response.value, SampleSchema)
assert response.value.name == "test"
def test_ensure_response_format_skips_if_already_parsed(self) -> None:
"""Test _ensure_response_format does not re-parse if value already matches format."""
from pydantic import BaseModel
class SampleSchema(BaseModel):
name: str
task = self._create_agent_task()
existing_value = SampleSchema(name="existing")
response = AgentRunResponse(
messages=[ChatMessage(role="assistant", text='{"name": "new"}')],
value=existing_value,
)
# Call _ensure_response_format
task._ensure_response_format(SampleSchema, "test-correlation", response)
# Value should remain unchanged (not re-parsed)
assert response.value is existing_value
assert response.value.name == "existing"
class TestDurableAIAgent:
"""Test suite for DurableAIAgent wrapper."""
def test_init(self) -> None:
"""Test DurableAIAgent initialization."""
mock_context = Mock()
mock_context.instance_id = "test-instance-123"
agent = DurableAIAgent(mock_context, "TestAgent")
assert agent.context == mock_context
assert agent.agent_name == "TestAgent"
def test_implements_agent_protocol(self) -> None:
"""Test that DurableAIAgent implements AgentProtocol."""
from agent_framework import AgentProtocol
mock_context = Mock()
agent = DurableAIAgent(mock_context, "TestAgent")
# Check that agent satisfies AgentProtocol
assert isinstance(agent, AgentProtocol)
def test_has_agent_protocol_properties(self) -> None:
"""Test that DurableAIAgent has AgentProtocol properties."""
mock_context = Mock()
agent = DurableAIAgent(mock_context, "TestAgent")
# AgentProtocol properties
assert hasattr(agent, "id")
assert hasattr(agent, "name")
assert hasattr(agent, "description")
assert hasattr(agent, "display_name")
# Verify values
assert agent.name == "TestAgent"
assert agent.description == "Durable agent proxy for TestAgent"
assert agent.display_name == "TestAgent"
assert agent.id is not None # Auto-generated UUID
def test_get_new_thread(self) -> None:
"""Test creating a new agent thread."""
mock_context = Mock()
mock_context.instance_id = "test-instance-456"
mock_context.new_uuid = Mock(return_value="test-guid-456")
agent = DurableAIAgent(mock_context, "WriterAgent")
thread = agent.get_new_thread()
assert isinstance(thread, DurableAgentThread)
assert thread.session_id is not None
session_id = thread.session_id
assert isinstance(session_id, AgentSessionId)
assert session_id.name == "WriterAgent"
assert session_id.key == "test-guid-456"
mock_context.new_uuid.assert_called_once()
def test_get_new_thread_deterministic(self) -> None:
"""Test that get_new_thread creates deterministic session IDs."""
mock_context = Mock()
mock_context.instance_id = "test-instance-789"
mock_context.new_uuid = Mock(side_effect=["session-guid-1", "session-guid-2"])
agent = DurableAIAgent(mock_context, "EditorAgent")
# Create multiple threads - they should have unique session IDs
thread1 = agent.get_new_thread()
thread2 = agent.get_new_thread()
assert isinstance(thread1, DurableAgentThread)
assert isinstance(thread2, DurableAgentThread)
session_id1 = thread1.session_id
session_id2 = thread2.session_id
assert session_id1 is not None and session_id2 is not None
assert isinstance(session_id1, AgentSessionId)
assert isinstance(session_id2, AgentSessionId)
assert session_id1.name == "EditorAgent"
assert session_id2.name == "EditorAgent"
assert session_id1.key == "session-guid-1"
assert session_id2.key == "session-guid-2"
assert mock_context.new_uuid.call_count == 2
def test_run_creates_entity_call(self) -> None:
"""Test that run() creates proper entity call and returns a Task."""
mock_context = Mock()
mock_context.instance_id = "test-instance-001"
mock_context.new_uuid = Mock(side_effect=["thread-guid", "correlation-guid"])
entity_task = _create_entity_task()
mock_context.call_entity = Mock(return_value=entity_task)
agent = DurableAIAgent(mock_context, "TestAgent")
# Create thread
thread = agent.get_new_thread()
# Call run() - returns AgentTask directly
task = agent.run(messages="Test message", thread=thread, enable_tool_calls=True)
assert isinstance(task, AgentTask)
assert task.children[0] == entity_task
# Verify call_entity was called with correct parameters
assert mock_context.call_entity.called
call_args = mock_context.call_entity.call_args
entity_id, operation, request = call_args[0]
assert operation == "run"
assert request["message"] == "Test message"
assert request["enable_tool_calls"] is True
assert "correlationId" in request
assert request["correlationId"] == "correlation-guid"
assert "thread_id" not in request
# Verify orchestration ID is set from context.instance_id
assert "orchestrationId" in request
assert request["orchestrationId"] == "test-instance-001"
def test_run_sets_orchestration_id(self) -> None:
"""Test that run() sets the orchestration_id from context.instance_id."""
mock_context = Mock()
mock_context.instance_id = "my-orchestration-123"
mock_context.new_uuid = Mock(side_effect=["thread-guid", "correlation-guid"])
entity_task = _create_entity_task()
mock_context.call_entity = Mock(return_value=entity_task)
agent = DurableAIAgent(mock_context, "TestAgent")
thread = agent.get_new_thread()
agent.run(messages="Test", thread=thread)
call_args = mock_context.call_entity.call_args
request = call_args[0][2]
assert request["orchestrationId"] == "my-orchestration-123"
def test_run_without_thread(self) -> None:
"""Test that run() works without explicit thread (creates unique session key)."""
mock_context = Mock()
mock_context.instance_id = "test-instance-002"
mock_context.new_uuid = Mock(side_effect=["auto-generated-guid", "correlation-guid"])
entity_task = _create_entity_task()
mock_context.call_entity = Mock(return_value=entity_task)
agent = DurableAIAgent(mock_context, "TestAgent")
# Call without thread
task = agent.run(messages="Test message")
assert isinstance(task, AgentTask)
assert task.children[0] == entity_task
# Verify the entity ID uses the auto-generated GUID with dafx- prefix
call_args = mock_context.call_entity.call_args
entity_id = call_args[0][0]
assert entity_id.name == "dafx-TestAgent"
assert entity_id.key == "auto-generated-guid"
# Should be called twice: once for session_key, once for correlationId
assert mock_context.new_uuid.call_count == 2
def test_run_with_response_format(self) -> None:
"""Test that run() passes response format correctly."""
mock_context = Mock()
mock_context.instance_id = "test-instance-003"
entity_task = _create_entity_task()
mock_context.call_entity = Mock(return_value=entity_task)
agent = DurableAIAgent(mock_context, "TestAgent")
from pydantic import BaseModel
class SampleSchema(BaseModel):
key: str
# Create thread and call
thread = agent.get_new_thread()
task = agent.run(messages="Test message", thread=thread, response_format=SampleSchema)
assert isinstance(task, AgentTask)
assert task.children[0] == entity_task
# Verify schema was passed in the call_entity arguments
call_args = mock_context.call_entity.call_args
input_data = call_args[0][2] # Third argument is input_data
assert "response_format" in input_data
assert input_data["response_format"]["__response_schema_type__"] == "pydantic_model"
assert input_data["response_format"]["module"] == SampleSchema.__module__
assert input_data["response_format"]["qualname"] == SampleSchema.__qualname__
def test_messages_to_string(self) -> None:
"""Test converting ChatMessage list to string."""
from agent_framework import ChatMessage
mock_context = Mock()
agent = DurableAIAgent(mock_context, "TestAgent")
messages = [
ChatMessage(role="user", text="Hello"),
ChatMessage(role="assistant", text="Hi there"),
ChatMessage(role="user", text="How are you?"),
]
result = agent._messages_to_string(messages)
assert result == "Hello\nHi there\nHow are you?"
def test_run_with_chat_message(self) -> None:
"""Test that run() handles ChatMessage input."""
from agent_framework import ChatMessage
mock_context = Mock()
mock_context.new_uuid = Mock(side_effect=["thread-guid", "correlation-guid"])
entity_task = _create_entity_task()
mock_context.call_entity = Mock(return_value=entity_task)
agent = DurableAIAgent(mock_context, "TestAgent")
thread = agent.get_new_thread()
# Call with ChatMessage
msg = ChatMessage(role="user", text="Hello")
task = agent.run(messages=msg, thread=thread)
assert isinstance(task, AgentTask)
assert task.children[0] == entity_task
# Verify message was converted to string
call_args = mock_context.call_entity.call_args
request = call_args[0][2]
assert request["message"] == "Hello"
def test_run_stream_raises_not_implemented(self) -> None:
"""Test that run_stream() method raises NotImplementedError."""
mock_context = Mock()
agent = DurableAIAgent(mock_context, "TestAgent")
with pytest.raises(NotImplementedError) as exc_info:
agent.run_stream("Test message")
error_msg = str(exc_info.value)
assert "Streaming is not supported" in error_msg
def test_entity_id_format(self) -> None:
"""Test that EntityId is created with correct format (name, key)."""
from azure.durable_functions import EntityId
mock_context = Mock()
mock_context.new_uuid = Mock(return_value="test-guid-789")
mock_context.call_entity = Mock(return_value=_create_entity_task())
agent = DurableAIAgent(mock_context, "WriterAgent")
thread = agent.get_new_thread()
# Call run() to trigger entity ID creation
agent.run("Test", thread=thread)
# Verify call_entity was called with correct EntityId
call_args = mock_context.call_entity.call_args
entity_id = call_args[0][0]
# EntityId should be EntityId(name="dafx-WriterAgent", key="test-guid-789")
# Which formats as "@dafx-writeragent@test-guid-789"
assert isinstance(entity_id, EntityId)
assert entity_id.name == "dafx-WriterAgent"
assert entity_id.key == "test-guid-789"
assert str(entity_id) == "@dafx-writeragent@test-guid-789"
class TestAgentFunctionAppGetAgent:
"""Test suite for AgentFunctionApp.get_agent."""
def test_get_agent_method(self) -> None:
"""Test get_agent method creates DurableAIAgent for registered agent."""
app = _app_with_registered_agents("MyAgent")
mock_context = Mock()
mock_context.instance_id = "test-instance-100"
agent = app.get_agent(mock_context, "MyAgent")
assert isinstance(agent, DurableAIAgent)
assert agent.agent_name == "MyAgent"
assert agent.context == mock_context
def test_get_agent_raises_for_unregistered_agent(self) -> None:
"""Test get_agent raises ValueError when agent is not registered."""
app = _app_with_registered_agents("KnownAgent")
@@ -484,15 +209,9 @@ class TestAgentFunctionAppGetAgent:
class TestOrchestrationIntegration:
"""Integration tests for orchestration scenarios."""
def test_sequential_agent_calls_simulation(self) -> None:
def test_sequential_agent_calls_simulation(self, executor_with_multiple_uuids: tuple[Any, Mock, list[str]]) -> None:
"""Simulate sequential agent calls in an orchestration."""
mock_context = Mock()
mock_context.instance_id = "test-orchestration-001"
# new_uuid will be called 3 times:
# 1. thread creation
# 2. correlationId for first call
# 3. correlationId for second call
mock_context.new_uuid = Mock(side_effect=["deterministic-guid-001", "corr-1", "corr-2"])
executor, context, uuid_hexes = executor_with_multiple_uuids
# Track entity calls
entity_calls: list[dict[str, Any]] = []
@@ -501,10 +220,10 @@ class TestOrchestrationIntegration:
entity_calls.append({"entity_id": str(entity_id), "operation": operation, "input": input_data})
return _create_entity_task()
mock_context.call_entity = Mock(side_effect=mock_call_entity_side_effect)
context.call_entity = Mock(side_effect=mock_call_entity_side_effect)
app = _app_with_registered_agents("WriterAgent")
agent = app.get_agent(mock_context, "WriterAgent")
# Create agent directly with executor (not via app.get_agent)
agent = DurableAIAgent(executor, "WriterAgent")
# Create thread
thread = agent.get_new_thread()
@@ -520,18 +239,15 @@ class TestOrchestrationIntegration:
# Verify both calls used the same entity (same session key)
assert len(entity_calls) == 2
assert entity_calls[0]["entity_id"] == entity_calls[1]["entity_id"]
# EntityId format is @dafx-writeragent@deterministic-guid-001
assert entity_calls[0]["entity_id"] == "@dafx-writeragent@deterministic-guid-001"
# new_uuid called 3 times: thread + 2 correlation IDs
assert mock_context.new_uuid.call_count == 3
# EntityId format is @dafx-writeragent@<uuid_hex>
expected_entity_id = f"@dafx-writeragent@{uuid_hexes[0]}"
assert entity_calls[0]["entity_id"] == expected_entity_id
# generate_unique_id called 3 times: thread + 2 correlation IDs
assert executor.generate_unique_id.call_count == 3
def test_multiple_agents_in_orchestration(self) -> None:
def test_multiple_agents_in_orchestration(self, executor_with_multiple_uuids: tuple[Any, Mock, list[str]]) -> None:
"""Test using multiple different agents in one orchestration."""
mock_context = Mock()
mock_context.instance_id = "test-orchestration-002"
# Mock new_uuid to return different GUIDs for each call
# Order: writer thread, editor thread, writer correlation, editor correlation
mock_context.new_uuid = Mock(side_effect=["writer-guid-001", "editor-guid-002", "writer-corr", "editor-corr"])
executor, context, uuid_hexes = executor_with_multiple_uuids
entity_calls: list[str] = []
@@ -539,11 +255,11 @@ class TestOrchestrationIntegration:
entity_calls.append(str(entity_id))
return _create_entity_task()
mock_context.call_entity = Mock(side_effect=mock_call_entity_side_effect)
context.call_entity = Mock(side_effect=mock_call_entity_side_effect)
app = _app_with_registered_agents("WriterAgent", "EditorAgent")
writer = app.get_agent(mock_context, "WriterAgent")
editor = app.get_agent(mock_context, "EditorAgent")
# Create agents directly with executor (not via app.get_agent)
writer = DurableAIAgent(executor, "WriterAgent")
editor = DurableAIAgent(executor, "EditorAgent")
writer_thread = writer.get_new_thread()
editor_thread = editor.get_new_thread()
@@ -557,62 +273,11 @@ class TestOrchestrationIntegration:
# Verify different entity IDs were used
assert len(entity_calls) == 2
# EntityId format is @dafx-agentname@guid (lowercased agent name with dafx- prefix)
assert entity_calls[0] == "@dafx-writeragent@writer-guid-001"
assert entity_calls[1] == "@dafx-editoragent@editor-guid-002"
class TestAgentThreadSerialization:
"""Test that AgentThread can be serialized for orchestration state."""
async def test_agent_thread_serialize(self) -> None:
"""Test that AgentThread can be serialized."""
thread = AgentThread()
# Serialize
serialized = await thread.serialize()
assert isinstance(serialized, dict)
assert "service_thread_id" in serialized
async def test_agent_thread_deserialize(self) -> None:
"""Test that AgentThread can be deserialized."""
thread = AgentThread()
serialized = await thread.serialize()
# Deserialize
restored = await AgentThread.deserialize(serialized)
assert isinstance(restored, AgentThread)
assert restored.service_thread_id == thread.service_thread_id
async def test_durable_agent_thread_serialization(self) -> None:
"""Test that DurableAgentThread persists session metadata during serialization."""
mock_context = Mock()
mock_context.instance_id = "test-instance-999"
mock_context.new_uuid = Mock(return_value="test-guid-999")
agent = DurableAIAgent(mock_context, "TestAgent")
thread = agent.get_new_thread()
assert isinstance(thread, DurableAgentThread)
# Verify custom attribute and property exist
assert thread.session_id is not None
session_id = thread.session_id
assert isinstance(session_id, AgentSessionId)
assert session_id.name == "TestAgent"
assert session_id.key == "test-guid-999"
# Standard serialization should still work
serialized = await thread.serialize()
assert isinstance(serialized, dict)
assert serialized.get("durable_session_id") == str(session_id)
# After deserialization, we'd need to restore the custom attribute
# This would be handled by the orchestration framework
restored = await DurableAgentThread.deserialize(serialized)
assert isinstance(restored, DurableAgentThread)
assert restored.session_id == session_id
# EntityId format is @dafx-agentname@uuid_hex (lowercased agent name with dafx- prefix)
expected_writer_id = f"@dafx-writeragent@{uuid_hexes[0]}"
expected_editor_id = f"@dafx-editoragent@{uuid_hexes[1]}"
assert entity_calls[0] == expected_writer_id
assert entity_calls[1] == expected_editor_id
if __name__ == "__main__":
-248
View File
@@ -1,248 +0,0 @@
# Design: Durable Task Provider for Agent Framework
## Overview
This package, `agent-framework-durabletask`, provides a durability layer for the Microsoft Agent Framework using the `durabletask` Python SDK. It enables stateful, reliable, and distributed agent execution on any platform (Bring Your Own Platform), decoupling the agent's durability from the Azure Functions platform.
## Design Decision
**Selected Approach: Object-Oriented Wrappers with Symmetric Factory Pattern**
We will use a symmetric Object-Oriented design where both the Client (external) and Orchestrator (internal) expose a consistent interface for retrieving and interacting with durable agents.
## Core Philosophy
* **Native `DurableEntity` Support**: We will leverage the `DurableEntity` support introduced in `durabletask` v1.0.0.
* **Symmetric Factories**: `DurableAIAgentClient` (for external use) and `DurableAIAgentOrchestrator` (for internal use) both provide a `get_agent` method.
* **Unified Interface**: `DurableAIAgent` serves as the common interface for executing agents, regardless of the context (Client vs Orchestration).
* **Consistent Return Type**: `DurableAIAgent.run` always returns a `Task` (or compatible awaitable), ensuring consistent usage patterns.
## Architecture
### 1. Package Structure
```text
packages/durabletask/
├── pyproject.toml
├── README.md
├── agent_framework_durabletask/
│ ├── __init__.py
│ ├── _worker.py # DurableAIAgentWorker
│ ├── _client.py # DurableAIAgentClient
│ ├── _orchestrator.py # DurableAIAgentOrchestrator
│ ├── _entities.py # AgentEntity implementation
│ ├── _models.py # Data models (RunRequest, AgentResponse, etc.)
│ ├── _durable_agent_state.py # State schema (Ported from azurefunctions)
│ ├── _shim.py # DurableAIAgent implementation (will be ported from azurefunctions)
│ └── _utils.py # Mixins and helpers
└── tests/
```
### 2. State Management (`_durable_agent_state.py`)
**Important**: This will be the state maintained in the durable entities for both `durabletask` and `azurefunctions` package.
### 3. The Agent Entity (`_entities.py`)
We will implement a class `AgentEntity` that inherits from `durabletask.entities.DurableEntity`.
**Important**: This will be ported from `azurefunctions` package too but with slight modifications, details TBD.
### 4. The Worker Wrapper (`_worker.py`)
The `DurableAIAgentWorker` wraps an existing `durabletask` worker instance.
```python
class DurableAIAgentWorker:
def __init__(self, worker: TaskHubGrpcWorker):
self._worker = worker
self._registered_agents: dict[str, AgentProtocol] = {}
def add_agent(self, agent: AgentProtocol) -> None:
"""Registers an agent with the worker.
Uses the factory pattern to create an AgentEntity class with the agent
instance injected, then registers it with the durabletask worker.
"""
# Store the agent reference
self._registered_agents[agent.name] = agent
# Create a configured entity class using the factory
entity_class = create_agent_entity(agent)
# Register the entity class with the worker
# The worker.add_entity method takes a class or function
self._worker.add_entity(entity_class)
def start(self):
"""Start the worker to begin processing tasks."""
self._worker.start()
def stop(self):
"""Stop the worker gracefully."""
self._worker.stop()
```
### 5. The Mixin (`_utils.py`)
```python
class GetDurableAgentMixin:
"""Mixin to provide get_agent interface."""
def get_agent(self, agent_name: str) -> 'DurableAIAgent':
raise NotImplementedError
```
### 6. The Client Wrapper (`_client.py`)
The `DurableAIAgentClient` is for external clients (e.g., FastAPI, CLI).
```python
class DurableAIAgentClient(GetDurableAgentMixin):
def __init__(self, client: TaskHubGrpcClient):
self._client = client
async def get_agent(self, agent_name: str) -> 'DurableAIAgent':
"""Retrieves a DurableAIAgent shim.
Validates existence by attempting to fetch entity state/metadata.
"""
# Validation logic using self._client.get_entity(...)
# ...
return DurableAIAgent(self, agent_name)
def run_agent(self, agent_name: str, message: str, **kwargs) -> 'Task':
"""Runs agent via signal + poll and returns a Task wrapper."""
# Returns a ClientTask (wrapper around asyncio.Task)
pass
```
### 7. The Orchestration Context Wrapper (`_orchestration_context.py`)
The `DurableAIAgentOrchestrationContext` is for use *inside* orchestrations to get access to agents that were registered in the workers.
```python
class DurableAIAgentOrchestrationContext(GetDurableAgentMixin):
def __init__(self, context: OrchestrationContext):
self._context = context
def get_agent(self, agent_name: str) -> 'DurableAIAgent':
"""Retrieves a DurableAIAgent shim.
Validation is deferred or performed via call_entity if needed.
"""
return DurableAIAgent(self, agent_name)
def run_agent(self, agent_name: str, message: str, **kwargs) -> 'Task':
"""Runs agent via call_entity and returns the Task."""
# Returns the native durabletask.Task
pass
```
### 8. The Durable Agent Shim (`_shim.py`)
The `DurableAIAgent` implements `AgentProtocol` but delegates execution to the provider. This will be ported from `azurefunctions` package and updated accordingly.
```python
class DurableAIAgent(AgentProtocol):
"""A shim that delegates execution to the provider (Client or Orchestrator)."""
def __init__(self, provider: GetDurableAgentMixin, name: str):
self._provider = provider
self._name = name
@property
def name(self) -> str:
return self._name
def run(self, message: str, **kwargs) -> 'Task':
"""Executes the agent.
Returns:
Task: A yieldable/awaitable task object.
"""
return self._provider.run_agent(
agent_name=self.name,
message=message,
**kwargs
)
```
## Usage Experience
**Scenario A: Worker Side**
```python
# 1. Define your agent
# The agent can be any implementation of AgentProtocol.
# For example, a standard Agent with a model and instructions.
my_agent = Agent(
name="my_agent",
instructions="You are a helpful assistant.",
model=openai_model
)
# 2. Create the worker and the agent worker wrapper
with DurableTaskSchedulerWorker(...) as worker:
agent_worker = DurableAIAgentWorker(worker)
# 3. Register the agent
agent_worker.add_agent(my_agent)
# 4. Start the worker
worker.start()
# ... keep running ...
```
**Scenario B: Client Side**
```python
# 1. Configure the Durable Task client
client = DurableTaskSchedulerClient(...)
# 2. Create the Agent Client wrapper
agent_client = DurableAIAgentClient(client)
# 3. Get a reference to the agent
agent = await agent_client.get_agent("my_agent")
# 4. Run the agent
# The returned object is designed to be compatible with both `await` (Client)
# and `yield` (Orchestrator). Implementation details on this unified return type will follow.
response = await agent.run("Hello")
```
**Scenario C: Orchestration Side**
```python
def orchestrator(context: OrchestrationContext):
# 1. Create the Agent Orchestrator wrapper
agent_orch = DurableAIAgentOrchestrator(context)
# 2. Get a reference to the agent
agent = agent_orch.get_agent("my_agent")
# 3. Run the agent (returns a Task, so we yield it)
result = yield agent.run("Hello")
return result
```
## Additional Styles Considered
### Inheritance Pattern for worker and client (like `DurableAIAgentWorker`, `DurableAIAgentClient`, etc)
We investigated inheriting `DurableAIAgentWorker` directly from `TaskHubGrpcWorker` (or `DurableTaskSchedulerWorker`) to provide a unified API where the agent worker *is* a durable task worker (and similarly the client).
**Why we chose Composition over Inheritance:**
1. **Initialization Divergence:** The `durabletask` package has two distinct worker classes with incompatible `__init__` signatures:
* `TaskHubGrpcWorker`: Requires `host_address`, `metadata`, etc.
* `DurableTaskSchedulerWorker`: Requires `host_address`, `taskhub`, `token_credential`, etc.
To support both via inheritance, we would need to maintain two separate classes (e.g., `DurableAIAgentGrpcWorker` and `DurableAIAgentSchedulerWorker`) or use a complex Mixin approach. This increases the API surface area and maintenance burden.
2. **Encapsulation:** The logic for Azure Managed DTS (authentication, routing) is currently encapsulated in an internal interceptor class within `durabletask`. Without changes to the upstream package to expose this logic, we cannot create a single "Universal" worker class that inherits from the base worker but supports Azure features.
3. **Flexibility:** The Composition pattern allows `DurableAIAgentWorker` to accept *any* instance of a worker that satisfies the required interface. This makes it forward-compatible with future worker implementations or custom subclasses without requiring code changes in our package.
4. **Simplicity:** While Composition requires a two-step setup (instantiate worker, then wrap it), it keeps the `agent-framework-durabletask` package simple, focused, and loosely coupled from the implementation details of the underlying `durabletask` workers.
@@ -3,6 +3,7 @@
"""Durable Task integration for Microsoft Agent Framework."""
from ._callbacks import AgentCallbackContext, AgentResponseCallbackProtocol
from ._client import DurableAIAgentClient
from ._constants import (
DEFAULT_MAX_POLL_RETRIES,
DEFAULT_POLL_INTERVAL_SECONDS,
@@ -41,7 +42,12 @@ from ._durable_agent_state import (
DurableAgentStateUsageContent,
)
from ._entities import AgentEntity, AgentEntityStateProviderMixin
from ._models import RunRequest, serialize_response_format
from ._executors import DurableAgentExecutor
from ._models import AgentSessionId, DurableAgentThread, RunRequest
from ._orchestration_context import DurableAIAgentOrchestrationContext
from ._response_utils import ensure_response_format, load_agent_response
from ._shim import DurableAIAgent
from ._worker import DurableAIAgentWorker
__all__ = [
"DEFAULT_MAX_POLL_RETRIES",
@@ -58,8 +64,14 @@ __all__ = [
"AgentEntity",
"AgentEntityStateProviderMixin",
"AgentResponseCallbackProtocol",
"AgentSessionId",
"ApiResponseFields",
"ContentTypes",
"DurableAIAgent",
"DurableAIAgentClient",
"DurableAIAgentOrchestrationContext",
"DurableAIAgentWorker",
"DurableAgentExecutor",
"DurableAgentState",
"DurableAgentStateContent",
"DurableAgentStateData",
@@ -80,7 +92,10 @@ __all__ = [
"DurableAgentStateUriContent",
"DurableAgentStateUsage",
"DurableAgentStateUsageContent",
"DurableAgentThread",
"DurableAgentThread",
"DurableStateFields",
"RunRequest",
"serialize_response_format",
"ensure_response_format",
"load_agent_response",
]
@@ -0,0 +1,90 @@
# Copyright (c) Microsoft. All rights reserved.
"""Client wrapper for Durable Task Agent Framework.
This module provides the DurableAIAgentClient class for external clients to interact
with durable agents via gRPC.
"""
from __future__ import annotations
from agent_framework import AgentRunResponse, get_logger
from durabletask.client import TaskHubGrpcClient
from ._constants import DEFAULT_MAX_POLL_RETRIES, DEFAULT_POLL_INTERVAL_SECONDS
from ._executors import ClientAgentExecutor
from ._shim import DurableAgentProvider, DurableAIAgent
logger = get_logger("agent_framework.durabletask.client")
class DurableAIAgentClient(DurableAgentProvider[AgentRunResponse]):
"""Client wrapper for interacting with durable agents externally.
This class wraps a durabletask TaskHubGrpcClient and provides a convenient
interface for retrieving and executing durable agents from external contexts.
Example:
```python
from durabletask import TaskHubGrpcClient
from agent_framework_durabletask import DurableAIAgentClient
# Create the underlying client
client = TaskHubGrpcClient(host_address="localhost:4001")
# Wrap it with the agent client
agent_client = DurableAIAgentClient(client)
# Get an agent reference
agent = agent_client.get_agent("assistant")
# Run the agent (synchronous call that waits for completion)
response = agent.run("Hello, how are you?")
print(response.text)
```
"""
def __init__(
self,
client: TaskHubGrpcClient,
max_poll_retries: int = DEFAULT_MAX_POLL_RETRIES,
poll_interval_seconds: float = DEFAULT_POLL_INTERVAL_SECONDS,
):
"""Initialize the client wrapper.
Args:
client: The durabletask client instance to wrap
max_poll_retries: Maximum polling attempts when waiting for responses
poll_interval_seconds: Delay in seconds between polling attempts
"""
self._client = client
# Validate and set polling parameters
self.max_poll_retries = max(1, max_poll_retries)
self.poll_interval_seconds = (
poll_interval_seconds if poll_interval_seconds > 0 else DEFAULT_POLL_INTERVAL_SECONDS
)
self._executor = ClientAgentExecutor(self._client, self.max_poll_retries, self.poll_interval_seconds)
logger.debug("[DurableAIAgentClient] Initialized with client type: %s", type(client).__name__)
def get_agent(self, agent_name: str) -> DurableAIAgent[AgentRunResponse]:
"""Retrieve a DurableAIAgent shim for the specified agent.
This method returns a proxy object that can be used to execute the agent.
The actual agent must be registered on a worker with the same name.
Args:
agent_name: Name of the agent to retrieve (without the dafx- prefix)
Returns:
DurableAIAgent instance that can be used to run the agent
Note:
This method does not validate that the agent exists. Validation
will occur when the agent is executed. If the entity doesn't exist,
the execution will fail with an appropriate error.
"""
logger.debug("[DurableAIAgentClient] Creating agent proxy for: %s", agent_name)
return DurableAIAgent(self._executor, agent_name)
@@ -53,7 +53,7 @@ from agent_framework import (
)
from dateutil import parser as date_parser
from ._constants import ApiResponseFields, ContentTypes, DurableStateFields
from ._constants import ContentTypes, DurableStateFields
from ._models import RunRequest, serialize_response_format
logger = get_logger("agent_framework.durabletask.durable_agent_state")
@@ -452,7 +452,7 @@ class DurableAgentState:
"""Get the count of conversation entries (requests + responses)."""
return len(self.data.conversation_history)
def try_get_agent_response(self, correlation_id: str) -> dict[str, Any] | None:
def try_get_agent_response(self, correlation_id: str) -> AgentRunResponse | None:
"""Try to get an agent response by correlation ID.
This method searches the conversation history for a response entry matching the given
@@ -474,14 +474,8 @@ class DurableAgentState:
for entry in self.data.conversation_history:
if entry.correlation_id == correlation_id and isinstance(entry, DurableAgentStateResponse):
# Found the entry, extract response data
# Get the text content from assistant messages only
content = "\n".join(message.text for message in entry.messages if message.text)
return DurableAgentStateResponse.to_run_response(entry)
return {
ApiResponseFields.CONTENT: content,
ApiResponseFields.MESSAGE_COUNT: self.message_count,
ApiResponseFields.CORRELATION_ID: correlation_id,
}
return None
@@ -705,6 +699,21 @@ class DurableAgentStateResponse(DurableAgentStateEntry):
usage=DurableAgentStateUsage.from_usage(response.usage_details),
)
@staticmethod
def to_run_response(
response_entry: DurableAgentStateResponse,
) -> AgentRunResponse:
"""Converts a DurableAgentStateResponse back to an AgentRunResponse."""
messages = [m.to_chat_message() for m in response_entry.messages]
usage_details = response_entry.usage.to_usage_details() if response_entry.usage is not None else UsageDetails()
return AgentRunResponse(
created_at=response_entry.created_at.isoformat(),
messages=messages,
usage_details=usage_details,
)
class DurableAgentStateMessage:
"""Represents a message within a conversation history entry.
@@ -1214,14 +1223,24 @@ class DurableAgentStateUsage:
input_token_count=usage.input_token_count,
output_token_count=usage.output_token_count,
total_token_count=usage.total_token_count,
extensionData=usage.additional_counts,
)
def to_usage_details(self) -> UsageDetails:
# Convert back to AI SDK UsageDetails
extension_data: dict[str, int] = {}
if self.extensionData is not None:
for k, v in self.extensionData.items():
try:
extension_data[k] = int(v)
except (ValueError, TypeError):
continue
return UsageDetails(
input_token_count=self.input_token_count,
output_token_count=self.output_token_count,
total_token_count=self.total_token_count,
**extension_data,
)
@@ -128,7 +128,7 @@ class AgentEntity:
) -> AgentRunResponse:
"""Execute the agent with a message."""
if isinstance(request, str):
run_request = RunRequest(message=request, role=Role.USER)
run_request = RunRequest.from_json(request)
elif isinstance(request, dict):
run_request = RunRequest.from_dict(request)
else:
@@ -139,8 +139,6 @@ class AgentEntity:
correlation_id = run_request.correlation_id
if not thread_id:
raise ValueError("Entity State Provider must provide a thread_id")
if not correlation_id:
raise ValueError("RunRequest must include a correlation_id")
response_format = run_request.response_format
enable_tool_calls = run_request.enable_tool_calls
@@ -0,0 +1,460 @@
# Copyright (c) Microsoft. All rights reserved.
"""Provider strategies for Durable Agent execution.
These classes are internal execution strategies used by the DurableAIAgent shim.
They are intentionally separate from the public client/orchestration APIs to keep
only `get_agent` exposed to consumers. Executors implement the execution contract
and are injected into the shim.
"""
from __future__ import annotations
import time
import uuid
from abc import ABC, abstractmethod
from datetime import datetime, timezone
from typing import Any, Generic, TypeVar
from agent_framework import AgentRunResponse, AgentThread, ChatMessage, ErrorContent, Role, get_logger
from durabletask.client import TaskHubGrpcClient
from durabletask.entities import EntityInstanceId
from durabletask.task import CompositeTask, OrchestrationContext, Task
from pydantic import BaseModel
from ._constants import DEFAULT_MAX_POLL_RETRIES, DEFAULT_POLL_INTERVAL_SECONDS
from ._durable_agent_state import DurableAgentState
from ._models import AgentSessionId, DurableAgentThread, RunRequest
from ._response_utils import ensure_response_format, load_agent_response
logger = get_logger("agent_framework.durabletask.executors")
# TypeVar for the task type returned by executors
TaskT = TypeVar("TaskT")
class DurableAgentTask(CompositeTask[AgentRunResponse]):
"""A custom Task that wraps entity calls and provides typed AgentRunResponse results.
This task wraps the underlying entity call task and intercepts its completion
to convert the raw result into a typed AgentRunResponse object.
"""
def __init__(
self,
entity_task: Task[Any],
response_format: type[BaseModel] | None,
correlation_id: str,
):
"""Initialize the DurableAgentTask.
Args:
entity_task: The underlying entity call task
response_format: Optional Pydantic model for response parsing
correlation_id: Correlation ID for logging
"""
self._response_format = response_format
self._correlation_id = correlation_id
super().__init__([entity_task]) # type: ignore[misc]
def on_child_completed(self, task: Task[Any]) -> None:
"""Handle completion of the underlying entity task.
Parameters
----------
task : Task
The entity call task that just completed
"""
if self.is_complete:
return
if task.is_failed:
# Propagate the failure
self._exception = task.get_exception()
self._is_complete = True
if self._parent is not None:
self._parent.on_child_completed(self)
return
# Task succeeded - transform the raw result
raw_result = task.get_result()
logger.debug(
"[DurableAgentTask] Converting raw result for correlation_id %s",
self._correlation_id,
)
try:
response = load_agent_response(raw_result)
if self._response_format is not None:
ensure_response_format(
self._response_format,
self._correlation_id,
response,
)
# Set the typed AgentRunResponse as this task's result
self._result = response
self._is_complete = True
if self._parent is not None:
self._parent.on_child_completed(self)
except Exception:
logger.exception(
"[DurableAgentTask] Failed to convert result for correlation_id: %s",
self._correlation_id,
)
raise
class DurableAgentExecutor(ABC, Generic[TaskT]):
"""Abstract base class for durable agent execution strategies.
Type Parameters:
TaskT: The task type returned by this executor
"""
@abstractmethod
def run_durable_agent(
self,
agent_name: str,
run_request: RunRequest,
thread: AgentThread | None = None,
) -> TaskT:
"""Execute the durable agent.
Returns:
TaskT: The task type specific to this executor implementation
"""
raise NotImplementedError
def get_new_thread(self, agent_name: str, **kwargs: Any) -> DurableAgentThread:
"""Create a new DurableAgentThread with random session ID."""
session_id = self._create_session_id(agent_name)
return DurableAgentThread.from_session_id(session_id, **kwargs)
def _create_session_id(
self,
agent_name: str,
thread: AgentThread | None = None,
) -> AgentSessionId:
"""Create the AgentSessionId for the execution."""
if isinstance(thread, DurableAgentThread) and thread.session_id is not None:
return thread.session_id
# Create new session ID - either no thread provided or it's a regular AgentThread
key = self.generate_unique_id()
return AgentSessionId(name=agent_name, key=key)
def generate_unique_id(self) -> str:
"""Generate a new Unique ID."""
return uuid.uuid4().hex
def get_run_request(
self,
message: str,
response_format: type[BaseModel] | None,
enable_tool_calls: bool,
) -> RunRequest:
"""Create a RunRequest for the given parameters."""
correlation_id = self.generate_unique_id()
return RunRequest(
message=message,
response_format=response_format,
enable_tool_calls=enable_tool_calls,
correlation_id=correlation_id,
)
class ClientAgentExecutor(DurableAgentExecutor[AgentRunResponse]):
"""Execution strategy for external clients.
Note: Returns AgentRunResponse directly since the execution
is blocking until response is available via polling
as per the design of TaskHubGrpcClient.
"""
def __init__(
self,
client: TaskHubGrpcClient,
max_poll_retries: int = DEFAULT_MAX_POLL_RETRIES,
poll_interval_seconds: float = DEFAULT_POLL_INTERVAL_SECONDS,
):
self._client = client
self.max_poll_retries = max_poll_retries
self.poll_interval_seconds = poll_interval_seconds
def run_durable_agent(
self,
agent_name: str,
run_request: RunRequest,
thread: AgentThread | None = None,
) -> AgentRunResponse:
"""Execute the agent via the durabletask client.
Signals the agent entity with a message request, then polls the entity
state to retrieve the response once processing is complete.
Note: This is a blocking/synchronous operation (in line with how
TaskHubGrpcClient works) that polls until a response is available or
timeout occurs.
Args:
agent_name: Name of the agent to execute
run_request: The run request containing message and optional response format
thread: Optional conversation thread (creates new if not provided)
Returns:
AgentRunResponse: The agent's response after execution completes
"""
# Signal the entity with the request
entity_id = self._signal_agent_entity(agent_name, run_request, thread)
# Poll for the response
agent_response = self._poll_for_agent_response(entity_id, run_request.correlation_id)
# Handle and return the result
return self._handle_agent_response(agent_response, run_request.response_format, run_request.correlation_id)
def _signal_agent_entity(
self,
agent_name: str,
run_request: RunRequest,
thread: AgentThread | None,
) -> EntityInstanceId:
"""Signal the agent entity with a run request.
Args:
agent_name: Name of the agent to execute
run_request: The run request containing message and optional response format
thread: Optional conversation thread
Returns:
entity_id
"""
# Get or create session ID
session_id = self._create_session_id(agent_name, thread)
# Create the entity ID
entity_id = EntityInstanceId(
entity=session_id.entity_name,
key=session_id.key,
)
logger.debug(
"[ClientAgentExecutor] Signaling entity '%s' (session: %s, correlation: %s)",
agent_name,
session_id,
run_request.correlation_id,
)
self._client.signal_entity(entity_id, "run", run_request.to_dict())
return entity_id
def _poll_for_agent_response(
self,
entity_id: EntityInstanceId,
correlation_id: str,
) -> AgentRunResponse | None:
"""Poll the entity for a response with retries.
Args:
entity_id: Entity instance identifier
correlation_id: Correlation ID to track the request
Returns:
The agent response if found, None if timeout occurs
"""
agent_response = None
for attempt in range(1, self.max_poll_retries + 1):
time.sleep(self.poll_interval_seconds)
agent_response = self._poll_entity_for_response(entity_id, correlation_id)
if agent_response is not None:
logger.info(
"[ClientAgentExecutor] Found response (attempt %d/%d, correlation: %s)",
attempt,
self.max_poll_retries,
correlation_id,
)
break
logger.debug(
"[ClientAgentExecutor] Response not ready (attempt %d/%d)",
attempt,
self.max_poll_retries,
)
return agent_response
def _handle_agent_response(
self,
agent_response: AgentRunResponse | None,
response_format: type[BaseModel] | None,
correlation_id: str,
) -> AgentRunResponse:
"""Handle the agent response or create an error response.
Args:
agent_response: The response from polling, or None if timeout
response_format: Optional response format for validation
correlation_id: Correlation ID for logging
Returns:
AgentRunResponse with either the agent's response or an error message
"""
if agent_response is not None:
try:
# Validate response format if specified
if response_format is not None:
ensure_response_format(
response_format,
correlation_id,
agent_response,
)
return agent_response
except Exception as e:
logger.exception(
"[ClientAgentExecutor] Error converting response for correlation: %s",
correlation_id,
)
error_message = ChatMessage(
role=Role.SYSTEM,
contents=[
ErrorContent(
message=f"Error processing agent response: {e}",
error_code="response_processing_error",
)
],
)
else:
logger.warning(
"[ClientAgentExecutor] Timeout after %d attempts (correlation: %s)",
self.max_poll_retries,
correlation_id,
)
error_message = ChatMessage(
role=Role.SYSTEM,
contents=[
ErrorContent(
message=f"Timeout waiting for agent response after {self.max_poll_retries} attempts",
error_code="response_timeout",
)
],
)
return AgentRunResponse(
messages=[error_message],
created_at=datetime.now(timezone.utc).isoformat(),
)
def _poll_entity_for_response(
self,
entity_id: EntityInstanceId,
correlation_id: str,
) -> AgentRunResponse | None:
"""Poll the entity state for a response matching the correlation ID.
Args:
entity_id: Entity instance identifier
correlation_id: Correlation ID to search for
Returns:
Response AgentRunResponse, None otherwise
"""
try:
entity_metadata = self._client.get_entity(entity_id, include_state=True)
if entity_metadata is None:
return None
state_json = entity_metadata.get_state()
if not state_json:
return None
state = DurableAgentState.from_json(state_json)
# Use the helper method to get response by correlation ID
return state.try_get_agent_response(correlation_id)
except Exception as e:
logger.warning(
"[ClientAgentExecutor] Error reading entity state: %s",
e,
)
return None
class OrchestrationAgentExecutor(DurableAgentExecutor[DurableAgentTask]):
"""Execution strategy for orchestrations (sync/yield)."""
def __init__(self, context: OrchestrationContext):
self._context = context
logger.debug("[OrchestrationAgentExecutor] Initialized")
def get_run_request(
self,
message: str,
response_format: type[BaseModel] | None,
enable_tool_calls: bool,
) -> RunRequest:
"""Get the current run request from the orchestration context.
Returns:
RunRequest: The current run request
"""
request = super().get_run_request(
message,
response_format,
enable_tool_calls,
)
request.orchestration_id = self._context.instance_id
return request
def run_durable_agent(
self,
agent_name: str,
run_request: RunRequest,
thread: AgentThread | None = None,
) -> DurableAgentTask:
"""Execute the agent via orchestration context.
Calls the agent entity and returns a DurableAgentTask that can be yielded
in orchestrations to wait for the entity's response.
Args:
agent_name: Name of the agent to execute
run_request: The run request containing message and optional response format
thread: Optional conversation thread (creates new if not provided)
Returns:
DurableAgentTask: A task wrapping the entity call that yields AgentRunResponse
"""
# Resolve session
session_id = self._create_session_id(agent_name, thread)
# Create the entity ID
entity_id = EntityInstanceId(
entity=session_id.entity_name,
key=session_id.key,
)
logger.debug(
"[OrchestrationAgentExecutor] correlation_id: %s entity_id: %s session_id: %s",
run_request.correlation_id,
entity_id,
session_id,
)
# Call the entity and get the underlying task
entity_task: Task[Any] = self._context.call_entity(entity_id, "run", run_request.to_dict()) # type: ignore
# Wrap in DurableAgentTask for response transformation
return DurableAgentTask(
entity_task=entity_task,
response_format=run_request.response_format,
correlation_id=run_request.correlation_id,
)
@@ -8,12 +8,15 @@ This module defines the request and response models used by the framework.
from __future__ import annotations
import inspect
import json
import uuid
from collections.abc import MutableMapping
from dataclasses import dataclass
from datetime import datetime
from datetime import datetime, timezone
from importlib import import_module
from typing import TYPE_CHECKING, Any, cast
from agent_framework import Role
from agent_framework import AgentThread, Role
from ._constants import REQUEST_RESPONSE_FORMAT_TEXT
@@ -101,38 +104,38 @@ class RunRequest:
role: The role of the message sender (user, system, or assistant)
response_format: Optional Pydantic BaseModel type describing the structured response format
enable_tool_calls: Whether to enable tool calls for this request
correlation_id: Optional correlation ID for tracking the response to this specific request
correlation_id: Correlation ID for tracking the response to this specific request
created_at: Optional timestamp when the request was created
orchestration_id: Optional ID of the orchestration that initiated this request
"""
message: str
request_response_format: str
correlation_id: str
role: Role = Role.USER
response_format: type[BaseModel] | None = None
enable_tool_calls: bool = True
correlation_id: str | None = None
created_at: datetime | None = None
orchestration_id: str | None = None
def __init__(
self,
message: str,
correlation_id: str,
request_response_format: str = REQUEST_RESPONSE_FORMAT_TEXT,
role: Role | str | None = Role.USER,
response_format: type[BaseModel] | None = None,
enable_tool_calls: bool = True,
correlation_id: str | None = None,
created_at: datetime | None = None,
orchestration_id: str | None = None,
) -> None:
self.message = message
self.correlation_id = correlation_id
self.role = self.coerce_role(role)
self.response_format = response_format
self.request_response_format = request_response_format
self.enable_tool_calls = enable_tool_calls
self.correlation_id = correlation_id
self.created_at = created_at
self.created_at = created_at if created_at is not None else datetime.now(tz=timezone.utc)
self.orchestration_id = orchestration_id
@staticmethod
@@ -154,11 +157,10 @@ class RunRequest:
"enable_tool_calls": self.enable_tool_calls,
"role": self.role.value,
"request_response_format": self.request_response_format,
"correlationId": self.correlation_id,
}
if self.response_format:
result["response_format"] = serialize_response_format(self.response_format)
if self.correlation_id:
result["correlationId"] = self.correlation_id
if self.created_at:
result["created_at"] = self.created_at.isoformat()
if self.orchestration_id:
@@ -166,6 +168,16 @@ class RunRequest:
return result
@classmethod
def from_json(cls, data: str) -> RunRequest:
"""Create RunRequest from JSON string."""
try:
dict_data = json.loads(data)
except json.JSONDecodeError as e:
raise ValueError("The durable agent state is not valid JSON.") from e
return cls.from_dict(dict_data)
@classmethod
def from_dict(cls, data: dict[str, Any]) -> RunRequest:
"""Create RunRequest from dictionary."""
@@ -176,13 +188,120 @@ class RunRequest:
except ValueError:
created_at = None
correlation_id = data.get("correlationId")
if not correlation_id:
raise ValueError("correlationId is required in RunRequest data")
return cls(
message=data.get("message", ""),
correlation_id=correlation_id,
request_response_format=data.get("request_response_format", REQUEST_RESPONSE_FORMAT_TEXT),
role=cls.coerce_role(data.get("role")),
response_format=_deserialize_response_format(data.get("response_format")),
enable_tool_calls=data.get("enable_tool_calls", True),
correlation_id=data.get("correlationId"),
created_at=created_at,
orchestration_id=data.get("orchestrationId"),
)
@dataclass
class AgentSessionId:
"""Represents an agent session identifier (name + key)."""
name: str
key: str
ENTITY_NAME_PREFIX: str = "dafx-"
@staticmethod
def to_entity_name(name: str) -> str:
return f"{AgentSessionId.ENTITY_NAME_PREFIX}{name}"
@staticmethod
def with_random_key(name: str) -> AgentSessionId:
return AgentSessionId(name=name, key=uuid.uuid4().hex)
@property
def entity_name(self) -> str:
return self.to_entity_name(self.name)
def __str__(self) -> str:
return f"@{self.name}@{self.key}"
def __repr__(self) -> str:
return f"AgentSessionId(name='{self.name}', key='{self.key}')"
@staticmethod
def parse(session_id_string: str) -> AgentSessionId:
if not session_id_string.startswith("@"):
raise ValueError(f"Invalid agent session ID format: {session_id_string}")
parts = session_id_string[1:].split("@", 1)
if len(parts) != 2:
raise ValueError(f"Invalid agent session ID format: {session_id_string}")
return AgentSessionId(name=parts[0], key=parts[1])
class DurableAgentThread(AgentThread):
"""Durable agent thread that tracks the owning :class:`AgentSessionId`."""
_SERIALIZED_SESSION_ID_KEY = "durable_session_id"
def __init__(
self,
*,
session_id: AgentSessionId | None = None,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
self._session_id: AgentSessionId | None = session_id
@property
def session_id(self) -> AgentSessionId | None:
return self._session_id
@session_id.setter
def session_id(self, value: AgentSessionId | None) -> None:
self._session_id = value
@classmethod
def from_session_id(
cls,
session_id: AgentSessionId,
**kwargs: Any,
) -> DurableAgentThread:
return cls(session_id=session_id, **kwargs)
async def serialize(self, **kwargs: Any) -> dict[str, Any]:
state = await super().serialize(**kwargs)
if self._session_id is not None:
state[self._SERIALIZED_SESSION_ID_KEY] = str(self._session_id)
return state
@classmethod
async def deserialize(
cls,
serialized_thread_state: MutableMapping[str, Any],
*,
message_store: Any = None,
**kwargs: Any,
) -> DurableAgentThread:
state_payload = dict(serialized_thread_state)
session_id_value = state_payload.pop(cls._SERIALIZED_SESSION_ID_KEY, None)
thread = await super().deserialize(
state_payload,
message_store=message_store,
**kwargs,
)
if not isinstance(thread, DurableAgentThread):
raise TypeError("Deserialized thread is not a DurableAgentThread instance")
if session_id_value is None:
return thread
if not isinstance(session_id_value, str):
raise ValueError("durable_session_id must be a string when present in serialized state")
thread.session_id = AgentSessionId.parse(session_id_value)
return thread
@@ -0,0 +1,75 @@
# Copyright (c) Microsoft. All rights reserved.
"""Orchestration context wrapper for Durable Task Agent Framework.
This module provides the DurableAIAgentOrchestrationContext class for use inside
orchestration functions to interact with durable agents.
"""
from __future__ import annotations
from agent_framework import get_logger
from durabletask.task import OrchestrationContext
from ._executors import DurableAgentTask, OrchestrationAgentExecutor
from ._shim import DurableAgentProvider, DurableAIAgent
logger = get_logger("agent_framework.durabletask.orchestration_context")
class DurableAIAgentOrchestrationContext(DurableAgentProvider[DurableAgentTask]):
"""Orchestration context wrapper for interacting with durable agents internally.
This class wraps a durabletask OrchestrationContext and provides a convenient
interface for retrieving and executing durable agents from within orchestration
functions.
Example:
```python
from durabletask import Orchestration
from agent_framework_durabletask import DurableAIAgentOrchestrationContext
def my_orchestration(context: OrchestrationContext):
# Wrap the context
agent_context = DurableAIAgentOrchestrationContext(context)
# Get an agent reference
agent = agent_context.get_agent("assistant")
# Run the agent (returns a Task to be yielded)
result = yield agent.run("Hello, how are you?")
return result.text
```
"""
def __init__(self, context: OrchestrationContext):
"""Initialize the orchestration context wrapper.
Args:
context: The durabletask orchestration context to wrap
"""
self._context = context
self._executor = OrchestrationAgentExecutor(self._context)
logger.debug("[DurableAIAgentOrchestrationContext] Initialized")
def get_agent(self, agent_name: str) -> DurableAIAgent[DurableAgentTask]:
"""Retrieve a DurableAIAgent shim for the specified agent.
This method returns a proxy object that can be used to execute the agent
within an orchestration. The agent's run() method will return a Task that
must be yielded.
Args:
agent_name: Name of the agent to retrieve (without the dafx- prefix)
Returns:
DurableAIAgent instance that can be used to run the agent
Note:
Validation is deferred to execution time. The entity must be registered
on a worker with the name f"dafx-{agent_name}".
"""
logger.debug("[DurableAIAgentOrchestrationContext] Creating agent proxy for: %s", agent_name)
return DurableAIAgent(self._executor, agent_name)
@@ -0,0 +1,62 @@
# Copyright (c) Microsoft. All rights reserved.
"""Shared utilities for handling AgentRunResponse parsing and validation."""
from typing import Any
from agent_framework import AgentRunResponse, get_logger
from pydantic import BaseModel
logger = get_logger("agent_framework.durabletask.response_utils")
def load_agent_response(agent_response: AgentRunResponse | dict[str, Any] | None) -> AgentRunResponse:
"""Convert raw payloads into AgentRunResponse instance.
Args:
agent_response: The response to convert, can be an AgentRunResponse, dict, or None
Returns:
AgentRunResponse: The converted response object
Raises:
ValueError: If agent_response is None
TypeError: If agent_response is an unsupported type
"""
if agent_response is None:
raise ValueError("agent_response cannot be None")
logger.debug("[load_agent_response] Loading agent response of type: %s", type(agent_response))
if isinstance(agent_response, AgentRunResponse):
return agent_response
if isinstance(agent_response, dict):
logger.debug("[load_agent_response] Converting dict payload using AgentRunResponse.from_dict")
return AgentRunResponse.from_dict(agent_response)
raise TypeError(f"Unsupported type for agent_response: {type(agent_response)}")
def ensure_response_format(
response_format: type[BaseModel] | None,
correlation_id: str,
response: AgentRunResponse,
) -> None:
"""Ensure the AgentRunResponse value is parsed into the expected response_format.
This function modifies the response in-place by parsing its value attribute
into the specified Pydantic model format.
Args:
response_format: Optional Pydantic model class to parse the response value into
correlation_id: Correlation ID for logging purposes
response: The AgentRunResponse object to validate and parse
"""
if response_format is not None and not isinstance(response.value, response_format):
response.try_parse_value(response_format)
logger.debug(
"[ensure_response_format] Loaded AgentRunResponse.value for correlation_id %s with type: %s",
correlation_id,
type(response.value).__name__,
)
@@ -0,0 +1,185 @@
# Copyright (c) Microsoft. All rights reserved.
"""Durable Agent Shim for Durable Task Framework.
This module provides the DurableAIAgent shim that implements AgentProtocol
and provides a consistent interface for both Client and Orchestration contexts.
The actual execution is delegated to the context-specific providers.
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator
from typing import Any, Generic, TypeVar
from agent_framework import AgentProtocol, AgentRunResponseUpdate, AgentThread, ChatMessage
from pydantic import BaseModel
from ._executors import DurableAgentExecutor
from ._models import DurableAgentThread
# TypeVar for the task type returned by executors
# Covariant because TaskT only appears in return positions (output)
TaskT = TypeVar("TaskT", covariant=True)
class DurableAgentProvider(ABC, Generic[TaskT]):
"""Abstract provider for constructing durable agent proxies.
Implemented by context-specific wrappers (client/orchestration) to return a
`DurableAIAgent` shim backed by their respective `DurableAgentExecutor`
implementation, ensuring a consistent `get_agent` entry point regardless of
execution context.
"""
@abstractmethod
def get_agent(self, agent_name: str) -> DurableAIAgent[TaskT]:
"""Retrieve a DurableAIAgent shim for the specified agent.
Args:
agent_name: Name of the agent to retrieve
Returns:
DurableAIAgent instance that can be used to run the agent
Raises:
NotImplementedError: Must be implemented by subclasses
"""
raise NotImplementedError("Subclasses must implement get_agent()")
class DurableAIAgent(AgentProtocol, Generic[TaskT]):
"""A durable agent proxy that delegates execution to the provider.
This class implements AgentProtocol but with one critical difference:
- AgentProtocol.run() returns a Coroutine (async, must await)
- DurableAIAgent.run() returns TaskT (sync Task object - must yield
or the AgentRunResponse directly in the case of TaskHubGrpcClient)
This represents fundamentally different execution models but maintains the same
interface contract for all other properties and methods.
The underlying provider determines how execution occurs (entity calls, HTTP requests, etc.)
and what type of Task object is returned.
Type Parameters:
TaskT: The task type returned by this agent (e.g., AgentRunResponse, DurableAgentTask, AgentTask)
"""
def __init__(self, executor: DurableAgentExecutor[TaskT], name: str, *, agent_id: str | None = None):
"""Initialize the shim with a provider and agent name.
Args:
executor: The execution provider (Client or OrchestrationContext)
name: The name of the agent to execute
agent_id: Optional unique identifier for the agent (defaults to name)
"""
self._executor = executor
self._name = name
self._id = agent_id if agent_id is not None else name
self._display_name = name
self._description = f"Durable agent proxy for {name}"
@property
def id(self) -> str:
"""Get the unique identifier for this agent."""
return self._id
@property
def name(self) -> str | None:
"""Get the name of the agent."""
return self._name
@property
def display_name(self) -> str:
"""Get the display name of the agent."""
return self._display_name
@property
def description(self) -> str | None:
"""Get the description of the agent."""
return self._description
def run( # pyright: ignore[reportIncompatibleMethodOverride]
self,
messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
*,
thread: AgentThread | None = None,
response_format: type[BaseModel] | None = None,
enable_tool_calls: bool = True,
) -> TaskT:
"""Execute the agent via the injected provider.
Note:
This method overrides AgentProtocol.run() with a different return type:
- AgentProtocol.run() returns Coroutine[Any, Any, AgentRunResponse] (async)
- DurableAIAgent.run() returns TaskT (Task object for yielding)
This is intentional to support orchestration contexts that use yield patterns
instead of async/await patterns.
Returns:
TaskT: The task type specific to the executor
"""
message_str = self._normalize_messages(messages)
run_request = self._executor.get_run_request(
message=message_str,
response_format=response_format,
enable_tool_calls=enable_tool_calls,
)
return self._executor.run_durable_agent(
agent_name=self._name,
run_request=run_request,
thread=thread,
)
def run_stream(
self,
messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
*,
thread: AgentThread | None = None,
**kwargs: Any,
) -> AsyncIterator[AgentRunResponseUpdate]:
"""Run the agent with streaming (not supported for durable agents).
Args:
messages: The message(s) to send to the agent
thread: Optional agent thread for conversation context
**kwargs: Additional arguments
Raises:
NotImplementedError: Streaming is not supported for durable agents
"""
raise NotImplementedError("Streaming is not supported for durable agents")
def get_new_thread(self, **kwargs: Any) -> DurableAgentThread:
"""Create a new agent thread via the provider."""
return self._executor.get_new_thread(self._name, **kwargs)
def _normalize_messages(self, messages: str | ChatMessage | list[str] | list[ChatMessage] | None) -> str:
"""Convert supported message inputs to a single string.
Args:
messages: The messages to normalize
Returns:
A single string representation of the messages
"""
if messages is None:
return ""
if isinstance(messages, str):
return messages
if isinstance(messages, ChatMessage):
return messages.text or ""
if isinstance(messages, list):
if not messages:
return ""
first_item = messages[0]
if isinstance(first_item, str):
return "\n".join(messages) # type: ignore[arg-type]
# List of ChatMessage
return "\n".join([msg.text or "" for msg in messages]) # type: ignore[union-attr]
return ""
@@ -0,0 +1,218 @@
# Copyright (c) Microsoft. All rights reserved.
"""Worker wrapper for Durable Task Agent Framework.
This module provides the DurableAIAgentWorker class that wraps a durabletask worker
and enables registration of agents as durable entities.
"""
from __future__ import annotations
import asyncio
from typing import Any
from agent_framework import AgentProtocol, get_logger
from durabletask.worker import TaskHubGrpcWorker
from ._callbacks import AgentResponseCallbackProtocol
from ._entities import AgentEntity, DurableTaskEntityStateProvider
logger = get_logger("agent_framework.durabletask.worker")
class DurableAIAgentWorker:
"""Wrapper for durabletask worker that enables agent registration.
This class wraps an existing TaskHubGrpcWorker instance and provides
a convenient interface for registering agents as durable entities.
Example:
```python
from durabletask import TaskHubGrpcWorker
from agent_framework import ChatAgent
from agent_framework_durabletask import DurableAIAgentWorker
# Create the underlying worker
worker = TaskHubGrpcWorker(host_address="localhost:4001")
# Wrap it with the agent worker
agent_worker = DurableAIAgentWorker(worker)
# Register agents
my_agent = ChatAgent(chat_client=client, name="assistant")
agent_worker.add_agent(my_agent)
# Start the worker
worker.start()
```
"""
def __init__(
self,
worker: TaskHubGrpcWorker,
callback: AgentResponseCallbackProtocol | None = None,
):
"""Initialize the worker wrapper.
Args:
worker: The durabletask worker instance to wrap
callback: Optional callback for agent response notifications
"""
self._worker = worker
self._callback = callback
self._registered_agents: dict[str, AgentProtocol] = {}
logger.debug("[DurableAIAgentWorker] Initialized with worker type: %s", type(worker).__name__)
def add_agent(
self,
agent: AgentProtocol,
callback: AgentResponseCallbackProtocol | None = None,
) -> None:
"""Register an agent with the worker.
This method creates a durable entity class for the agent and registers
it with the underlying durabletask worker. The entity will be accessible
by the name "dafx-{agent_name}".
Args:
agent: The agent to register (must have a name)
callback: Optional callback for this specific agent (overrides worker-level callback)
Raises:
ValueError: If the agent doesn't have a name or is already registered
"""
agent_name = agent.name
if not agent_name:
raise ValueError("Agent must have a name to be registered")
if agent_name in self._registered_agents:
raise ValueError(f"Agent '{agent_name}' is already registered")
logger.info("[DurableAIAgentWorker] Registering agent: %s as entity: dafx-%s", agent_name, agent_name)
# Store the agent reference
self._registered_agents[agent_name] = agent
# Use agent-specific callback if provided, otherwise use worker-level callback
effective_callback = callback or self._callback
# Create a configured entity class using the factory
entity_class = self.__create_agent_entity(agent, effective_callback)
# Register the entity class with the worker
# The worker.add_entity method takes a class
entity_registered: str = self._worker.add_entity(entity_class) # pyright: ignore[reportUnknownMemberType]
logger.debug(
"[DurableAIAgentWorker] Successfully registered entity class %s for agent: %s",
entity_registered,
agent_name,
)
def start(self) -> None:
"""Start the worker to begin processing tasks.
Note:
This method delegates to the underlying worker's start method.
The worker will block until stopped.
"""
logger.info("[DurableAIAgentWorker] Starting worker with %d registered agents", len(self._registered_agents))
self._worker.start()
def stop(self) -> None:
"""Stop the worker gracefully.
Note:
This method delegates to the underlying worker's stop method.
"""
logger.info("[DurableAIAgentWorker] Stopping worker")
self._worker.stop()
@property
def registered_agent_names(self) -> list[str]:
"""Get the names of all registered agents.
Returns:
List of agent names (without the dafx- prefix)
"""
return list(self._registered_agents.keys())
def __create_agent_entity(
self,
agent: AgentProtocol,
callback: AgentResponseCallbackProtocol | None = None,
) -> type[DurableTaskEntityStateProvider]:
"""Factory function to create a DurableEntity class configured with an agent.
This factory creates a new class that combines the entity state provider
with the agent execution logic. Each agent gets its own entity class.
Args:
agent: The agent instance to wrap
callback: Optional callback for agent responses
Returns:
A new DurableEntity subclass configured for this agent
"""
agent_name = agent.name or type(agent).__name__
entity_name = f"dafx-{agent_name}"
class ConfiguredAgentEntity(DurableTaskEntityStateProvider):
"""Durable entity configured with a specific agent instance."""
def __init__(self) -> None:
super().__init__()
# Create the AgentEntity with this state provider
self._agent_entity = AgentEntity(
agent=agent,
callback=callback,
state_provider=self,
)
logger.debug(
"[ConfiguredAgentEntity] Initialized entity for agent: %s (entity name: %s)",
agent_name,
entity_name,
)
def run(self, request: Any) -> Any:
"""Handle run requests from clients or orchestrations.
Args:
request: RunRequest as dict or string
Returns:
AgentRunResponse as dict
"""
logger.debug("[ConfiguredAgentEntity.run] Executing agent: %s", agent_name)
# Get or create event loop for async execution
try:
loop = asyncio.get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# Run the async agent execution synchronously
if loop.is_running():
# If loop is already running (shouldn't happen in entity context),
# create a temporary loop
temp_loop = asyncio.new_event_loop()
try:
response = temp_loop.run_until_complete(self._agent_entity.run(request))
finally:
temp_loop.close()
else:
response = loop.run_until_complete(self._agent_entity.run(request))
return response.to_dict()
def reset(self) -> None:
"""Reset the agent's conversation history."""
logger.debug("[ConfiguredAgentEntity.reset] Resetting agent: %s", agent_name)
self._agent_entity.reset()
# Set the entity name to match the prefixed agent name
# This is used by durabletask to register the entity
ConfiguredAgentEntity.__name__ = entity_name
ConfiguredAgentEntity.__qualname__ = entity_name
return ConfiguredAgentEntity
@@ -0,0 +1,272 @@
# Copyright (c) Microsoft. All rights reserved.
"""Unit tests for AgentSessionId and DurableAgentThread."""
import pytest
from agent_framework import AgentThread
from agent_framework_durabletask._models import AgentSessionId, DurableAgentThread
class TestAgentSessionId:
"""Test suite for AgentSessionId."""
def test_init_creates_session_id(self) -> None:
"""Test that AgentSessionId initializes correctly."""
session_id = AgentSessionId(name="AgentEntity", key="test-key-123")
assert session_id.name == "AgentEntity"
assert session_id.key == "test-key-123"
def test_with_random_key_generates_guid(self) -> None:
"""Test that with_random_key generates a GUID."""
session_id = AgentSessionId.with_random_key(name="AgentEntity")
assert session_id.name == "AgentEntity"
assert len(session_id.key) == 32 # UUID hex is 32 chars
# Verify it's a valid hex string
int(session_id.key, 16)
def test_with_random_key_unique_keys(self) -> None:
"""Test that with_random_key generates unique keys."""
session_id1 = AgentSessionId.with_random_key(name="AgentEntity")
session_id2 = AgentSessionId.with_random_key(name="AgentEntity")
assert session_id1.key != session_id2.key
def test_str_representation(self) -> None:
"""Test string representation."""
session_id = AgentSessionId(name="AgentEntity", key="test-key-123")
str_repr = str(session_id)
assert str_repr == "@AgentEntity@test-key-123"
def test_repr_representation(self) -> None:
"""Test repr representation."""
session_id = AgentSessionId(name="AgentEntity", key="test-key")
repr_str = repr(session_id)
assert "AgentSessionId" in repr_str
assert "AgentEntity" in repr_str
assert "test-key" in repr_str
def test_parse_valid_session_id(self) -> None:
"""Test parsing valid session ID string."""
session_id = AgentSessionId.parse("@AgentEntity@test-key-123")
assert session_id.name == "AgentEntity"
assert session_id.key == "test-key-123"
def test_parse_invalid_format_no_prefix(self) -> None:
"""Test parsing invalid format without @ prefix."""
with pytest.raises(ValueError) as exc_info:
AgentSessionId.parse("AgentEntity@test-key")
assert "Invalid agent session ID format" in str(exc_info.value)
def test_parse_invalid_format_single_part(self) -> None:
"""Test parsing invalid format with single part."""
with pytest.raises(ValueError) as exc_info:
AgentSessionId.parse("@AgentEntity")
assert "Invalid agent session ID format" in str(exc_info.value)
def test_parse_with_multiple_at_signs_in_key(self) -> None:
"""Test parsing with @ signs in the key."""
session_id = AgentSessionId.parse("@AgentEntity@key-with@symbols")
assert session_id.name == "AgentEntity"
assert session_id.key == "key-with@symbols"
def test_parse_round_trip(self) -> None:
"""Test round-trip parse and string conversion."""
original = AgentSessionId(name="AgentEntity", key="test-key")
str_repr = str(original)
parsed = AgentSessionId.parse(str_repr)
assert parsed.name == original.name
assert parsed.key == original.key
def test_to_entity_name_adds_prefix(self) -> None:
"""Test that to_entity_name adds the dafx- prefix."""
entity_name = AgentSessionId.to_entity_name("TestAgent")
assert entity_name == "dafx-TestAgent"
class TestDurableAgentThread:
"""Test suite for DurableAgentThread."""
def test_init_with_session_id(self) -> None:
"""Test DurableAgentThread initialization with session ID."""
session_id = AgentSessionId(name="TestAgent", key="test-key")
thread = DurableAgentThread(session_id=session_id)
assert thread.session_id is not None
assert thread.session_id == session_id
def test_init_without_session_id(self) -> None:
"""Test DurableAgentThread initialization without session ID."""
thread = DurableAgentThread()
assert thread.session_id is None
def test_session_id_setter(self) -> None:
"""Test setting a session ID to an existing thread."""
thread = DurableAgentThread()
assert thread.session_id is None
session_id = AgentSessionId(name="TestAgent", key="test-key")
thread.session_id = session_id
assert thread.session_id is not None
assert thread.session_id == session_id
assert thread.session_id.name == "TestAgent"
def test_from_session_id(self) -> None:
"""Test creating DurableAgentThread from session ID."""
session_id = AgentSessionId(name="TestAgent", key="test-key")
thread = DurableAgentThread.from_session_id(session_id)
assert isinstance(thread, DurableAgentThread)
assert thread.session_id is not None
assert thread.session_id == session_id
assert thread.session_id.name == "TestAgent"
assert thread.session_id.key == "test-key"
def test_from_session_id_with_service_thread_id(self) -> None:
"""Test creating DurableAgentThread with service thread ID."""
session_id = AgentSessionId(name="TestAgent", key="test-key")
thread = DurableAgentThread.from_session_id(session_id, service_thread_id="service-123")
assert thread.session_id is not None
assert thread.session_id == session_id
assert thread.service_thread_id == "service-123"
async def test_serialize_with_session_id(self) -> None:
"""Test serialization includes session ID."""
session_id = AgentSessionId(name="TestAgent", key="test-key")
thread = DurableAgentThread(session_id=session_id)
serialized = await thread.serialize()
assert isinstance(serialized, dict)
assert "durable_session_id" in serialized
assert serialized["durable_session_id"] == "@TestAgent@test-key"
async def test_serialize_without_session_id(self) -> None:
"""Test serialization without session ID."""
thread = DurableAgentThread()
serialized = await thread.serialize()
assert isinstance(serialized, dict)
assert "durable_session_id" not in serialized
async def test_deserialize_with_session_id(self) -> None:
"""Test deserialization restores session ID."""
serialized = {
"service_thread_id": "thread-123",
"durable_session_id": "@TestAgent@test-key",
}
thread = await DurableAgentThread.deserialize(serialized)
assert isinstance(thread, DurableAgentThread)
assert thread.session_id is not None
assert thread.session_id.name == "TestAgent"
assert thread.session_id.key == "test-key"
assert thread.service_thread_id == "thread-123"
async def test_deserialize_without_session_id(self) -> None:
"""Test deserialization without session ID."""
serialized = {
"service_thread_id": "thread-456",
}
thread = await DurableAgentThread.deserialize(serialized)
assert isinstance(thread, DurableAgentThread)
assert thread.session_id is None
assert thread.service_thread_id == "thread-456"
async def test_round_trip_serialization(self) -> None:
"""Test round-trip serialization preserves session ID."""
session_id = AgentSessionId(name="TestAgent", key="test-key-789")
original = DurableAgentThread(session_id=session_id)
serialized = await original.serialize()
restored = await DurableAgentThread.deserialize(serialized)
assert isinstance(restored, DurableAgentThread)
assert restored.session_id is not None
assert restored.session_id.name == session_id.name
assert restored.session_id.key == session_id.key
async def test_deserialize_invalid_session_id_type(self) -> None:
"""Test deserialization with invalid session ID type raises error."""
serialized = {
"service_thread_id": "thread-123",
"durable_session_id": 12345, # Invalid type
}
with pytest.raises(ValueError, match="durable_session_id must be a string"):
await DurableAgentThread.deserialize(serialized)
class TestAgentThreadCompatibility:
"""Test suite for compatibility between AgentThread and DurableAgentThread."""
async def test_agent_thread_serialize(self) -> None:
"""Test that base AgentThread can be serialized."""
thread = AgentThread()
serialized = await thread.serialize()
assert isinstance(serialized, dict)
assert "service_thread_id" in serialized
async def test_agent_thread_deserialize(self) -> None:
"""Test that base AgentThread can be deserialized."""
thread = AgentThread()
serialized = await thread.serialize()
restored = await AgentThread.deserialize(serialized)
assert isinstance(restored, AgentThread)
assert restored.service_thread_id == thread.service_thread_id
async def test_durable_thread_is_agent_thread(self) -> None:
"""Test that DurableAgentThread is an AgentThread."""
thread = DurableAgentThread()
assert isinstance(thread, AgentThread)
assert isinstance(thread, DurableAgentThread)
class TestModelIntegration:
"""Test suite for integration between models."""
def test_session_id_string_format(self) -> None:
"""Test that AgentSessionId string format is consistent."""
session_id = AgentSessionId.with_random_key("AgentEntity")
session_id_str = str(session_id)
assert session_id_str.startswith("@AgentEntity@")
async def test_thread_with_session_preserves_on_serialization(self) -> None:
"""Test that thread with session ID preserves it through serialization."""
session_id = AgentSessionId(name="TestAgent", key="preserved-key")
thread = DurableAgentThread.from_session_id(session_id)
# Serialize and deserialize
serialized = await thread.serialize()
restored = await DurableAgentThread.deserialize(serialized)
# Session ID should be preserved
assert restored.session_id is not None
assert restored.session_id.name == "TestAgent"
assert restored.session_id.key == "preserved-key"
if __name__ == "__main__":
pytest.main([__file__, "-v", "--tb=short"])
@@ -0,0 +1,142 @@
# Copyright (c) Microsoft. All rights reserved.
"""Unit tests for DurableAIAgentClient.
Focuses on critical client workflows: agent retrieval, protocol compliance, and integration.
Run with: pytest tests/test_client.py -v
"""
from unittest.mock import Mock
import pytest
from agent_framework import AgentProtocol
from agent_framework_durabletask import DurableAgentThread, DurableAIAgentClient
from agent_framework_durabletask._constants import DEFAULT_MAX_POLL_RETRIES, DEFAULT_POLL_INTERVAL_SECONDS
from agent_framework_durabletask._shim import DurableAIAgent
@pytest.fixture
def mock_grpc_client() -> Mock:
"""Create a mock TaskHubGrpcClient for testing."""
return Mock()
@pytest.fixture
def agent_client(mock_grpc_client: Mock) -> DurableAIAgentClient:
"""Create a DurableAIAgentClient with mock gRPC client."""
return DurableAIAgentClient(mock_grpc_client)
@pytest.fixture
def agent_client_with_custom_polling(mock_grpc_client: Mock) -> DurableAIAgentClient:
"""Create a DurableAIAgentClient with custom polling parameters."""
return DurableAIAgentClient(
mock_grpc_client,
max_poll_retries=15,
poll_interval_seconds=0.5,
)
class TestDurableAIAgentClientGetAgent:
"""Test core workflow: retrieving agents from the client."""
def test_get_agent_returns_durable_agent_shim(self, agent_client: DurableAIAgentClient) -> None:
"""Verify get_agent returns a DurableAIAgent instance."""
agent = agent_client.get_agent("assistant")
assert isinstance(agent, DurableAIAgent)
assert isinstance(agent, AgentProtocol)
def test_get_agent_shim_has_correct_name(self, agent_client: DurableAIAgentClient) -> None:
"""Verify retrieved agent has the correct name."""
agent = agent_client.get_agent("my_agent")
assert agent.name == "my_agent"
def test_get_agent_multiple_times_returns_new_instances(self, agent_client: DurableAIAgentClient) -> None:
"""Verify multiple get_agent calls return independent instances."""
agent1 = agent_client.get_agent("assistant")
agent2 = agent_client.get_agent("assistant")
assert agent1 is not agent2 # Different object instances
def test_get_agent_different_agents(self, agent_client: DurableAIAgentClient) -> None:
"""Verify client can retrieve multiple different agents."""
agent1 = agent_client.get_agent("agent1")
agent2 = agent_client.get_agent("agent2")
assert agent1.name == "agent1"
assert agent2.name == "agent2"
class TestDurableAIAgentClientIntegration:
"""Test integration scenarios between client and agent shim."""
def test_client_agent_has_working_run_method(self, agent_client: DurableAIAgentClient) -> None:
"""Verify agent from client has callable run method (even if not yet implemented)."""
agent = agent_client.get_agent("assistant")
assert hasattr(agent, "run")
assert callable(agent.run)
def test_client_agent_can_create_threads(self, agent_client: DurableAIAgentClient) -> None:
"""Verify agent from client can create DurableAgentThread instances."""
agent = agent_client.get_agent("assistant")
thread = agent.get_new_thread()
assert isinstance(thread, DurableAgentThread)
def test_client_agent_thread_with_parameters(self, agent_client: DurableAIAgentClient) -> None:
"""Verify agent can create threads with custom parameters."""
agent = agent_client.get_agent("assistant")
thread = agent.get_new_thread(service_thread_id="client-session-123")
assert isinstance(thread, DurableAgentThread)
assert thread.service_thread_id == "client-session-123"
class TestDurableAIAgentClientPollingConfiguration:
"""Test polling configuration parameters for DurableAIAgentClient."""
def test_client_uses_default_polling_parameters(self, agent_client: DurableAIAgentClient) -> None:
"""Verify client initializes with default polling parameters."""
assert agent_client.max_poll_retries == DEFAULT_MAX_POLL_RETRIES
assert agent_client.poll_interval_seconds == DEFAULT_POLL_INTERVAL_SECONDS
def test_client_accepts_custom_polling_parameters(
self, agent_client_with_custom_polling: DurableAIAgentClient
) -> None:
"""Verify client accepts and stores custom polling parameters."""
assert agent_client_with_custom_polling.max_poll_retries == 15
assert agent_client_with_custom_polling.poll_interval_seconds == 0.5
def test_client_validates_max_poll_retries(self, mock_grpc_client: Mock) -> None:
"""Verify client validates and normalizes max_poll_retries."""
# Test with zero - should enforce minimum of 1
client = DurableAIAgentClient(mock_grpc_client, max_poll_retries=0)
assert client.max_poll_retries == 1
# Test with negative - should enforce minimum of 1
client = DurableAIAgentClient(mock_grpc_client, max_poll_retries=-5)
assert client.max_poll_retries == 1
def test_client_validates_poll_interval_seconds(self, mock_grpc_client: Mock) -> None:
"""Verify client validates and normalizes poll_interval_seconds."""
# Test with zero - should use default
client = DurableAIAgentClient(mock_grpc_client, poll_interval_seconds=0)
assert client.poll_interval_seconds == DEFAULT_POLL_INTERVAL_SECONDS
# Test with negative - should use default
client = DurableAIAgentClient(mock_grpc_client, poll_interval_seconds=-0.5)
assert client.poll_interval_seconds == DEFAULT_POLL_INTERVAL_SECONDS
# Test with valid float
client = DurableAIAgentClient(mock_grpc_client, poll_interval_seconds=2.5)
assert client.poll_interval_seconds == 2.5
if __name__ == "__main__":
pytest.main([__file__, "-v", "--tb=short"])
@@ -112,20 +112,21 @@ class TestDurableAgentStateMessageCreatedAt:
"""Test suite for DurableAgentStateMessage created_at field handling."""
def test_message_from_run_request_without_created_at_preserves_none(self) -> None:
"""Test from_run_request preserves None created_at instead of defaulting to current time.
"""Test from_run_request handles auto-populated created_at from RunRequest.
When a RunRequest has no created_at value, the resulting DurableAgentStateMessage
should also have None for created_at, not default to current UTC time.
When a RunRequest is created with None for created_at, RunRequest defaults it to
current UTC time. The resulting DurableAgentStateMessage should have this timestamp.
"""
run_request = RunRequest(
message="test message",
correlation_id="corr-run",
created_at=None, # Explicitly None
created_at=None, # RunRequest will default this to current time
)
durable_message = DurableAgentStateMessage.from_run_request(run_request)
assert durable_message.created_at is None
# RunRequest auto-populates created_at, so it should not be None
assert durable_message.created_at is not None
def test_message_from_run_request_with_created_at_parses_correctly(self) -> None:
"""Test from_run_request correctly parses a valid created_at timestamp."""
@@ -0,0 +1,320 @@
# Copyright (c) Microsoft. All rights reserved.
"""Unit tests for DurableAgentExecutor implementations.
Focuses on critical behavioral flows for executor strategies.
Run with: pytest tests/test_executors.py -v
"""
import time
from typing import Any
from unittest.mock import Mock
import pytest
from agent_framework import AgentRunResponse, Role
from durabletask.entities import EntityInstanceId
from durabletask.task import Task
from pydantic import BaseModel
from agent_framework_durabletask import DurableAgentThread
from agent_framework_durabletask._constants import DEFAULT_MAX_POLL_RETRIES, DEFAULT_POLL_INTERVAL_SECONDS
from agent_framework_durabletask._executors import (
ClientAgentExecutor,
DurableAgentTask,
OrchestrationAgentExecutor,
)
from agent_framework_durabletask._models import AgentSessionId, RunRequest
# Fixtures
@pytest.fixture
def mock_client() -> Mock:
"""Provide a mock client for ClientAgentExecutor tests."""
client = Mock()
client.signal_entity = Mock()
client.get_entity = Mock(return_value=None)
return client
@pytest.fixture
def mock_entity_task() -> Mock:
"""Provide a mock entity task."""
return Mock(spec=Task)
@pytest.fixture
def mock_orchestration_context(mock_entity_task: Mock) -> Mock:
"""Provide a mock orchestration context with call_entity configured."""
context = Mock()
context.call_entity = Mock(return_value=mock_entity_task)
return context
@pytest.fixture
def sample_run_request() -> RunRequest:
"""Provide a sample RunRequest for tests."""
return RunRequest(message="test message", correlation_id="test-123")
@pytest.fixture
def client_executor(mock_client: Mock) -> ClientAgentExecutor:
"""Provide a ClientAgentExecutor with minimal polling for fast tests."""
return ClientAgentExecutor(mock_client, max_poll_retries=1, poll_interval_seconds=0.01)
@pytest.fixture
def orchestration_executor(mock_orchestration_context: Mock) -> OrchestrationAgentExecutor:
"""Provide an OrchestrationAgentExecutor."""
return OrchestrationAgentExecutor(mock_orchestration_context)
@pytest.fixture
def successful_agent_response() -> dict[str, Any]:
"""Provide a successful agent response dictionary."""
return {
"messages": [{"role": "assistant", "contents": [{"type": "text", "text": "Hello!"}]}],
"created_at": "2025-12-30T10:00:00Z",
}
class TestExecutorThreadCreation:
"""Test that executors properly create DurableAgentThread with parameters."""
def test_client_executor_creates_durable_thread(self, mock_client: Mock) -> None:
"""Verify ClientAgentExecutor creates DurableAgentThread instances."""
executor = ClientAgentExecutor(mock_client)
thread = executor.get_new_thread("test_agent")
assert isinstance(thread, DurableAgentThread)
def test_client_executor_forwards_kwargs_to_thread(self, mock_client: Mock) -> None:
"""Verify ClientAgentExecutor forwards kwargs to DurableAgentThread creation."""
executor = ClientAgentExecutor(mock_client)
thread = executor.get_new_thread("test_agent", service_thread_id="client-123")
assert isinstance(thread, DurableAgentThread)
assert thread.service_thread_id == "client-123"
def test_orchestration_executor_creates_durable_thread(
self, orchestration_executor: OrchestrationAgentExecutor
) -> None:
"""Verify OrchestrationAgentExecutor creates DurableAgentThread instances."""
thread = orchestration_executor.get_new_thread("test_agent")
assert isinstance(thread, DurableAgentThread)
def test_orchestration_executor_forwards_kwargs_to_thread(
self, orchestration_executor: OrchestrationAgentExecutor
) -> None:
"""Verify OrchestrationAgentExecutor forwards kwargs to DurableAgentThread creation."""
thread = orchestration_executor.get_new_thread("test_agent", service_thread_id="orch-456")
assert isinstance(thread, DurableAgentThread)
assert thread.service_thread_id == "orch-456"
class TestClientAgentExecutorRun:
"""Test that ClientAgentExecutor.run_durable_agent works as implemented."""
def test_client_executor_run_returns_response(
self, client_executor: ClientAgentExecutor, sample_run_request: RunRequest
) -> None:
"""Verify ClientAgentExecutor.run_durable_agent returns AgentRunResponse (synchronous)."""
result = client_executor.run_durable_agent("test_agent", sample_run_request)
# Verify it returns an AgentRunResponse (synchronous, not a coroutine)
assert isinstance(result, AgentRunResponse)
assert result is not None
class TestClientAgentExecutorPollingConfiguration:
"""Test polling configuration parameters for ClientAgentExecutor."""
def test_executor_uses_default_polling_parameters(self, mock_client: Mock) -> None:
"""Verify executor initializes with default polling parameters."""
executor = ClientAgentExecutor(mock_client)
assert executor.max_poll_retries == DEFAULT_MAX_POLL_RETRIES
assert executor.poll_interval_seconds == DEFAULT_POLL_INTERVAL_SECONDS
def test_executor_accepts_custom_polling_parameters(self, mock_client: Mock) -> None:
"""Verify executor accepts and stores custom polling parameters."""
executor = ClientAgentExecutor(mock_client, max_poll_retries=20, poll_interval_seconds=0.5)
assert executor.max_poll_retries == 20
assert executor.poll_interval_seconds == 0.5
def test_executor_respects_custom_max_poll_retries(self, mock_client: Mock, sample_run_request: RunRequest) -> None:
"""Verify executor respects custom max_poll_retries during polling."""
# Create executor with only 2 retries
executor = ClientAgentExecutor(mock_client, max_poll_retries=2, poll_interval_seconds=0.01)
# Run the agent
result = executor.run_durable_agent("test_agent", sample_run_request)
# Verify it returns AgentRunResponse (should timeout after 2 attempts)
assert isinstance(result, AgentRunResponse)
# Verify get_entity was called 2 times (max_poll_retries)
assert mock_client.get_entity.call_count == 2
def test_executor_respects_custom_poll_interval(self, mock_client: Mock, sample_run_request: RunRequest) -> None:
"""Verify executor respects custom poll_interval_seconds during polling."""
# Create executor with very short interval
executor = ClientAgentExecutor(mock_client, max_poll_retries=3, poll_interval_seconds=0.01)
# Measure time taken
start = time.time()
result = executor.run_durable_agent("test_agent", sample_run_request)
elapsed = time.time() - start
# Should take roughly 3 * 0.01 = 0.03 seconds (plus overhead)
# Be generous with timing to avoid flakiness
assert elapsed < 0.2 # Should be quick with 0.01 interval
assert isinstance(result, AgentRunResponse)
class TestOrchestrationAgentExecutorRun:
"""Test OrchestrationAgentExecutor.run_durable_agent implementation."""
def test_orchestration_executor_run_returns_durable_agent_task(
self, orchestration_executor: OrchestrationAgentExecutor, sample_run_request: RunRequest
) -> None:
"""Verify OrchestrationAgentExecutor.run_durable_agent returns DurableAgentTask."""
result = orchestration_executor.run_durable_agent("test_agent", sample_run_request)
assert isinstance(result, DurableAgentTask)
def test_orchestration_executor_calls_entity_with_correct_parameters(
self,
mock_orchestration_context: Mock,
orchestration_executor: OrchestrationAgentExecutor,
sample_run_request: RunRequest,
) -> None:
"""Verify call_entity is invoked with correct entity ID and request."""
orchestration_executor.run_durable_agent("test_agent", sample_run_request)
# Verify call_entity was called once
assert mock_orchestration_context.call_entity.call_count == 1
# Get the call arguments
call_args = mock_orchestration_context.call_entity.call_args
entity_id_arg = call_args[0][0]
operation_arg = call_args[0][1]
request_dict_arg = call_args[0][2]
# Verify entity ID
assert isinstance(entity_id_arg, EntityInstanceId)
assert entity_id_arg.entity == "dafx-test_agent"
# Verify operation name
assert operation_arg == "run"
# Verify request dict
assert request_dict_arg == sample_run_request.to_dict()
def test_orchestration_executor_uses_thread_session_id(
self,
mock_orchestration_context: Mock,
orchestration_executor: OrchestrationAgentExecutor,
sample_run_request: RunRequest,
) -> None:
"""Verify executor uses thread's session ID when provided."""
# Create thread with specific session ID
session_id = AgentSessionId(name="test_agent", key="specific-key-123")
thread = DurableAgentThread.from_session_id(session_id)
result = orchestration_executor.run_durable_agent("test_agent", sample_run_request, thread=thread)
# Verify call_entity was called with the specific key
call_args = mock_orchestration_context.call_entity.call_args
entity_id_arg = call_args[0][0]
assert entity_id_arg.key == "specific-key-123"
assert isinstance(result, DurableAgentTask)
class TestDurableAgentTask:
"""Test DurableAgentTask completion and response transformation."""
def test_durable_agent_task_transforms_successful_result(
self, mock_entity_task: Mock, successful_agent_response: dict[str, Any]
) -> None:
"""Verify DurableAgentTask converts successful entity result to AgentRunResponse."""
mock_entity_task.is_failed = False
mock_entity_task.get_result = Mock(return_value=successful_agent_response)
task = DurableAgentTask(entity_task=mock_entity_task, response_format=None, correlation_id="test-123")
# Simulate child task completion
task.on_child_completed(mock_entity_task)
assert task.is_complete
result = task.get_result()
assert isinstance(result, AgentRunResponse)
assert len(result.messages) == 1
assert result.messages[0].role == Role.ASSISTANT
def test_durable_agent_task_propagates_failure(self, mock_entity_task: Mock) -> None:
"""Verify DurableAgentTask propagates task failures."""
mock_entity_task.is_failed = True
mock_entity_task.get_exception = Mock(return_value=ValueError("Entity error"))
task = DurableAgentTask(entity_task=mock_entity_task, response_format=None, correlation_id="test-123")
# Simulate child task completion with failure
task.on_child_completed(mock_entity_task)
assert task.is_complete
assert task.is_failed
exception = task.get_exception()
assert isinstance(exception, ValueError)
assert str(exception) == "Entity error"
def test_durable_agent_task_validates_response_format(self, mock_entity_task: Mock) -> None:
"""Verify DurableAgentTask validates response format when provided."""
mock_entity_task.is_failed = False
mock_entity_task.get_result = Mock(
return_value={
"messages": [{"role": "assistant", "contents": [{"type": "text", "text": '{"answer": "42"}'}]}],
"created_at": "2025-12-30T10:00:00Z",
}
)
class TestResponse(BaseModel):
answer: str
task = DurableAgentTask(entity_task=mock_entity_task, response_format=TestResponse, correlation_id="test-123")
# Simulate child task completion
task.on_child_completed(mock_entity_task)
assert task.is_complete
result = task.get_result()
assert isinstance(result, AgentRunResponse)
def test_durable_agent_task_ignores_duplicate_completion(
self, mock_entity_task: Mock, successful_agent_response: dict[str, Any]
) -> None:
"""Verify DurableAgentTask ignores duplicate completion calls."""
mock_entity_task.is_failed = False
mock_entity_task.get_result = Mock(return_value=successful_agent_response)
task = DurableAgentTask(entity_task=mock_entity_task, response_format=None, correlation_id="test-123")
# Simulate child task completion twice
task.on_child_completed(mock_entity_task)
first_result = task.get_result()
task.on_child_completed(mock_entity_task)
second_result = task.get_result()
# Should be the same result, get_result should only be called once
assert first_result is second_result
assert mock_entity_task.get_result.call_count == 1
if __name__ == "__main__":
pytest.main([__file__, "-v", "--tb=short"])
@@ -18,9 +18,10 @@ class TestRunRequest:
def test_init_with_defaults(self) -> None:
"""Test RunRequest initialization with defaults."""
request = RunRequest(message="Hello")
request = RunRequest(message="Hello", correlation_id="corr-001")
assert request.message == "Hello"
assert request.correlation_id == "corr-001"
assert request.role == Role.USER
assert request.response_format is None
assert request.enable_tool_calls is True
@@ -30,30 +31,33 @@ class TestRunRequest:
schema = ModuleStructuredResponse
request = RunRequest(
message="Hello",
correlation_id="corr-002",
role=Role.SYSTEM,
response_format=schema,
enable_tool_calls=False,
)
assert request.message == "Hello"
assert request.correlation_id == "corr-002"
assert request.role == Role.SYSTEM
assert request.response_format is schema
assert request.enable_tool_calls is False
def test_init_coerces_string_role(self) -> None:
"""Ensure string role values are coerced into Role instances."""
request = RunRequest(message="Hello", role="system") # type: ignore[arg-type]
request = RunRequest(message="Hello", correlation_id="corr-003", role="system") # type: ignore[arg-type]
assert request.role == Role.SYSTEM
def test_to_dict_with_defaults(self) -> None:
"""Test to_dict with default values."""
request = RunRequest(message="Test message")
request = RunRequest(message="Test message", correlation_id="corr-004")
data = request.to_dict()
assert data["message"] == "Test message"
assert data["enable_tool_calls"] is True
assert data["role"] == "user"
assert data["correlationId"] == "corr-004"
assert "response_format" not in data or data["response_format"] is None
assert "thread_id" not in data
@@ -62,6 +66,7 @@ class TestRunRequest:
schema = ModuleStructuredResponse
request = RunRequest(
message="Hello",
correlation_id="corr-005",
role=Role.ASSISTANT,
response_format=schema,
enable_tool_calls=False,
@@ -69,6 +74,7 @@ class TestRunRequest:
data = request.to_dict()
assert data["message"] == "Hello"
assert data["correlationId"] == "corr-005"
assert data["role"] == "assistant"
assert data["response_format"]["__response_schema_type__"] == "pydantic_model"
assert data["response_format"]["module"] == schema.__module__
@@ -78,16 +84,17 @@ class TestRunRequest:
def test_from_dict_with_defaults(self) -> None:
"""Test from_dict with minimal data."""
data = {"message": "Hello"}
data = {"message": "Hello", "correlationId": "corr-006"}
request = RunRequest.from_dict(data)
assert request.message == "Hello"
assert request.correlation_id == "corr-006"
assert request.role == Role.USER
assert request.enable_tool_calls is True
def test_from_dict_ignores_thread_id_field(self) -> None:
"""Ensure legacy thread_id input does not break RunRequest parsing."""
request = RunRequest.from_dict({"message": "Hello", "thread_id": "ignored"})
request = RunRequest.from_dict({"message": "Hello", "correlationId": "corr-007", "thread_id": "ignored"})
assert request.message == "Hello"
@@ -95,6 +102,7 @@ class TestRunRequest:
"""Test from_dict with all fields."""
data = {
"message": "Test",
"correlationId": "corr-008",
"role": "system",
"response_format": {
"__response_schema_type__": "pydantic_model",
@@ -106,13 +114,14 @@ class TestRunRequest:
request = RunRequest.from_dict(data)
assert request.message == "Test"
assert request.correlation_id == "corr-008"
assert request.role == Role.SYSTEM
assert request.response_format is ModuleStructuredResponse
assert request.enable_tool_calls is False
def test_from_dict_with_unknown_role_preserves_value(self) -> None:
def test_from_dict_unknown_role_preserves_value(self) -> None:
"""Test from_dict keeps custom roles intact."""
data = {"message": "Test", "role": "reviewer"}
data = {"message": "Test", "correlationId": "corr-009", "role": "reviewer"}
request = RunRequest.from_dict(data)
assert request.role.value == "reviewer"
@@ -120,15 +129,22 @@ class TestRunRequest:
def test_from_dict_empty_message(self) -> None:
"""Test from_dict with empty message."""
request = RunRequest.from_dict({})
request = RunRequest.from_dict({"correlationId": "corr-010"})
assert request.message == ""
assert request.correlation_id == "corr-010"
assert request.role == Role.USER
def test_from_dict_missing_correlation_id_raises(self) -> None:
"""Test from_dict raises when correlationId is missing."""
with pytest.raises(ValueError, match="correlationId is required"):
RunRequest.from_dict({"message": "Test"})
def test_round_trip_dict_conversion(self) -> None:
"""Test round-trip to_dict and from_dict."""
original = RunRequest(
message="Test message",
correlation_id="corr-011",
role=Role.SYSTEM,
response_format=ModuleStructuredResponse,
enable_tool_calls=False,
@@ -138,6 +154,7 @@ class TestRunRequest:
restored = RunRequest.from_dict(data)
assert restored.message == original.message
assert restored.correlation_id == original.correlation_id
assert restored.role == original.role
assert restored.response_format is ModuleStructuredResponse
assert restored.enable_tool_calls == original.enable_tool_calls
@@ -146,6 +163,7 @@ class TestRunRequest:
"""Ensure Pydantic response formats serialize and deserialize properly."""
original = RunRequest(
message="Structured",
correlation_id="corr-012",
response_format=ModuleStructuredResponse,
)
@@ -186,7 +204,7 @@ class TestRunRequest:
original = RunRequest(
message="Test message",
role=Role.SYSTEM,
correlation_id="corr-123",
correlation_id="corr-124",
)
data = original.to_dict()
@@ -200,6 +218,7 @@ class TestRunRequest:
"""Test RunRequest initialization with orchestration_id."""
request = RunRequest(
message="Test message",
correlation_id="corr-125",
orchestration_id="orch-123",
)
@@ -210,6 +229,7 @@ class TestRunRequest:
"""Test to_dict includes orchestrationId."""
request = RunRequest(
message="Test",
correlation_id="corr-126",
orchestration_id="orch-456",
)
data = request.to_dict()
@@ -221,6 +241,7 @@ class TestRunRequest:
"""Test to_dict excludes orchestrationId when not set."""
request = RunRequest(
message="Test",
correlation_id="corr-127",
)
data = request.to_dict()
@@ -230,6 +251,7 @@ class TestRunRequest:
"""Test from_dict with orchestrationId."""
data = {
"message": "Test",
"correlationId": "corr-128",
"orchestrationId": "orch-789",
}
request = RunRequest.from_dict(data)
@@ -242,7 +264,7 @@ class TestRunRequest:
original = RunRequest(
message="Test message",
role=Role.SYSTEM,
correlation_id="corr-123",
correlation_id="corr-129",
orchestration_id="orch-123",
)
@@ -0,0 +1,98 @@
# Copyright (c) Microsoft. All rights reserved.
"""Unit tests for DurableAIAgentOrchestrationContext.
Focuses on critical orchestration workflows: agent retrieval and integration.
Run with: pytest tests/test_orchestration_context.py -v
"""
from unittest.mock import Mock
import pytest
from agent_framework import AgentProtocol
from agent_framework_durabletask import DurableAgentThread
from agent_framework_durabletask._orchestration_context import DurableAIAgentOrchestrationContext
from agent_framework_durabletask._shim import DurableAIAgent
@pytest.fixture
def mock_orchestration_context() -> Mock:
"""Create a mock OrchestrationContext for testing."""
return Mock()
@pytest.fixture
def agent_context(mock_orchestration_context: Mock) -> DurableAIAgentOrchestrationContext:
"""Create a DurableAIAgentOrchestrationContext with mock context."""
return DurableAIAgentOrchestrationContext(mock_orchestration_context)
class TestDurableAIAgentOrchestrationContextGetAgent:
"""Test core workflow: retrieving agents from orchestration context."""
def test_get_agent_returns_durable_agent_shim(self, agent_context: DurableAIAgentOrchestrationContext) -> None:
"""Verify get_agent returns a DurableAIAgent instance."""
agent = agent_context.get_agent("assistant")
assert isinstance(agent, DurableAIAgent)
assert isinstance(agent, AgentProtocol)
def test_get_agent_shim_has_correct_name(self, agent_context: DurableAIAgentOrchestrationContext) -> None:
"""Verify retrieved agent has the correct name."""
agent = agent_context.get_agent("my_agent")
assert agent.name == "my_agent"
def test_get_agent_multiple_times_returns_new_instances(
self, agent_context: DurableAIAgentOrchestrationContext
) -> None:
"""Verify multiple get_agent calls return independent instances."""
agent1 = agent_context.get_agent("assistant")
agent2 = agent_context.get_agent("assistant")
assert agent1 is not agent2 # Different object instances
def test_get_agent_different_agents(self, agent_context: DurableAIAgentOrchestrationContext) -> None:
"""Verify context can retrieve multiple different agents."""
agent1 = agent_context.get_agent("agent1")
agent2 = agent_context.get_agent("agent2")
assert agent1.name == "agent1"
assert agent2.name == "agent2"
class TestDurableAIAgentOrchestrationContextIntegration:
"""Test integration scenarios between orchestration context and agent shim."""
def test_orchestration_agent_has_working_run_method(
self, agent_context: DurableAIAgentOrchestrationContext
) -> None:
"""Verify agent from context has callable run method (even if not yet implemented)."""
agent = agent_context.get_agent("assistant")
assert hasattr(agent, "run")
assert callable(agent.run)
def test_orchestration_agent_can_create_threads(self, agent_context: DurableAIAgentOrchestrationContext) -> None:
"""Verify agent from context can create DurableAgentThread instances."""
agent = agent_context.get_agent("assistant")
thread = agent.get_new_thread()
assert isinstance(thread, DurableAgentThread)
def test_orchestration_agent_thread_with_parameters(
self, agent_context: DurableAIAgentOrchestrationContext
) -> None:
"""Verify agent can create threads with custom parameters."""
agent = agent_context.get_agent("assistant")
thread = agent.get_new_thread(service_thread_id="orch-session-456")
assert isinstance(thread, DurableAgentThread)
assert thread.service_thread_id == "orch-session-456"
if __name__ == "__main__":
pytest.main([__file__, "-v", "--tb=short"])
@@ -0,0 +1,206 @@
# Copyright (c) Microsoft. All rights reserved.
"""Unit tests for DurableAIAgent shim and DurableAgentProvider.
Focuses on critical message normalization, delegation, and protocol compliance.
Run with: pytest tests/test_shim.py -v
"""
from typing import Any
from unittest.mock import Mock
import pytest
from agent_framework import AgentProtocol, ChatMessage
from pydantic import BaseModel
from agent_framework_durabletask import DurableAgentThread
from agent_framework_durabletask._executors import DurableAgentExecutor
from agent_framework_durabletask._models import RunRequest
from agent_framework_durabletask._shim import DurableAgentProvider, DurableAIAgent
class ResponseFormatModel(BaseModel):
"""Test Pydantic model for response format testing."""
result: str
@pytest.fixture
def mock_executor() -> Mock:
"""Create a mock executor for testing."""
mock = Mock(spec=DurableAgentExecutor)
mock.run_durable_agent = Mock(return_value=None)
mock.get_new_thread = Mock(return_value=DurableAgentThread())
# Mock get_run_request to create actual RunRequest objects
def create_run_request(
message: str, response_format: type[BaseModel] | None = None, enable_tool_calls: bool = True
) -> RunRequest:
import uuid
return RunRequest(
message=message,
correlation_id=str(uuid.uuid4()),
response_format=response_format,
enable_tool_calls=enable_tool_calls,
)
mock.get_run_request = Mock(side_effect=create_run_request)
return mock
@pytest.fixture
def test_agent(mock_executor: Mock) -> DurableAIAgent[Any]:
"""Create a test agent with mock executor."""
return DurableAIAgent(mock_executor, "test_agent")
class TestDurableAIAgentMessageNormalization:
"""Test that DurableAIAgent properly normalizes various message input types."""
def test_run_accepts_string_message(self, test_agent: DurableAIAgent[Any], mock_executor: Mock) -> None:
"""Verify run accepts and normalizes string messages."""
test_agent.run("Hello, world!")
mock_executor.run_durable_agent.assert_called_once()
# Verify agent_name and run_request were passed correctly as kwargs
_, kwargs = mock_executor.run_durable_agent.call_args
assert kwargs["agent_name"] == "test_agent"
assert kwargs["run_request"].message == "Hello, world!"
def test_run_accepts_chat_message(self, test_agent: DurableAIAgent[Any], mock_executor: Mock) -> None:
"""Verify run accepts and normalizes ChatMessage objects."""
chat_msg = ChatMessage(role="user", text="Test message")
test_agent.run(chat_msg)
mock_executor.run_durable_agent.assert_called_once()
_, kwargs = mock_executor.run_durable_agent.call_args
assert kwargs["run_request"].message == "Test message"
def test_run_accepts_list_of_strings(self, test_agent: DurableAIAgent[Any], mock_executor: Mock) -> None:
"""Verify run accepts and joins list of strings."""
test_agent.run(["First message", "Second message"])
mock_executor.run_durable_agent.assert_called_once()
_, kwargs = mock_executor.run_durable_agent.call_args
assert kwargs["run_request"].message == "First message\nSecond message"
def test_run_accepts_list_of_chat_messages(self, test_agent: DurableAIAgent[Any], mock_executor: Mock) -> None:
"""Verify run accepts and joins list of ChatMessage objects."""
messages = [
ChatMessage(role="user", text="Message 1"),
ChatMessage(role="assistant", text="Message 2"),
]
test_agent.run(messages)
mock_executor.run_durable_agent.assert_called_once()
_, kwargs = mock_executor.run_durable_agent.call_args
assert kwargs["run_request"].message == "Message 1\nMessage 2"
def test_run_handles_none_message(self, test_agent: DurableAIAgent[Any], mock_executor: Mock) -> None:
"""Verify run handles None message gracefully."""
test_agent.run(None)
mock_executor.run_durable_agent.assert_called_once()
_, kwargs = mock_executor.run_durable_agent.call_args
assert kwargs["run_request"].message == ""
def test_run_handles_empty_list(self, test_agent: DurableAIAgent[Any], mock_executor: Mock) -> None:
"""Verify run handles empty list gracefully."""
test_agent.run([])
mock_executor.run_durable_agent.assert_called_once()
_, kwargs = mock_executor.run_durable_agent.call_args
assert kwargs["run_request"].message == ""
class TestDurableAIAgentParameterFlow:
"""Test that parameters flow correctly through the shim to executor."""
def test_run_forwards_thread_parameter(self, test_agent: DurableAIAgent[Any], mock_executor: Mock) -> None:
"""Verify run forwards thread parameter to executor."""
thread = DurableAgentThread(service_thread_id="test-thread")
test_agent.run("message", thread=thread)
mock_executor.run_durable_agent.assert_called_once()
_, kwargs = mock_executor.run_durable_agent.call_args
assert kwargs["thread"] == thread
def test_run_forwards_response_format(self, test_agent: DurableAIAgent[Any], mock_executor: Mock) -> None:
"""Verify run forwards response_format parameter to executor."""
test_agent.run("message", response_format=ResponseFormatModel)
mock_executor.run_durable_agent.assert_called_once()
_, kwargs = mock_executor.run_durable_agent.call_args
assert kwargs["run_request"].response_format == ResponseFormatModel
class TestDurableAIAgentProtocolCompliance:
"""Test that DurableAIAgent implements AgentProtocol correctly."""
def test_agent_implements_protocol(self, test_agent: DurableAIAgent[Any]) -> None:
"""Verify DurableAIAgent implements AgentProtocol."""
assert isinstance(test_agent, AgentProtocol)
def test_agent_has_required_properties(self, test_agent: DurableAIAgent[Any]) -> None:
"""Verify DurableAIAgent has all required AgentProtocol properties."""
assert hasattr(test_agent, "id")
assert hasattr(test_agent, "name")
assert hasattr(test_agent, "display_name")
assert hasattr(test_agent, "description")
def test_agent_id_defaults_to_name(self, mock_executor: Mock) -> None:
"""Verify agent id defaults to name when not provided."""
agent: DurableAIAgent[Any] = DurableAIAgent(mock_executor, "my_agent")
assert agent.id == "my_agent"
assert agent.name == "my_agent"
def test_agent_id_can_be_customized(self, mock_executor: Mock) -> None:
"""Verify agent id can be set independently from name."""
agent: DurableAIAgent[Any] = DurableAIAgent(mock_executor, "my_agent", agent_id="custom-id")
assert agent.id == "custom-id"
assert agent.name == "my_agent"
class TestDurableAIAgentThreadManagement:
"""Test thread creation and management."""
def test_get_new_thread_delegates_to_executor(self, test_agent: DurableAIAgent[Any], mock_executor: Mock) -> None:
"""Verify get_new_thread delegates to executor."""
mock_thread = DurableAgentThread()
mock_executor.get_new_thread.return_value = mock_thread
thread = test_agent.get_new_thread()
mock_executor.get_new_thread.assert_called_once_with("test_agent")
assert thread == mock_thread
def test_get_new_thread_forwards_kwargs(self, test_agent: DurableAIAgent[Any], mock_executor: Mock) -> None:
"""Verify get_new_thread forwards kwargs to executor."""
mock_thread = DurableAgentThread(service_thread_id="thread-123")
mock_executor.get_new_thread.return_value = mock_thread
test_agent.get_new_thread(service_thread_id="thread-123")
mock_executor.get_new_thread.assert_called_once()
_, kwargs = mock_executor.get_new_thread.call_args
assert kwargs["service_thread_id"] == "thread-123"
class TestDurableAgentProviderInterface:
"""Test that DurableAgentProvider defines the correct interface."""
def test_provider_cannot_be_instantiated(self) -> None:
"""Verify DurableAgentProvider is abstract and cannot be instantiated."""
with pytest.raises(TypeError):
DurableAgentProvider() # type: ignore[abstract]
def test_provider_defines_get_agent_method(self) -> None:
"""Verify DurableAgentProvider defines get_agent abstract method."""
assert hasattr(DurableAgentProvider, "get_agent")
if __name__ == "__main__":
pytest.main([__file__, "-v", "--tb=short"])
@@ -0,0 +1,168 @@
# Copyright (c) Microsoft. All rights reserved.
"""Unit tests for DurableAIAgentWorker.
Focuses on critical worker flows: agent registration, validation, callbacks, and lifecycle.
"""
from unittest.mock import Mock
import pytest
from agent_framework_durabletask import DurableAIAgentWorker
@pytest.fixture
def mock_grpc_worker() -> Mock:
"""Create a mock TaskHubGrpcWorker for testing."""
mock = Mock()
mock.add_entity = Mock(return_value="dafx-test_agent")
mock.start = Mock()
mock.stop = Mock()
return mock
@pytest.fixture
def mock_agent() -> Mock:
"""Create a mock agent for testing."""
agent = Mock()
agent.name = "test_agent"
return agent
@pytest.fixture
def agent_worker(mock_grpc_worker: Mock) -> DurableAIAgentWorker:
"""Create a DurableAIAgentWorker with mock worker."""
return DurableAIAgentWorker(mock_grpc_worker)
class TestDurableAIAgentWorkerRegistration:
"""Test agent registration behavior."""
def test_add_agent_accepts_agent_with_name(
self, agent_worker: DurableAIAgentWorker, mock_agent: Mock, mock_grpc_worker: Mock
) -> None:
"""Verify that agents with names can be registered."""
agent_worker.add_agent(mock_agent)
# Verify entity was registered with underlying worker
mock_grpc_worker.add_entity.assert_called_once()
# Verify agent name is tracked
assert "test_agent" in agent_worker.registered_agent_names
def test_add_agent_rejects_agent_without_name(self, agent_worker: DurableAIAgentWorker) -> None:
"""Verify that agents without names are rejected."""
agent_no_name = Mock()
agent_no_name.name = None
with pytest.raises(ValueError, match="Agent must have a name"):
agent_worker.add_agent(agent_no_name)
def test_add_agent_rejects_empty_name(self, agent_worker: DurableAIAgentWorker) -> None:
"""Verify that agents with empty names are rejected."""
agent_empty_name = Mock()
agent_empty_name.name = ""
with pytest.raises(ValueError, match="Agent must have a name"):
agent_worker.add_agent(agent_empty_name)
def test_add_agent_rejects_duplicate_names(self, agent_worker: DurableAIAgentWorker, mock_agent: Mock) -> None:
"""Verify duplicate agent names are not allowed."""
agent_worker.add_agent(mock_agent)
# Try to register another agent with the same name
duplicate_agent = Mock()
duplicate_agent.name = "test_agent"
with pytest.raises(ValueError, match="already registered"):
agent_worker.add_agent(duplicate_agent)
def test_registered_agent_names_tracks_multiple_agents(self, agent_worker: DurableAIAgentWorker) -> None:
"""Verify registered_agent_names tracks all registered agents."""
agent1 = Mock()
agent1.name = "agent1"
agent2 = Mock()
agent2.name = "agent2"
agent3 = Mock()
agent3.name = "agent3"
agent_worker.add_agent(agent1)
agent_worker.add_agent(agent2)
agent_worker.add_agent(agent3)
registered = agent_worker.registered_agent_names
assert "agent1" in registered
assert "agent2" in registered
assert "agent3" in registered
assert len(registered) == 3
class TestDurableAIAgentWorkerCallbacks:
"""Test callback configuration behavior."""
def test_worker_level_callback_accepted(self, mock_grpc_worker: Mock) -> None:
"""Verify worker-level callback can be set."""
mock_callback = Mock()
agent_worker = DurableAIAgentWorker(mock_grpc_worker, callback=mock_callback)
assert agent_worker is not None
def test_agent_level_callback_accepted(self, agent_worker: DurableAIAgentWorker, mock_agent: Mock) -> None:
"""Verify agent-level callback can be set during registration."""
mock_callback = Mock()
# Should not raise exception
agent_worker.add_agent(mock_agent, callback=mock_callback)
assert "test_agent" in agent_worker.registered_agent_names
def test_none_callback_accepted(self, mock_grpc_worker: Mock, mock_agent: Mock) -> None:
"""Verify None callback is valid (no callbacks required)."""
agent_worker = DurableAIAgentWorker(mock_grpc_worker, callback=None)
agent_worker.add_agent(mock_agent, callback=None)
assert "test_agent" in agent_worker.registered_agent_names
class TestDurableAIAgentWorkerLifecycle:
"""Test worker lifecycle behavior."""
def test_start_delegates_to_underlying_worker(
self, agent_worker: DurableAIAgentWorker, mock_grpc_worker: Mock
) -> None:
"""Verify start() delegates to wrapped worker."""
agent_worker.start()
mock_grpc_worker.start.assert_called_once()
def test_stop_delegates_to_underlying_worker(
self, agent_worker: DurableAIAgentWorker, mock_grpc_worker: Mock
) -> None:
"""Verify stop() delegates to wrapped worker."""
agent_worker.stop()
mock_grpc_worker.stop.assert_called_once()
def test_start_works_with_no_agents(self, agent_worker: DurableAIAgentWorker, mock_grpc_worker: Mock) -> None:
"""Verify worker can start even with no agents registered."""
agent_worker.start()
mock_grpc_worker.start.assert_called_once()
def test_start_works_with_multiple_agents(self, agent_worker: DurableAIAgentWorker, mock_grpc_worker: Mock) -> None:
"""Verify worker can start with multiple agents registered."""
agent1 = Mock()
agent1.name = "agent1"
agent2 = Mock()
agent2.name = "agent2"
agent_worker.add_agent(agent1)
agent_worker.add_agent(agent2)
agent_worker.start()
mock_grpc_worker.start.assert_called_once()
assert len(agent_worker.registered_agent_names) == 2
if __name__ == "__main__":
pytest.main([__file__, "-v", "--tb=short"])
@@ -10,6 +10,7 @@ Prerequisites: configure `AZURE_OPENAI_ENDPOINT`, `AZURE_OPENAI_CHAT_DEPLOYMENT_
import json
import logging
from collections.abc import Generator
from typing import Any
import azure.functions as func
@@ -44,7 +45,7 @@ app = AgentFunctionApp(agents=[_create_writer_agent()], enable_health_check=True
# 4. Orchestration that runs the agent sequentially on a shared thread for chaining behaviour.
@app.orchestration_trigger(context_name="context")
def single_agent_orchestration(context: DurableOrchestrationContext):
def single_agent_orchestration(context: DurableOrchestrationContext) -> Generator[Any, Any, str]:
"""Run the writer agent twice on the same thread to mirror chaining behaviour."""
writer = app.get_agent(context, WRITER_AGENT_NAME)
@@ -116,12 +117,6 @@ async def get_orchestration_status(
)
status = await client.get_status(instance_id)
if status is None:
return func.HttpResponse(
body=json.dumps({"error": "Instance not found"}),
status_code=404,
mimetype="application/json",
)
response_data: dict[str, Any] = {
"instanceId": status.instance_id,
@@ -10,6 +10,7 @@ Prerequisites: configure `AZURE_OPENAI_ENDPOINT`, `AZURE_OPENAI_CHAT_DEPLOYMENT_
import json
import logging
from collections.abc import Generator
from typing import Any, cast
from agent_framework import AgentRunResponse
@@ -51,7 +52,7 @@ app.add_agent(agents[1])
# 4. Durable Functions orchestration that runs both agents in parallel.
@app.orchestration_trigger(context_name="context")
def multi_agent_concurrent_orchestration(context: DurableOrchestrationContext):
def multi_agent_concurrent_orchestration(context: DurableOrchestrationContext) -> Generator[Any, Any, dict[str, str]]:
"""Fan out to two domain-specific agents and aggregate their responses."""
prompt = context.get_input()
@@ -137,12 +138,6 @@ async def get_orchestration_status(
)
status = await client.get_status(instance_id)
if status is None:
return func.HttpResponse(
body=json.dumps({"error": "Instance not found"}),
status_code=404,
mimetype="application/json",
)
response_data: dict[str, Any] = {
"instanceId": status.instance_id,
@@ -11,7 +11,7 @@ Functions host."""
import json
import logging
from collections.abc import Mapping
from collections.abc import Generator, Mapping
from typing import Any, cast
import azure.functions as func
@@ -74,7 +74,7 @@ def send_email(message: str) -> str:
# 4. Orchestration validates input, runs agents, and branches on spam results.
@app.orchestration_trigger(context_name="context")
def spam_detection_orchestration(context: DurableOrchestrationContext):
def spam_detection_orchestration(context: DurableOrchestrationContext) -> Generator[Any, Any, str]:
payload_raw = context.get_input()
if not isinstance(payload_raw, Mapping):
raise ValueError("Email data is required")
@@ -105,7 +105,7 @@ def spam_detection_orchestration(context: DurableOrchestrationContext):
spam_result = cast(SpamDetectionResult, spam_result_raw.value)
if spam_result.is_spam:
result = yield context.call_activity("handle_spam_email", spam_result.reason)
result = yield context.call_activity("handle_spam_email", spam_result.reason) # type: ignore[misc]
return result
email_thread = email_agent.get_new_thread()
@@ -125,7 +125,7 @@ def spam_detection_orchestration(context: DurableOrchestrationContext):
email_result = cast(EmailResponse, email_result_raw.value)
result = yield context.call_activity("send_email", email_result.response)
result = yield context.call_activity("send_email", email_result.response) # type: ignore[misc]
return result
@@ -196,12 +196,6 @@ async def get_orchestration_status(
)
status = await client.get_status(instance_id)
if status is None:
return func.HttpResponse(
body=json.dumps({"error": "Instance not found"}),
status_code=404,
mimetype="application/json",
)
response_data: dict[str, Any] = {
"instanceId": status.instance_id,
@@ -10,7 +10,7 @@ either `AZURE_OPENAI_API_KEY` or sign in with Azure CLI before running `func sta
import json
import logging
from collections.abc import Mapping
from collections.abc import Generator, Mapping
from datetime import timedelta
from typing import Any
@@ -62,7 +62,7 @@ app = AgentFunctionApp(agents=[_create_writer_agent()], enable_health_check=True
# 3. Activities encapsulate external work for review notifications and publishing.
@app.activity_trigger(input_name="content")
def notify_user_for_approval(content: dict) -> None:
def notify_user_for_approval(content: Any) -> None:
model = GeneratedContent.model_validate(content)
logger.info("NOTIFICATION: Please review the following content for approval:")
logger.info("Title: %s", model.title or "(untitled)")
@@ -71,7 +71,7 @@ def notify_user_for_approval(content: dict) -> None:
@app.activity_trigger(input_name="content")
def publish_content(content: dict) -> None:
def publish_content(content: Any) -> None:
model = GeneratedContent.model_validate(content)
logger.info("PUBLISHING: Content has been published successfully:")
logger.info("Title: %s", model.title or "(untitled)")
@@ -80,7 +80,7 @@ def publish_content(content: dict) -> None:
# 4. Orchestration loops until the human approves, times out, or attempts are exhausted.
@app.orchestration_trigger(context_name="context")
def content_generation_hitl_orchestration(context: DurableOrchestrationContext):
def content_generation_hitl_orchestration(context: DurableOrchestrationContext) -> Generator[Any, Any, dict[str, str]]:
payload_raw = context.get_input()
if not isinstance(payload_raw, Mapping):
raise ValueError("Content generation input is required")
@@ -102,7 +102,7 @@ def content_generation_hitl_orchestration(context: DurableOrchestrationContext):
)
content = initial_raw.value
logger.info("Type of content after extraction: %s", type(content))
logger.info("Type of content after extraction: %s", type(content)) # type: ignore[misc]
if content is None or not isinstance(content, GeneratedContent):
raise ValueError("Agent returned no content after extraction.")
@@ -114,7 +114,7 @@ def content_generation_hitl_orchestration(context: DurableOrchestrationContext):
f"Requesting human feedback. Iteration #{attempt}. Timeout: {payload.approval_timeout_hours} hour(s)."
)
yield context.call_activity("notify_user_for_approval", content.model_dump())
yield context.call_activity("notify_user_for_approval", content.model_dump()) # type: ignore[misc]
approval_task = context.wait_for_external_event(HUMAN_APPROVAL_EVENT)
timeout_task = context.create_timer(
@@ -129,7 +129,7 @@ def content_generation_hitl_orchestration(context: DurableOrchestrationContext):
if approval_payload.approved:
context.set_custom_status("Content approved by human reviewer. Publishing content...")
yield context.call_activity("publish_content", content.model_dump())
yield context.call_activity("publish_content", content.model_dump()) # type: ignore[misc]
context.set_custom_status(
f"Content published successfully at {context.current_utc_datetime:%Y-%m-%dT%H:%M:%S}"
)
@@ -287,7 +287,7 @@ async def get_orchestration_status(
)
# Check if status is None or if the instance doesn't exist (runtime_status is None)
if status is None or getattr(status, "runtime_status", None) is None:
if getattr(status, "runtime_status", None) is None:
return func.HttpResponse(
body=json.dumps({"error": "Instance not found."}),
status_code=404,
@@ -0,0 +1,66 @@
# Single Agent Sample
This sample demonstrates how to use the durable agents extension to create a worker-client setup that hosts a single AI agent and provides interactive conversation via the Durable Task Scheduler.
## Key Concepts Demonstrated
- Using the Microsoft Agent Framework to define a simple AI agent with a name and instructions.
- Registering durable agents with the worker and interacting with them via a client.
- Conversation management (via threads) for isolated interactions.
- Worker-client architecture for distributed agent execution.
## Environment Setup
See the [README.md](../README.md) file in the parent directory for more information on how to configure the environment, including how to install and run common sample dependencies.
## Running the Sample
With the environment setup, you can run the sample using separate worker and client processes:
**Start the worker:**
```bash
cd samples/getting_started/durabletask/01_single_agent
python worker.py
```
The worker will register the Joker agent and listen for requests.
**In a new terminal, run the client:**
```bash
python client.py
```
The client will interact with the Joker agent:
```
Starting Durable Task Agent Client...
Using taskhub: default
Using endpoint: http://localhost:8080
Getting reference to Joker agent...
Created conversation thread: a1b2c3d4-e5f6-7890-abcd-ef1234567890
User: Tell me a short joke about cloud computing.
Joker: Why did the cloud break up with the server?
Because it found someone more "uplifting"!
User: Now tell me one about Python programming.
Joker: Why do Python programmers prefer dark mode?
Because light attracts bugs!
```
## Viewing Agent State
You can view the state of the agent in the Durable Task Scheduler dashboard:
1. Open your browser and navigate to `http://localhost:8082`
2. In the dashboard, you can view the state of the Joker agent, including its conversation history and current state
The agent maintains conversation state across multiple interactions, and you can inspect this state in the dashboard to understand how the durable agents extension manages conversation context.
@@ -0,0 +1,92 @@
"""Client application for interacting with a Durable Task hosted agent.
This client connects to the Durable Task Scheduler and sends requests to
registered agents, demonstrating how to interact with agents from external processes.
Prerequisites:
- The worker must be running with the agent registered
- Set AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_CHAT_DEPLOYMENT_NAME
(plus AZURE_OPENAI_API_KEY or Azure CLI authentication)
- Durable Task Scheduler must be running
"""
import asyncio
import logging
import os
from agent_framework_durabletask import DurableAIAgentClient
from azure.identity import DefaultAzureCredential
from durabletask.azuremanaged.client import DurableTaskSchedulerClient
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
async def main() -> None:
"""Main entry point for the client application."""
logger.info("Starting Durable Task Agent Client...")
# Get environment variables for taskhub and endpoint with defaults
taskhub_name = os.getenv("TASKHUB", "default")
endpoint = os.getenv("ENDPOINT", "http://localhost:8080")
logger.info(f"Using taskhub: {taskhub_name}")
logger.info(f"Using endpoint: {endpoint}")
logger.info("")
# Set credential to None for emulator, or DefaultAzureCredential for Azure
credential = None if endpoint == "http://localhost:8080" else DefaultAzureCredential()
# Create a client using Azure Managed Durable Task
client = DurableTaskSchedulerClient(
host_address=endpoint,
secure_channel=endpoint != "http://localhost:8080",
taskhub=taskhub_name,
token_credential=credential
)
# Wrap it with the agent client
agent_client = DurableAIAgentClient(client)
# Get a reference to the Joker agent
logger.info("Getting reference to Joker agent...")
joker = agent_client.get_agent("Joker")
# Create a new thread for the conversation
thread = joker.get_new_thread()
logger.info(f"Created conversation thread: {thread.session_id}")
logger.info("")
try:
# First message
message1 = "Tell me a short joke about cloud computing."
logger.info(f"User: {message1}")
logger.info("")
# Run the agent - this blocks until the response is ready
response1 = joker.run(message1, thread=thread)
logger.info(f"Agent: {response1.text}")
logger.info("")
# Second message - continuing the conversation
message2 = "Now tell me one about Python programming."
logger.info(f"User: {message2}")
logger.info("")
response2 = joker.run(message2, thread=thread)
logger.info(f"Agent: {response2.text}")
logger.info("")
logger.info(f"Conversation completed successfully!")
logger.info(f"Thread ID: {thread.session_id}")
except Exception as e:
logger.exception(f"Error during agent interaction: {e}")
finally:
logger.info("Client shutting down")
if __name__ == "__main__":
asyncio.run(main())
@@ -0,0 +1,6 @@
# Agent Framework packages (installing from local package until a package is published)
-e ../../../../
-e ../../../../packages/durabletask
# Azure authentication
azure-identity
@@ -0,0 +1,137 @@
"""Single Agent Sample - Durable Task Integration (Combined Worker + Client)
This sample demonstrates running both the worker and client in a single process.
The worker is started first to register the agent, then client operations are
performed against the running worker.
Prerequisites:
- Set AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_CHAT_DEPLOYMENT_NAME
(plus AZURE_OPENAI_API_KEY or Azure CLI authentication)
- Durable Task Scheduler must be running (e.g., using Docker)
To run this sample:
python sample.py
"""
import logging
import os
from agent_framework.azure import AzureOpenAIChatClient
from agent_framework_durabletask import DurableAIAgentClient, DurableAIAgentWorker
from azure.identity import AzureCliCredential, DefaultAzureCredential
from dotenv import load_dotenv
from durabletask.azuremanaged.client import DurableTaskSchedulerClient
from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def create_joker_agent():
"""Create the Joker agent using Azure OpenAI.
Returns:
AgentProtocol: The configured Joker agent
"""
return AzureOpenAIChatClient(credential=AzureCliCredential()).create_agent(
name="Joker",
instructions="You are good at telling jokes.",
)
def main():
"""Main entry point - runs both worker and client in single process."""
logger.info("Starting Durable Task Agent Sample (Combined Worker + Client)...")
# Get environment variables for taskhub and endpoint with defaults
taskhub_name = os.getenv("TASKHUB", "default")
endpoint = os.getenv("ENDPOINT", "http://localhost:8080")
logger.info(f"Using taskhub: {taskhub_name}")
logger.info(f"Using endpoint: {endpoint}")
logger.info("")
# Set credential to None for emulator, or DefaultAzureCredential for Azure
credential = None if endpoint == "http://localhost:8080" else DefaultAzureCredential()
secure_channel = endpoint != "http://localhost:8080"
# Create and start the worker using a context manager
with DurableTaskSchedulerWorker(
host_address=endpoint,
secure_channel=secure_channel,
taskhub=taskhub_name,
token_credential=credential
) as worker:
# Wrap with the agent worker
agent_worker = DurableAIAgentWorker(worker)
# Create and register the Joker agent
logger.info("Creating and registering Joker agent...")
joker_agent = create_joker_agent()
agent_worker.add_agent(joker_agent)
logger.info(f"✓ Registered agent: {joker_agent.name}")
logger.info(f" Entity name: dafx-{joker_agent.name}")
logger.info("")
# Start the worker
worker.start()
logger.info("Worker started and listening for requests...")
logger.info("")
# Create the client
client = DurableTaskSchedulerClient(
host_address=endpoint,
secure_channel=secure_channel,
taskhub=taskhub_name,
token_credential=credential
)
# Wrap it with the agent client
agent_client = DurableAIAgentClient(client)
# Get a reference to the Joker agent
logger.info("Getting reference to Joker agent...")
joker = agent_client.get_agent("Joker")
# Create a new thread for the conversation
thread = joker.get_new_thread()
logger.info(f"Created conversation thread: {thread.session_id}")
logger.info("")
try:
# First message
message1 = "Tell me a short joke about cloud computing."
logger.info(f"User: {message1}")
logger.info("")
# Run the agent - this blocks until the response is ready
response1 = joker.run(message1, thread=thread)
logger.info(f"Agent: {response1.text}; {response1}")
logger.info("")
# Second message - continuing the conversation
message2 = "Now tell me one about Python programming."
logger.info(f"User: {message2}")
logger.info("")
response2 = joker.run(message2, thread=thread)
logger.info(f"Agent: {response2.text}; {response2}")
logger.info("")
logger.info(f"Conversation completed successfully!")
logger.info(f"Thread ID: {thread.session_id}")
except Exception as e:
logger.exception(f"Error during agent interaction: {e}")
logger.info("")
logger.info("Sample completed. Worker shutting down...")
if __name__ == "__main__":
load_dotenv()
main()
@@ -0,0 +1,89 @@
"""Worker process for hosting a single Azure OpenAI-powered agent using Durable Task.
This worker registers agents as durable entities and continuously listens for requests.
The worker should run as a background service, processing incoming agent requests.
Prerequisites:
- Set AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_CHAT_DEPLOYMENT_NAME
(plus AZURE_OPENAI_API_KEY or Azure CLI authentication)
- Start a Durable Task Scheduler (e.g., using Docker)
"""
import asyncio
import logging
import os
from agent_framework.azure import AzureOpenAIChatClient
from agent_framework_durabletask import DurableAIAgentWorker
from azure.identity import AzureCliCredential, DefaultAzureCredential
from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def create_joker_agent():
"""Create the Joker agent using Azure OpenAI.
Returns:
AgentProtocol: The configured Joker agent
"""
return AzureOpenAIChatClient(credential=AzureCliCredential()).create_agent(
name="Joker",
instructions="You are good at telling jokes.",
)
async def main():
"""Main entry point for the worker process."""
logger.info("Starting Durable Task Agent Worker...")
# Get environment variables for taskhub and endpoint with defaults
taskhub_name = os.getenv("TASKHUB", "default")
endpoint = os.getenv("ENDPOINT", "http://localhost:8080")
logger.info(f"Using taskhub: {taskhub_name}")
logger.info(f"Using endpoint: {endpoint}")
# Set credential to None for emulator, or DefaultAzureCredential for Azure
credential = None if endpoint == "http://localhost:8080" else DefaultAzureCredential()
# Create a worker using Azure Managed Durable Task
worker = DurableTaskSchedulerWorker(
host_address=endpoint,
secure_channel=endpoint != "http://localhost:8080",
taskhub=taskhub_name,
token_credential=credential
)
# Wrap it with the agent worker
agent_worker = DurableAIAgentWorker(worker)
# Create and register the Joker agent
logger.info("Creating and registering Joker agent...")
joker_agent = create_joker_agent()
agent_worker.add_agent(joker_agent)
logger.info(f"✓ Registered agent: {joker_agent.name}")
logger.info(f" Entity name: dafx-{joker_agent.name}")
logger.info("")
logger.info("Worker is ready and listening for requests...")
logger.info("Press Ctrl+C to stop.")
logger.info("")
try:
# Start the worker (this blocks until stopped)
worker.start()
# Keep the worker running
while True:
await asyncio.sleep(1)
except KeyboardInterrupt:
logger.info("Worker shutdown initiated")
logger.info("Worker stopped")
if __name__ == "__main__":
asyncio.run(main())
@@ -0,0 +1,72 @@
# Single Agent Orchestration Chaining Sample
This sample demonstrates how to chain multiple invocations of the same agent using a durable orchestration while preserving conversation state between runs.
## Key Concepts Demonstrated
- Using durable orchestrations to coordinate sequential agent invocations.
- Chaining agent calls where the output of one run becomes input to the next.
- Maintaining conversation context across sequential runs using a shared thread.
- Using `DurableAIAgentOrchestrationContext` to access agents within orchestrations.
## Environment Setup
See the [README.md](../README.md) file in the parent directory for more information on how to configure the environment, including how to install and run common sample dependencies.
## Running the Sample
With the environment setup, you can run the sample using one of two approaches:
### Option 1: Combined Worker + Client (Quick Start)
```bash
cd samples/getting_started/durabletask/04_single_agent_orchestration_chaining
python sample.py
```
This runs both worker and client in a single process.
### Option 2: Separate Worker and Client
**Start the worker in one terminal:**
```bash
python worker.py
```
**In a new terminal, run the client:**
```bash
python client.py
```
The orchestration will execute the writer agent twice sequentially, and you'll see output like:
```
[Orchestration] Starting single agent chaining...
[Orchestration] Created thread: abc-123
[Orchestration] First agent run: Generating initial sentence...
[Orchestration] Initial response: Every small step forward is progress toward mastery.
[Orchestration] Second agent run: Refining the sentence...
[Orchestration] Refined response: Each small step forward brings you closer to mastery and growth.
[Orchestration] Chaining complete
================================================================================
Orchestration Result
================================================================================
Each small step forward brings you closer to mastery and growth.
```
## Viewing Orchestration State
You can view the state of the orchestration in the Durable Task Scheduler dashboard:
1. Open your browser and navigate to `http://localhost:8082`
2. In the dashboard, you can view the orchestration instance, including:
- The sequential execution of both agent runs
- The conversation thread shared between runs
- Input and output at each step
- Overall orchestration state and history
The orchestration maintains the conversation context across both agent invocations, demonstrating how durable orchestrations can coordinate multi-step agent workflows.
@@ -0,0 +1,104 @@
"""Client application for starting a single agent chaining orchestration.
This client connects to the Durable Task Scheduler and starts an orchestration
that runs a writer agent twice sequentially on the same thread, demonstrating
how conversation context is maintained across multiple agent invocations.
Prerequisites:
- The worker must be running with the writer agent and orchestration registered
- Set AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_CHAT_DEPLOYMENT_NAME
(plus AZURE_OPENAI_API_KEY or Azure CLI authentication)
- Durable Task Scheduler must be running
"""
import asyncio
import json
import logging
import os
from azure.identity import DefaultAzureCredential
from durabletask.azuremanaged.client import DurableTaskSchedulerClient
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
async def main() -> None:
"""Main entry point for the client application."""
logger.info("Starting Durable Task Single Agent Chaining Orchestration Client...")
# Get environment variables for taskhub and endpoint with defaults
taskhub_name = os.getenv("TASKHUB", "default")
endpoint = os.getenv("ENDPOINT", "http://localhost:8080")
logger.info(f"Using taskhub: {taskhub_name}")
logger.info(f"Using endpoint: {endpoint}")
logger.info("")
# Set credential to None for emulator, or DefaultAzureCredential for Azure
credential = None if endpoint == "http://localhost:8080" else DefaultAzureCredential()
# Create a client using Azure Managed Durable Task
client = DurableTaskSchedulerClient(
host_address=endpoint,
secure_channel=endpoint != "http://localhost:8080",
taskhub=taskhub_name,
token_credential=credential
)
logger.info("Starting single agent chaining orchestration...")
logger.info("")
try:
# Start the orchestration
instance_id = client.schedule_new_orchestration(
orchestrator="single_agent_chaining_orchestration",
input="",
)
logger.info(f"Orchestration started with instance ID: {instance_id}")
logger.info("Waiting for orchestration to complete...")
logger.info("")
# Retrieve the final state
metadata = client.wait_for_orchestration_completion(
instance_id=instance_id,
timeout=300
)
if metadata and metadata.runtime_status.name == "COMPLETED":
result = metadata.serialized_output
logger.info("=" * 80)
logger.info("Orchestration completed successfully!")
logger.info("=" * 80)
logger.info("")
logger.info("Results:")
logger.info("")
# Parse and display the result
if result:
final_text = json.loads(result)
logger.info("Final refined sentence:")
logger.info(f" {final_text}")
logger.info("")
logger.info("=" * 80)
elif metadata:
logger.error(f"Orchestration ended with status: {metadata.runtime_status.name}")
if metadata.serialized_output:
logger.error(f"Output: {metadata.serialized_output}")
else:
logger.error("Orchestration did not complete within the timeout period")
except Exception as e:
logger.exception(f"Error during orchestration: {e}")
finally:
logger.info("")
logger.info("Client shutting down")
if __name__ == "__main__":
asyncio.run(main())
@@ -0,0 +1,6 @@
# Agent Framework packages (installing from local package until a package is published)
-e ../../../../
-e ../../../../packages/durabletask
# Azure authentication
azure-identity
@@ -0,0 +1,255 @@
"""Single Agent Orchestration Chaining Sample - Durable Task Integration
This sample demonstrates chaining two invocations of the same agent inside a Durable Task
orchestration while preserving the conversation state between runs. The orchestration
runs the writer agent sequentially on a shared thread to refine text iteratively.
Components used:
- AzureOpenAIChatClient to construct the writer agent
- DurableTaskSchedulerWorker and DurableAIAgentWorker for agent hosting
- DurableTaskSchedulerClient and orchestration for sequential agent invocations
- Thread management to maintain conversation context across invocations
Prerequisites:
- Set AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_CHAT_DEPLOYMENT_NAME
(plus AZURE_OPENAI_API_KEY or Azure CLI authentication)
- Durable Task Scheduler must be running (e.g., using Docker emulator)
To run this sample:
python sample.py
"""
import asyncio
import json
import logging
import os
from collections.abc import Generator
from typing import Any
from agent_framework import AgentRunResponse
from agent_framework.azure import AzureOpenAIChatClient
from agent_framework_durabletask import DurableAIAgentOrchestrationContext, DurableAIAgentWorker
from azure.identity import AzureCliCredential, DefaultAzureCredential
from dotenv import load_dotenv
from durabletask.task import OrchestrationContext, Task
from durabletask.azuremanaged.client import DurableTaskSchedulerClient
from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Agent name
WRITER_AGENT_NAME = "WriterAgent"
def create_writer_agent():
"""Create the Writer agent using Azure OpenAI.
This agent refines short pieces of text, enhancing initial sentences
and polishing improved versions further.
Returns:
AgentProtocol: The configured Writer agent
"""
instructions = (
"You refine short pieces of text. When given an initial sentence you enhance it;\n"
"when given an improved sentence you polish it further."
)
return AzureOpenAIChatClient(credential=AzureCliCredential()).create_agent(
name=WRITER_AGENT_NAME,
instructions=instructions,
)
def single_agent_chaining_orchestration(
context: OrchestrationContext, _: str
) -> Generator[Task[Any], Any, str]:
"""Orchestration that runs the writer agent twice on the same thread.
This demonstrates chaining behavior where the output of the first agent run
becomes part of the input for the second run, all while maintaining the
conversation context through a shared thread.
Args:
context: The orchestration context
_: Input parameter (unused)
Returns:
str: The final refined text from the second agent run
"""
logger.info("[Orchestration] Starting single agent chaining...")
# Wrap the orchestration context to access agents
agent_context = DurableAIAgentOrchestrationContext(context)
# Get the writer agent using the agent context
writer = agent_context.get_agent(WRITER_AGENT_NAME)
# Create a new thread for the conversation - this will be shared across both runs
writer_thread = writer.get_new_thread()
logger.info(f"[Orchestration] Created thread: {writer_thread.session_id}")
# First run: Generate an initial inspirational sentence
logger.info("[Orchestration] First agent run: Generating initial sentence...")
initial_response: AgentRunResponse = yield writer.run(
messages="Write a concise inspirational sentence about learning.",
thread=writer_thread,
)
logger.info(f"[Orchestration] Initial response: {initial_response.text}")
# Second run: Refine the initial response on the same thread
improved_prompt = (
f"Improve this further while keeping it under 25 words: "
f"{initial_response.text}"
)
logger.info("[Orchestration] Second agent run: Refining the sentence...")
refined_response: AgentRunResponse = yield writer.run(
messages=improved_prompt,
thread=writer_thread,
)
logger.info(f"[Orchestration] Refined response: {refined_response.text}")
logger.info("[Orchestration] Chaining complete")
return refined_response.text
async def run_client(
endpoint: str, taskhub_name: str, credential: DefaultAzureCredential | None
):
"""Run the client to start and monitor the orchestration.
Args:
endpoint: The durable task scheduler endpoint
taskhub_name: The task hub name
credential: The credential for authentication
"""
logger.info("")
logger.info("=" * 80)
logger.info("CLIENT: Starting orchestration...")
logger.info("=" * 80)
logger.info("")
# Create a client
client = DurableTaskSchedulerClient(
host_address=endpoint,
secure_channel=endpoint != "http://localhost:8080",
taskhub=taskhub_name,
token_credential=credential
)
try:
# Start the orchestration
instance_id = client.schedule_new_orchestration(
single_agent_chaining_orchestration
)
logger.info(f"Orchestration started with instance ID: {instance_id}")
logger.info("Waiting for orchestration to complete...")
logger.info("")
# Retrieve the final state
metadata = client.wait_for_orchestration_completion(
instance_id=instance_id,
timeout=300
)
if metadata and metadata.runtime_status.name == "COMPLETED":
result = metadata.serialized_output
logger.info("")
logger.info("=" * 80)
logger.info("ORCHESTRATION COMPLETED SUCCESSFULLY!")
logger.info("=" * 80)
logger.info("")
# Parse and display the result
if result:
final_text = json.loads(result)
logger.info("Final refined sentence:")
logger.info(f" {final_text}")
else:
logger.warning("No output returned from orchestration")
elif metadata:
logger.error(f"Orchestration did not complete successfully: {metadata.runtime_status.name}")
if metadata.serialized_output:
logger.error(f"Output: {metadata.serialized_output}")
else:
logger.error("Could not retrieve orchestration metadata")
except Exception as e:
logger.exception(f"Error during orchestration: {e}")
logger.info("")
logger.info("Client shutting down")
def main():
"""Main entry point - runs both worker and client in single process."""
logger.info("Starting Single Agent Orchestration Chaining Sample...")
logger.info("")
# Load environment variables
load_dotenv()
# Get environment variables for taskhub and endpoint with defaults
taskhub_name = os.getenv("TASKHUB", "default")
endpoint = os.getenv("ENDPOINT", "http://localhost:8080")
logger.info(f"Using taskhub: {taskhub_name}")
logger.info(f"Using endpoint: {endpoint}")
logger.info("")
# Set credential to None for emulator, or DefaultAzureCredential for Azure
credential = None if endpoint == "http://localhost:8080" else DefaultAzureCredential()
secure_channel = endpoint != "http://localhost:8080"
# Create and start the worker using a context manager
with DurableTaskSchedulerWorker(
host_address=endpoint,
secure_channel=secure_channel,
taskhub=taskhub_name,
token_credential=credential
) as worker:
# Wrap with the agent worker
agent_worker = DurableAIAgentWorker(worker)
# Create and register the Writer agent
logger.info("Creating and registering Writer agent...")
writer_agent = create_writer_agent()
agent_worker.add_agent(writer_agent)
logger.info(f"✓ Registered agent: {writer_agent.name}")
logger.info(f" Entity name: dafx-{writer_agent.name}")
# Register the orchestration function
logger.info("Registering orchestration function...")
worker.add_orchestrator(single_agent_chaining_orchestration)
logger.info("✓ Registered orchestration: single_agent_chaining_orchestration")
logger.info("")
# Start the worker
worker.start()
logger.info("Worker started and listening for requests...")
logger.info("")
# Run the client in the same process
try:
asyncio.run(run_client(endpoint, taskhub_name, credential))
except KeyboardInterrupt:
logger.info("Sample interrupted by user")
finally:
logger.info("Worker stopping...")
logger.info("Sample completed")
if __name__ == "__main__":
load_dotenv()
main()
@@ -0,0 +1,167 @@
"""Worker process for hosting a single agent with chaining orchestration using Durable Task.
This worker registers a writer agent and an orchestration function that demonstrates
chaining behavior by running the agent twice sequentially on the same thread,
preserving conversation context between invocations.
Prerequisites:
- Set AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_CHAT_DEPLOYMENT_NAME
(plus AZURE_OPENAI_API_KEY or Azure CLI authentication)
- Start a Durable Task Scheduler (e.g., using Docker)
"""
import asyncio
from collections.abc import Generator
import logging
import os
from typing import Any
from agent_framework import AgentRunResponse
from agent_framework.azure import AzureOpenAIChatClient
from agent_framework_durabletask import DurableAIAgentOrchestrationContext, DurableAIAgentWorker
from azure.identity import AzureCliCredential, DefaultAzureCredential
from durabletask.task import OrchestrationContext, Task
from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Agent name
WRITER_AGENT_NAME = "WriterAgent"
def create_writer_agent():
"""Create the Writer agent using Azure OpenAI.
This agent refines short pieces of text, enhancing initial sentences
and polishing improved versions further.
Returns:
AgentProtocol: The configured Writer agent
"""
instructions = (
"You refine short pieces of text. When given an initial sentence you enhance it;\n"
"when given an improved sentence you polish it further."
)
return AzureOpenAIChatClient(credential=AzureCliCredential()).create_agent(
name=WRITER_AGENT_NAME,
instructions=instructions,
)
def single_agent_chaining_orchestration(
context: OrchestrationContext, _: str
) -> Generator[Task[Any], Any, str]:
"""Orchestration that runs the writer agent twice on the same thread.
This demonstrates chaining behavior where the output of the first agent run
becomes part of the input for the second run, all while maintaining the
conversation context through a shared thread.
Args:
context: The orchestration context
_: Input parameter (unused)
Returns:
str: The final refined text from the second agent run
"""
logger.info("[Orchestration] Starting single agent chaining...")
# Wrap the orchestration context to access agents
agent_context = DurableAIAgentOrchestrationContext(context)
# Get the writer agent using the agent context
writer = agent_context.get_agent(WRITER_AGENT_NAME)
# Create a new thread for the conversation - this will be shared across both runs
writer_thread = writer.get_new_thread()
logger.info(f"[Orchestration] Created thread: {writer_thread.session_id}")
# First run: Generate an initial inspirational sentence
logger.info("[Orchestration] First agent run: Generating initial sentence...")
initial_response: AgentRunResponse = yield writer.run(
messages="Write a concise inspirational sentence about learning.",
thread=writer_thread,
)
logger.info(f"[Orchestration] Initial response: {initial_response.text}")
# Second run: Refine the initial response on the same thread
improved_prompt = (
f"Improve this further while keeping it under 25 words: "
f"{initial_response.text}"
)
logger.info("[Orchestration] Second agent run: Refining the sentence...")
refined_response: AgentRunResponse = yield writer.run(
messages=improved_prompt,
thread=writer_thread,
)
logger.info(f"[Orchestration] Refined response: {refined_response.text}")
logger.info("[Orchestration] Chaining complete")
return refined_response.text
async def main():
"""Main entry point for the worker process."""
logger.info("Starting Durable Task Single Agent Chaining Worker with Orchestration...")
# Get environment variables for taskhub and endpoint with defaults
taskhub_name = os.getenv("TASKHUB", "default")
endpoint = os.getenv("ENDPOINT", "http://localhost:8080")
logger.info(f"Using taskhub: {taskhub_name}")
logger.info(f"Using endpoint: {endpoint}")
# Set credential to None for emulator, or DefaultAzureCredential for Azure
credential = None if endpoint == "http://localhost:8080" else DefaultAzureCredential()
# Create a worker using Azure Managed Durable Task
worker = DurableTaskSchedulerWorker(
host_address=endpoint,
secure_channel=endpoint != "http://localhost:8080",
taskhub=taskhub_name,
token_credential=credential
)
# Wrap it with the agent worker
agent_worker = DurableAIAgentWorker(worker)
# Create and register the Writer agent
logger.info("Creating and registering Writer agent...")
writer_agent = create_writer_agent()
agent_worker.add_agent(writer_agent)
logger.info(f"✓ Registered agent: {writer_agent.name}")
logger.info(f" Entity name: dafx-{writer_agent.name}")
logger.info("")
# Register the orchestration function
logger.info("Registering orchestration function...")
worker.add_orchestrator(single_agent_chaining_orchestration)
logger.info(f"✓ Registered orchestration: {single_agent_chaining_orchestration.__name__}")
logger.info("")
logger.info("Worker is ready and listening for requests...")
logger.info("Press Ctrl+C to stop.")
logger.info("")
try:
# Start the worker (this blocks until stopped)
worker.start()
# Keep the worker running
while True:
await asyncio.sleep(1)
except KeyboardInterrupt:
logger.info("Worker shutdown initiated")
logger.info("Worker stopped")
if __name__ == "__main__":
asyncio.run(main())
@@ -0,0 +1,75 @@
# Multi-Agent Orchestration with Concurrency Sample
This sample demonstrates how to host multiple agents and run them concurrently using a durable orchestration, aggregating their responses into a single result.
## Key Concepts Demonstrated
- Running multiple specialized agents in parallel within an orchestration.
- Using `OrchestrationAgentExecutor` to get `DurableAgentTask` objects for concurrent execution.
- Aggregating results from multiple agents using `task.when_all()`.
- Creating separate conversation threads for independent agent contexts.
## Environment Setup
See the [README.md](../README.md) file in the parent directory for more information on how to configure the environment, including how to install and run common sample dependencies.
## Running the Sample
With the environment setup, you can run the sample using one of two approaches:
### Option 1: Combined Worker + Client (Quick Start)
```bash
cd samples/getting_started/durabletask/05_multi_agent_orchestration_concurrency
python sample.py
```
This runs both worker and client in a single process.
### Option 2: Separate Worker and Client
**Start the worker in one terminal:**
```bash
python worker.py
```
**In a new terminal, run the client:**
```bash
python client.py
```
The orchestration will execute both agents concurrently, and you'll see output like:
```
Prompt: What is temperature?
Starting multi-agent concurrent orchestration...
Orchestration started with instance ID: abc123...
Orchestration status: COMPLETED
Results:
Physicist's response:
Temperature measures the average kinetic energy of particles in a system...
Chemist's response:
Temperature reflects how molecular motion influences reaction rates...
```
## Viewing Orchestration State
You can view the state of the orchestration in the Durable Task Scheduler dashboard:
1. Open your browser and navigate to `http://localhost:8082`
2. In the dashboard, you can view the orchestration instance, including:
- The concurrent execution of both agents (Physicist and Chemist)
- Separate conversation threads for each agent
- Parallel task execution and completion timing
- Aggregated results from both agents
- Overall orchestration state and history
The orchestration demonstrates how multiple agents can be executed in parallel, with results collected and aggregated once all agents complete.
@@ -0,0 +1,114 @@
"""Client application for starting a multi-agent concurrent orchestration.
This client connects to the Durable Task Scheduler and starts an orchestration
that runs two agents (physicist and chemist) concurrently, then retrieves and
displays the aggregated results.
Prerequisites:
- The worker must be running with both agents and orchestration registered
- Set AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_CHAT_DEPLOYMENT_NAME
(plus AZURE_OPENAI_API_KEY or Azure CLI authentication)
- Durable Task Scheduler must be running
"""
import asyncio
import json
import logging
import os
from azure.identity import DefaultAzureCredential
from durabletask.azuremanaged.client import DurableTaskSchedulerClient
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
async def main() -> None:
"""Main entry point for the client application."""
logger.info("Starting Durable Task Multi-Agent Orchestration Client...")
# Get environment variables for taskhub and endpoint with defaults
taskhub_name = os.getenv("TASKHUB", "default")
endpoint = os.getenv("ENDPOINT", "http://localhost:8080")
logger.info(f"Using taskhub: {taskhub_name}")
logger.info(f"Using endpoint: {endpoint}")
logger.info("")
# Set credential to None for emulator, or DefaultAzureCredential for Azure
credential = None if endpoint == "http://localhost:8080" else DefaultAzureCredential()
# Create a client using Azure Managed Durable Task
client = DurableTaskSchedulerClient(
host_address=endpoint,
secure_channel=endpoint != "http://localhost:8080",
taskhub=taskhub_name,
token_credential=credential
)
# Define the prompt to send to both agents
prompt = "What is temperature?"
logger.info(f"Prompt: {prompt}")
logger.info("")
logger.info("Starting multi-agent concurrent orchestration...")
try:
# Start the orchestration with the prompt as input
instance_id = client.schedule_new_orchestration(
orchestrator="multi_agent_concurrent_orchestration",
input=prompt,
)
logger.info(f"Orchestration started with instance ID: {instance_id}")
logger.info("Waiting for orchestration to complete...")
logger.info("")
# Retrieve the final state
metadata = client.wait_for_orchestration_completion(
instance_id=instance_id,
)
if metadata and metadata.runtime_status.name == "COMPLETED":
result = metadata.serialized_output
logger.info("=" * 80)
logger.info("Orchestration completed successfully!")
logger.info("=" * 80)
logger.info("")
logger.info(f"Prompt: {prompt}")
logger.info("")
logger.info("Results:")
logger.info("")
# Parse and display the result
if result:
result_dict = json.loads(result)
logger.info("Physicist's response:")
logger.info(f" {result_dict.get('physicist', 'N/A')}")
logger.info("")
logger.info("Chemist's response:")
logger.info(f" {result_dict.get('chemist', 'N/A')}")
logger.info("")
logger.info("=" * 80)
elif metadata:
logger.error(f"Orchestration ended with status: {metadata.runtime_status.name}")
if metadata.serialized_output:
logger.error(f"Output: {metadata.serialized_output}")
else:
logger.error("Orchestration did not complete within the timeout period")
except Exception as e:
logger.exception(f"Error during orchestration: {e}")
finally:
logger.info("")
logger.info("Client shutting down")
if __name__ == "__main__":
asyncio.run(main())
@@ -0,0 +1,6 @@
# Agent Framework packages (installing from local package until a package is published)
-e ../../../../
-e ../../../../packages/durabletask
# Azure authentication
azure-identity
@@ -0,0 +1,266 @@
"""Multi-Agent Orchestration Sample - Durable Task Integration (Combined Worker + Client)
This sample demonstrates running both the worker and client in a single process for
concurrent multi-agent orchestration. The worker registers two domain-specific agents
(physicist and chemist) and an orchestration function that runs them in parallel.
The orchestration uses OrchestrationAgentExecutor to execute agents concurrently
and aggregate their responses.
Prerequisites:
- Set AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_CHAT_DEPLOYMENT_NAME
(plus AZURE_OPENAI_API_KEY or Azure CLI authentication)
- Durable Task Scheduler must be running (e.g., using Docker)
To run this sample:
python sample.py
"""
import asyncio
import json
import logging
import os
from collections.abc import Generator
from typing import Any
from agent_framework import AgentRunResponse
from agent_framework.azure import AzureOpenAIChatClient
from agent_framework_durabletask import DurableAIAgentOrchestrationContext, DurableAIAgentWorker
from azure.identity import AzureCliCredential, DefaultAzureCredential
from dotenv import load_dotenv
from durabletask.task import OrchestrationContext, when_all, Task
from durabletask.azuremanaged.client import DurableTaskSchedulerClient
from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Agent names
PHYSICIST_AGENT_NAME = "PhysicistAgent"
CHEMIST_AGENT_NAME = "ChemistAgent"
def create_physicist_agent():
"""Create the Physicist agent using Azure OpenAI.
Returns:
AgentProtocol: The configured Physicist agent
"""
return AzureOpenAIChatClient(credential=AzureCliCredential()).create_agent(
name=PHYSICIST_AGENT_NAME,
instructions="You are an expert in physics. You answer questions from a physics perspective.",
)
def create_chemist_agent():
"""Create the Chemist agent using Azure OpenAI.
Returns:
AgentProtocol: The configured Chemist agent
"""
return AzureOpenAIChatClient(credential=AzureCliCredential()).create_agent(
name=CHEMIST_AGENT_NAME,
instructions="You are an expert in chemistry. You answer questions from a chemistry perspective.",
)
def multi_agent_concurrent_orchestration(context: OrchestrationContext, prompt: str) -> Generator[Task[Any], Any, dict[str, str]]:
"""Orchestration that runs both agents in parallel and aggregates results.
Uses DurableAIAgentOrchestrationContext to wrap the orchestration context and
access agents via the OrchestrationAgentExecutor.
Args:
context: The orchestration context
Returns:
dict: Dictionary with 'physicist' and 'chemist' response texts
"""
logger.info(f"[Orchestration] Starting concurrent execution for prompt: {prompt}")
# Wrap the orchestration context to access agents
agent_context = DurableAIAgentOrchestrationContext(context)
# Get agents using the agent context (returns DurableAIAgent proxies)
physicist = agent_context.get_agent(PHYSICIST_AGENT_NAME)
chemist = agent_context.get_agent(CHEMIST_AGENT_NAME)
# Create separate threads for each agent
physicist_thread = physicist.get_new_thread()
chemist_thread = chemist.get_new_thread()
logger.info(f"[Orchestration] Created threads - Physicist: {physicist_thread.session_id}, Chemist: {chemist_thread.session_id}")
# Create tasks from agent.run() calls - these return DurableAgentTask instances
physicist_task = physicist.run(messages=str(prompt), thread=physicist_thread)
chemist_task = chemist.run(messages=str(prompt), thread=chemist_thread)
logger.info("[Orchestration] Created agent tasks, executing concurrently...")
# Execute both tasks concurrently using task.when_all
# The DurableAgentTask instances wrap the underlying entity calls
task_results = yield when_all([physicist_task, chemist_task])
logger.info("[Orchestration] Both agents completed")
# Extract results from the tasks - DurableAgentTask yields AgentRunResponse
physicist_result: AgentRunResponse = task_results[0]
chemist_result: AgentRunResponse = task_results[1]
result = {
"physicist": physicist_result.text,
"chemist": chemist_result.text,
}
logger.info(f"[Orchestration] Aggregated results ready")
return result
async def run_client(endpoint: str, taskhub_name: str, credential: DefaultAzureCredential | None, prompt: str):
"""Run the client to start and monitor the orchestration.
Args:
endpoint: The durable task scheduler endpoint
taskhub_name: The task hub name
credential: The credential for authentication
prompt: The prompt to send to both agents
"""
logger.info("")
logger.info("=" * 80)
logger.info("CLIENT: Starting orchestration...")
logger.info("=" * 80)
# Create a client
client = DurableTaskSchedulerClient(
host_address=endpoint,
secure_channel=endpoint != "http://localhost:8080",
taskhub=taskhub_name,
token_credential=credential
)
logger.info(f"Prompt: {prompt}")
logger.info("")
try:
# Start the orchestration with the prompt as input
instance_id = client.schedule_new_orchestration(
multi_agent_concurrent_orchestration,
input=prompt,
)
logger.info(f"Orchestration started with instance ID: {instance_id}")
logger.info("Waiting for orchestration to complete...")
logger.info("")
# Retrieve the final state
metadata = client.wait_for_orchestration_completion(
instance_id=instance_id
)
if metadata and metadata.runtime_status.name == "COMPLETED":
result = metadata.serialized_output
logger.info("")
logger.info("=" * 80)
logger.info("ORCHESTRATION COMPLETED SUCCESSFULLY!")
logger.info("=" * 80)
logger.info("")
logger.info(f"Prompt: {prompt}")
logger.info("")
logger.info("Results:")
logger.info("")
# Parse and display the result
if result:
result_dict = json.loads(result)
logger.info("Physicist's response:")
logger.info(f" {result_dict.get('physicist', 'N/A')}")
logger.info("")
logger.info("Chemist's response:")
logger.info(f" {result_dict.get('chemist', 'N/A')}")
logger.info("")
logger.info("=" * 80)
elif metadata:
logger.error(f"Orchestration ended with status: {metadata.runtime_status.name}")
if metadata.serialized_output:
logger.error(f"Output: {metadata.serialized_output}")
else:
logger.error("Orchestration did not complete within the timeout period")
except Exception as e:
logger.exception(f"Error during orchestration: {e}")
def main():
"""Main entry point - runs both worker and client in single process."""
logger.info("Starting Durable Task Multi-Agent Orchestration Sample (Combined Worker + Client)...")
# Get environment variables for taskhub and endpoint with defaults
taskhub_name = os.getenv("TASKHUB", "default")
endpoint = os.getenv("ENDPOINT", "http://localhost:8080")
logger.info(f"Using taskhub: {taskhub_name}")
logger.info(f"Using endpoint: {endpoint}")
logger.info("")
# Set credential to None for emulator, or DefaultAzureCredential for Azure
credential = None if endpoint == "http://localhost:8080" else DefaultAzureCredential()
secure_channel = endpoint != "http://localhost:8080"
# Create and start the worker using a context manager
with DurableTaskSchedulerWorker(
host_address=endpoint,
secure_channel=secure_channel,
taskhub=taskhub_name,
token_credential=credential
) as worker:
# Wrap with the agent worker
agent_worker = DurableAIAgentWorker(worker)
# Create and register both agents
logger.info("Creating and registering agents...")
physicist_agent = create_physicist_agent()
chemist_agent = create_chemist_agent()
agent_worker.add_agent(physicist_agent)
agent_worker.add_agent(chemist_agent)
logger.info(f"✓ Registered agent: {physicist_agent.name}")
logger.info(f" Entity name: dafx-{physicist_agent.name}")
logger.info(f"✓ Registered agent: {chemist_agent.name}")
logger.info(f" Entity name: dafx-{chemist_agent.name}")
logger.info("")
# Register the orchestration function
logger.info("Registering orchestration function...")
worker.add_orchestrator(multi_agent_concurrent_orchestration)
logger.info(f"✓ Registered orchestration: {multi_agent_concurrent_orchestration.__name__}")
logger.info("")
# Start the worker
worker.start()
logger.info("Worker started and listening for requests...")
# Define the prompt
prompt = "What is temperature?"
try:
# Run the client to start the orchestration
asyncio.run(run_client(endpoint, taskhub_name, credential, prompt))
except Exception as e:
logger.exception(f"Error during sample execution: {e}")
logger.info("")
logger.info("Sample completed. Worker shutting down...")
if __name__ == "__main__":
load_dotenv()
main()
@@ -0,0 +1,175 @@
"""Worker process for hosting multiple agents with orchestration using Durable Task.
This worker registers two domain-specific agents (physicist and chemist) and an orchestration
function that runs them concurrently. The orchestration uses OrchestrationAgentExecutor
to execute agents in parallel and aggregate their responses.
Prerequisites:
- Set AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_CHAT_DEPLOYMENT_NAME
(plus AZURE_OPENAI_API_KEY or Azure CLI authentication)
- Start a Durable Task Scheduler (e.g., using Docker)
"""
import asyncio
from collections.abc import Generator
import logging
import os
from typing import Any
from agent_framework import AgentRunResponse
from agent_framework.azure import AzureOpenAIChatClient
from agent_framework_durabletask import DurableAIAgentOrchestrationContext, DurableAIAgentWorker
from azure.identity import AzureCliCredential, DefaultAzureCredential
from durabletask.task import OrchestrationContext, when_all, Task
from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Agent names
PHYSICIST_AGENT_NAME = "PhysicistAgent"
CHEMIST_AGENT_NAME = "ChemistAgent"
def create_physicist_agent():
"""Create the Physicist agent using Azure OpenAI.
Returns:
AgentProtocol: The configured Physicist agent
"""
return AzureOpenAIChatClient(credential=AzureCliCredential()).create_agent(
name=PHYSICIST_AGENT_NAME,
instructions="You are an expert in physics. You answer questions from a physics perspective.",
)
def create_chemist_agent():
"""Create the Chemist agent using Azure OpenAI.
Returns:
AgentProtocol: The configured Chemist agent
"""
return AzureOpenAIChatClient(credential=AzureCliCredential()).create_agent(
name=CHEMIST_AGENT_NAME,
instructions="You are an expert in chemistry. You answer questions from a chemistry perspective.",
)
def multi_agent_concurrent_orchestration(context: OrchestrationContext, prompt: str) -> Generator[Task[Any], Any, dict[str, str]]:
"""Orchestration that runs both agents in parallel and aggregates results.
Uses DurableAIAgentOrchestrationContext to wrap the orchestration context and
access agents via the OrchestrationAgentExecutor.
Args:
context: The orchestration context
Returns:
dict: Dictionary with 'physicist' and 'chemist' response texts
"""
logger.info(f"[Orchestration] Starting concurrent execution for prompt: {prompt}")
# Wrap the orchestration context to access agents
agent_context = DurableAIAgentOrchestrationContext(context)
# Get agents using the agent context (returns DurableAIAgent proxies)
physicist = agent_context.get_agent(PHYSICIST_AGENT_NAME)
chemist = agent_context.get_agent(CHEMIST_AGENT_NAME)
# Create separate threads for each agent
physicist_thread = physicist.get_new_thread()
chemist_thread = chemist.get_new_thread()
logger.info(f"[Orchestration] Created threads - Physicist: {physicist_thread.session_id}, Chemist: {chemist_thread.session_id}")
# Create tasks from agent.run() calls - these return DurableAgentTask instances
physicist_task = physicist.run(messages=str(prompt), thread=physicist_thread)
chemist_task = chemist.run(messages=str(prompt), thread=chemist_thread)
logger.info("[Orchestration] Created agent tasks, executing concurrently...")
# Execute both tasks concurrently using when_all
# The DurableAgentTask instances wrap the underlying entity calls
task_results = yield when_all([physicist_task, chemist_task])
logger.info("[Orchestration] Both agents completed")
# Extract results from the tasks - DurableAgentTask yields AgentRunResponse
physicist_result: AgentRunResponse = task_results[0]
chemist_result: AgentRunResponse = task_results[1]
result = {
"physicist": physicist_result.text,
"chemist": chemist_result.text,
}
logger.info(f"[Orchestration] Aggregated results ready")
return result
async def main():
"""Main entry point for the worker process."""
logger.info("Starting Durable Task Multi-Agent Worker with Orchestration...")
# Get environment variables for taskhub and endpoint with defaults
taskhub_name = os.getenv("TASKHUB", "default")
endpoint = os.getenv("ENDPOINT", "http://localhost:8080")
logger.info(f"Using taskhub: {taskhub_name}")
logger.info(f"Using endpoint: {endpoint}")
# Set credential to None for emulator, or DefaultAzureCredential for Azure
credential = None if endpoint == "http://localhost:8080" else DefaultAzureCredential()
# Create a worker using Azure Managed Durable Task
worker = DurableTaskSchedulerWorker(
host_address=endpoint,
secure_channel=endpoint != "http://localhost:8080",
taskhub=taskhub_name,
token_credential=credential
)
# Wrap it with the agent worker
agent_worker = DurableAIAgentWorker(worker)
# Create and register both agents
logger.info("Creating and registering agents...")
physicist_agent = create_physicist_agent()
chemist_agent = create_chemist_agent()
agent_worker.add_agent(physicist_agent)
agent_worker.add_agent(chemist_agent)
logger.info(f"✓ Registered agent: {physicist_agent.name}")
logger.info(f" Entity name: dafx-{physicist_agent.name}")
logger.info(f"✓ Registered agent: {chemist_agent.name}")
logger.info(f" Entity name: dafx-{chemist_agent.name}")
logger.info("")
# Register the orchestration function
logger.info("Registering orchestration function...")
worker.add_orchestrator(multi_agent_concurrent_orchestration)
logger.info(f"✓ Registered orchestration: {multi_agent_concurrent_orchestration.__name__}")
logger.info("")
logger.info("Worker is ready and listening for requests...")
logger.info("Press Ctrl+C to stop.")
logger.info("")
try:
# Start the worker (this blocks until stopped)
worker.start()
# Keep the worker running
while True:
await asyncio.sleep(1)
except KeyboardInterrupt:
logger.info("Worker shutdown initiated")
logger.info("Worker stopped")
if __name__ == "__main__":
asyncio.run(main())
@@ -0,0 +1,124 @@
# Durable Task Samples
This directory contains samples for durable agent hosting using the Durable Task Scheduler. These samples demonstrate the worker-client architecture pattern, enabling distributed agent execution with persistent conversation state.
- **[01_single_agent](01_single_agent/)**: A sample that demonstrates how to host a single conversational agent using the Durable Task Scheduler and interact with it via a client.
- **[04_single_agent_orchestration_chaining](04_single_agent_orchestration_chaining/)**: A sample that demonstrates how to chain multiple invocations of the same agent using a durable orchestration.
- **[05_multi_agent_orchestration_concurrency](05_multi_agent_orchestration_concurrency/)**: A sample that demonstrates how to host multiple agents and run them concurrently using a durable orchestration.
## Running the Samples
These samples are designed to be run locally in a cloned repository.
### Prerequisites
The following prerequisites are required to run the samples:
- [Python 3.9 or later](https://www.python.org/downloads/)
- [Azure CLI](https://learn.microsoft.com/cli/azure/install-azure-cli) installed and authenticated (`az login`) or an API key for the Azure OpenAI service
- [Azure OpenAI Service](https://learn.microsoft.com/azure/ai-services/openai/how-to/create-resource) with a deployed model (gpt-4o-mini or better is recommended)
- [Durable Task Scheduler](https://learn.microsoft.com/azure/azure-functions/durable/durable-task-scheduler/develop-with-durable-task-scheduler) (local emulator or Azure-hosted)
- [Docker](https://docs.docker.com/get-docker/) installed if running the Durable Task Scheduler emulator locally
### Configuring RBAC Permissions for Azure OpenAI
These samples are configured to use the Azure OpenAI service with RBAC permissions to access the model. You'll need to configure the RBAC permissions for the Azure OpenAI service to allow the Python app to access the model.
Below is an example of how to configure the RBAC permissions for the Azure OpenAI service to allow the current user to access the model.
Bash (Linux/macOS/WSL):
```bash
az role assignment create \
--assignee "yourname@contoso.com" \
--role "Cognitive Services OpenAI User" \
--scope /subscriptions/<your-subscription-id>/resourceGroups/<your-resource-group-name>/providers/Microsoft.CognitiveServices/accounts/<your-openai-resource-name>
```
PowerShell:
```powershell
az role assignment create `
--assignee "yourname@contoso.com" `
--role "Cognitive Services OpenAI User" `
--scope /subscriptions/<your-subscription-id>/resourceGroups/<your-resource-group-name>/providers/Microsoft.CognitiveServices/accounts/<your-openai-resource-name>
```
More information on how to configure RBAC permissions for Azure OpenAI can be found in the [Azure OpenAI documentation](https://learn.microsoft.com/azure/ai-services/openai/how-to/create-resource?pivots=cli).
### Setting an API key for the Azure OpenAI service
As an alternative to configuring Azure RBAC permissions, you can set an API key for the Azure OpenAI service by setting the `AZURE_OPENAI_API_KEY` environment variable.
Bash (Linux/macOS/WSL):
```bash
export AZURE_OPENAI_API_KEY="your-api-key"
```
PowerShell:
```powershell
$env:AZURE_OPENAI_API_KEY="your-api-key"
```
### Start Durable Task Scheduler
Most samples use the Durable Task Scheduler (DTS) to support hosted agents and durable orchestrations. DTS also allows you to view the status of orchestrations and their inputs and outputs from a web UI.
To run the Durable Task Scheduler locally, you can use the following `docker` command:
```bash
docker run -d --name dts-emulator -p 8080:8080 -p 8082:8082 mcr.microsoft.com/dts/dts-emulator:latest
```
The DTS dashboard will be available at `http://localhost:8082`.
### Environment Configuration
Each sample reads configuration from environment variables. You'll need to set the following environment variables:
```bash
export AZURE_OPENAI_ENDPOINT="https://your-resource.openai.azure.com/"
export AZURE_OPENAI_CHAT_DEPLOYMENT_NAME="your-deployment-name"
```
### Installing Dependencies
Navigate to the sample directory and install dependencies:
```bash
cd samples/getting_started/durabletask/01_single_agent
pip install -r requirements.txt
```
### Running the Samples
Each sample follows a worker-client architecture. Most samples provide separate `worker.py` and `client.py` files, though some include a combined `sample.py` for convenience.
**Running with separate worker and client:**
In one terminal, start the worker:
```bash
python worker.py
```
In another terminal, run the client:
```bash
python client.py
```
**Running with combined sample:**
```bash
python sample.py
```
### Viewing the Sample Output
The sample output is displayed directly in the terminal where you ran the Python script. Agent responses are printed to stdout with log formatting for better readability.
You can also see the state of agents and orchestrations in the Durable Task Scheduler dashboard at `http://localhost:8082`.