mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: [Breaking] Python: Respond with AgentRunResponse with serialized structured output (#2285)
* Respond with AgentRunResponse * Fix response_Format type * Address comments * Fix tests * Fix log * Addressed comments * Code cleanup * Use AgentTask vs Generator * Address comments * use lazy logging * fix mypy errors
This commit is contained in:
committed by
GitHub
Unverified
parent
0f2c5e6cb8
commit
306c81aef8
+2
-2
@@ -53,7 +53,7 @@ from agent_framework import (
|
||||
)
|
||||
from dateutil import parser as date_parser
|
||||
|
||||
from ._models import RunRequest, _serialize_response_format
|
||||
from ._models import RunRequest, serialize_response_format
|
||||
|
||||
logger = get_logger("agent_framework.azurefunctions.durable_agent_state")
|
||||
|
||||
@@ -494,7 +494,7 @@ class DurableAgentStateRequest(DurableAgentStateEntry):
|
||||
messages=[DurableAgentStateMessage.from_run_request(request)],
|
||||
created_at=datetime.now(tz=timezone.utc),
|
||||
response_type=request.request_response_format,
|
||||
response_schema=_serialize_response_format(request.response_format),
|
||||
response_schema=serialize_response_format(request.response_format),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -9,9 +9,7 @@ allows for long-running agent conversations.
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
from collections.abc import AsyncIterable, Callable
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, cast
|
||||
|
||||
import azure.durable_functions as df
|
||||
@@ -30,11 +28,10 @@ from ._durable_agent_state import (
|
||||
DurableAgentState,
|
||||
DurableAgentStateData,
|
||||
DurableAgentStateEntry,
|
||||
DurableAgentStateMessage,
|
||||
DurableAgentStateRequest,
|
||||
DurableAgentStateResponse,
|
||||
)
|
||||
from ._models import AgentResponse, RunRequest
|
||||
from ._models import RunRequest
|
||||
|
||||
logger = get_logger("agent_framework.azurefunctions.entities")
|
||||
|
||||
@@ -97,7 +94,7 @@ class AgentEntity:
|
||||
self,
|
||||
context: df.DurableEntityContext,
|
||||
request: RunRequest | dict[str, Any] | str,
|
||||
) -> dict[str, Any]:
|
||||
) -> AgentRunResponse:
|
||||
"""Execute the agent with a message directly in the entity.
|
||||
|
||||
Args:
|
||||
@@ -105,13 +102,8 @@ class AgentEntity:
|
||||
request: RunRequest object, dict, or string message (for backward compatibility)
|
||||
|
||||
Returns:
|
||||
Dict with status information and response (serialized AgentResponse)
|
||||
|
||||
Note:
|
||||
The agent returns an AgentRunResponse object which is stored in state.
|
||||
This method extracts the text/structured response and returns an AgentResponse dict.
|
||||
AgentRunResponse enriched with execution metadata.
|
||||
"""
|
||||
# Convert string or dict to RunRequest
|
||||
if isinstance(request, str):
|
||||
run_request = RunRequest(message=request, role=Role.USER)
|
||||
elif isinstance(request, dict):
|
||||
@@ -135,8 +127,6 @@ class AgentEntity:
|
||||
logger.debug(f"[AgentEntity.run_agent] Received Message: {state_request}")
|
||||
|
||||
try:
|
||||
logger.debug("[AgentEntity.run_agent] Starting agent invocation")
|
||||
|
||||
# Build messages from conversation history, excluding error responses
|
||||
# Error responses are kept in history for tracking but not sent to the agent
|
||||
chat_messages: list[ChatMessage] = [
|
||||
@@ -164,83 +154,39 @@ class AgentEntity:
|
||||
type(agent_run_response).__name__,
|
||||
)
|
||||
|
||||
response_text = None
|
||||
structured_response = None
|
||||
response_str: str | None = None
|
||||
|
||||
try:
|
||||
if response_format:
|
||||
try:
|
||||
response_str = agent_run_response.text
|
||||
structured_response = json.loads(response_str)
|
||||
logger.debug("Parsed structured JSON response")
|
||||
except json.JSONDecodeError as decode_error:
|
||||
logger.warning(f"Failed to parse JSON response: {decode_error}")
|
||||
response_text = response_str
|
||||
else:
|
||||
raw_text = agent_run_response.text
|
||||
response_text = raw_text if raw_text else "No response"
|
||||
preview = response_text
|
||||
logger.debug(f"Response: {preview[:100]}..." if len(preview) > 100 else f"Response: {preview}")
|
||||
response_text = agent_run_response.text if agent_run_response.text else "No response"
|
||||
logger.debug(f"Response: {response_text[:100]}...")
|
||||
except Exception as extraction_error:
|
||||
logger.error(
|
||||
f"Error extracting response: {extraction_error}",
|
||||
"Error extracting response text: %s",
|
||||
extraction_error,
|
||||
exc_info=True,
|
||||
)
|
||||
response_text = "Error extracting response"
|
||||
|
||||
state_response = DurableAgentStateResponse.from_run_response(correlation_id, agent_run_response)
|
||||
self.state.data.conversation_history.append(state_response)
|
||||
|
||||
agent_response = AgentResponse(
|
||||
response=response_text,
|
||||
message=str(message),
|
||||
thread_id=str(thread_id),
|
||||
status="success",
|
||||
message_count=len(self.state.data.conversation_history),
|
||||
structured_response=structured_response,
|
||||
)
|
||||
result = agent_response.to_dict()
|
||||
|
||||
logger.debug("[AgentEntity.run_agent] AgentRunResponse stored in conversation history")
|
||||
|
||||
return result
|
||||
return agent_run_response
|
||||
|
||||
except Exception as exc:
|
||||
import traceback
|
||||
|
||||
error_traceback = traceback.format_exc()
|
||||
logger.error("[AgentEntity.run_agent] Agent execution failed")
|
||||
logger.error(f"Error: {exc!s}")
|
||||
logger.error(f"Error type: {type(exc).__name__}")
|
||||
logger.error(f"Full traceback:\n{error_traceback}")
|
||||
logger.exception("[AgentEntity.run_agent] Agent execution failed.")
|
||||
|
||||
# Create error message
|
||||
error_message = DurableAgentStateMessage.from_chat_message(
|
||||
ChatMessage(
|
||||
role=Role.ASSISTANT, contents=[ErrorContent(message=str(exc), error_code=type(exc).__name__)]
|
||||
)
|
||||
error_message = ChatMessage(
|
||||
role=Role.ASSISTANT, contents=[ErrorContent(message=str(exc), error_code=type(exc).__name__)]
|
||||
)
|
||||
|
||||
error_response = AgentRunResponse(messages=[error_message])
|
||||
|
||||
# Create and store error response in conversation history
|
||||
error_state_response = DurableAgentStateResponse(
|
||||
correlation_id=correlation_id,
|
||||
created_at=datetime.now(tz=timezone.utc),
|
||||
messages=[error_message],
|
||||
is_error=True,
|
||||
)
|
||||
error_state_response = DurableAgentStateResponse.from_run_response(correlation_id, error_response)
|
||||
error_state_response.is_error = True
|
||||
self.state.data.conversation_history.append(error_state_response)
|
||||
|
||||
error_response = AgentResponse(
|
||||
response=f"Error: {exc!s}",
|
||||
message=str(message),
|
||||
thread_id=str(thread_id),
|
||||
status="error",
|
||||
message_count=len(self.state.data.conversation_history),
|
||||
error=str(exc),
|
||||
error_type=type(exc).__name__,
|
||||
)
|
||||
return error_response.to_dict()
|
||||
return error_response
|
||||
|
||||
async def _invoke_agent(
|
||||
self,
|
||||
@@ -432,7 +378,7 @@ def create_agent_entity(
|
||||
request = "" if input_data is None else str(cast(object, input_data))
|
||||
|
||||
result = await entity.run_agent(context, request)
|
||||
context.set_result(result)
|
||||
context.set_result(result.to_dict())
|
||||
|
||||
elif operation == "reset":
|
||||
entity.reset(context)
|
||||
@@ -442,15 +388,13 @@ def create_agent_entity(
|
||||
logger.error("[entity_function] Unknown operation: %s", operation)
|
||||
context.set_result({"error": f"Unknown operation: {operation}"})
|
||||
|
||||
logger.debug("State dict: %s", entity.state.to_dict())
|
||||
context.set_state(entity.state.to_dict())
|
||||
serialized_state = entity.state.to_dict()
|
||||
logger.debug("State dict: %s", serialized_state)
|
||||
context.set_state(serialized_state)
|
||||
logger.info(f"[entity_function] Operation {operation} completed successfully")
|
||||
|
||||
except Exception as exc:
|
||||
import traceback
|
||||
|
||||
logger.error("[entity_function] Error in entity: %s", exc)
|
||||
logger.error(f"[entity_function] Traceback:\n{traceback.format_exc()}")
|
||||
logger.exception("[entity_function] Error executing entity operation %s", exc)
|
||||
context.set_result({"error": str(exc), "status": "error"})
|
||||
|
||||
def entity_function(context: df.DurableEntityContext) -> None:
|
||||
|
||||
@@ -213,7 +213,7 @@ class DurableAgentThread(AgentThread):
|
||||
return thread
|
||||
|
||||
|
||||
def _serialize_response_format(response_format: type[BaseModel] | None) -> Any:
|
||||
def serialize_response_format(response_format: type[BaseModel] | None) -> Any:
|
||||
"""Serialize response format for transport across durable function boundaries."""
|
||||
if response_format is None:
|
||||
return None
|
||||
@@ -339,7 +339,7 @@ class RunRequest:
|
||||
"request_response_format": self.request_response_format,
|
||||
}
|
||||
if self.response_format:
|
||||
result["response_format"] = _serialize_response_format(self.response_format)
|
||||
result["response_format"] = serialize_response_format(self.response_format)
|
||||
if self.thread_id:
|
||||
result["thread_id"] = self.thread_id
|
||||
if self.correlation_id:
|
||||
@@ -362,50 +362,3 @@ class RunRequest:
|
||||
correlation_id=data.get("correlationId"),
|
||||
created_at=data.get("created_at"),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentResponse:
|
||||
"""Response from agent execution.
|
||||
|
||||
Attributes:
|
||||
response: The agent's text response (or None for structured responses)
|
||||
message: The original message sent to the agent
|
||||
thread_id: The thread identifier
|
||||
status: Status of the execution (success, error, etc.)
|
||||
message_count: Number of messages in the conversation
|
||||
error: Error message if status is error
|
||||
error_type: Type of error if status is error
|
||||
structured_response: Structured response if response_format was provided
|
||||
"""
|
||||
|
||||
response: str | None
|
||||
message: str
|
||||
thread_id: str | None
|
||||
status: str
|
||||
message_count: int = 0
|
||||
error: str | None = None
|
||||
error_type: str | None = None
|
||||
structured_response: dict[str, Any] | None = None
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary for JSON serialization."""
|
||||
result: dict[str, Any] = {
|
||||
"message": self.message,
|
||||
"thread_id": self.thread_id,
|
||||
"status": self.status,
|
||||
"message_count": self.message_count,
|
||||
}
|
||||
|
||||
# Add response or structured_response based on what's available
|
||||
if self.structured_response is not None:
|
||||
result["structured_response"] = self.structured_response
|
||||
elif self.response is not None:
|
||||
result["response"] = self.response
|
||||
|
||||
if self.error:
|
||||
result["error"] = self.error
|
||||
if self.error_type:
|
||||
result["error_type"] = self.error_type
|
||||
|
||||
return result
|
||||
|
||||
@@ -6,21 +6,148 @@ This module provides support for using agents inside Durable Function orchestrat
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from collections.abc import AsyncIterator
|
||||
from collections.abc import AsyncIterator, Callable
|
||||
from typing import TYPE_CHECKING, Any, TypeAlias, cast
|
||||
|
||||
from agent_framework import AgentProtocol, AgentRunResponseUpdate, AgentThread, ChatMessage, get_logger
|
||||
from agent_framework import (
|
||||
AgentProtocol,
|
||||
AgentRunResponse,
|
||||
AgentRunResponseUpdate,
|
||||
AgentThread,
|
||||
ChatMessage,
|
||||
get_logger,
|
||||
)
|
||||
from azure.durable_functions.models import TaskBase
|
||||
from azure.durable_functions.models.Task import CompoundTask, TaskState
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ._models import AgentSessionId, DurableAgentThread, RunRequest
|
||||
|
||||
logger = get_logger("agent_framework.azurefunctions.orchestration")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from azure.durable_functions import DurableOrchestrationContext as _DurableOrchestrationContext
|
||||
CompoundActionConstructor: TypeAlias = Callable[[list[Any]], Any] | None
|
||||
|
||||
AgentOrchestrationContextType: TypeAlias = _DurableOrchestrationContext
|
||||
if TYPE_CHECKING:
|
||||
from azure.durable_functions import DurableOrchestrationContext
|
||||
|
||||
class _TypedCompoundTask(CompoundTask): # type: ignore[misc]
|
||||
_first_error: Any
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tasks: list[TaskBase],
|
||||
compound_action_constructor: CompoundActionConstructor = None,
|
||||
) -> None: ...
|
||||
|
||||
AgentOrchestrationContextType: TypeAlias = DurableOrchestrationContext
|
||||
else:
|
||||
AgentOrchestrationContextType = Any
|
||||
_TypedCompoundTask = CompoundTask
|
||||
|
||||
|
||||
class AgentTask(_TypedCompoundTask):
|
||||
"""A custom Task that wraps entity calls and provides typed AgentRunResponse results.
|
||||
|
||||
This task wraps the underlying entity call task and intercepts its completion
|
||||
to convert the raw result into a typed AgentRunResponse object.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
entity_task: TaskBase,
|
||||
response_format: type[BaseModel] | None,
|
||||
correlation_id: str,
|
||||
):
|
||||
"""Initialize the AgentTask.
|
||||
|
||||
Args:
|
||||
entity_task: The underlying entity call task
|
||||
response_format: Optional Pydantic model for response parsing
|
||||
correlation_id: Correlation ID for logging
|
||||
"""
|
||||
super().__init__([entity_task])
|
||||
self._response_format = response_format
|
||||
self._correlation_id = correlation_id
|
||||
|
||||
# Override action_repr to expose the inner task's action directly
|
||||
# This ensures compatibility with ReplaySchema V3 which expects Action objects.
|
||||
self.action_repr = entity_task.action_repr
|
||||
|
||||
# Also copy the task ID to match the entity task's identity
|
||||
self.id = entity_task.id
|
||||
|
||||
def try_set_value(self, child: TaskBase) -> None:
|
||||
"""Transition the AgentTask to a terminal state and set its value to `AgentRunResponse`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
child : TaskBase
|
||||
The entity call task that just completed
|
||||
"""
|
||||
if child.state is TaskState.SUCCEEDED:
|
||||
# Delegate to parent class for standard completion logic
|
||||
if len(self.pending_tasks) == 0:
|
||||
# Transform the raw result before setting it
|
||||
raw_result = child.result
|
||||
logger.debug(
|
||||
"[AgentTask] Converting raw result for correlation_id %s",
|
||||
self._correlation_id,
|
||||
)
|
||||
|
||||
try:
|
||||
response = self._load_agent_response(raw_result)
|
||||
|
||||
if self._response_format is not None:
|
||||
self._ensure_response_format(
|
||||
self._response_format,
|
||||
self._correlation_id,
|
||||
response,
|
||||
)
|
||||
|
||||
# Set the typed AgentRunResponse as this task's result
|
||||
self.set_value(is_error=False, value=response)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"[AgentTask] Failed to convert result for correlation_id: %s",
|
||||
self._correlation_id,
|
||||
)
|
||||
self.set_value(is_error=True, value=e)
|
||||
else:
|
||||
# If error not handled by the parent, set it explicitly.
|
||||
if self._first_error is None:
|
||||
self._first_error = child.result
|
||||
self.set_value(is_error=True, value=self._first_error)
|
||||
|
||||
def _load_agent_response(self, agent_response: AgentRunResponse | dict[str, Any] | None) -> AgentRunResponse:
|
||||
"""Convert raw payloads into AgentRunResponse instance."""
|
||||
if agent_response is None:
|
||||
raise ValueError("agent_response cannot be None")
|
||||
|
||||
logger.debug("[load_agent_response] Loading agent response of type: %s", type(agent_response))
|
||||
|
||||
if isinstance(agent_response, AgentRunResponse):
|
||||
return agent_response
|
||||
if isinstance(agent_response, dict):
|
||||
logger.debug("[load_agent_response] Converting dict payload using AgentRunResponse.from_dict")
|
||||
return AgentRunResponse.from_dict(agent_response)
|
||||
|
||||
raise TypeError(f"Unsupported type for agent_response: {type(agent_response)}")
|
||||
|
||||
def _ensure_response_format(
|
||||
self,
|
||||
response_format: type[BaseModel] | None,
|
||||
correlation_id: str,
|
||||
response: AgentRunResponse,
|
||||
) -> None:
|
||||
"""Ensure the AgentRunResponse value is parsed into the expected response_format."""
|
||||
if response_format is not None and not isinstance(response.value, response_format):
|
||||
response.try_parse_value(response_format)
|
||||
|
||||
logger.debug(
|
||||
"[DurableAIAgent] Loaded AgentRunResponse.value for correlation_id %s with type: %s",
|
||||
correlation_id,
|
||||
type(response.value).__name__,
|
||||
)
|
||||
|
||||
|
||||
class DurableAIAgent(AgentProtocol):
|
||||
@@ -59,7 +186,7 @@ class DurableAIAgent(AgentProtocol):
|
||||
self._name = agent_name
|
||||
self._display_name = agent_name
|
||||
self._description = f"Durable agent proxy for {agent_name}"
|
||||
logger.debug(f"[DurableAIAgent] Initialized for agent: {agent_name}")
|
||||
logger.debug("[DurableAIAgent] Initialized for agent: %s", agent_name)
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
@@ -81,38 +208,45 @@ class DurableAIAgent(AgentProtocol):
|
||||
"""Get the description of the agent."""
|
||||
return self._description
|
||||
|
||||
def run(
|
||||
# We return an AgentTask here which is a TaskBase subclass.
|
||||
# This is an intentional deviation from AgentProtocol which defines run() as async.
|
||||
# The AgentTask can be yielded in Durable Functions orchestrations and will provide
|
||||
# a typed AgentRunResponse result.
|
||||
def run( # type: ignore[override]
|
||||
self,
|
||||
messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
|
||||
*,
|
||||
thread: AgentThread | None = None,
|
||||
response_format: type[BaseModel] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Any: # TODO(msft-team): Add a wrapper to respond correctly with `AgentRunResponse`
|
||||
"""Execute the agent with messages and return a Task for orchestrations.
|
||||
) -> AgentTask:
|
||||
"""Execute the agent with messages and return an AgentTask for orchestrations.
|
||||
|
||||
This method implements AgentProtocol and returns a Task that can be yielded
|
||||
in Durable Functions orchestrations.
|
||||
This method implements AgentProtocol and returns an AgentTask (subclass of TaskBase)
|
||||
that can be yielded in Durable Functions orchestrations. The task's result will be
|
||||
a typed AgentRunResponse.
|
||||
|
||||
Args:
|
||||
messages: The message(s) to send to the agent
|
||||
thread: Optional agent thread for conversation context
|
||||
**kwargs: Additional arguments (enable_tool_calls, response_format, etc.)
|
||||
response_format: Optional Pydantic model for response parsing
|
||||
**kwargs: Additional arguments (enable_tool_calls)
|
||||
|
||||
Returns:
|
||||
Task that will resolve to the agent response
|
||||
An AgentTask that resolves to an AgentRunResponse when yielded
|
||||
|
||||
Example:
|
||||
@app.orchestration_trigger(context_name="context")
|
||||
def my_orchestration(context):
|
||||
agent = app.get_agent(context, "MyAgent")
|
||||
thread = agent.get_new_thread()
|
||||
result = yield agent.run("Hello", thread=thread)
|
||||
response = yield agent.run("Hello", thread=thread)
|
||||
# response is typed as AgentRunResponse
|
||||
"""
|
||||
message_str = self._normalize_messages(messages)
|
||||
|
||||
# Extract optional parameters from kwargs
|
||||
enable_tool_calls = kwargs.get("enable_tool_calls", True)
|
||||
response_format = kwargs.get("response_format")
|
||||
|
||||
# Get the session ID for the entity
|
||||
if isinstance(thread, DurableAgentThread) and thread.session_id is not None:
|
||||
@@ -122,7 +256,7 @@ class DurableAIAgent(AgentProtocol):
|
||||
# This ensures each call gets its own conversation context
|
||||
session_key = str(self.context.new_uuid())
|
||||
session_id = AgentSessionId(name=self.agent_name, key=session_key)
|
||||
logger.warning(f"[DurableAIAgent] No thread provided, created unique session_id: {session_id}")
|
||||
logger.warning("[DurableAIAgent] No thread provided, created unique session_id: %s", session_id)
|
||||
|
||||
# Create entity ID from session ID
|
||||
entity_id = session_id.to_entity_id()
|
||||
@@ -130,6 +264,12 @@ class DurableAIAgent(AgentProtocol):
|
||||
# Generate a deterministic correlation ID for this call
|
||||
# This is required by the entity and must be unique per call
|
||||
correlation_id = str(self.context.new_uuid())
|
||||
logger.debug(
|
||||
"[DurableAIAgent] Using correlation_id: %s for entity_id: %s for session_id: %s",
|
||||
correlation_id,
|
||||
entity_id,
|
||||
session_id,
|
||||
)
|
||||
|
||||
# Prepare the request using RunRequest model
|
||||
run_request = RunRequest(
|
||||
@@ -140,11 +280,24 @@ class DurableAIAgent(AgentProtocol):
|
||||
response_format=response_format,
|
||||
)
|
||||
|
||||
logger.debug(f"[DurableAIAgent] Calling entity {entity_id} with message: {message_str[:100]}...")
|
||||
logger.debug("[DurableAIAgent] Calling entity %s with message: %s", entity_id, message_str[:100])
|
||||
|
||||
# Call the entity and return the Task directly
|
||||
# The orchestration will yield this Task
|
||||
return self.context.call_entity(entity_id, "run_agent", run_request.to_dict())
|
||||
# Call the entity to get the underlying task
|
||||
entity_task = self.context.call_entity(entity_id, "run_agent", run_request.to_dict())
|
||||
|
||||
# Wrap it in an AgentTask that will convert the result to AgentRunResponse
|
||||
agent_task = AgentTask(
|
||||
entity_task=entity_task,
|
||||
response_format=response_format,
|
||||
correlation_id=correlation_id,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"[DurableAIAgent] Created AgentTask for correlation_id %s",
|
||||
correlation_id,
|
||||
)
|
||||
|
||||
return agent_task
|
||||
|
||||
def run_stream(
|
||||
self,
|
||||
@@ -179,7 +332,7 @@ class DurableAIAgent(AgentProtocol):
|
||||
|
||||
thread = DurableAgentThread.from_session_id(session_id, **kwargs)
|
||||
|
||||
logger.debug(f"[DurableAIAgent] Created new thread with session_id: {session_id}")
|
||||
logger.debug("[DurableAIAgent] Created new thread with session_id: %s", session_id)
|
||||
return thread
|
||||
|
||||
def _messages_to_string(self, messages: list[ChatMessage]) -> str:
|
||||
|
||||
@@ -10,7 +10,7 @@ from unittest.mock import ANY, AsyncMock, Mock, patch
|
||||
import azure.durable_functions as df
|
||||
import azure.functions as func
|
||||
import pytest
|
||||
from agent_framework import AgentRunResponse, ChatMessage
|
||||
from agent_framework import AgentRunResponse, ChatMessage, ErrorContent
|
||||
|
||||
from agent_framework_azurefunctions import AgentFunctionApp
|
||||
from agent_framework_azurefunctions._app import WAIT_FOR_RESPONSE_FIELD, WAIT_FOR_RESPONSE_HEADER
|
||||
@@ -343,10 +343,8 @@ class TestAgentEntityOperations:
|
||||
{"message": "Test message", "thread_id": "test-conv-123", "correlationId": "corr-app-entity-1"},
|
||||
)
|
||||
|
||||
assert result["status"] == "success"
|
||||
assert result["response"] == "Test response"
|
||||
assert result["message"] == "Test message"
|
||||
assert result["thread_id"] == "test-conv-123"
|
||||
assert isinstance(result, AgentRunResponse)
|
||||
assert result.text == "Test response"
|
||||
assert entity.state.message_count == 2
|
||||
|
||||
async def test_entity_stores_conversation_history(self) -> None:
|
||||
@@ -591,10 +589,12 @@ class TestErrorHandling:
|
||||
mock_context, {"message": "Test message", "thread_id": "conv-1", "correlationId": "corr-app-error-1"}
|
||||
)
|
||||
|
||||
assert result["status"] == "error"
|
||||
assert "error" in result
|
||||
assert "Agent error" in result["error"]
|
||||
assert result["error_type"] == "Exception"
|
||||
assert isinstance(result, AgentRunResponse)
|
||||
assert len(result.messages) == 1
|
||||
content = result.messages[0].contents[0]
|
||||
assert isinstance(content, ErrorContent)
|
||||
assert "Agent error" in (content.message or "")
|
||||
assert content.error_code == "Exception"
|
||||
|
||||
def test_entity_function_handles_exception(self) -> None:
|
||||
"""Test that the entity function handles exceptions gracefully."""
|
||||
|
||||
@@ -12,7 +12,7 @@ from typing import Any, TypeVar
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from agent_framework import AgentRunResponse, AgentRunResponseUpdate, ChatMessage, Role
|
||||
from agent_framework import AgentRunResponse, AgentRunResponseUpdate, ChatMessage, ErrorContent, Role
|
||||
from pydantic import BaseModel
|
||||
|
||||
from agent_framework_azurefunctions._durable_agent_state import (
|
||||
@@ -133,10 +133,8 @@ class TestAgentEntityRunAgent:
|
||||
assert getattr(sent_message.role, "value", sent_message.role) == "user"
|
||||
|
||||
# Verify result
|
||||
assert result["status"] == "success"
|
||||
assert result["response"] == "Test response"
|
||||
assert result["message"] == "Test message"
|
||||
assert result["thread_id"] == "conv-123"
|
||||
assert isinstance(result, AgentRunResponse)
|
||||
assert result.text == "Test response"
|
||||
|
||||
async def test_run_agent_streaming_callbacks_invoked(self) -> None:
|
||||
"""Ensure streaming updates trigger callbacks and run() is not used."""
|
||||
@@ -168,8 +166,8 @@ class TestAgentEntityRunAgent:
|
||||
},
|
||||
)
|
||||
|
||||
assert result["status"] == "success"
|
||||
assert "Hello" in result.get("response", "")
|
||||
assert isinstance(result, AgentRunResponse)
|
||||
assert "Hello" in result.text
|
||||
assert callback.stream_mock.await_count == len(updates)
|
||||
assert callback.response_mock.await_count == 1
|
||||
mock_agent.run.assert_not_called()
|
||||
@@ -215,8 +213,8 @@ class TestAgentEntityRunAgent:
|
||||
},
|
||||
)
|
||||
|
||||
assert result["status"] == "success"
|
||||
assert result.get("response") == "Final response"
|
||||
assert isinstance(result, AgentRunResponse)
|
||||
assert result.text == "Final response"
|
||||
assert callback.stream_mock.await_count == 0
|
||||
assert callback.response_mock.await_count == 1
|
||||
|
||||
@@ -294,44 +292,6 @@ class TestAgentEntityRunAgent:
|
||||
mock_context, {"message": "Message", "thread_id": None, "correlationId": "corr-entity-5"}
|
||||
)
|
||||
|
||||
async def test_run_agent_handles_response_without_text_attribute(self) -> None:
|
||||
"""Test that run_agent handles responses without a text attribute."""
|
||||
mock_agent = Mock()
|
||||
|
||||
class NoTextResponse(AgentRunResponse):
|
||||
@property
|
||||
def text(self) -> str: # type: ignore[override]
|
||||
raise AttributeError("text attribute missing")
|
||||
|
||||
mock_response = NoTextResponse(messages=[ChatMessage(role="assistant", text="ignored")])
|
||||
mock_agent.run = AsyncMock(return_value=mock_response)
|
||||
|
||||
entity = AgentEntity(mock_agent)
|
||||
mock_context = Mock()
|
||||
|
||||
result = await entity.run_agent(
|
||||
mock_context, {"message": "Message", "thread_id": "conv-1", "correlationId": "corr-entity-6"}
|
||||
)
|
||||
|
||||
# Should handle gracefully
|
||||
assert result["status"] == "success"
|
||||
assert result["response"] == "Error extracting response"
|
||||
|
||||
async def test_run_agent_handles_none_response_text(self) -> None:
|
||||
"""Test that run_agent handles responses with None text."""
|
||||
mock_agent = Mock()
|
||||
mock_agent.run = AsyncMock(return_value=_agent_response(None))
|
||||
|
||||
entity = AgentEntity(mock_agent)
|
||||
mock_context = Mock()
|
||||
|
||||
result = await entity.run_agent(
|
||||
mock_context, {"message": "Message", "thread_id": "conv-1", "correlationId": "corr-entity-7"}
|
||||
)
|
||||
|
||||
assert result["status"] == "success"
|
||||
assert result["response"] == "No response"
|
||||
|
||||
async def test_run_agent_multiple_conversations(self) -> None:
|
||||
"""Test that run_agent maintains history across multiple messages."""
|
||||
mock_agent = Mock()
|
||||
@@ -621,10 +581,12 @@ class TestErrorHandling:
|
||||
mock_context, {"message": "Message", "thread_id": "conv-1", "correlationId": "corr-entity-error-1"}
|
||||
)
|
||||
|
||||
assert result["status"] == "error"
|
||||
assert "error" in result
|
||||
assert "Agent failed" in result["error"]
|
||||
assert result["error_type"] == "Exception"
|
||||
assert isinstance(result, AgentRunResponse)
|
||||
assert len(result.messages) == 1
|
||||
content = result.messages[0].contents[0]
|
||||
assert isinstance(content, ErrorContent)
|
||||
assert "Agent failed" in (content.message or "")
|
||||
assert content.error_code == "Exception"
|
||||
|
||||
async def test_run_agent_handles_value_error(self) -> None:
|
||||
"""Test that run_agent handles ValueError instances."""
|
||||
@@ -638,9 +600,12 @@ class TestErrorHandling:
|
||||
mock_context, {"message": "Message", "thread_id": "conv-1", "correlationId": "corr-entity-error-2"}
|
||||
)
|
||||
|
||||
assert result["status"] == "error"
|
||||
assert result["error_type"] == "ValueError"
|
||||
assert "Invalid input" in result["error"]
|
||||
assert isinstance(result, AgentRunResponse)
|
||||
assert len(result.messages) == 1
|
||||
content = result.messages[0].contents[0]
|
||||
assert isinstance(content, ErrorContent)
|
||||
assert content.error_code == "ValueError"
|
||||
assert "Invalid input" in str(content.message)
|
||||
|
||||
async def test_run_agent_handles_timeout_error(self) -> None:
|
||||
"""Test that run_agent handles TimeoutError instances."""
|
||||
@@ -654,8 +619,11 @@ class TestErrorHandling:
|
||||
mock_context, {"message": "Message", "thread_id": "conv-1", "correlationId": "corr-entity-error-3"}
|
||||
)
|
||||
|
||||
assert result["status"] == "error"
|
||||
assert result["error_type"] == "TimeoutError"
|
||||
assert isinstance(result, AgentRunResponse)
|
||||
assert len(result.messages) == 1
|
||||
content = result.messages[0].contents[0]
|
||||
assert isinstance(content, ErrorContent)
|
||||
assert content.error_code == "TimeoutError"
|
||||
|
||||
def test_entity_function_handles_exception_in_operation(self) -> None:
|
||||
"""Test that the entity function handles exceptions gracefully."""
|
||||
@@ -690,9 +658,10 @@ class TestErrorHandling:
|
||||
)
|
||||
|
||||
# Even on error, message info should be preserved
|
||||
assert result["message"] == "Test message"
|
||||
assert result["thread_id"] == "conv-123"
|
||||
assert result["status"] == "error"
|
||||
assert isinstance(result, AgentRunResponse)
|
||||
assert len(result.messages) == 1
|
||||
content = result.messages[0].contents[0]
|
||||
assert isinstance(content, ErrorContent)
|
||||
|
||||
|
||||
class TestConversationHistory:
|
||||
@@ -800,10 +769,8 @@ class TestRunRequestSupport:
|
||||
|
||||
result = await entity.run_agent(mock_context, request)
|
||||
|
||||
assert result["status"] == "success"
|
||||
assert result["response"] == "Response"
|
||||
assert result["message"] == "Test message"
|
||||
assert result["thread_id"] == "conv-123"
|
||||
assert isinstance(result, AgentRunResponse)
|
||||
assert result.text == "Response"
|
||||
|
||||
async def test_run_agent_with_dict_request(self) -> None:
|
||||
"""Test run_agent with a dictionary request."""
|
||||
@@ -823,9 +790,8 @@ class TestRunRequestSupport:
|
||||
|
||||
result = await entity.run_agent(mock_context, request_dict)
|
||||
|
||||
assert result["status"] == "success"
|
||||
assert result["message"] == "Test message"
|
||||
assert result["thread_id"] == "conv-456"
|
||||
assert isinstance(result, AgentRunResponse)
|
||||
assert result.text == "Response"
|
||||
|
||||
async def test_run_agent_with_string_raises_without_correlation(self) -> None:
|
||||
"""Test that run_agent rejects legacy string input without correlation ID."""
|
||||
@@ -879,10 +845,9 @@ class TestRunRequestSupport:
|
||||
|
||||
result = await entity.run_agent(mock_context, request)
|
||||
|
||||
assert result["status"] == "success"
|
||||
# Should have structured_response
|
||||
if "structured_response" in result:
|
||||
assert result["structured_response"]["answer"] == 42
|
||||
assert isinstance(result, AgentRunResponse)
|
||||
assert result.text == '{"answer": 42}'
|
||||
assert result.value is None
|
||||
|
||||
async def test_run_agent_disable_tool_calls(self) -> None:
|
||||
"""Test run_agent with tool calls disabled."""
|
||||
@@ -898,7 +863,7 @@ class TestRunRequestSupport:
|
||||
|
||||
result = await entity.run_agent(mock_context, request)
|
||||
|
||||
assert result["status"] == "success"
|
||||
assert isinstance(result, AgentRunResponse)
|
||||
# Agent should have been called (tool disabling is framework-dependent)
|
||||
mock_agent.run.assert_called_once()
|
||||
|
||||
@@ -925,8 +890,24 @@ class TestRunRequestSupport:
|
||||
# Verify result was set
|
||||
assert mock_context.set_result.called
|
||||
result = mock_context.set_result.call_args[0][0]
|
||||
assert result["status"] == "success"
|
||||
assert result["message"] == "Test message"
|
||||
assert isinstance(result, dict)
|
||||
|
||||
# Check if messages are present
|
||||
assert "messages" in result
|
||||
assert len(result["messages"]) > 0
|
||||
message = result["messages"][0]
|
||||
|
||||
# Check for text in various possible locations
|
||||
text_found = False
|
||||
if "text" in message and message["text"] == "Response":
|
||||
text_found = True
|
||||
elif "contents" in message:
|
||||
for content in message["contents"]:
|
||||
if isinstance(content, dict) and content.get("text") == "Response":
|
||||
text_found = True
|
||||
break
|
||||
|
||||
assert text_found, f"Response text not found in message: {message}"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -7,7 +7,7 @@ import pytest
|
||||
from agent_framework import Role
|
||||
from pydantic import BaseModel
|
||||
|
||||
from agent_framework_azurefunctions._models import AgentResponse, AgentSessionId, RunRequest
|
||||
from agent_framework_azurefunctions._models import AgentSessionId, RunRequest
|
||||
|
||||
|
||||
class ModuleStructuredResponse(BaseModel):
|
||||
@@ -337,107 +337,6 @@ class TestRunRequest:
|
||||
assert restored.thread_id == original.thread_id
|
||||
|
||||
|
||||
class TestAgentResponse:
|
||||
"""Test suite for AgentResponse."""
|
||||
|
||||
def test_init_with_required_fields(self) -> None:
|
||||
"""Test AgentResponse initialization with required fields."""
|
||||
response = AgentResponse(
|
||||
response="Test response", message="Test message", thread_id="thread-123", status="success"
|
||||
)
|
||||
|
||||
assert response.response == "Test response"
|
||||
assert response.message == "Test message"
|
||||
assert response.thread_id == "thread-123"
|
||||
assert response.status == "success"
|
||||
assert response.message_count == 0
|
||||
assert response.error is None
|
||||
assert response.error_type is None
|
||||
assert response.structured_response is None
|
||||
|
||||
def test_init_with_all_fields(self) -> None:
|
||||
"""Test AgentResponse initialization with all fields."""
|
||||
structured = {"answer": "42"}
|
||||
response = AgentResponse(
|
||||
response=None,
|
||||
message="What is the answer?",
|
||||
thread_id="thread-456",
|
||||
status="success",
|
||||
message_count=5,
|
||||
error=None,
|
||||
error_type=None,
|
||||
structured_response=structured,
|
||||
)
|
||||
|
||||
assert response.response is None
|
||||
assert response.structured_response == structured
|
||||
assert response.message_count == 5
|
||||
|
||||
def test_to_dict_with_text_response(self) -> None:
|
||||
"""Test to_dict with text response."""
|
||||
response = AgentResponse(
|
||||
response="Text response", message="Message", thread_id="thread-1", status="success", message_count=3
|
||||
)
|
||||
data = response.to_dict()
|
||||
|
||||
assert data["response"] == "Text response"
|
||||
assert data["message"] == "Message"
|
||||
assert data["thread_id"] == "thread-1"
|
||||
assert data["status"] == "success"
|
||||
assert data["message_count"] == 3
|
||||
assert "structured_response" not in data
|
||||
assert "error" not in data
|
||||
assert "error_type" not in data
|
||||
|
||||
def test_to_dict_with_structured_response(self) -> None:
|
||||
"""Test to_dict with structured response."""
|
||||
structured = {"answer": 42, "confidence": 0.95}
|
||||
response = AgentResponse(
|
||||
response=None,
|
||||
message="Question",
|
||||
thread_id="thread-2",
|
||||
status="success",
|
||||
structured_response=structured,
|
||||
)
|
||||
data = response.to_dict()
|
||||
|
||||
assert data["structured_response"] == structured
|
||||
assert "response" not in data
|
||||
|
||||
def test_to_dict_with_error(self) -> None:
|
||||
"""Test to_dict with error."""
|
||||
response = AgentResponse(
|
||||
response=None,
|
||||
message="Failed message",
|
||||
thread_id="thread-3",
|
||||
status="error",
|
||||
error="Something went wrong",
|
||||
error_type="ValueError",
|
||||
)
|
||||
data = response.to_dict()
|
||||
|
||||
assert data["status"] == "error"
|
||||
assert data["error"] == "Something went wrong"
|
||||
assert data["error_type"] == "ValueError"
|
||||
|
||||
def test_to_dict_prefers_structured_over_text(self) -> None:
|
||||
"""Test to_dict prefers structured_response over response."""
|
||||
structured = {"result": "structured"}
|
||||
response = AgentResponse(
|
||||
response="Text response",
|
||||
message="Message",
|
||||
thread_id="thread-4",
|
||||
status="success",
|
||||
structured_response=structured,
|
||||
)
|
||||
data = response.to_dict()
|
||||
|
||||
assert "structured_response" in data
|
||||
assert data["structured_response"] == structured
|
||||
# Text response should not be included when structured is present
|
||||
assert "response" not in data
|
||||
|
||||
|
||||
class TestModelIntegration:
|
||||
"""Test suite for integration between models."""
|
||||
|
||||
@@ -450,21 +349,6 @@ class TestModelIntegration:
|
||||
assert request.thread_id == str(session_id)
|
||||
assert request.thread_id.startswith("@AgentEntity@")
|
||||
|
||||
def test_response_from_run_request(self) -> None:
|
||||
"""Test creating AgentResponse from RunRequest."""
|
||||
request = RunRequest(message="What is 2+2?", thread_id="thread-123", role=Role.USER)
|
||||
|
||||
response = AgentResponse(
|
||||
response="4",
|
||||
message=request.message,
|
||||
thread_id=request.thread_id,
|
||||
status="success",
|
||||
message_count=1,
|
||||
)
|
||||
|
||||
assert response.message == request.message
|
||||
assert response.thread_id == request.thread_id
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v", "--tb=short"])
|
||||
|
||||
@@ -6,10 +6,12 @@ from typing import Any
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from agent_framework import AgentThread
|
||||
from agent_framework import AgentRunResponse, AgentThread, ChatMessage
|
||||
from azure.durable_functions.models.Task import TaskBase, TaskState
|
||||
|
||||
from agent_framework_azurefunctions import AgentFunctionApp, DurableAIAgent
|
||||
from agent_framework_azurefunctions._models import AgentSessionId, DurableAgentThread
|
||||
from agent_framework_azurefunctions._orchestration import AgentTask
|
||||
|
||||
|
||||
def _app_with_registered_agents(*agent_names: str) -> AgentFunctionApp:
|
||||
@@ -21,6 +23,169 @@ def _app_with_registered_agents(*agent_names: str) -> AgentFunctionApp:
|
||||
return app
|
||||
|
||||
|
||||
class _FakeTask(TaskBase):
|
||||
"""Concrete TaskBase for testing AgentTask wiring."""
|
||||
|
||||
def __init__(self, task_id: int = 1):
|
||||
super().__init__(task_id, [])
|
||||
self._set_is_scheduled(False)
|
||||
self.action_repr = []
|
||||
self.state = TaskState.RUNNING
|
||||
|
||||
|
||||
def _create_entity_task(task_id: int = 1) -> TaskBase:
|
||||
"""Create a minimal TaskBase instance for AgentTask tests."""
|
||||
return _FakeTask(task_id)
|
||||
|
||||
|
||||
class TestAgentResponseHelpers:
|
||||
"""Tests for helper utilities that prepare AgentRunResponse values."""
|
||||
|
||||
@staticmethod
|
||||
def _create_agent_task() -> AgentTask:
|
||||
entity_task = _create_entity_task()
|
||||
return AgentTask(entity_task, None, "correlation-id")
|
||||
|
||||
def test_load_agent_response_from_instance(self) -> None:
|
||||
task = self._create_agent_task()
|
||||
response = AgentRunResponse(messages=[ChatMessage(role="assistant", text='{"foo": "bar"}')])
|
||||
|
||||
loaded = task._load_agent_response(response)
|
||||
|
||||
assert loaded is response
|
||||
assert loaded.value is None
|
||||
|
||||
def test_load_agent_response_from_serialized(self) -> None:
|
||||
task = self._create_agent_task()
|
||||
serialized = AgentRunResponse(messages=[ChatMessage(role="assistant", text="structured")]).to_dict()
|
||||
serialized["value"] = {"answer": 42}
|
||||
|
||||
loaded = task._load_agent_response(serialized)
|
||||
|
||||
assert loaded is not None
|
||||
assert loaded.value == {"answer": 42}
|
||||
loaded_dict = loaded.to_dict()
|
||||
assert loaded_dict["type"] == "agent_run_response"
|
||||
|
||||
def test_load_agent_response_rejects_none(self) -> None:
|
||||
task = self._create_agent_task()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
task._load_agent_response(None)
|
||||
|
||||
def test_load_agent_response_rejects_unsupported_type(self) -> None:
|
||||
task = self._create_agent_task()
|
||||
|
||||
with pytest.raises(TypeError, match="Unsupported type"):
|
||||
task._load_agent_response(["invalid", "list"]) # type: ignore[arg-type]
|
||||
|
||||
def test_try_set_value_success(self) -> None:
|
||||
"""Test try_set_value correctly processes successful task completion."""
|
||||
entity_task = _create_entity_task()
|
||||
task = AgentTask(entity_task, None, "correlation-id")
|
||||
|
||||
# Simulate successful entity task completion
|
||||
entity_task.state = TaskState.SUCCEEDED
|
||||
entity_task.result = AgentRunResponse(messages=[ChatMessage(role="assistant", text="Test response")]).to_dict()
|
||||
|
||||
# Clear pending_tasks to simulate that parent has processed the child
|
||||
task.pending_tasks.clear()
|
||||
|
||||
# Call try_set_value
|
||||
task.try_set_value(entity_task)
|
||||
|
||||
# Verify task completed successfully with AgentRunResponse
|
||||
assert task.state == TaskState.SUCCEEDED
|
||||
assert isinstance(task.result, AgentRunResponse)
|
||||
assert task.result.text == "Test response"
|
||||
|
||||
def test_try_set_value_failure(self) -> None:
|
||||
"""Test try_set_value correctly handles failed task completion."""
|
||||
entity_task = _create_entity_task()
|
||||
task = AgentTask(entity_task, None, "correlation-id")
|
||||
|
||||
# Simulate failed entity task
|
||||
entity_task.state = TaskState.FAILED
|
||||
entity_task.result = Exception("Entity call failed")
|
||||
|
||||
# Call try_set_value
|
||||
task.try_set_value(entity_task)
|
||||
|
||||
# Verify task failed with the error
|
||||
assert task.state == TaskState.FAILED
|
||||
assert isinstance(task.result, Exception)
|
||||
assert str(task.result) == "Entity call failed"
|
||||
|
||||
def test_try_set_value_with_response_format(self) -> None:
|
||||
"""Test try_set_value parses structured output when response_format is provided."""
|
||||
from pydantic import BaseModel
|
||||
|
||||
class TestSchema(BaseModel):
|
||||
answer: str
|
||||
|
||||
entity_task = _create_entity_task()
|
||||
task = AgentTask(entity_task, TestSchema, "correlation-id")
|
||||
|
||||
# Simulate successful entity task with JSON response
|
||||
entity_task.state = TaskState.SUCCEEDED
|
||||
entity_task.result = AgentRunResponse(
|
||||
messages=[ChatMessage(role="assistant", text='{"answer": "42"}')]
|
||||
).to_dict()
|
||||
|
||||
# Clear pending_tasks to simulate that parent has processed the child
|
||||
task.pending_tasks.clear()
|
||||
|
||||
# Call try_set_value
|
||||
task.try_set_value(entity_task)
|
||||
|
||||
# Verify task completed and value was parsed
|
||||
assert task.state == TaskState.SUCCEEDED
|
||||
assert isinstance(task.result, AgentRunResponse)
|
||||
assert isinstance(task.result.value, TestSchema)
|
||||
assert task.result.value.answer == "42"
|
||||
|
||||
def test_ensure_response_format_parses_value(self) -> None:
|
||||
"""Test _ensure_response_format correctly parses response value."""
|
||||
from pydantic import BaseModel
|
||||
|
||||
class SampleSchema(BaseModel):
|
||||
name: str
|
||||
|
||||
task = self._create_agent_task()
|
||||
response = AgentRunResponse(messages=[ChatMessage(role="assistant", text='{"name": "test"}')])
|
||||
|
||||
# Value should be None initially
|
||||
assert response.value is None
|
||||
|
||||
# Parse the value
|
||||
task._ensure_response_format(SampleSchema, "test-correlation", response)
|
||||
|
||||
# Value should now be parsed
|
||||
assert isinstance(response.value, SampleSchema)
|
||||
assert response.value.name == "test"
|
||||
|
||||
def test_ensure_response_format_skips_if_already_parsed(self) -> None:
|
||||
"""Test _ensure_response_format does not re-parse if value already matches format."""
|
||||
from pydantic import BaseModel
|
||||
|
||||
class SampleSchema(BaseModel):
|
||||
name: str
|
||||
|
||||
task = self._create_agent_task()
|
||||
existing_value = SampleSchema(name="existing")
|
||||
response = AgentRunResponse(
|
||||
messages=[ChatMessage(role="assistant", text='{"name": "new"}')],
|
||||
value=existing_value,
|
||||
)
|
||||
|
||||
# Call _ensure_response_format
|
||||
task._ensure_response_format(SampleSchema, "test-correlation", response)
|
||||
|
||||
# Value should remain unchanged (not re-parsed)
|
||||
assert response.value is existing_value
|
||||
assert response.value.name == "existing"
|
||||
|
||||
|
||||
class TestDurableAIAgent:
|
||||
"""Test suite for DurableAIAgent wrapper."""
|
||||
|
||||
@@ -111,22 +276,19 @@ class TestDurableAIAgent:
|
||||
mock_context.instance_id = "test-instance-001"
|
||||
mock_context.new_uuid = Mock(side_effect=["thread-guid", "correlation-guid"])
|
||||
|
||||
# Mock call_entity to return a Task-like object
|
||||
mock_task = Mock()
|
||||
mock_task._is_scheduled = False # Task attribute that orchestration checks
|
||||
|
||||
mock_context.call_entity = Mock(return_value=mock_task)
|
||||
entity_task = _create_entity_task()
|
||||
mock_context.call_entity = Mock(return_value=entity_task)
|
||||
|
||||
agent = DurableAIAgent(mock_context, "TestAgent")
|
||||
|
||||
# Create thread
|
||||
thread = agent.get_new_thread()
|
||||
|
||||
# Call run() - it should return the Task directly
|
||||
# Call run() - returns AgentTask directly
|
||||
task = agent.run(messages="Test message", thread=thread, enable_tool_calls=True)
|
||||
|
||||
# Verify run() returns the Task from call_entity
|
||||
assert task == mock_task
|
||||
assert isinstance(task, AgentTask)
|
||||
assert task.children[0] == entity_task
|
||||
|
||||
# Verify call_entity was called with correct parameters
|
||||
assert mock_context.call_entity.called
|
||||
@@ -145,19 +307,18 @@ class TestDurableAIAgent:
|
||||
"""Test that run() works without explicit thread (creates unique session key)."""
|
||||
mock_context = Mock()
|
||||
mock_context.instance_id = "test-instance-002"
|
||||
# Two calls to new_uuid: one for session_key, one for correlationId
|
||||
mock_context.new_uuid = Mock(side_effect=["auto-generated-guid", "correlation-guid"])
|
||||
|
||||
mock_task = Mock()
|
||||
mock_task._is_scheduled = False
|
||||
mock_context.call_entity = Mock(return_value=mock_task)
|
||||
entity_task = _create_entity_task()
|
||||
mock_context.call_entity = Mock(return_value=entity_task)
|
||||
|
||||
agent = DurableAIAgent(mock_context, "TestAgent")
|
||||
|
||||
# Call without thread
|
||||
task = agent.run(messages="Test message")
|
||||
|
||||
assert task == mock_task
|
||||
assert isinstance(task, AgentTask)
|
||||
assert task.children[0] == entity_task
|
||||
|
||||
# Verify the entity ID uses the auto-generated GUID with dafx- prefix
|
||||
call_args = mock_context.call_entity.call_args
|
||||
@@ -172,9 +333,8 @@ class TestDurableAIAgent:
|
||||
mock_context = Mock()
|
||||
mock_context.instance_id = "test-instance-003"
|
||||
|
||||
mock_task = Mock()
|
||||
mock_task._is_scheduled = False
|
||||
mock_context.call_entity = Mock(return_value=mock_task)
|
||||
entity_task = _create_entity_task()
|
||||
mock_context.call_entity = Mock(return_value=entity_task)
|
||||
|
||||
agent = DurableAIAgent(mock_context, "TestAgent")
|
||||
|
||||
@@ -188,7 +348,8 @@ class TestDurableAIAgent:
|
||||
|
||||
task = agent.run(messages="Test message", thread=thread, response_format=SampleSchema)
|
||||
|
||||
assert task == mock_task
|
||||
assert isinstance(task, AgentTask)
|
||||
assert task.children[0] == entity_task
|
||||
|
||||
# Verify schema was passed in the call_entity arguments
|
||||
call_args = mock_context.call_entity.call_args
|
||||
@@ -221,8 +382,8 @@ class TestDurableAIAgent:
|
||||
|
||||
mock_context = Mock()
|
||||
mock_context.new_uuid = Mock(side_effect=["thread-guid", "correlation-guid"])
|
||||
mock_task = Mock()
|
||||
mock_context.call_entity = Mock(return_value=mock_task)
|
||||
entity_task = _create_entity_task()
|
||||
mock_context.call_entity = Mock(return_value=entity_task)
|
||||
|
||||
agent = DurableAIAgent(mock_context, "TestAgent")
|
||||
thread = agent.get_new_thread()
|
||||
@@ -231,7 +392,8 @@ class TestDurableAIAgent:
|
||||
msg = ChatMessage(role="user", text="Hello")
|
||||
task = agent.run(messages=msg, thread=thread)
|
||||
|
||||
assert task == mock_task
|
||||
assert isinstance(task, AgentTask)
|
||||
assert task.children[0] == entity_task
|
||||
|
||||
# Verify message was converted to string
|
||||
call_args = mock_context.call_entity.call_args
|
||||
@@ -255,7 +417,7 @@ class TestDurableAIAgent:
|
||||
|
||||
mock_context = Mock()
|
||||
mock_context.new_uuid = Mock(return_value="test-guid-789")
|
||||
mock_context.call_entity = Mock(return_value=Mock())
|
||||
mock_context.call_entity = Mock(return_value=_create_entity_task())
|
||||
|
||||
agent = DurableAIAgent(mock_context, "WriterAgent")
|
||||
thread = agent.get_new_thread()
|
||||
@@ -314,13 +476,9 @@ class TestOrchestrationIntegration:
|
||||
# Track entity calls
|
||||
entity_calls: list[dict[str, Any]] = []
|
||||
|
||||
def mock_call_entity_side_effect(entity_id: Any, operation: str, input_data: dict[str, Any]) -> Mock:
|
||||
def mock_call_entity_side_effect(entity_id: Any, operation: str, input_data: dict[str, Any]) -> TaskBase:
|
||||
entity_calls.append({"entity_id": str(entity_id), "operation": operation, "input": input_data})
|
||||
|
||||
# Return a mock Task
|
||||
mock_task = Mock()
|
||||
mock_task._is_scheduled = False
|
||||
return mock_task
|
||||
return _create_entity_task()
|
||||
|
||||
mock_context.call_entity = Mock(side_effect=mock_call_entity_side_effect)
|
||||
|
||||
@@ -330,13 +488,13 @@ class TestOrchestrationIntegration:
|
||||
# Create thread
|
||||
thread = agent.get_new_thread()
|
||||
|
||||
# First call - returns Task
|
||||
# First call - returns AgentTask
|
||||
task1 = agent.run("Write something", thread=thread)
|
||||
assert hasattr(task1, "_is_scheduled")
|
||||
assert isinstance(task1, AgentTask)
|
||||
|
||||
# Second call - returns Task
|
||||
# Second call - returns AgentTask
|
||||
task2 = agent.run("Improve: something", thread=thread)
|
||||
assert hasattr(task2, "_is_scheduled")
|
||||
assert isinstance(task2, AgentTask)
|
||||
|
||||
# Verify both calls used the same entity (same session key)
|
||||
assert len(entity_calls) == 2
|
||||
@@ -356,11 +514,9 @@ class TestOrchestrationIntegration:
|
||||
|
||||
entity_calls: list[str] = []
|
||||
|
||||
def mock_call_entity_side_effect(entity_id: Any, operation: str, input_data: dict[str, Any]) -> Mock:
|
||||
def mock_call_entity_side_effect(entity_id: Any, operation: str, input_data: dict[str, Any]) -> TaskBase:
|
||||
entity_calls.append(str(entity_id))
|
||||
mock_task = Mock()
|
||||
mock_task._is_scheduled = False
|
||||
return mock_task
|
||||
return _create_entity_task()
|
||||
|
||||
mock_context.call_entity = Mock(side_effect=mock_call_entity_side_effect)
|
||||
|
||||
@@ -371,12 +527,12 @@ class TestOrchestrationIntegration:
|
||||
writer_thread = writer.get_new_thread()
|
||||
editor_thread = editor.get_new_thread()
|
||||
|
||||
# Call both agents - returns Tasks
|
||||
# Call both agents - returns AgentTasks
|
||||
writer_task = writer.run("Write", thread=writer_thread)
|
||||
editor_task = editor.run("Edit", thread=editor_thread)
|
||||
|
||||
assert hasattr(writer_task, "_is_scheduled")
|
||||
assert hasattr(editor_task, "_is_scheduled")
|
||||
assert isinstance(writer_task, AgentTask)
|
||||
assert isinstance(editor_task, AgentTask)
|
||||
|
||||
# Verify different entity IDs were used
|
||||
assert len(entity_calls) == 2
|
||||
|
||||
+2
-2
@@ -57,7 +57,7 @@ def single_agent_orchestration(context: DurableOrchestrationContext):
|
||||
|
||||
improved_prompt = (
|
||||
"Improve this further while keeping it under 25 words: "
|
||||
f"{initial.get('response', '').strip()}"
|
||||
f"{initial.text}"
|
||||
)
|
||||
|
||||
refined = yield writer.run(
|
||||
@@ -65,7 +65,7 @@ def single_agent_orchestration(context: DurableOrchestrationContext):
|
||||
thread=writer_thread,
|
||||
)
|
||||
|
||||
return refined.get("response", "")
|
||||
return refined.text
|
||||
|
||||
|
||||
# 5. HTTP endpoint to kick off the orchestration and return the status query URI.
|
||||
|
||||
+10
-4
@@ -10,8 +10,9 @@ Prerequisites: configure `AZURE_OPENAI_ENDPOINT`, `AZURE_OPENAI_CHAT_DEPLOYMENT_
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
from agent_framework import AgentRunResponse
|
||||
import azure.functions as func
|
||||
from agent_framework.azure import AgentFunctionApp, AzureOpenAIChatClient
|
||||
from azure.durable_functions import DurableOrchestrationClient, DurableOrchestrationContext
|
||||
@@ -63,14 +64,19 @@ def multi_agent_concurrent_orchestration(context: DurableOrchestrationContext):
|
||||
physicist_thread = physicist.get_new_thread()
|
||||
chemist_thread = chemist.get_new_thread()
|
||||
|
||||
# Create tasks from agent.run() calls
|
||||
physicist_task = physicist.run(messages=str(prompt), thread=physicist_thread)
|
||||
chemist_task = chemist.run(messages=str(prompt), thread=chemist_thread)
|
||||
|
||||
results = yield context.task_all([physicist_task, chemist_task])
|
||||
# Execute both tasks concurrently using task_all
|
||||
task_results = yield context.task_all([physicist_task, chemist_task])
|
||||
|
||||
physicist_result = cast(AgentRunResponse, task_results[0])
|
||||
chemist_result = cast(AgentRunResponse, task_results[1])
|
||||
|
||||
return {
|
||||
"physicist": results[0].get("response", ""),
|
||||
"chemist": results[1].get("response", ""),
|
||||
"physicist": physicist_result.text,
|
||||
"chemist": chemist_result.text,
|
||||
}
|
||||
|
||||
|
||||
|
||||
+2
-20
@@ -102,7 +102,7 @@ def spam_detection_orchestration(context: DurableOrchestrationContext):
|
||||
response_format=SpamDetectionResult,
|
||||
)
|
||||
|
||||
spam_result = cast(SpamDetectionResult, _coerce_structured(spam_result_raw, SpamDetectionResult))
|
||||
spam_result = cast(SpamDetectionResult, spam_result_raw.value)
|
||||
|
||||
if spam_result.is_spam:
|
||||
result = yield context.call_activity("handle_spam_email", spam_result.reason)
|
||||
@@ -123,7 +123,7 @@ def spam_detection_orchestration(context: DurableOrchestrationContext):
|
||||
response_format=EmailResponse,
|
||||
)
|
||||
|
||||
email_result = cast(EmailResponse, _coerce_structured(email_result_raw, EmailResponse))
|
||||
email_result = cast(EmailResponse, email_result_raw.value)
|
||||
|
||||
result = yield context.call_activity("send_email", email_result.response)
|
||||
return result
|
||||
@@ -231,24 +231,6 @@ def _build_status_url(request_url: str, instance_id: str, *, route: str) -> str:
|
||||
return f"{base_url}/api/{route}/status/{instance_id}"
|
||||
|
||||
|
||||
def _coerce_structured(result: Mapping[str, Any], model: type[BaseModel]) -> BaseModel:
|
||||
structured = result.get("structured_response") if isinstance(result, Mapping) else None
|
||||
if structured is not None:
|
||||
return model.model_validate(structured)
|
||||
|
||||
response_text = result.get("response") if isinstance(result, Mapping) else None
|
||||
if isinstance(response_text, str) and response_text.strip():
|
||||
try:
|
||||
parsed = json.loads(response_text)
|
||||
if isinstance(parsed, Mapping):
|
||||
return model.model_validate(parsed)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("[ConditionalOrchestration] Failed to parse agent JSON response; raising error.")
|
||||
|
||||
# If parsing failed, raise to surface the issue to the caller.
|
||||
raise ValueError(f"Agent response could not be parsed as {model.__name__}.")
|
||||
|
||||
|
||||
"""
|
||||
Expected response from `POST /api/spamdetection/run`:
|
||||
|
||||
|
||||
+12
-19
@@ -100,7 +100,12 @@ def content_generation_hitl_orchestration(context: DurableOrchestrationContext):
|
||||
thread=writer_thread,
|
||||
response_format=GeneratedContent,
|
||||
)
|
||||
content = _coerce_generated_content(initial_raw)
|
||||
|
||||
content = initial_raw.value
|
||||
logger.info("Type of content after extraction: %s", type(content))
|
||||
|
||||
if content is None or not isinstance(content, GeneratedContent):
|
||||
raise ValueError("Agent returned no content after extraction.")
|
||||
|
||||
attempt = 0
|
||||
while attempt < payload.max_review_attempts:
|
||||
@@ -142,7 +147,12 @@ def content_generation_hitl_orchestration(context: DurableOrchestrationContext):
|
||||
thread=writer_thread,
|
||||
response_format=GeneratedContent,
|
||||
)
|
||||
content = _coerce_generated_content(rewritten_raw)
|
||||
|
||||
rewritten_value = rewritten_raw.value
|
||||
if rewritten_value is None or not isinstance(rewritten_value, GeneratedContent):
|
||||
raise ValueError("Agent returned no content after rewrite.")
|
||||
|
||||
content = rewritten_value
|
||||
else:
|
||||
context.set_custom_status(
|
||||
f"Human approval timed out after {payload.approval_timeout_hours} hour(s). Treating as rejection."
|
||||
@@ -317,23 +327,6 @@ def _build_status_url(request_url: str, instance_id: str, *, route: str) -> str:
|
||||
return f"{base_url}/api/{route}/status/{instance_id}"
|
||||
|
||||
|
||||
def _coerce_generated_content(result: Mapping[str, Any]) -> GeneratedContent:
|
||||
structured = result.get("structured_response") if isinstance(result, Mapping) else None
|
||||
if structured is not None:
|
||||
return GeneratedContent.model_validate(structured)
|
||||
|
||||
response_text = result.get("response") if isinstance(result, Mapping) else None
|
||||
if isinstance(response_text, str) and response_text.strip():
|
||||
try:
|
||||
parsed = json.loads(response_text)
|
||||
if isinstance(parsed, Mapping):
|
||||
return GeneratedContent.model_validate(parsed)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("[HITL] Failed to parse agent JSON response; falling back to defaults.")
|
||||
|
||||
raise ValueError("Agent response could not be parsed as GeneratedContent.")
|
||||
|
||||
|
||||
def _parse_human_approval(raw: Any) -> HumanApproval:
|
||||
if isinstance(raw, Mapping):
|
||||
return HumanApproval.model_validate(raw)
|
||||
|
||||
Reference in New Issue
Block a user