Python: [Breaking] Python: Respond with AgentRunResponse with serialized structured output (#2285)

* Respond with AgentRunResponse

* Fix response_format type

* Address comments

* Fix tests

* Fix log

* Addressed comments

* Code cleanup

* Use AgentTask vs Generator

* Address comments

* use lazy logging

* fix mypy errors
This commit is contained in:
Laveesh Rohra
2025-11-26 10:44:28 -08:00
committed by GitHub
Unverified
parent 0f2c5e6cb8
commit 306c81aef8
12 changed files with 484 additions and 432 deletions
@@ -53,7 +53,7 @@ from agent_framework import (
)
from dateutil import parser as date_parser
from ._models import RunRequest, _serialize_response_format
from ._models import RunRequest, serialize_response_format
logger = get_logger("agent_framework.azurefunctions.durable_agent_state")
@@ -494,7 +494,7 @@ class DurableAgentStateRequest(DurableAgentStateEntry):
messages=[DurableAgentStateMessage.from_run_request(request)],
created_at=datetime.now(tz=timezone.utc),
response_type=request.request_response_format,
response_schema=_serialize_response_format(request.response_format),
response_schema=serialize_response_format(request.response_format),
)
@@ -9,9 +9,7 @@ allows for long-running agent conversations.
import asyncio
import inspect
import json
from collections.abc import AsyncIterable, Callable
from datetime import datetime, timezone
from typing import Any, cast
import azure.durable_functions as df
@@ -30,11 +28,10 @@ from ._durable_agent_state import (
DurableAgentState,
DurableAgentStateData,
DurableAgentStateEntry,
DurableAgentStateMessage,
DurableAgentStateRequest,
DurableAgentStateResponse,
)
from ._models import AgentResponse, RunRequest
from ._models import RunRequest
logger = get_logger("agent_framework.azurefunctions.entities")
@@ -97,7 +94,7 @@ class AgentEntity:
self,
context: df.DurableEntityContext,
request: RunRequest | dict[str, Any] | str,
) -> dict[str, Any]:
) -> AgentRunResponse:
"""Execute the agent with a message directly in the entity.
Args:
@@ -105,13 +102,8 @@ class AgentEntity:
request: RunRequest object, dict, or string message (for backward compatibility)
Returns:
Dict with status information and response (serialized AgentResponse)
Note:
The agent returns an AgentRunResponse object which is stored in state.
This method extracts the text/structured response and returns an AgentResponse dict.
AgentRunResponse enriched with execution metadata.
"""
# Convert string or dict to RunRequest
if isinstance(request, str):
run_request = RunRequest(message=request, role=Role.USER)
elif isinstance(request, dict):
@@ -135,8 +127,6 @@ class AgentEntity:
logger.debug(f"[AgentEntity.run_agent] Received Message: {state_request}")
try:
logger.debug("[AgentEntity.run_agent] Starting agent invocation")
# Build messages from conversation history, excluding error responses
# Error responses are kept in history for tracking but not sent to the agent
chat_messages: list[ChatMessage] = [
@@ -164,83 +154,39 @@ class AgentEntity:
type(agent_run_response).__name__,
)
response_text = None
structured_response = None
response_str: str | None = None
try:
if response_format:
try:
response_str = agent_run_response.text
structured_response = json.loads(response_str)
logger.debug("Parsed structured JSON response")
except json.JSONDecodeError as decode_error:
logger.warning(f"Failed to parse JSON response: {decode_error}")
response_text = response_str
else:
raw_text = agent_run_response.text
response_text = raw_text if raw_text else "No response"
preview = response_text
logger.debug(f"Response: {preview[:100]}..." if len(preview) > 100 else f"Response: {preview}")
response_text = agent_run_response.text if agent_run_response.text else "No response"
logger.debug(f"Response: {response_text[:100]}...")
except Exception as extraction_error:
logger.error(
f"Error extracting response: {extraction_error}",
"Error extracting response text: %s",
extraction_error,
exc_info=True,
)
response_text = "Error extracting response"
state_response = DurableAgentStateResponse.from_run_response(correlation_id, agent_run_response)
self.state.data.conversation_history.append(state_response)
agent_response = AgentResponse(
response=response_text,
message=str(message),
thread_id=str(thread_id),
status="success",
message_count=len(self.state.data.conversation_history),
structured_response=structured_response,
)
result = agent_response.to_dict()
logger.debug("[AgentEntity.run_agent] AgentRunResponse stored in conversation history")
return result
return agent_run_response
except Exception as exc:
import traceback
error_traceback = traceback.format_exc()
logger.error("[AgentEntity.run_agent] Agent execution failed")
logger.error(f"Error: {exc!s}")
logger.error(f"Error type: {type(exc).__name__}")
logger.error(f"Full traceback:\n{error_traceback}")
logger.exception("[AgentEntity.run_agent] Agent execution failed.")
# Create error message
error_message = DurableAgentStateMessage.from_chat_message(
ChatMessage(
role=Role.ASSISTANT, contents=[ErrorContent(message=str(exc), error_code=type(exc).__name__)]
)
error_message = ChatMessage(
role=Role.ASSISTANT, contents=[ErrorContent(message=str(exc), error_code=type(exc).__name__)]
)
error_response = AgentRunResponse(messages=[error_message])
# Create and store error response in conversation history
error_state_response = DurableAgentStateResponse(
correlation_id=correlation_id,
created_at=datetime.now(tz=timezone.utc),
messages=[error_message],
is_error=True,
)
error_state_response = DurableAgentStateResponse.from_run_response(correlation_id, error_response)
error_state_response.is_error = True
self.state.data.conversation_history.append(error_state_response)
error_response = AgentResponse(
response=f"Error: {exc!s}",
message=str(message),
thread_id=str(thread_id),
status="error",
message_count=len(self.state.data.conversation_history),
error=str(exc),
error_type=type(exc).__name__,
)
return error_response.to_dict()
return error_response
async def _invoke_agent(
self,
@@ -432,7 +378,7 @@ def create_agent_entity(
request = "" if input_data is None else str(cast(object, input_data))
result = await entity.run_agent(context, request)
context.set_result(result)
context.set_result(result.to_dict())
elif operation == "reset":
entity.reset(context)
@@ -442,15 +388,13 @@ def create_agent_entity(
logger.error("[entity_function] Unknown operation: %s", operation)
context.set_result({"error": f"Unknown operation: {operation}"})
logger.debug("State dict: %s", entity.state.to_dict())
context.set_state(entity.state.to_dict())
serialized_state = entity.state.to_dict()
logger.debug("State dict: %s", serialized_state)
context.set_state(serialized_state)
logger.info(f"[entity_function] Operation {operation} completed successfully")
except Exception as exc:
import traceback
logger.error("[entity_function] Error in entity: %s", exc)
logger.error(f"[entity_function] Traceback:\n{traceback.format_exc()}")
logger.exception("[entity_function] Error executing entity operation %s", exc)
context.set_result({"error": str(exc), "status": "error"})
def entity_function(context: df.DurableEntityContext) -> None:
@@ -213,7 +213,7 @@ class DurableAgentThread(AgentThread):
return thread
def _serialize_response_format(response_format: type[BaseModel] | None) -> Any:
def serialize_response_format(response_format: type[BaseModel] | None) -> Any:
"""Serialize response format for transport across durable function boundaries."""
if response_format is None:
return None
@@ -339,7 +339,7 @@ class RunRequest:
"request_response_format": self.request_response_format,
}
if self.response_format:
result["response_format"] = _serialize_response_format(self.response_format)
result["response_format"] = serialize_response_format(self.response_format)
if self.thread_id:
result["thread_id"] = self.thread_id
if self.correlation_id:
@@ -362,50 +362,3 @@ class RunRequest:
correlation_id=data.get("correlationId"),
created_at=data.get("created_at"),
)
@dataclass
class AgentResponse:
"""Response from agent execution.
Attributes:
response: The agent's text response (or None for structured responses)
message: The original message sent to the agent
thread_id: The thread identifier
status: Status of the execution (success, error, etc.)
message_count: Number of messages in the conversation
error: Error message if status is error
error_type: Type of error if status is error
structured_response: Structured response if response_format was provided
"""
response: str | None
message: str
thread_id: str | None
status: str
message_count: int = 0
error: str | None = None
error_type: str | None = None
structured_response: dict[str, Any] | None = None
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary for JSON serialization."""
result: dict[str, Any] = {
"message": self.message,
"thread_id": self.thread_id,
"status": self.status,
"message_count": self.message_count,
}
# Add response or structured_response based on what's available
if self.structured_response is not None:
result["structured_response"] = self.structured_response
elif self.response is not None:
result["response"] = self.response
if self.error:
result["error"] = self.error
if self.error_type:
result["error_type"] = self.error_type
return result
@@ -6,21 +6,148 @@ This module provides support for using agents inside Durable Function orchestrat
"""
import uuid
from collections.abc import AsyncIterator
from collections.abc import AsyncIterator, Callable
from typing import TYPE_CHECKING, Any, TypeAlias, cast
from agent_framework import AgentProtocol, AgentRunResponseUpdate, AgentThread, ChatMessage, get_logger
from agent_framework import (
AgentProtocol,
AgentRunResponse,
AgentRunResponseUpdate,
AgentThread,
ChatMessage,
get_logger,
)
from azure.durable_functions.models import TaskBase
from azure.durable_functions.models.Task import CompoundTask, TaskState
from pydantic import BaseModel
from ._models import AgentSessionId, DurableAgentThread, RunRequest
logger = get_logger("agent_framework.azurefunctions.orchestration")
if TYPE_CHECKING:
from azure.durable_functions import DurableOrchestrationContext as _DurableOrchestrationContext
CompoundActionConstructor: TypeAlias = Callable[[list[Any]], Any] | None
AgentOrchestrationContextType: TypeAlias = _DurableOrchestrationContext
if TYPE_CHECKING:
from azure.durable_functions import DurableOrchestrationContext
class _TypedCompoundTask(CompoundTask): # type: ignore[misc]
_first_error: Any
def __init__(
self,
tasks: list[TaskBase],
compound_action_constructor: CompoundActionConstructor = None,
) -> None: ...
AgentOrchestrationContextType: TypeAlias = DurableOrchestrationContext
else:
AgentOrchestrationContextType = Any
_TypedCompoundTask = CompoundTask
class AgentTask(_TypedCompoundTask):
    """Compound task that exposes an entity call's result as an AgentRunResponse.

    The task has a single child: the entity call task. When that child completes,
    ``try_set_value`` converts the child's raw result into an ``AgentRunResponse``
    and stores the converted object as this task's own value.
    """

    def __init__(
        self,
        entity_task: TaskBase,
        response_format: type[BaseModel] | None,
        correlation_id: str,
    ):
        """Wrap an entity call task.

        Args:
            entity_task: The underlying entity call task
            response_format: Optional Pydantic model used to parse the response value
            correlation_id: Correlation ID included in log messages
        """
        super().__init__([entity_task])
        self._correlation_id = correlation_id
        self._response_format = response_format

        # Take over the child's identity: reuse its task ID and expose its action
        # directly. The action is surfaced for compatibility with ReplaySchema V3,
        # which expects Action objects.
        self.id = entity_task.id
        self.action_repr = entity_task.action_repr

    def try_set_value(self, child: TaskBase) -> None:
        """Move this task to a terminal state once the entity call has completed.

        On success the child's result is converted to an ``AgentRunResponse`` and
        set as this task's value; a conversion failure is logged and set as the
        task's error instead. On child failure the child's result is recorded as
        the first error, unless one was already recorded.

        Args:
            child: The entity call task that just completed
        """
        if child.state is not TaskState.SUCCEEDED:
            # Only report the failure if no earlier error has been recorded.
            if self._first_error is None:
                self._first_error = child.result
                self.set_value(is_error=True, value=self._first_error)
            return

        if len(self.pending_tasks) != 0:
            # Children are still pending; nothing to finalize yet.
            return

        payload = child.result
        logger.debug(
            "[AgentTask] Converting raw result for correlation_id %s",
            self._correlation_id,
        )

        try:
            converted = self._load_agent_response(payload)

            if self._response_format is not None:
                self._ensure_response_format(
                    self._response_format,
                    self._correlation_id,
                    converted,
                )

            # The typed AgentRunResponse becomes this task's result.
            self.set_value(is_error=False, value=converted)
        except Exception as conversion_error:
            logger.exception(
                "[AgentTask] Failed to convert result for correlation_id: %s",
                self._correlation_id,
            )
            self.set_value(is_error=True, value=conversion_error)

    def _load_agent_response(self, agent_response: AgentRunResponse | dict[str, Any] | None) -> AgentRunResponse:
        """Turn a raw entity result into an AgentRunResponse.

        Args:
            agent_response: An AgentRunResponse instance or its dict form

        Returns:
            The instance itself, or one built with ``AgentRunResponse.from_dict``.

        Raises:
            ValueError: If ``agent_response`` is None.
            TypeError: If ``agent_response`` is neither an AgentRunResponse nor a dict.
        """
        if agent_response is None:
            raise ValueError("agent_response cannot be None")

        logger.debug("[load_agent_response] Loading agent response of type: %s", type(agent_response))

        if isinstance(agent_response, AgentRunResponse):
            return agent_response

        if not isinstance(agent_response, dict):
            raise TypeError(f"Unsupported type for agent_response: {type(agent_response)}")

        logger.debug("[load_agent_response] Converting dict payload using AgentRunResponse.from_dict")
        return AgentRunResponse.from_dict(agent_response)

    def _ensure_response_format(
        self,
        response_format: type[BaseModel] | None,
        correlation_id: str,
        response: AgentRunResponse,
    ) -> None:
        """Parse ``response.value`` into ``response_format`` when it is not already that type.

        Args:
            response_format: Expected Pydantic model, or None to skip parsing
            correlation_id: Correlation ID included in the log message
            response: The response whose value should be parsed
        """
        if response_format is None or isinstance(response.value, response_format):
            # No format requested, or the value already has the requested type.
            return

        response.try_parse_value(response_format)

        logger.debug(
            "[DurableAIAgent] Loaded AgentRunResponse.value for correlation_id %s with type: %s",
            correlation_id,
            type(response.value).__name__,
        )
class DurableAIAgent(AgentProtocol):
@@ -59,7 +186,7 @@ class DurableAIAgent(AgentProtocol):
self._name = agent_name
self._display_name = agent_name
self._description = f"Durable agent proxy for {agent_name}"
logger.debug(f"[DurableAIAgent] Initialized for agent: {agent_name}")
logger.debug("[DurableAIAgent] Initialized for agent: %s", agent_name)
@property
def id(self) -> str:
@@ -81,38 +208,45 @@ class DurableAIAgent(AgentProtocol):
"""Get the description of the agent."""
return self._description
def run(
# We return an AgentTask here which is a TaskBase subclass.
# This is an intentional deviation from AgentProtocol which defines run() as async.
# The AgentTask can be yielded in Durable Functions orchestrations and will provide
# a typed AgentRunResponse result.
def run( # type: ignore[override]
self,
messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
*,
thread: AgentThread | None = None,
response_format: type[BaseModel] | None = None,
**kwargs: Any,
) -> Any: # TODO(msft-team): Add a wrapper to respond correctly with `AgentRunResponse`
"""Execute the agent with messages and return a Task for orchestrations.
) -> AgentTask:
"""Execute the agent with messages and return an AgentTask for orchestrations.
This method implements AgentProtocol and returns a Task that can be yielded
in Durable Functions orchestrations.
This method implements AgentProtocol and returns an AgentTask (subclass of TaskBase)
that can be yielded in Durable Functions orchestrations. The task's result will be
a typed AgentRunResponse.
Args:
messages: The message(s) to send to the agent
thread: Optional agent thread for conversation context
**kwargs: Additional arguments (enable_tool_calls, response_format, etc.)
response_format: Optional Pydantic model for response parsing
**kwargs: Additional arguments (enable_tool_calls)
Returns:
Task that will resolve to the agent response
An AgentTask that resolves to an AgentRunResponse when yielded
Example:
@app.orchestration_trigger(context_name="context")
def my_orchestration(context):
agent = app.get_agent(context, "MyAgent")
thread = agent.get_new_thread()
result = yield agent.run("Hello", thread=thread)
response = yield agent.run("Hello", thread=thread)
# response is typed as AgentRunResponse
"""
message_str = self._normalize_messages(messages)
# Extract optional parameters from kwargs
enable_tool_calls = kwargs.get("enable_tool_calls", True)
response_format = kwargs.get("response_format")
# Get the session ID for the entity
if isinstance(thread, DurableAgentThread) and thread.session_id is not None:
@@ -122,7 +256,7 @@ class DurableAIAgent(AgentProtocol):
# This ensures each call gets its own conversation context
session_key = str(self.context.new_uuid())
session_id = AgentSessionId(name=self.agent_name, key=session_key)
logger.warning(f"[DurableAIAgent] No thread provided, created unique session_id: {session_id}")
logger.warning("[DurableAIAgent] No thread provided, created unique session_id: %s", session_id)
# Create entity ID from session ID
entity_id = session_id.to_entity_id()
@@ -130,6 +264,12 @@ class DurableAIAgent(AgentProtocol):
# Generate a deterministic correlation ID for this call
# This is required by the entity and must be unique per call
correlation_id = str(self.context.new_uuid())
logger.debug(
"[DurableAIAgent] Using correlation_id: %s for entity_id: %s for session_id: %s",
correlation_id,
entity_id,
session_id,
)
# Prepare the request using RunRequest model
run_request = RunRequest(
@@ -140,11 +280,24 @@ class DurableAIAgent(AgentProtocol):
response_format=response_format,
)
logger.debug(f"[DurableAIAgent] Calling entity {entity_id} with message: {message_str[:100]}...")
logger.debug("[DurableAIAgent] Calling entity %s with message: %s", entity_id, message_str[:100])
# Call the entity and return the Task directly
# The orchestration will yield this Task
return self.context.call_entity(entity_id, "run_agent", run_request.to_dict())
# Call the entity to get the underlying task
entity_task = self.context.call_entity(entity_id, "run_agent", run_request.to_dict())
# Wrap it in an AgentTask that will convert the result to AgentRunResponse
agent_task = AgentTask(
entity_task=entity_task,
response_format=response_format,
correlation_id=correlation_id,
)
logger.debug(
"[DurableAIAgent] Created AgentTask for correlation_id %s",
correlation_id,
)
return agent_task
def run_stream(
self,
@@ -179,7 +332,7 @@ class DurableAIAgent(AgentProtocol):
thread = DurableAgentThread.from_session_id(session_id, **kwargs)
logger.debug(f"[DurableAIAgent] Created new thread with session_id: {session_id}")
logger.debug("[DurableAIAgent] Created new thread with session_id: %s", session_id)
return thread
def _messages_to_string(self, messages: list[ChatMessage]) -> str:
@@ -10,7 +10,7 @@ from unittest.mock import ANY, AsyncMock, Mock, patch
import azure.durable_functions as df
import azure.functions as func
import pytest
from agent_framework import AgentRunResponse, ChatMessage
from agent_framework import AgentRunResponse, ChatMessage, ErrorContent
from agent_framework_azurefunctions import AgentFunctionApp
from agent_framework_azurefunctions._app import WAIT_FOR_RESPONSE_FIELD, WAIT_FOR_RESPONSE_HEADER
@@ -343,10 +343,8 @@ class TestAgentEntityOperations:
{"message": "Test message", "thread_id": "test-conv-123", "correlationId": "corr-app-entity-1"},
)
assert result["status"] == "success"
assert result["response"] == "Test response"
assert result["message"] == "Test message"
assert result["thread_id"] == "test-conv-123"
assert isinstance(result, AgentRunResponse)
assert result.text == "Test response"
assert entity.state.message_count == 2
async def test_entity_stores_conversation_history(self) -> None:
@@ -591,10 +589,12 @@ class TestErrorHandling:
mock_context, {"message": "Test message", "thread_id": "conv-1", "correlationId": "corr-app-error-1"}
)
assert result["status"] == "error"
assert "error" in result
assert "Agent error" in result["error"]
assert result["error_type"] == "Exception"
assert isinstance(result, AgentRunResponse)
assert len(result.messages) == 1
content = result.messages[0].contents[0]
assert isinstance(content, ErrorContent)
assert "Agent error" in (content.message or "")
assert content.error_code == "Exception"
def test_entity_function_handles_exception(self) -> None:
"""Test that the entity function handles exceptions gracefully."""
@@ -12,7 +12,7 @@ from typing import Any, TypeVar
from unittest.mock import AsyncMock, Mock, patch
import pytest
from agent_framework import AgentRunResponse, AgentRunResponseUpdate, ChatMessage, Role
from agent_framework import AgentRunResponse, AgentRunResponseUpdate, ChatMessage, ErrorContent, Role
from pydantic import BaseModel
from agent_framework_azurefunctions._durable_agent_state import (
@@ -133,10 +133,8 @@ class TestAgentEntityRunAgent:
assert getattr(sent_message.role, "value", sent_message.role) == "user"
# Verify result
assert result["status"] == "success"
assert result["response"] == "Test response"
assert result["message"] == "Test message"
assert result["thread_id"] == "conv-123"
assert isinstance(result, AgentRunResponse)
assert result.text == "Test response"
async def test_run_agent_streaming_callbacks_invoked(self) -> None:
"""Ensure streaming updates trigger callbacks and run() is not used."""
@@ -168,8 +166,8 @@ class TestAgentEntityRunAgent:
},
)
assert result["status"] == "success"
assert "Hello" in result.get("response", "")
assert isinstance(result, AgentRunResponse)
assert "Hello" in result.text
assert callback.stream_mock.await_count == len(updates)
assert callback.response_mock.await_count == 1
mock_agent.run.assert_not_called()
@@ -215,8 +213,8 @@ class TestAgentEntityRunAgent:
},
)
assert result["status"] == "success"
assert result.get("response") == "Final response"
assert isinstance(result, AgentRunResponse)
assert result.text == "Final response"
assert callback.stream_mock.await_count == 0
assert callback.response_mock.await_count == 1
@@ -294,44 +292,6 @@ class TestAgentEntityRunAgent:
mock_context, {"message": "Message", "thread_id": None, "correlationId": "corr-entity-5"}
)
async def test_run_agent_handles_response_without_text_attribute(self) -> None:
    """Verify run_agent copes with a response whose ``text`` property raises."""

    class NoTextResponse(AgentRunResponse):
        """AgentRunResponse variant whose text accessor always fails."""

        @property
        def text(self) -> str:  # type: ignore[override]
            raise AttributeError("text attribute missing")

    broken_response = NoTextResponse(messages=[ChatMessage(role="assistant", text="ignored")])
    agent = Mock()
    agent.run = AsyncMock(return_value=broken_response)
    request = {"message": "Message", "thread_id": "conv-1", "correlationId": "corr-entity-6"}

    outcome = await AgentEntity(agent).run_agent(Mock(), request)

    # The extraction failure is expected to surface in the payload, not as an exception.
    assert outcome["status"] == "success"
    assert outcome["response"] == "Error extracting response"
async def test_run_agent_handles_none_response_text(self) -> None:
    """Verify run_agent reports a placeholder when the response text is None."""
    agent = Mock()
    agent.run = AsyncMock(return_value=_agent_response(None))
    request = {"message": "Message", "thread_id": "conv-1", "correlationId": "corr-entity-7"}

    outcome = await AgentEntity(agent).run_agent(Mock(), request)

    assert outcome["status"] == "success"
    assert outcome["response"] == "No response"
async def test_run_agent_multiple_conversations(self) -> None:
"""Test that run_agent maintains history across multiple messages."""
mock_agent = Mock()
@@ -621,10 +581,12 @@ class TestErrorHandling:
mock_context, {"message": "Message", "thread_id": "conv-1", "correlationId": "corr-entity-error-1"}
)
assert result["status"] == "error"
assert "error" in result
assert "Agent failed" in result["error"]
assert result["error_type"] == "Exception"
assert isinstance(result, AgentRunResponse)
assert len(result.messages) == 1
content = result.messages[0].contents[0]
assert isinstance(content, ErrorContent)
assert "Agent failed" in (content.message or "")
assert content.error_code == "Exception"
async def test_run_agent_handles_value_error(self) -> None:
"""Test that run_agent handles ValueError instances."""
@@ -638,9 +600,12 @@ class TestErrorHandling:
mock_context, {"message": "Message", "thread_id": "conv-1", "correlationId": "corr-entity-error-2"}
)
assert result["status"] == "error"
assert result["error_type"] == "ValueError"
assert "Invalid input" in result["error"]
assert isinstance(result, AgentRunResponse)
assert len(result.messages) == 1
content = result.messages[0].contents[0]
assert isinstance(content, ErrorContent)
assert content.error_code == "ValueError"
assert "Invalid input" in str(content.message)
async def test_run_agent_handles_timeout_error(self) -> None:
"""Test that run_agent handles TimeoutError instances."""
@@ -654,8 +619,11 @@ class TestErrorHandling:
mock_context, {"message": "Message", "thread_id": "conv-1", "correlationId": "corr-entity-error-3"}
)
assert result["status"] == "error"
assert result["error_type"] == "TimeoutError"
assert isinstance(result, AgentRunResponse)
assert len(result.messages) == 1
content = result.messages[0].contents[0]
assert isinstance(content, ErrorContent)
assert content.error_code == "TimeoutError"
def test_entity_function_handles_exception_in_operation(self) -> None:
"""Test that the entity function handles exceptions gracefully."""
@@ -690,9 +658,10 @@ class TestErrorHandling:
)
# Even on error, message info should be preserved
assert result["message"] == "Test message"
assert result["thread_id"] == "conv-123"
assert result["status"] == "error"
assert isinstance(result, AgentRunResponse)
assert len(result.messages) == 1
content = result.messages[0].contents[0]
assert isinstance(content, ErrorContent)
class TestConversationHistory:
@@ -800,10 +769,8 @@ class TestRunRequestSupport:
result = await entity.run_agent(mock_context, request)
assert result["status"] == "success"
assert result["response"] == "Response"
assert result["message"] == "Test message"
assert result["thread_id"] == "conv-123"
assert isinstance(result, AgentRunResponse)
assert result.text == "Response"
async def test_run_agent_with_dict_request(self) -> None:
"""Test run_agent with a dictionary request."""
@@ -823,9 +790,8 @@ class TestRunRequestSupport:
result = await entity.run_agent(mock_context, request_dict)
assert result["status"] == "success"
assert result["message"] == "Test message"
assert result["thread_id"] == "conv-456"
assert isinstance(result, AgentRunResponse)
assert result.text == "Response"
async def test_run_agent_with_string_raises_without_correlation(self) -> None:
"""Test that run_agent rejects legacy string input without correlation ID."""
@@ -879,10 +845,9 @@ class TestRunRequestSupport:
result = await entity.run_agent(mock_context, request)
assert result["status"] == "success"
# Should have structured_response
if "structured_response" in result:
assert result["structured_response"]["answer"] == 42
assert isinstance(result, AgentRunResponse)
assert result.text == '{"answer": 42}'
assert result.value is None
async def test_run_agent_disable_tool_calls(self) -> None:
"""Test run_agent with tool calls disabled."""
@@ -898,7 +863,7 @@ class TestRunRequestSupport:
result = await entity.run_agent(mock_context, request)
assert result["status"] == "success"
assert isinstance(result, AgentRunResponse)
# Agent should have been called (tool disabling is framework-dependent)
mock_agent.run.assert_called_once()
@@ -925,8 +890,24 @@ class TestRunRequestSupport:
# Verify result was set
assert mock_context.set_result.called
result = mock_context.set_result.call_args[0][0]
assert result["status"] == "success"
assert result["message"] == "Test message"
assert isinstance(result, dict)
# Check if messages are present
assert "messages" in result
assert len(result["messages"]) > 0
message = result["messages"][0]
# Check for text in various possible locations
text_found = False
if "text" in message and message["text"] == "Response":
text_found = True
elif "contents" in message:
for content in message["contents"]:
if isinstance(content, dict) and content.get("text") == "Response":
text_found = True
break
assert text_found, f"Response text not found in message: {message}"
if __name__ == "__main__":
@@ -7,7 +7,7 @@ import pytest
from agent_framework import Role
from pydantic import BaseModel
from agent_framework_azurefunctions._models import AgentResponse, AgentSessionId, RunRequest
from agent_framework_azurefunctions._models import AgentSessionId, RunRequest
class ModuleStructuredResponse(BaseModel):
@@ -337,107 +337,6 @@ class TestRunRequest:
assert restored.thread_id == original.thread_id
class TestAgentResponse:
    """Unit tests covering AgentResponse construction and serialization."""

    def test_init_with_required_fields(self) -> None:
        """Required fields are stored and optional fields take their defaults."""
        agent_response = AgentResponse(
            response="Test response", message="Test message", thread_id="thread-123", status="success"
        )

        expected_attributes = {
            "response": "Test response",
            "message": "Test message",
            "thread_id": "thread-123",
            "status": "success",
            "message_count": 0,
        }
        for attribute, expected in expected_attributes.items():
            assert getattr(agent_response, attribute) == expected
        for attribute in ("error", "error_type", "structured_response"):
            assert getattr(agent_response, attribute) is None

    def test_init_with_all_fields(self) -> None:
        """Every field can be supplied explicitly."""
        payload = {"answer": "42"}
        agent_response = AgentResponse(
            response=None,
            message="What is the answer?",
            thread_id="thread-456",
            status="success",
            message_count=5,
            error=None,
            error_type=None,
            structured_response=payload,
        )

        assert agent_response.response is None
        assert agent_response.structured_response == payload
        assert agent_response.message_count == 5

    def test_to_dict_with_text_response(self) -> None:
        """A text-only response serializes without structured or error keys."""
        agent_response = AgentResponse(
            response="Text response", message="Message", thread_id="thread-1", status="success", message_count=3
        )

        serialized = agent_response.to_dict()

        expected_items = {
            "response": "Text response",
            "message": "Message",
            "thread_id": "thread-1",
            "status": "success",
            "message_count": 3,
        }
        for key, expected in expected_items.items():
            assert serialized[key] == expected
        for absent_key in ("structured_response", "error", "error_type"):
            assert absent_key not in serialized

    def test_to_dict_with_structured_response(self) -> None:
        """A structured response serializes without the text key."""
        payload = {"answer": 42, "confidence": 0.95}
        agent_response = AgentResponse(
            response=None,
            message="Question",
            thread_id="thread-2",
            status="success",
            structured_response=payload,
        )

        serialized = agent_response.to_dict()

        assert serialized["structured_response"] == payload
        assert "response" not in serialized

    def test_to_dict_with_error(self) -> None:
        """Error details are included in the serialized form."""
        agent_response = AgentResponse(
            response=None,
            message="Failed message",
            thread_id="thread-3",
            status="error",
            error="Something went wrong",
            error_type="ValueError",
        )

        serialized = agent_response.to_dict()

        expected_items = {
            "status": "error",
            "error": "Something went wrong",
            "error_type": "ValueError",
        }
        for key, expected in expected_items.items():
            assert serialized[key] == expected

    def test_to_dict_prefers_structured_over_text(self) -> None:
        """When both replies are set, only the structured one is serialized."""
        payload = {"result": "structured"}
        agent_response = AgentResponse(
            response="Text response",
            message="Message",
            thread_id="thread-4",
            status="success",
            structured_response=payload,
        )

        serialized = agent_response.to_dict()

        assert "structured_response" in serialized
        assert serialized["structured_response"] == payload
        # The text reply must be left out whenever a structured reply exists.
        assert "response" not in serialized
class TestModelIntegration:
"""Test suite for integration between models."""
@@ -450,21 +349,6 @@ class TestModelIntegration:
assert request.thread_id == str(session_id)
assert request.thread_id.startswith("@AgentEntity@")
def test_response_from_run_request(self) -> None:
    """Build an AgentResponse from a RunRequest's fields and compare them."""
    run_request = RunRequest(message="What is 2+2?", thread_id="thread-123", role=Role.USER)

    # Copy the identifying fields straight from the request.
    agent_response = AgentResponse(
        response="4",
        message=run_request.message,
        thread_id=run_request.thread_id,
        status="success",
        message_count=1,
    )

    assert agent_response.message == run_request.message
    assert agent_response.thread_id == run_request.thread_id
if __name__ == "__main__":
    # Run only this module: verbose output, short tracebacks.
    pytest_args = [__file__, "-v", "--tb=short"]
    pytest.main(pytest_args)
@@ -6,10 +6,12 @@ from typing import Any
from unittest.mock import Mock
import pytest
from agent_framework import AgentThread
from agent_framework import AgentRunResponse, AgentThread, ChatMessage
from azure.durable_functions.models.Task import TaskBase, TaskState
from agent_framework_azurefunctions import AgentFunctionApp, DurableAIAgent
from agent_framework_azurefunctions._models import AgentSessionId, DurableAgentThread
from agent_framework_azurefunctions._orchestration import AgentTask
def _app_with_registered_agents(*agent_names: str) -> AgentFunctionApp:
@@ -21,6 +23,169 @@ def _app_with_registered_agents(*agent_names: str) -> AgentFunctionApp:
return app
class _FakeTask(TaskBase):
    """Concrete TaskBase stand-in used to test AgentTask wiring."""

    def __init__(self, task_id: int = 1):
        """Create a task with the given id, an empty action list, and RUNNING state."""
        # The base class is given the id and an empty list as its second argument.
        super().__init__(task_id, [])

        # Start unscheduled, with no action representation, and still running.
        self._set_is_scheduled(False)
        self.action_repr = []
        self.state = TaskState.RUNNING
def _create_entity_task(task_id: int = 1) -> TaskBase:
    """Return a fresh ``_FakeTask`` with the given id for AgentTask tests."""
    fake_task = _FakeTask(task_id)
    return fake_task
class TestAgentResponseHelpers:
    """Exercise the AgentTask helpers that prepare AgentRunResponse values."""

    @staticmethod
    def _create_agent_task() -> AgentTask:
        """Wrap a fresh fake entity task in an AgentTask with no response format."""
        return AgentTask(_create_entity_task(), None, "correlation-id")

    def test_load_agent_response_from_instance(self) -> None:
        """An AgentRunResponse instance is returned as-is, with no parsed value."""
        original = AgentRunResponse(messages=[ChatMessage(role="assistant", text='{"foo": "bar"}')])

        loaded = self._create_agent_task()._load_agent_response(original)

        assert loaded is original
        assert loaded.value is None

    def test_load_agent_response_from_serialized(self) -> None:
        """A serialized dict is rebuilt into a response that keeps its ``value``."""
        payload = AgentRunResponse(messages=[ChatMessage(role="assistant", text="structured")]).to_dict()
        payload["value"] = {"answer": 42}

        loaded = self._create_agent_task()._load_agent_response(payload)

        assert loaded is not None
        assert loaded.value == {"answer": 42}
        # Re-serializing must still carry the response type tag.
        assert loaded.to_dict()["type"] == "agent_run_response"

    def test_load_agent_response_rejects_none(self) -> None:
        """Passing ``None`` raises ValueError."""
        agent_task = self._create_agent_task()
        with pytest.raises(ValueError):
            agent_task._load_agent_response(None)

    def test_load_agent_response_rejects_unsupported_type(self) -> None:
        """Passing a list raises TypeError mentioning the unsupported type."""
        agent_task = self._create_agent_task()
        with pytest.raises(TypeError, match="Unsupported type"):
            agent_task._load_agent_response(["invalid", "list"])  # type: ignore[arg-type]

    def test_try_set_value_success(self) -> None:
        """try_set_value turns a succeeded entity result into an AgentRunResponse."""
        child = _create_entity_task()
        agent_task = AgentTask(child, None, "correlation-id")

        # The entity task finished successfully with a serialized response.
        child.state = TaskState.SUCCEEDED
        child.result = AgentRunResponse(messages=[ChatMessage(role="assistant", text="Test response")]).to_dict()

        # Emptying pending_tasks simulates the parent having processed the child.
        agent_task.pending_tasks.clear()

        agent_task.try_set_value(child)

        assert agent_task.state == TaskState.SUCCEEDED
        result = agent_task.result
        assert isinstance(result, AgentRunResponse)
        assert agent_task.result.text == "Test response"

    def test_try_set_value_failure(self) -> None:
        """try_set_value propagates a failed entity task's exception."""
        child = _create_entity_task()
        agent_task = AgentTask(child, None, "correlation-id")

        # The entity task failed with an exception as its result.
        child.state = TaskState.FAILED
        child.result = Exception("Entity call failed")

        agent_task.try_set_value(child)

        assert agent_task.state == TaskState.FAILED
        failure = agent_task.result
        assert isinstance(failure, Exception)
        assert str(agent_task.result) == "Entity call failed"

    def test_try_set_value_with_response_format(self) -> None:
        """try_set_value parses structured output when a response_format is given."""
        from pydantic import BaseModel

        class TestSchema(BaseModel):
            answer: str

        child = _create_entity_task()
        agent_task = AgentTask(child, TestSchema, "correlation-id")

        # The entity task succeeded and its message text is JSON for the schema.
        child.state = TaskState.SUCCEEDED
        child.result = AgentRunResponse(
            messages=[ChatMessage(role="assistant", text='{"answer": "42"}')]
        ).to_dict()

        # Emptying pending_tasks simulates the parent having processed the child.
        agent_task.pending_tasks.clear()

        agent_task.try_set_value(child)

        assert agent_task.state == TaskState.SUCCEEDED
        assert isinstance(agent_task.result, AgentRunResponse)
        assert isinstance(agent_task.result.value, TestSchema)
        assert agent_task.result.value.answer == "42"

    def test_ensure_response_format_parses_value(self) -> None:
        """_ensure_response_format fills ``value`` from the JSON message text."""
        from pydantic import BaseModel

        class SampleSchema(BaseModel):
            name: str

        agent_task = self._create_agent_task()
        run_response = AgentRunResponse(messages=[ChatMessage(role="assistant", text='{"name": "test"}')])

        # Nothing has been parsed before the helper runs.
        assert run_response.value is None

        agent_task._ensure_response_format(SampleSchema, "test-correlation", run_response)

        # The helper populated value with a schema instance.
        assert isinstance(run_response.value, SampleSchema)
        assert run_response.value.name == "test"

    def test_ensure_response_format_skips_if_already_parsed(self) -> None:
        """_ensure_response_format keeps a value that already matches the format."""
        from pydantic import BaseModel

        class SampleSchema(BaseModel):
            name: str

        agent_task = self._create_agent_task()
        preset = SampleSchema(name="existing")
        run_response = AgentRunResponse(
            messages=[ChatMessage(role="assistant", text='{"name": "new"}')],
            value=preset,
        )

        agent_task._ensure_response_format(SampleSchema, "test-correlation", run_response)

        # The very same object is kept; the message text was not re-parsed.
        assert run_response.value is preset
        assert run_response.value.name == "existing"
class TestDurableAIAgent:
"""Test suite for DurableAIAgent wrapper."""
@@ -111,22 +276,19 @@ class TestDurableAIAgent:
mock_context.instance_id = "test-instance-001"
mock_context.new_uuid = Mock(side_effect=["thread-guid", "correlation-guid"])
# Mock call_entity to return a Task-like object
mock_task = Mock()
mock_task._is_scheduled = False # Task attribute that orchestration checks
mock_context.call_entity = Mock(return_value=mock_task)
entity_task = _create_entity_task()
mock_context.call_entity = Mock(return_value=entity_task)
agent = DurableAIAgent(mock_context, "TestAgent")
# Create thread
thread = agent.get_new_thread()
# Call run() - it should return the Task directly
# Call run() - returns AgentTask directly
task = agent.run(messages="Test message", thread=thread, enable_tool_calls=True)
# Verify run() returns the Task from call_entity
assert task == mock_task
assert isinstance(task, AgentTask)
assert task.children[0] == entity_task
# Verify call_entity was called with correct parameters
assert mock_context.call_entity.called
@@ -145,19 +307,18 @@ class TestDurableAIAgent:
"""Test that run() works without explicit thread (creates unique session key)."""
mock_context = Mock()
mock_context.instance_id = "test-instance-002"
# Two calls to new_uuid: one for session_key, one for correlationId
mock_context.new_uuid = Mock(side_effect=["auto-generated-guid", "correlation-guid"])
mock_task = Mock()
mock_task._is_scheduled = False
mock_context.call_entity = Mock(return_value=mock_task)
entity_task = _create_entity_task()
mock_context.call_entity = Mock(return_value=entity_task)
agent = DurableAIAgent(mock_context, "TestAgent")
# Call without thread
task = agent.run(messages="Test message")
assert task == mock_task
assert isinstance(task, AgentTask)
assert task.children[0] == entity_task
# Verify the entity ID uses the auto-generated GUID with dafx- prefix
call_args = mock_context.call_entity.call_args
@@ -172,9 +333,8 @@ class TestDurableAIAgent:
mock_context = Mock()
mock_context.instance_id = "test-instance-003"
mock_task = Mock()
mock_task._is_scheduled = False
mock_context.call_entity = Mock(return_value=mock_task)
entity_task = _create_entity_task()
mock_context.call_entity = Mock(return_value=entity_task)
agent = DurableAIAgent(mock_context, "TestAgent")
@@ -188,7 +348,8 @@ class TestDurableAIAgent:
task = agent.run(messages="Test message", thread=thread, response_format=SampleSchema)
assert task == mock_task
assert isinstance(task, AgentTask)
assert task.children[0] == entity_task
# Verify schema was passed in the call_entity arguments
call_args = mock_context.call_entity.call_args
@@ -221,8 +382,8 @@ class TestDurableAIAgent:
mock_context = Mock()
mock_context.new_uuid = Mock(side_effect=["thread-guid", "correlation-guid"])
mock_task = Mock()
mock_context.call_entity = Mock(return_value=mock_task)
entity_task = _create_entity_task()
mock_context.call_entity = Mock(return_value=entity_task)
agent = DurableAIAgent(mock_context, "TestAgent")
thread = agent.get_new_thread()
@@ -231,7 +392,8 @@ class TestDurableAIAgent:
msg = ChatMessage(role="user", text="Hello")
task = agent.run(messages=msg, thread=thread)
assert task == mock_task
assert isinstance(task, AgentTask)
assert task.children[0] == entity_task
# Verify message was converted to string
call_args = mock_context.call_entity.call_args
@@ -255,7 +417,7 @@ class TestDurableAIAgent:
mock_context = Mock()
mock_context.new_uuid = Mock(return_value="test-guid-789")
mock_context.call_entity = Mock(return_value=Mock())
mock_context.call_entity = Mock(return_value=_create_entity_task())
agent = DurableAIAgent(mock_context, "WriterAgent")
thread = agent.get_new_thread()
@@ -314,13 +476,9 @@ class TestOrchestrationIntegration:
# Track entity calls
entity_calls: list[dict[str, Any]] = []
def mock_call_entity_side_effect(entity_id: Any, operation: str, input_data: dict[str, Any]) -> Mock:
def mock_call_entity_side_effect(entity_id: Any, operation: str, input_data: dict[str, Any]) -> TaskBase:
entity_calls.append({"entity_id": str(entity_id), "operation": operation, "input": input_data})
# Return a mock Task
mock_task = Mock()
mock_task._is_scheduled = False
return mock_task
return _create_entity_task()
mock_context.call_entity = Mock(side_effect=mock_call_entity_side_effect)
@@ -330,13 +488,13 @@ class TestOrchestrationIntegration:
# Create thread
thread = agent.get_new_thread()
# First call - returns Task
# First call - returns AgentTask
task1 = agent.run("Write something", thread=thread)
assert hasattr(task1, "_is_scheduled")
assert isinstance(task1, AgentTask)
# Second call - returns Task
# Second call - returns AgentTask
task2 = agent.run("Improve: something", thread=thread)
assert hasattr(task2, "_is_scheduled")
assert isinstance(task2, AgentTask)
# Verify both calls used the same entity (same session key)
assert len(entity_calls) == 2
@@ -356,11 +514,9 @@ class TestOrchestrationIntegration:
entity_calls: list[str] = []
def mock_call_entity_side_effect(entity_id: Any, operation: str, input_data: dict[str, Any]) -> Mock:
def mock_call_entity_side_effect(entity_id: Any, operation: str, input_data: dict[str, Any]) -> TaskBase:
entity_calls.append(str(entity_id))
mock_task = Mock()
mock_task._is_scheduled = False
return mock_task
return _create_entity_task()
mock_context.call_entity = Mock(side_effect=mock_call_entity_side_effect)
@@ -371,12 +527,12 @@ class TestOrchestrationIntegration:
writer_thread = writer.get_new_thread()
editor_thread = editor.get_new_thread()
# Call both agents - returns Tasks
# Call both agents - returns AgentTasks
writer_task = writer.run("Write", thread=writer_thread)
editor_task = editor.run("Edit", thread=editor_thread)
assert hasattr(writer_task, "_is_scheduled")
assert hasattr(editor_task, "_is_scheduled")
assert isinstance(writer_task, AgentTask)
assert isinstance(editor_task, AgentTask)
# Verify different entity IDs were used
assert len(entity_calls) == 2
@@ -57,7 +57,7 @@ def single_agent_orchestration(context: DurableOrchestrationContext):
improved_prompt = (
"Improve this further while keeping it under 25 words: "
f"{initial.get('response', '').strip()}"
f"{initial.text}"
)
refined = yield writer.run(
@@ -65,7 +65,7 @@ def single_agent_orchestration(context: DurableOrchestrationContext):
thread=writer_thread,
)
return refined.get("response", "")
return refined.text
# 5. HTTP endpoint to kick off the orchestration and return the status query URI.
@@ -10,8 +10,9 @@ Prerequisites: configure `AZURE_OPENAI_ENDPOINT`, `AZURE_OPENAI_CHAT_DEPLOYMENT_
import json
import logging
from typing import Any
from typing import Any, cast
from agent_framework import AgentRunResponse
import azure.functions as func
from agent_framework.azure import AgentFunctionApp, AzureOpenAIChatClient
from azure.durable_functions import DurableOrchestrationClient, DurableOrchestrationContext
@@ -63,14 +64,19 @@ def multi_agent_concurrent_orchestration(context: DurableOrchestrationContext):
physicist_thread = physicist.get_new_thread()
chemist_thread = chemist.get_new_thread()
# Create tasks from agent.run() calls
physicist_task = physicist.run(messages=str(prompt), thread=physicist_thread)
chemist_task = chemist.run(messages=str(prompt), thread=chemist_thread)
results = yield context.task_all([physicist_task, chemist_task])
# Execute both tasks concurrently using task_all
task_results = yield context.task_all([physicist_task, chemist_task])
physicist_result = cast(AgentRunResponse, task_results[0])
chemist_result = cast(AgentRunResponse, task_results[1])
return {
"physicist": results[0].get("response", ""),
"chemist": results[1].get("response", ""),
"physicist": physicist_result.text,
"chemist": chemist_result.text,
}
@@ -102,7 +102,7 @@ def spam_detection_orchestration(context: DurableOrchestrationContext):
response_format=SpamDetectionResult,
)
spam_result = cast(SpamDetectionResult, _coerce_structured(spam_result_raw, SpamDetectionResult))
spam_result = cast(SpamDetectionResult, spam_result_raw.value)
if spam_result.is_spam:
result = yield context.call_activity("handle_spam_email", spam_result.reason)
@@ -123,7 +123,7 @@ def spam_detection_orchestration(context: DurableOrchestrationContext):
response_format=EmailResponse,
)
email_result = cast(EmailResponse, _coerce_structured(email_result_raw, EmailResponse))
email_result = cast(EmailResponse, email_result_raw.value)
result = yield context.call_activity("send_email", email_result.response)
return result
@@ -231,24 +231,6 @@ def _build_status_url(request_url: str, instance_id: str, *, route: str) -> str:
return f"{base_url}/api/{route}/status/{instance_id}"
def _coerce_structured(result: Mapping[str, Any], model: type[BaseModel]) -> BaseModel:
    """Validate an agent result mapping into an instance of ``model``.

    The ``structured_response`` entry is used when it is not None. Otherwise a
    non-blank string under ``response`` is decoded as JSON and validated if it
    decodes to a mapping.

    Raises:
        ValueError: If ``result`` is not a mapping or neither entry yields a
            payload that can be validated.
    """
    if isinstance(result, Mapping):
        structured_payload = result.get("structured_response")
        if structured_payload is not None:
            return model.model_validate(structured_payload)

        raw_text = result.get("response")
        if isinstance(raw_text, str) and raw_text.strip():
            try:
                decoded = json.loads(raw_text)
                if isinstance(decoded, Mapping):
                    return model.model_validate(decoded)
            except json.JSONDecodeError:
                logger.warning("[ConditionalOrchestration] Failed to parse agent JSON response; raising error.")

    # Nothing usable was found, so surface the problem to the caller.
    raise ValueError(f"Agent response could not be parsed as {model.__name__}.")
"""
Expected response from `POST /api/spamdetection/run`:
@@ -100,7 +100,12 @@ def content_generation_hitl_orchestration(context: DurableOrchestrationContext):
thread=writer_thread,
response_format=GeneratedContent,
)
content = _coerce_generated_content(initial_raw)
content = initial_raw.value
logger.info("Type of content after extraction: %s", type(content))
if content is None or not isinstance(content, GeneratedContent):
raise ValueError("Agent returned no content after extraction.")
attempt = 0
while attempt < payload.max_review_attempts:
@@ -142,7 +147,12 @@ def content_generation_hitl_orchestration(context: DurableOrchestrationContext):
thread=writer_thread,
response_format=GeneratedContent,
)
content = _coerce_generated_content(rewritten_raw)
rewritten_value = rewritten_raw.value
if rewritten_value is None or not isinstance(rewritten_value, GeneratedContent):
raise ValueError("Agent returned no content after rewrite.")
content = rewritten_value
else:
context.set_custom_status(
f"Human approval timed out after {payload.approval_timeout_hours} hour(s). Treating as rejection."
@@ -317,23 +327,6 @@ def _build_status_url(request_url: str, instance_id: str, *, route: str) -> str:
return f"{base_url}/api/{route}/status/{instance_id}"
def _coerce_generated_content(result: Mapping[str, Any]) -> GeneratedContent:
    """Validate an agent result mapping into a ``GeneratedContent`` instance.

    The ``structured_response`` entry is used when it is not None. Otherwise a
    non-blank string under ``response`` is decoded as JSON and validated if it
    decodes to a mapping.

    Raises:
        ValueError: If ``result`` is not a mapping, or neither entry yields a
            payload that can be validated. When the text is not valid JSON the
            ``json.JSONDecodeError`` is attached as the cause.
    """
    structured = result.get("structured_response") if isinstance(result, Mapping) else None
    if structured is not None:
        return GeneratedContent.model_validate(structured)

    response_text = result.get("response") if isinstance(result, Mapping) else None
    if isinstance(response_text, str) and response_text.strip():
        try:
            parsed = json.loads(response_text)
            if isinstance(parsed, Mapping):
                return GeneratedContent.model_validate(parsed)
        except json.JSONDecodeError as err:
            # There is no default to fall back to: this function always raises
            # here, so the log says so and the decode error is kept as the cause.
            logger.warning("[HITL] Failed to parse agent JSON response; raising error.")
            raise ValueError("Agent response could not be parsed as GeneratedContent.") from err

    raise ValueError("Agent response could not be parsed as GeneratedContent.")
def _parse_human_approval(raw: Any) -> HumanApproval:
if isinstance(raw, Mapping):
return HumanApproval.model_validate(raw)