mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: [BREAKING] Moved to a single get_response and run API (#3379)
* WIP * big update to new ResponseStream model * fixed tests and typing * fixed tests and typing * fixed tools typevar import * fix * mypy fix * mypy fixes and some cleanup * fix missing quoted names * and client * fix imports agui * fix anthropic override * fix agui * fix ag ui * fix import * fix anthropic types * fix mypy * refactoring * updated typing * fix 3.11 * fixes * redid layering of chat clients and agents * redid layering of chat clients and agents * Fix lint, type, and test issues after rebase - Add @overload decorators to AgentProtocol.run() for type compatibility - Add missing docstring params (middleware, function_invocation_configuration) - Fix TODO format (TD002) by adding author tags - Fix broken observability tests from upstream: - Replace non-existent use_instrumentation with direct instantiation - Replace non-existent use_agent_instrumentation with AgentTelemetryLayer mixin - Fix get_streaming_response to use get_response(stream=True) - Add AgentInitializationError import - Update streaming exception tests to match actual behavior * Fix AgentExecutionException import error in test_agents.py - Replace non-existent AgentExecutionException with AgentRunException * Fix test import and asyncio deprecation issues - Add 'tests' to pythonpath in ag-ui pyproject.toml for utils_test_ag_ui import - Replace deprecated asyncio.get_event_loop().run_until_complete with asyncio.run * Fix azure-ai test failures - Update _prepare_options patching to use correct class path - Fix test_to_azure_ai_agent_tools_web_search_missing_connection to clear env vars * Convert ag-ui utils_test_ag_ui.py to conftest.py - Move test utilities to conftest.py for proper pytest discovery - Update all test imports to use conftest instead of utils_test_ag_ui - Remove old utils_test_ag_ui.py file - Revert pythonpath change in pyproject.toml * fix: use relative imports for ag-ui test utilities * fix agui * Rename Bare*Client to Raw*Client and BaseChatClient - Renamed BareChatClient to BaseChatClient (abstract base class) - Renamed BareOpenAIChatClient to RawOpenAIChatClient - Renamed BareOpenAIResponsesClient to RawOpenAIResponsesClient - Renamed BareAzureAIClient to RawAzureAIClient - Added warning docstrings to Raw* classes about layer ordering - Updated README in samples/getting_started/agents/custom with layer docs - Added test for span ordering with function calling * Fix layer ordering: FunctionInvocationLayer before ChatTelemetryLayer This ensures each inner LLM call gets its own telemetry span, resulting in the correct span sequence: chat -> execute_tool -> chat Updated all production clients and test mocks to use correct ordering: - ChatMiddlewareLayer (first) - FunctionInvocationLayer (second) - ChatTelemetryLayer (third) - BaseChatClient/Raw...Client (fourth) * Remove run_stream usage * Fix conversation_id propagation * Python: Add BaseAgent implementation for Claude Agent SDK (#3509) * Added ClaudeAgent implementation * Updated streaming logic * Small updates * Small update * Fixes * Small fix * Naming improvements * Updated imports * Addressed comments * Updated package versions * Update Claude agent connector layering * fix test and plugin * Store function middleware in invocation layer * Fix telemetry streaming and ag-ui tests * Remove legacy ag-ui tests folder * updates * Remove terminate flag from FunctionInvocationContext, use MiddlewareTermination instead - Remove terminate attribute from FunctionInvocationContext - Add result attribute to MiddlewareTermination to carry function results - FunctionMiddlewarePipeline.execute() now lets MiddlewareTermination propagate - _auto_invoke_function captures context.result in exception before re-raising - _try_execute_function_calls catches MiddlewareTermination and sets should_terminate - Fix handoff middleware to append to chat_client.function_middleware directly - Update tests to use raise MiddlewareTermination instead of context.terminate - Add middleware flow documentation in samples/concepts/tools/README.md - Fix ag-ui to use FunctionMiddlewarePipeline instead of removed create_function_middleware_pipeline * fix: remove references to removed terminate flag in purview tests, add type ignore * fix: move _test_utils.py from package to test folder * fix: call get_final_response() to trigger context provider notification in streaming test * fix: correct broken links in tools README * docs: clarify default middleware behavior in summary table * fix: ensure inner stream result hooks are called when using map()/from_awaitable() * Fix mypy type errors * Address PR review comments on observability.py - Remove TODO comment about unconsumed streams, add explanatory note instead - Remove redundant _close_span cleanup hook (already called in _finalize_stream) - Clarify behavior: cleanup hooks run after stream iteration, if stream is not consumed the span remains open until garbage collected * Remove gen_ai.client.operation.duration from span attributes Duration is a metrics-only attribute per OpenTelemetry semantic conventions. It should be recorded to the histogram but not set as a span attribute. * Remove duration from _get_response_attributes, pass directly to _capture_response Duration is a metrics-only attribute. It's now passed directly to _capture_response instead of being included in the attributes dict that gets set on the span. * Remove redundant _close_span cleanup hook in AgentTelemetryLayer _finalize_stream already calls _close_span() in its finally block, so adding it as a separate cleanup hook is redundant. * Use weakref.finalize to close span when stream is garbage collected If a user creates a streaming response but never consumes it, the cleanup hooks won't run. Now we register a weak reference finalizer that will close the span when the stream object is garbage collected, ensuring spans don't leak in this scenario. * Fix _get_finalizers_from_stream to use _result_hooks attribute Renamed function to _get_result_hooks_from_stream and fixed it to look for the _result_hooks attribute which is the correct name in ResponseStream class. * Add missing asyncio import in test_request_info_mixin.py * Fix leftover merge conflict marker in image_generation sample * Update integration tests * Fix integration tests: increase max_iterations from 1 to 2 Tests with tool_choice options require at least 2 iterations: 1. First iteration to get function call and execute the tool 2. Second iteration to get the final text response With max_iterations=1, streaming tests would return early with only the function call/result but no final text content. * Fix duplicate function call error in conversation-based APIs When using conversation_id (for Responses/Assistants APIs), the server already has the function call message from the previous response. We should only send the new function result message, not all messages including the function call which would cause a duplicate ID error. Fix: When conversation_id is set, only send the last message (the tool result) instead of all response.messages. * Add regression test for conversation_id propagation between tool iterations Port test from PR #3664 with updates for new streaming API pattern. Tests that conversation_id is properly updated in options dict during function invocation loop iterations. * Fix tool_choice=required to return after tool execution When tool_choice is 'required', the user's intent is to force exactly one tool call. After the tool executes, return immediately with the function call and result - don't continue to call the model again. This fixes integration tests that were failing with empty text responses because with tool_choice=required, the model would keep returning function calls instead of text. Also adds regression tests for: - conversation_id propagation between tool iterations (from PR #3664) - tool_choice=required returns after tool execution * Document tool_choice behavior in tools README - Add table explaining tool_choice values (auto, none, required) - Explain why tool_choice=required returns immediately after tool execution - Add code example showing the difference between required and auto - Update flow diagram to show the early return path for tool_choice=required * Fix tool_choice=None behavior - don't default to 'auto' Remove the hardcoded default of 'auto' for tool_choice in ChatAgent init. When tool_choice is not specified (None), it will now not be sent to the API, allowing the API's default behavior to be used. Users who want tool_choice='auto' can still explicitly set it either in default_options or at runtime. Fixes #3585 * Fix tool_choice=none should not remove tools In OpenAI Assistants client, tools were not being sent when tool_choice='none'. This was incorrect - tool_choice='none' means the model won't call tools, but tools should still be available in the request (they may be used later in the conversation). Fixes #3585 * Add test for tool_choice=none preserving tools Adds a regression test to ensure that when tool_choice='none' is set but tools are provided, the tools are still sent to the API. This verifies the fix for #3585. * Fix tool_choice=none should not remove tools in all clients Apply the same fix to OpenAI Responses client and Azure AI client: - OpenAI Responses: Remove else block that popped tool_choice/parallel_tool_calls - Azure AI: Remove tool_choice != 'none' check when adding tools When tool_choice='none', the model won't call tools, but tools should still be sent to the API so they're available for future turns. Also update README to clarify tool_choice=required supports multiple tools. Fixes #3585 * Keep tool_choice even when tools is None Move tool_choice processing outside of the 'if tools' block in OpenAI Responses client so tool_choice is sent to the API even when no tools are provided. * Update test to match new parallel_tool_calls behavior Changed test_prepare_options_removes_parallel_tool_calls_when_no_tools to test_prepare_options_preserves_parallel_tool_calls_when_no_tools to reflect that parallel_tool_calls is now preserved even when no tools are present, consistent with the tool_choice behavior. * Fix ChatMessage API and Role enum usage after rebase - Update ChatMessage instantiation to use keyword args (role=, text=, contents=) - Fix Role enum comparisons to use .value for string comparison - Add created_at to AgentResponse in error handling - Fix AgentResponse.from_updates -> from_agent_run_response_updates - Fix DurableAgentStateMessage.from_chat_message to convert Role enum to string - Add Role import where needed * Fix additional ChatMessage API and method name changes - Fix ChatMessage usage in workflow files (use text= instead of contents= for strings) - Fix AgentResponse.from_updates -> from_agent_run_response_updates in workflow files - Fix test files for ChatMessage and Role enum usage * Fix remaining ChatMessage API usage in test files * Fix more ChatMessage and Role API changes in source and test files - Fix ChatMessage in _magentic.py replan method - Fix Role enum comparison in test assertions - Fix remaining test files with old ChatMessage syntax * Fix ChatMessage and Role API changes across packages - Add Role import where missing - Fix ChatMessage signature: positional args to keyword args (role=, text=, contents=) - Fix Role enum comparisons: .role.value instead of .role string - Fix FinishReason enum usage in ag-ui event converters - Rename AgentResponse.from_updates to from_agent_run_response_updates in ag-ui Fixes API compatibility after Types API Review improvements merge * Fix ChatMessage and Role API changes in github_copilot tests * Fix ChatMessage and Role API changes in redis and github_copilot packages - Fix redis provider: Role enum comparison using .value - Fix redis tests: ChatMessage signature and Role comparisons - Fix github_copilot tests: ChatMessage signature and Role comparisons - Update docstring examples in redis chat message store * Fix ChatMessage and Role API changes in devui package - Fix executor: ChatMessage signature change - Fix conversations: Role enum to string conversion in two places - Fix tests: ChatMessage signatures and Role comparisons * Fix ChatMessage and Role API changes in a2a and lab packages - Fix a2a tests: Role comparisons and ChatMessage signatures - Fix lab tau2 source: Role enum comparison in flip_messages, log_messages, sliding_window - Fix lab tau2 tests: ChatMessage signatures and Role comparisons * Remove duplicate test files from ag-ui/tests (tests are in ag_ui_tests) * Fix ChatMessage and Role API changes across packages After rebasing on upstream/main which merged PR #3647 (Types API Review improvements), fix all packages to use the new API: - ChatMessage: Use keyword args (role=, text=, contents=) instead of positional args - Role: Compare using .value attribute since it's now an enum Packages fixed: - ag-ui: Fixed Role value extraction bugs in _message_adapters.py - anthropic: Fixed ChatMessage and Role comparisons in tests - azure-ai: Fixed Role comparison in _client.py - azure-ai-search: Fixed ChatMessage and Role in source/tests - bedrock: Fixed ChatMessage signatures in tests - chatkit: Fixed ChatMessage and Role in source/tests - copilotstudio: Fixed ChatMessage and Role in tests - declarative: Fixed ChatMessage in _executors_agents.py - mem0: Fixed ChatMessage and Role in source/tests - purview: Fixed ChatMessage in source/tests * Fix mypy errors for ChatMessage and Role API changes - durabletask: Use str() fallback in role value extraction - core: Fix ChatMessage in _orchestrator_helpers.py to use keyword args - core: Add type ignore for _conversation_state.py contents deserialization - ag-ui: Fix type ignore comments (call-overload instead of arg-type) - azure-ai-search: Fix get_role_value type hint to accept Any - lab: Move get_role_value to module level with Any type hint * Improve CI test timeout configuration - Increase job timeout from 10 to 15 minutes - Reduce per-test timeout to 60s (was 900s/300s) - Add --timeout_method thread for better timeout handling - Add --timeout-verbose to see which tests are slow - Reduce retries from 3 to 2 and delay from 10s to 5s This ensures individual test timeouts are shorter than the job timeout, providing better visibility when tests hang. With 60s timeout and 2 retries, worst case per test is ~180s. * Fix ChatMessage API usage in docstrings and source - Fix ChatMessage positional args in docstrings: _serialization.py, _threads.py, _middleware.py - Fix ChatMessage in tau2 runner.py - Fix role comparison in _orchestrator_helpers.py to use .value - Fix role comparison in _group_chat.py docstring example - Fix role assertions in test_durable_entities.py to use .value * Revert tool_choice/parallel_tool_calls changes - must be removed when no tools OpenAI API requires tool_choice and parallel_tool_calls to only be present when tools are specified. Restored the logic that removes these options when there are no tools. - Restored check in _chat_client.py to remove tool_choice and parallel_tool_calls when no tools present - Restored same logic in _responses_client.py - Reverted test to expect the correct behavior * fixed issue in tests * fix: resolve merge conflict markers in ag-ui tests * fix: restructure ag-ui tests and fix Role/FinishReason to use string types * fix: streaming function invocation and middleware termination - Refactor streaming function invocation to use get_final_response() on inner streams - Fix MiddlewareTermination to accept result parameter for passing results - Fix _AutoHandoffMiddleware to use MiddlewareTermination instead of context.terminate - Fix AgentMiddlewareLayer.run() to properly forward function/chat middleware - Remove duplicate middleware registration in AgentMiddlewareLayer.__init__ - Fix exception handling in _auto_invoke_function to properly capture termination - Fix mypy errors in core package - Update tests to use stream=True parameter for unified run API * fix all tests command * Refactor integration tests to use pytest fixtures - Merge testutils.py into conftest.py for azurefunctions integration tests - Merge dt_testutils.py into conftest.py for durabletask integration tests - Convert all integration tests to use fixtures instead of direct imports (fixes ModuleNotFoundError with --import-mode=importlib) - Add sample_helper fixture for azurefunctions tests - Add agent_client_factory and orchestration_helper fixtures for durabletask - Integration tests now skip with descriptive messages when services unavailable - Restructure devui tests into tests/devui/ with proper conftest.py - Add test organization guidelines to CODING_STANDARD.md - Remove __init__.py from test directories per pytest best practices * Fix pytest_collection_modifyitems to only skip integration tests The hook was skipping all tests in the test session, not just integration tests. Now it only skips items in the integration_tests directory. * Fix mem0 tests failing on Python 3.13 Use patch.object on the imported module instead of @patch with string path to ensure the mock takes effect regardless of import timing. * fix mem0 * another attempt for mem0 * fix for mem0 * fix mem0 * Increase worker initialization wait time in durabletask tests Increase from 2 to 8 seconds to allow time for: - Python startup and module imports - Azure OpenAI client creation - Agent registration with DTS worker - Worker connection to DTS This helps prevent test failures in CI where the first tests may run before the worker is fully ready to process requests. * Fix streaming test to use ResponseStream with finalizer The _consume_stream method now expects a ResponseStream that can provide a final AgentResponse via get_final_response(). Update the test to use ResponseStream with AgentResponse.from_updates as the finalizer. * Fix MockToolCallingAgent to use new ResponseStream API and update samples * small updates to run_stream to run * fix sub workflow * temp fix for az func test --------- Co-authored-by: Dmytro Struk <13853051+dmytrostruk@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
d1205896a1
commit
3dc59c83b5
@@ -817,7 +817,7 @@ class DurableAgentStateMessage:
|
||||
]
|
||||
|
||||
return DurableAgentStateMessage(
|
||||
role=chat_message.role,
|
||||
role=chat_message.role if hasattr(chat_message.role, "value") else str(chat_message.role),
|
||||
contents=contents_list,
|
||||
author_name=chat_message.author_name,
|
||||
extension_data=dict(chat_message.additional_properties) if chat_message.additional_properties else None,
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
from collections.abc import AsyncIterable
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, cast
|
||||
|
||||
from agent_framework import (
|
||||
@@ -14,6 +14,7 @@ from agent_framework import (
|
||||
AgentResponseUpdate,
|
||||
ChatMessage,
|
||||
Content,
|
||||
ResponseStream,
|
||||
get_logger,
|
||||
)
|
||||
from durabletask.entities import DurableEntity
|
||||
@@ -177,7 +178,10 @@ class AgentEntity:
|
||||
error_message = ChatMessage(
|
||||
role="assistant", contents=[Content.from_error(message=str(exc), error_code=type(exc).__name__)]
|
||||
)
|
||||
error_response = AgentResponse(messages=[error_message])
|
||||
error_response = AgentResponse(
|
||||
messages=[error_message],
|
||||
created_at=datetime.now(tz=timezone.utc).isoformat(),
|
||||
)
|
||||
|
||||
error_state_response = DurableAgentStateResponse.from_run_response(correlation_id, error_response)
|
||||
error_state_response.is_error = True
|
||||
@@ -202,40 +206,47 @@ class AgentEntity:
|
||||
request_message=request_message,
|
||||
)
|
||||
|
||||
run_stream_callable = getattr(self.agent, "run_stream", None)
|
||||
if callable(run_stream_callable):
|
||||
try:
|
||||
stream_candidate = run_stream_callable(**run_kwargs)
|
||||
if inspect.isawaitable(stream_candidate):
|
||||
stream_candidate = await stream_candidate
|
||||
run_callable = getattr(self.agent, "run", None)
|
||||
if run_callable is None or not callable(run_callable):
|
||||
raise AttributeError("Agent does not implement run() method")
|
||||
|
||||
return await self._consume_stream(
|
||||
stream=cast(AsyncIterable[AgentResponseUpdate], stream_candidate),
|
||||
callback_context=callback_context,
|
||||
)
|
||||
except TypeError as type_error:
|
||||
if "__aiter__" not in str(type_error):
|
||||
raise
|
||||
logger.debug(
|
||||
"run_stream returned a non-async result; falling back to run(): %s",
|
||||
type_error,
|
||||
)
|
||||
except Exception as stream_error:
|
||||
logger.warning(
|
||||
"run_stream failed; falling back to run(): %s",
|
||||
stream_error,
|
||||
exc_info=True,
|
||||
)
|
||||
else:
|
||||
logger.debug("Agent does not expose run_stream; falling back to run().")
|
||||
# Try streaming first with run(stream=True)
|
||||
try:
|
||||
stream_candidate = run_callable(stream=True, **run_kwargs)
|
||||
if inspect.isawaitable(stream_candidate):
|
||||
stream_candidate = await stream_candidate
|
||||
|
||||
agent_run_response = await self._invoke_non_stream(run_kwargs)
|
||||
return await self._consume_stream(
|
||||
stream=stream_candidate, # type: ignore[arg-type]
|
||||
callback_context=callback_context,
|
||||
)
|
||||
except TypeError as type_error:
|
||||
if "__aiter__" not in str(type_error) and "stream" not in str(type_error):
|
||||
raise
|
||||
logger.debug(
|
||||
"run(stream=True) returned a non-async result; falling back to run(): %s",
|
||||
type_error,
|
||||
)
|
||||
except Exception as stream_error:
|
||||
logger.warning(
|
||||
"run(stream=True) failed; falling back to run(): %s",
|
||||
stream_error,
|
||||
exc_info=True,
|
||||
)
|
||||
agent_run_response = run_callable(**run_kwargs)
|
||||
if inspect.isawaitable(agent_run_response):
|
||||
agent_run_response = await agent_run_response
|
||||
|
||||
if not isinstance(agent_run_response, AgentResponse):
|
||||
raise TypeError(
|
||||
f"Agent run() must return an AgentResponse instance; received {type(agent_run_response).__name__}"
|
||||
)
|
||||
await self._notify_final_response(agent_run_response, callback_context)
|
||||
return agent_run_response
|
||||
|
||||
async def _consume_stream(
|
||||
self,
|
||||
stream: AsyncIterable[AgentResponseUpdate],
|
||||
stream: ResponseStream[AgentResponseUpdate, AgentResponse],
|
||||
callback_context: AgentCallbackContext | None = None,
|
||||
) -> AgentResponse:
|
||||
"""Consume streaming responses and build the final AgentResponse."""
|
||||
@@ -245,30 +256,11 @@ class AgentEntity:
|
||||
updates.append(update)
|
||||
await self._notify_stream_update(update, callback_context)
|
||||
|
||||
if updates:
|
||||
response = AgentResponse.from_updates(updates)
|
||||
else:
|
||||
logger.debug("[AgentEntity] No streaming updates received; creating empty response")
|
||||
response = AgentResponse(messages=[])
|
||||
response = await stream.get_final_response()
|
||||
|
||||
await self._notify_final_response(response, callback_context)
|
||||
return response
|
||||
|
||||
async def _invoke_non_stream(self, run_kwargs: dict[str, Any]) -> AgentResponse:
|
||||
"""Invoke the agent without streaming support."""
|
||||
run_callable = getattr(self.agent, "run", None)
|
||||
if run_callable is None or not callable(run_callable):
|
||||
raise AttributeError("Agent does not implement run() method")
|
||||
|
||||
result = run_callable(**run_kwargs)
|
||||
if inspect.isawaitable(result):
|
||||
result = await result
|
||||
|
||||
if not isinstance(result, AgentResponse):
|
||||
raise TypeError(f"Agent run() must return an AgentResponse instance; received {type(result).__name__}")
|
||||
|
||||
return result
|
||||
|
||||
async def _notify_stream_update(
|
||||
self,
|
||||
update: AgentResponseUpdate,
|
||||
|
||||
@@ -10,10 +10,9 @@ The actual execution is delegated to the context-specific providers.
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any, Generic, TypeVar
|
||||
from typing import Any, Generic, Literal, TypeVar
|
||||
|
||||
from agent_framework import AgentProtocol, AgentResponseUpdate, AgentThread, ChatMessage
|
||||
from agent_framework import AgentProtocol, AgentThread, ChatMessage
|
||||
|
||||
from ._executors import DurableAgentExecutor
|
||||
from ._models import DurableAgentThread
|
||||
@@ -89,6 +88,7 @@ class DurableAIAgent(AgentProtocol, Generic[TaskT]):
|
||||
self,
|
||||
messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
|
||||
*,
|
||||
stream: Literal[False] = False,
|
||||
thread: AgentThread | None = None,
|
||||
options: dict[str, Any] | None = None,
|
||||
) -> TaskT:
|
||||
@@ -96,6 +96,8 @@ class DurableAIAgent(AgentProtocol, Generic[TaskT]):
|
||||
|
||||
Args:
|
||||
messages: The message(s) to send to the agent
|
||||
stream: Whether to use streaming for the response (must be False)
|
||||
DurableAgents do not support streaming mode.
|
||||
thread: Optional agent thread for conversation context
|
||||
options: Optional options dictionary. Supported keys include
|
||||
``response_format``, ``enable_tool_calls``, and ``wait_for_response``.
|
||||
@@ -115,6 +117,8 @@ class DurableAIAgent(AgentProtocol, Generic[TaskT]):
|
||||
Raises:
|
||||
ValueError: If wait_for_response=False is used in an unsupported context
|
||||
"""
|
||||
if stream is not False:
|
||||
raise ValueError("DurableAIAgent does not support streaming mode (stream must be False)")
|
||||
message_str = self._normalize_messages(messages)
|
||||
|
||||
run_request = self._executor.get_run_request(
|
||||
@@ -128,25 +132,6 @@ class DurableAIAgent(AgentProtocol, Generic[TaskT]):
|
||||
thread=thread,
|
||||
)
|
||||
|
||||
def run_stream( # type: ignore[override]
|
||||
self,
|
||||
messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
|
||||
*,
|
||||
thread: AgentThread | None = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[AgentResponseUpdate]:
|
||||
"""Run the agent with streaming (not supported for durable agents).
|
||||
|
||||
Args:
|
||||
messages: The message(s) to send to the agent
|
||||
thread: Optional agent thread for conversation context
|
||||
**kwargs: Additional arguments
|
||||
|
||||
Raises:
|
||||
NotImplementedError: Streaming is not supported for durable agents
|
||||
"""
|
||||
raise NotImplementedError("Streaming is not supported for durable agents")
|
||||
|
||||
def get_new_thread(self, **kwargs: Any) -> DurableAgentThread:
|
||||
"""Create a new agent thread via the provider."""
|
||||
return self._executor.get_new_thread(self.name, **kwargs)
|
||||
|
||||
@@ -45,6 +45,7 @@ environments = [
|
||||
fallback-version = "0.0.0"
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = 'tests'
|
||||
pythonpath = ["tests/integration_tests"]
|
||||
addopts = "-ra -q -r fEX"
|
||||
asyncio_mode = "auto"
|
||||
asyncio_default_fixture_loop_scope = "function"
|
||||
|
||||
@@ -2,8 +2,10 @@
|
||||
"""Pytest configuration and fixtures for durabletask integration tests."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import socket
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
@@ -11,14 +13,15 @@ import uuid
|
||||
from collections.abc import Generator
|
||||
from pathlib import Path
|
||||
from typing import Any, cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import pytest
|
||||
import redis.asyncio as aioredis
|
||||
from dotenv import load_dotenv
|
||||
from durabletask.azuremanaged.client import DurableTaskSchedulerClient
|
||||
from durabletask.client import OrchestrationStatus
|
||||
|
||||
# Add the integration_tests directory to the path so testutils can be imported
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
from agent_framework_durabletask import DurableAIAgentClient
|
||||
|
||||
# Load environment variables from .env file
|
||||
load_dotenv(Path(__file__).parent / ".env")
|
||||
@@ -27,6 +30,11 @@ load_dotenv(Path(__file__).parent / ".env")
|
||||
logging.basicConfig(level=logging.WARNING)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Environment and Service Checks
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _get_dts_endpoint() -> str:
|
||||
"""Get the DTS endpoint from environment or use default."""
|
||||
return os.getenv("ENDPOINT", "http://localhost:8080")
|
||||
@@ -36,13 +44,13 @@ def _check_dts_available(endpoint: str | None = None) -> bool:
|
||||
"""Check if DTS emulator is available at the given endpoint."""
|
||||
try:
|
||||
resolved_endpoint: str = _get_dts_endpoint() if endpoint is None else endpoint
|
||||
DurableTaskSchedulerClient(
|
||||
host_address=resolved_endpoint,
|
||||
secure_channel=False,
|
||||
taskhub="test",
|
||||
token_credential=None,
|
||||
)
|
||||
return True
|
||||
parsed = urlparse(resolved_endpoint)
|
||||
host = parsed.hostname or "localhost"
|
||||
port = parsed.port or 8080
|
||||
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
|
||||
sock.settimeout(2)
|
||||
return sock.connect_ex((host, port)) == 0
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
@@ -66,6 +74,207 @@ def _check_redis_available() -> bool:
|
||||
return False
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Client Factory Functions
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def create_dts_client(endpoint: str, taskhub: str) -> DurableTaskSchedulerClient:
|
||||
"""Create a DurableTaskSchedulerClient with common configuration.
|
||||
|
||||
Args:
|
||||
endpoint: The DTS endpoint address
|
||||
taskhub: The task hub name
|
||||
|
||||
Returns:
|
||||
A configured DurableTaskSchedulerClient instance
|
||||
"""
|
||||
return DurableTaskSchedulerClient(
|
||||
host_address=endpoint,
|
||||
secure_channel=False,
|
||||
taskhub=taskhub,
|
||||
token_credential=None,
|
||||
)
|
||||
|
||||
|
||||
def create_agent_client(
|
||||
endpoint: str,
|
||||
taskhub: str,
|
||||
max_poll_retries: int = 90,
|
||||
) -> tuple[DurableTaskSchedulerClient, DurableAIAgentClient]:
|
||||
"""Create a DurableAIAgentClient with the underlying DTS client.
|
||||
|
||||
Args:
|
||||
endpoint: The DTS endpoint address
|
||||
taskhub: The task hub name
|
||||
max_poll_retries: Max poll retries for the agent client
|
||||
|
||||
Returns:
|
||||
A tuple of (DurableTaskSchedulerClient, DurableAIAgentClient)
|
||||
"""
|
||||
dts_client = create_dts_client(endpoint, taskhub)
|
||||
agent_client = DurableAIAgentClient(dts_client, max_poll_retries=max_poll_retries)
|
||||
return dts_client, agent_client
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Orchestration Helper Class
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class OrchestrationHelper:
|
||||
"""Helper class for orchestration-related test operations."""
|
||||
|
||||
def __init__(self, dts_client: DurableTaskSchedulerClient):
|
||||
"""Initialize the orchestration helper.
|
||||
|
||||
Args:
|
||||
dts_client: The DurableTaskSchedulerClient instance to use
|
||||
"""
|
||||
self.client = dts_client
|
||||
|
||||
def wait_for_orchestration(
|
||||
self,
|
||||
instance_id: str,
|
||||
timeout: float = 60.0,
|
||||
) -> Any:
|
||||
"""Wait for an orchestration to complete.
|
||||
|
||||
Args:
|
||||
instance_id: The orchestration instance ID
|
||||
timeout: Maximum time to wait in seconds
|
||||
|
||||
Returns:
|
||||
The final OrchestrationMetadata
|
||||
|
||||
Raises:
|
||||
TimeoutError: If the orchestration doesn't complete within timeout
|
||||
RuntimeError: If the orchestration fails
|
||||
"""
|
||||
# Use the built-in wait_for_orchestration_completion method
|
||||
metadata = self.client.wait_for_orchestration_completion(
|
||||
instance_id=instance_id,
|
||||
timeout=int(timeout),
|
||||
)
|
||||
|
||||
if metadata is None:
|
||||
raise TimeoutError(f"Orchestration {instance_id} did not complete within {timeout} seconds")
|
||||
|
||||
# Check if failed or terminated
|
||||
if metadata.runtime_status == OrchestrationStatus.FAILED:
|
||||
raise RuntimeError(f"Orchestration {instance_id} failed: {metadata.serialized_custom_status}")
|
||||
if metadata.runtime_status == OrchestrationStatus.TERMINATED:
|
||||
raise RuntimeError(f"Orchestration {instance_id} was terminated")
|
||||
|
||||
return metadata
|
||||
|
||||
def wait_for_orchestration_with_output(
|
||||
self,
|
||||
instance_id: str,
|
||||
timeout: float = 60.0,
|
||||
) -> tuple[Any, Any]:
|
||||
"""Wait for an orchestration to complete and return its output.
|
||||
|
||||
Args:
|
||||
instance_id: The orchestration instance ID
|
||||
timeout: Maximum time to wait in seconds
|
||||
|
||||
Returns:
|
||||
A tuple of (OrchestrationMetadata, output)
|
||||
|
||||
Raises:
|
||||
TimeoutError: If the orchestration doesn't complete within timeout
|
||||
RuntimeError: If the orchestration fails
|
||||
"""
|
||||
metadata = self.wait_for_orchestration(instance_id, timeout)
|
||||
|
||||
# The output should be available in the metadata
|
||||
return metadata, metadata.serialized_output
|
||||
|
||||
def get_orchestration_status(self, instance_id: str) -> Any | None:
|
||||
"""Get the current status of an orchestration.
|
||||
|
||||
Args:
|
||||
instance_id: The orchestration instance ID
|
||||
|
||||
Returns:
|
||||
The OrchestrationMetadata or None if not found
|
||||
"""
|
||||
try:
|
||||
# Try to wait with a short timeout to get current status
|
||||
return self.client.wait_for_orchestration_completion(
|
||||
instance_id=instance_id,
|
||||
timeout=1, # Very short timeout, just checking status
|
||||
)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def raise_event(
|
||||
self,
|
||||
instance_id: str,
|
||||
event_name: str,
|
||||
event_data: Any = None,
|
||||
) -> None:
|
||||
"""Raise an external event to an orchestration.
|
||||
|
||||
Args:
|
||||
instance_id: The orchestration instance ID
|
||||
event_name: The name of the event
|
||||
event_data: The event data payload
|
||||
"""
|
||||
self.client.raise_orchestration_event(instance_id, event_name, data=event_data)
|
||||
|
||||
def wait_for_notification(self, instance_id: str, timeout_seconds: int = 30) -> bool:
|
||||
"""Wait for the orchestration to reach a notification point.
|
||||
|
||||
Polls the orchestration status until it appears to be waiting for approval.
|
||||
|
||||
Args:
|
||||
instance_id: The orchestration instance ID
|
||||
timeout_seconds: Maximum time to wait
|
||||
|
||||
Returns:
|
||||
True if notification detected, False if timeout
|
||||
"""
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < timeout_seconds:
|
||||
try:
|
||||
metadata = self.client.get_orchestration_state(
|
||||
instance_id=instance_id,
|
||||
)
|
||||
|
||||
if metadata:
|
||||
# Check if we're waiting for approval by examining custom status
|
||||
if metadata.serialized_custom_status:
|
||||
try:
|
||||
custom_status = json.loads(metadata.serialized_custom_status)
|
||||
# Handle both string and dict custom status
|
||||
status_str = custom_status if isinstance(custom_status, str) else str(custom_status)
|
||||
if status_str.lower().startswith("requesting human feedback"):
|
||||
return True
|
||||
except (json.JSONDecodeError, AttributeError):
|
||||
# If it's not JSON, treat as plain string
|
||||
if metadata.serialized_custom_status.lower().startswith("requesting human feedback"):
|
||||
return True
|
||||
|
||||
# Check for terminal states
|
||||
if metadata.runtime_status.name == "COMPLETED" or metadata.runtime_status.name == "FAILED":
|
||||
return False
|
||||
except Exception:
|
||||
# Silently ignore transient errors during polling (e.g., network issues, service unavailable).
|
||||
# The loop will retry until timeout, allowing the service to recover.
|
||||
pass
|
||||
|
||||
time.sleep(1)
|
||||
|
||||
return False
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Pytest Configuration
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def pytest_configure(config: pytest.Config) -> None:
|
||||
"""Register custom markers."""
|
||||
config.addinivalue_line("markers", "integration_test: mark test as integration test")
|
||||
@@ -109,6 +318,11 @@ def pytest_collection_modifyitems(config: pytest.Config, items: list[pytest.Item
|
||||
item.add_marker(skip_redis)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Pytest Fixtures
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def dts_endpoint() -> str:
|
||||
"""Get the DTS endpoint from environment or use default."""
|
||||
@@ -149,8 +363,7 @@ def worker_process(
|
||||
unique_taskhub: str,
|
||||
request: pytest.FixtureRequest,
|
||||
) -> Generator[dict[str, Any], None, None]:
|
||||
"""
|
||||
Start a worker process for the current test module by running the sample worker.py.
|
||||
"""Start a worker process for the current test module by running the sample worker.py.
|
||||
|
||||
This fixture:
|
||||
1. Determines which sample to run from @pytest.mark.sample()
|
||||
@@ -205,7 +418,15 @@ def worker_process(
|
||||
pytest.fail(f"Failed to start worker subprocess: {e}")
|
||||
|
||||
# Wait for worker to initialize
|
||||
time.sleep(2)
|
||||
# The worker needs time to:
|
||||
# 1. Start Python and import modules
|
||||
# 2. Create Azure OpenAI clients
|
||||
# 3. Register agents with the DTS worker
|
||||
# 4. Connect to DTS and be ready to receive signals
|
||||
#
|
||||
# We use a generous wait time because CI environments can be slow,
|
||||
# and the first test that runs depends on the worker being fully ready.
|
||||
time.sleep(8)
|
||||
|
||||
# Check if process is still running
|
||||
if process.poll() is not None:
|
||||
@@ -232,3 +453,33 @@ def worker_process(
|
||||
process.wait()
|
||||
except Exception as e:
|
||||
logging.warning(f"Error during worker process cleanup: {e}")
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def orchestration_helper(worker_process: dict[str, Any]) -> OrchestrationHelper:
|
||||
"""Create an OrchestrationHelper for the current test module."""
|
||||
dts_client = create_dts_client(worker_process["endpoint"], worker_process["taskhub"])
|
||||
return OrchestrationHelper(dts_client)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def agent_client_factory(worker_process: dict[str, Any]) -> type:
|
||||
"""Return a factory class for creating agent clients.
|
||||
|
||||
Usage in tests:
|
||||
def test_example(self, agent_client_factory):
|
||||
dts_client, agent_client = agent_client_factory.create(max_poll_retries=90)
|
||||
"""
|
||||
|
||||
class AgentClientFactory:
|
||||
"""Factory for creating DTS and Agent client pairs."""
|
||||
|
||||
endpoint = worker_process["endpoint"]
|
||||
taskhub = worker_process["taskhub"]
|
||||
|
||||
@classmethod
|
||||
def create(cls, max_poll_retries: int = 90) -> tuple[DurableTaskSchedulerClient, DurableAIAgentClient]:
|
||||
"""Create a DTS client and Agent client pair."""
|
||||
return create_agent_client(cls.endpoint, cls.taskhub, max_poll_retries)
|
||||
|
||||
return AgentClientFactory
|
||||
|
||||
@@ -1,205 +0,0 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Test utilities for durabletask integration tests."""
|
||||
|
||||
import json
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from durabletask.azuremanaged.client import DurableTaskSchedulerClient
|
||||
from durabletask.client import OrchestrationStatus
|
||||
|
||||
from agent_framework_durabletask import DurableAIAgentClient
|
||||
|
||||
|
||||
def create_dts_client(endpoint: str, taskhub: str) -> DurableTaskSchedulerClient:
|
||||
"""
|
||||
Create a DurableTaskSchedulerClient with common configuration.
|
||||
|
||||
Args:
|
||||
endpoint: The DTS endpoint address
|
||||
taskhub: The task hub name
|
||||
|
||||
Returns:
|
||||
A configured DurableTaskSchedulerClient instance
|
||||
"""
|
||||
return DurableTaskSchedulerClient(
|
||||
host_address=endpoint,
|
||||
secure_channel=False,
|
||||
taskhub=taskhub,
|
||||
token_credential=None,
|
||||
)
|
||||
|
||||
|
||||
def create_agent_client(
|
||||
endpoint: str,
|
||||
taskhub: str,
|
||||
max_poll_retries: int = 90,
|
||||
) -> tuple[DurableTaskSchedulerClient, DurableAIAgentClient]:
|
||||
"""
|
||||
Create a DurableAIAgentClient with the underlying DTS client.
|
||||
|
||||
Args:
|
||||
endpoint: The DTS endpoint address
|
||||
taskhub: The task hub name
|
||||
max_poll_retries: Max poll retries for the agent client
|
||||
|
||||
Returns:
|
||||
A tuple of (DurableTaskSchedulerClient, DurableAIAgentClient)
|
||||
"""
|
||||
dts_client = create_dts_client(endpoint, taskhub)
|
||||
agent_client = DurableAIAgentClient(dts_client, max_poll_retries=max_poll_retries)
|
||||
return dts_client, agent_client
|
||||
|
||||
|
||||
class OrchestrationHelper:
|
||||
"""Helper class for orchestration-related test operations."""
|
||||
|
||||
def __init__(self, dts_client: DurableTaskSchedulerClient):
|
||||
"""
|
||||
Initialize the orchestration helper.
|
||||
|
||||
Args:
|
||||
dts_client: The DurableTaskSchedulerClient instance to use
|
||||
"""
|
||||
self.client = dts_client
|
||||
|
||||
def wait_for_orchestration(
|
||||
self,
|
||||
instance_id: str,
|
||||
timeout: float = 60.0,
|
||||
) -> Any:
|
||||
"""
|
||||
Wait for an orchestration to complete.
|
||||
|
||||
Args:
|
||||
instance_id: The orchestration instance ID
|
||||
timeout: Maximum time to wait in seconds
|
||||
|
||||
Returns:
|
||||
The final OrchestrationMetadata
|
||||
|
||||
Raises:
|
||||
TimeoutError: If the orchestration doesn't complete within timeout
|
||||
RuntimeError: If the orchestration fails
|
||||
"""
|
||||
# Use the built-in wait_for_orchestration_completion method
|
||||
metadata = self.client.wait_for_orchestration_completion(
|
||||
instance_id=instance_id,
|
||||
timeout=int(timeout),
|
||||
)
|
||||
|
||||
if metadata is None:
|
||||
raise TimeoutError(f"Orchestration {instance_id} did not complete within {timeout} seconds")
|
||||
|
||||
# Check if failed or terminated
|
||||
if metadata.runtime_status == OrchestrationStatus.FAILED:
|
||||
raise RuntimeError(f"Orchestration {instance_id} failed: {metadata.serialized_custom_status}")
|
||||
if metadata.runtime_status == OrchestrationStatus.TERMINATED:
|
||||
raise RuntimeError(f"Orchestration {instance_id} was terminated")
|
||||
|
||||
return metadata
|
||||
|
||||
def wait_for_orchestration_with_output(
|
||||
self,
|
||||
instance_id: str,
|
||||
timeout: float = 60.0,
|
||||
) -> tuple[Any, Any]:
|
||||
"""
|
||||
Wait for an orchestration to complete and return its output.
|
||||
|
||||
Args:
|
||||
instance_id: The orchestration instance ID
|
||||
timeout: Maximum time to wait in seconds
|
||||
|
||||
Returns:
|
||||
A tuple of (OrchestrationMetadata, output)
|
||||
|
||||
Raises:
|
||||
TimeoutError: If the orchestration doesn't complete within timeout
|
||||
RuntimeError: If the orchestration fails
|
||||
"""
|
||||
metadata = self.wait_for_orchestration(instance_id, timeout)
|
||||
|
||||
# The output should be available in the metadata
|
||||
return metadata, metadata.serialized_output
|
||||
|
||||
def get_orchestration_status(self, instance_id: str) -> Any | None:
|
||||
"""
|
||||
Get the current status of an orchestration.
|
||||
|
||||
Args:
|
||||
instance_id: The orchestration instance ID
|
||||
|
||||
Returns:
|
||||
The OrchestrationMetadata or None if not found
|
||||
"""
|
||||
try:
|
||||
# Try to wait with a short timeout to get current status
|
||||
return self.client.wait_for_orchestration_completion(
|
||||
instance_id=instance_id,
|
||||
timeout=1, # Very short timeout, just checking status
|
||||
)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def raise_event(
|
||||
self,
|
||||
instance_id: str,
|
||||
event_name: str,
|
||||
event_data: Any = None,
|
||||
) -> None:
|
||||
"""
|
||||
Raise an external event to an orchestration.
|
||||
|
||||
Args:
|
||||
instance_id: The orchestration instance ID
|
||||
event_name: The name of the event
|
||||
event_data: The event data payload
|
||||
"""
|
||||
self.client.raise_orchestration_event(instance_id, event_name, data=event_data)
|
||||
|
||||
def wait_for_notification(self, instance_id: str, timeout_seconds: int = 30) -> bool:
|
||||
"""Wait for the orchestration to reach a notification point.
|
||||
|
||||
Polls the orchestration status until it appears to be waiting for approval.
|
||||
|
||||
Args:
|
||||
instance_id: The orchestration instance ID
|
||||
timeout_seconds: Maximum time to wait
|
||||
|
||||
Returns:
|
||||
True if notification detected, False if timeout
|
||||
"""
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < timeout_seconds:
|
||||
try:
|
||||
metadata = self.client.get_orchestration_state(
|
||||
instance_id=instance_id,
|
||||
)
|
||||
|
||||
if metadata:
|
||||
# Check if we're waiting for approval by examining custom status
|
||||
if metadata.serialized_custom_status:
|
||||
try:
|
||||
custom_status = json.loads(metadata.serialized_custom_status)
|
||||
# Handle both string and dict custom status
|
||||
status_str = custom_status if isinstance(custom_status, str) else str(custom_status)
|
||||
if status_str.lower().startswith("requesting human feedback"):
|
||||
return True
|
||||
except (json.JSONDecodeError, AttributeError):
|
||||
# If it's not JSON, treat as plain string
|
||||
if metadata.serialized_custom_status.lower().startswith("requesting human feedback"):
|
||||
return True
|
||||
|
||||
# Check for terminal states
|
||||
if metadata.runtime_status.name == "COMPLETED" or metadata.runtime_status.name == "FAILED":
|
||||
return False
|
||||
except Exception:
|
||||
# Silently ignore transient errors during polling (e.g., network issues, service unavailable).
|
||||
# The loop will retry until timeout, allowing the service to recover.
|
||||
pass
|
||||
|
||||
time.sleep(1)
|
||||
|
||||
return False
|
||||
@@ -10,10 +10,7 @@ Tests basic agent operations including:
|
||||
- Empty thread ID handling
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from dt_testutils import create_agent_client
|
||||
|
||||
# Module-level markers - applied to all tests in this module
|
||||
pytestmark = [
|
||||
@@ -28,13 +25,10 @@ class TestSingleAgent:
|
||||
"""Test suite for single agent functionality."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup(self, worker_process: dict[str, Any], dts_endpoint: str) -> None:
|
||||
def setup(self, agent_client_factory: type) -> None:
|
||||
"""Setup test fixtures."""
|
||||
self.endpoint: str = dts_endpoint
|
||||
self.taskhub: str = str(worker_process["taskhub"])
|
||||
|
||||
# Create agent client
|
||||
_, self.agent_client = create_agent_client(self.endpoint, self.taskhub)
|
||||
# Create agent client using the factory fixture
|
||||
_, self.agent_client = agent_client_factory.create()
|
||||
|
||||
def test_agent_registration(self) -> None:
|
||||
"""Test that the Joker agent is registered and accessible."""
|
||||
|
||||
@@ -10,10 +10,7 @@ Tests operations with multiple specialized agents:
|
||||
- Agent isolation and tool routing
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from dt_testutils import create_agent_client
|
||||
|
||||
# Agent names from the 02_multi_agent sample
|
||||
WEATHER_AGENT_NAME: str = "WeatherAgent"
|
||||
@@ -32,13 +29,10 @@ class TestMultiAgent:
|
||||
"""Test suite for multi-agent functionality."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup(self, worker_process: dict[str, Any], dts_endpoint: str) -> None:
|
||||
def setup(self, agent_client_factory: type) -> None:
|
||||
"""Setup test fixtures."""
|
||||
self.endpoint: str = dts_endpoint
|
||||
self.taskhub: str = str(worker_process["taskhub"])
|
||||
|
||||
# Create agent client
|
||||
_, self.agent_client = create_agent_client(self.endpoint, self.taskhub)
|
||||
# Create agent client using the factory fixture
|
||||
_, self.agent_client = agent_client_factory.create()
|
||||
|
||||
def test_multiple_agents_registered(self) -> None:
|
||||
"""Test that both agents are registered and accessible."""
|
||||
|
||||
+4
-9
@@ -22,11 +22,9 @@ import sys
|
||||
import time
|
||||
from datetime import timedelta
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
import redis.asyncio as aioredis
|
||||
from dt_testutils import OrchestrationHelper, create_agent_client
|
||||
|
||||
# Add sample directory to path to import RedisStreamResponseHandler
|
||||
SAMPLE_DIR = Path(__file__).parents[4] / "samples" / "getting_started" / "durabletask" / "03_single_agent_streaming"
|
||||
@@ -48,14 +46,11 @@ class TestSampleReliableStreaming:
|
||||
"""Tests for 03_single_agent_streaming sample."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup(self, worker_process: dict[str, Any], dts_endpoint: str) -> None:
|
||||
def setup(self, agent_client_factory: type, orchestration_helper) -> None:
|
||||
"""Setup test fixtures."""
|
||||
self.endpoint: str = dts_endpoint
|
||||
self.taskhub: str = str(worker_process["taskhub"])
|
||||
|
||||
# Create agent client
|
||||
dts_client, self.agent_client = create_agent_client(self.endpoint, self.taskhub)
|
||||
self.helper = OrchestrationHelper(dts_client)
|
||||
# Create agent client using the factory fixture
|
||||
_, self.agent_client = agent_client_factory.create()
|
||||
self.helper = orchestration_helper
|
||||
|
||||
# Redis configuration
|
||||
self.redis_connection_string = os.environ.get("REDIS_CONNECTION_STRING", "redis://localhost:6379")
|
||||
|
||||
+4
-11
@@ -11,10 +11,8 @@ Tests orchestration patterns with sequential agent calls:
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from dt_testutils import OrchestrationHelper, create_agent_client
|
||||
from durabletask.client import OrchestrationStatus
|
||||
|
||||
# Agent name from the 04_single_agent_orchestration_chaining sample
|
||||
@@ -36,16 +34,11 @@ class TestSingleAgentOrchestrationChaining:
|
||||
"""Test suite for single agent orchestration with chaining."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup(self, worker_process: dict[str, Any], dts_endpoint: str) -> None:
|
||||
def setup(self, agent_client_factory: type, orchestration_helper) -> None:
|
||||
"""Setup test fixtures."""
|
||||
self.endpoint: str = dts_endpoint
|
||||
self.taskhub: str = str(worker_process["taskhub"])
|
||||
|
||||
# Create agent client and DTS client
|
||||
self.dts_client, self.agent_client = create_agent_client(self.endpoint, self.taskhub)
|
||||
|
||||
# Create orchestration helper
|
||||
self.orch_helper = OrchestrationHelper(self.dts_client)
|
||||
# Create agent client using the factory fixture
|
||||
self.dts_client, self.agent_client = agent_client_factory.create()
|
||||
self.orch_helper = orchestration_helper
|
||||
|
||||
def test_agent_registered(self):
|
||||
"""Test that the Writer agent is registered."""
|
||||
|
||||
+4
-11
@@ -11,10 +11,8 @@ Tests concurrent execution patterns:
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from dt_testutils import OrchestrationHelper, create_agent_client
|
||||
from durabletask.client import OrchestrationStatus
|
||||
|
||||
# Agent names from the 05_multi_agent_orchestration_concurrency sample
|
||||
@@ -36,16 +34,11 @@ class TestMultiAgentOrchestrationConcurrency:
|
||||
"""Test suite for multi-agent orchestration with concurrency."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup(self, worker_process: dict[str, Any], dts_endpoint: str) -> None:
|
||||
def setup(self, agent_client_factory: type, orchestration_helper) -> None:
|
||||
"""Setup test fixtures."""
|
||||
self.endpoint = dts_endpoint
|
||||
self.taskhub = worker_process["taskhub"]
|
||||
|
||||
# Create agent client and DTS client
|
||||
self.dts_client, self.agent_client = create_agent_client(self.endpoint, self.taskhub)
|
||||
|
||||
# Create orchestration helper
|
||||
self.orch_helper = OrchestrationHelper(self.dts_client)
|
||||
# Create agent client using the factory fixture
|
||||
self.dts_client, self.agent_client = agent_client_factory.create()
|
||||
self.orch_helper = orchestration_helper
|
||||
|
||||
def test_agents_registered(self):
|
||||
"""Test that both agents are registered."""
|
||||
|
||||
+4
-11
@@ -11,10 +11,8 @@ Tests conditional orchestration patterns:
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from dt_testutils import OrchestrationHelper, create_agent_client
|
||||
from durabletask.client import OrchestrationStatus
|
||||
|
||||
# Agent names from the 06_multi_agent_orchestration_conditionals sample
|
||||
@@ -36,16 +34,11 @@ class TestMultiAgentOrchestrationConditionals:
|
||||
"""Test suite for multi-agent orchestration with conditionals."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup(self, worker_process: dict[str, Any], dts_endpoint: str) -> None:
|
||||
def setup(self, agent_client_factory: type, orchestration_helper) -> None:
|
||||
"""Setup test fixtures."""
|
||||
self.endpoint: str = dts_endpoint
|
||||
self.taskhub: str = str(worker_process["taskhub"])
|
||||
|
||||
# Create agent client and DTS client
|
||||
self.dts_client, self.agent_client = create_agent_client(self.endpoint, self.taskhub)
|
||||
|
||||
# Create orchestration helper
|
||||
self.orch_helper = OrchestrationHelper(self.dts_client)
|
||||
# Create agent client using the factory fixture
|
||||
self.dts_client, self.agent_client = agent_client_factory.create()
|
||||
self.orch_helper = orchestration_helper
|
||||
|
||||
def test_agents_registered(self):
|
||||
"""Test that both agents are registered."""
|
||||
|
||||
+4
-13
@@ -11,10 +11,8 @@ Tests human-in-the-loop (HITL) patterns:
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from dt_testutils import OrchestrationHelper, create_agent_client
|
||||
from durabletask.client import OrchestrationStatus
|
||||
|
||||
# Constants from the 07_single_agent_orchestration_hitl sample
|
||||
@@ -36,18 +34,11 @@ class TestSingleAgentOrchestrationHITL:
|
||||
"""Test suite for single agent orchestration with human-in-the-loop."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup(self, worker_process: dict[str, Any], dts_endpoint: str) -> None:
|
||||
def setup(self, agent_client_factory: type, orchestration_helper) -> None:
|
||||
"""Setup test fixtures."""
|
||||
self.endpoint: str = str(worker_process["endpoint"])
|
||||
self.taskhub: str = str(worker_process["taskhub"])
|
||||
|
||||
logging.info(f"Using taskhub: {self.taskhub} at endpoint: {self.endpoint}")
|
||||
|
||||
# Create agent client and DTS client
|
||||
self.dts_client, self.agent_client = create_agent_client(self.endpoint, self.taskhub)
|
||||
|
||||
# Create orchestration helper
|
||||
self.orch_helper = OrchestrationHelper(self.dts_client)
|
||||
# Create agent client using the factory fixture
|
||||
self.dts_client, self.agent_client = agent_client_factory.create()
|
||||
self.orch_helper = orchestration_helper
|
||||
|
||||
def test_agent_registered(self):
|
||||
"""Test that the Writer agent is registered."""
|
||||
|
||||
@@ -11,7 +11,7 @@ from typing import Any, TypeVar
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import pytest
|
||||
from agent_framework import AgentResponse, AgentResponseUpdate, ChatMessage, Content
|
||||
from agent_framework import AgentResponse, AgentResponseUpdate, ChatMessage, Content, ResponseStream
|
||||
from pydantic import BaseModel
|
||||
|
||||
from agent_framework_durabletask import (
|
||||
@@ -81,8 +81,27 @@ def _role_value(chat_message: DurableAgentStateMessage) -> str:
|
||||
|
||||
def _agent_response(text: str | None) -> AgentResponse:
|
||||
"""Create an AgentResponse with a single assistant message."""
|
||||
message = ChatMessage("assistant", [text]) if text is not None else ChatMessage("assistant", [])
|
||||
return AgentResponse(messages=[message])
|
||||
message = ChatMessage(role="assistant", text=text) if text is not None else ChatMessage(role="assistant", text="")
|
||||
return AgentResponse(messages=[message], created_at="2024-01-01T00:00:00Z")
|
||||
|
||||
|
||||
def _create_mock_run(response: AgentResponse | None = None, side_effect: Exception | None = None):
|
||||
"""Create a mock run function that handles stream parameter correctly.
|
||||
|
||||
The durabletask entity code tries run(stream=True) first, then falls back to run(stream=False).
|
||||
This helper creates a mock that raises TypeError for streaming (to trigger fallback) and
|
||||
returns the response or raises the side_effect for non-streaming.
|
||||
"""
|
||||
|
||||
async def mock_run(*args, stream=False, **kwargs):
|
||||
if stream:
|
||||
# Simulate "streaming not supported" to trigger fallback
|
||||
raise TypeError("streaming not supported")
|
||||
if side_effect:
|
||||
raise side_effect
|
||||
return response
|
||||
|
||||
return mock_run
|
||||
|
||||
|
||||
class RecordingCallback:
|
||||
@@ -194,7 +213,14 @@ class TestAgentEntityRunAgent:
|
||||
"""Test that run executes the agent."""
|
||||
mock_agent = Mock()
|
||||
mock_response = _agent_response("Test response")
|
||||
mock_agent.run = AsyncMock(return_value=mock_response)
|
||||
|
||||
# Mock run() to return response for non-streaming, raise for streaming (to test fallback)
|
||||
async def mock_run(*args, stream=False, **kwargs):
|
||||
if stream:
|
||||
raise TypeError("streaming not supported")
|
||||
return mock_response
|
||||
|
||||
mock_agent.run = mock_run
|
||||
|
||||
entity = _make_entity(mock_agent)
|
||||
|
||||
@@ -203,22 +229,12 @@ class TestAgentEntityRunAgent:
|
||||
"correlationId": "corr-entity-1",
|
||||
})
|
||||
|
||||
# Verify agent.run was called
|
||||
mock_agent.run.assert_called_once()
|
||||
_, kwargs = mock_agent.run.call_args
|
||||
sent_messages: list[Any] = kwargs.get("messages")
|
||||
assert len(sent_messages) == 1
|
||||
sent_message = sent_messages[0]
|
||||
assert isinstance(sent_message, ChatMessage)
|
||||
assert getattr(sent_message, "text", None) == "Test message"
|
||||
assert getattr(sent_message.role, "value", sent_message.role) == "user"
|
||||
|
||||
# Verify result
|
||||
assert isinstance(result, AgentResponse)
|
||||
assert result.text == "Test response"
|
||||
|
||||
async def test_run_agent_streaming_callbacks_invoked(self) -> None:
|
||||
"""Ensure streaming updates trigger callbacks and run() is not used."""
|
||||
"""Ensure streaming updates trigger callbacks when using run(stream=True)."""
|
||||
updates = [
|
||||
AgentResponseUpdate(contents=[Content.from_text(text="Hello")]),
|
||||
AgentResponseUpdate(contents=[Content.from_text(text=" world")]),
|
||||
@@ -230,8 +246,17 @@ class TestAgentEntityRunAgent:
|
||||
|
||||
mock_agent = Mock()
|
||||
mock_agent.name = "StreamingAgent"
|
||||
mock_agent.run_stream = Mock(return_value=update_generator())
|
||||
mock_agent.run = AsyncMock(side_effect=AssertionError("run() should not be called when streaming succeeds"))
|
||||
|
||||
# Mock run() to return ResponseStream when stream=True
|
||||
def mock_run(*args, stream=False, **kwargs):
|
||||
if stream:
|
||||
return ResponseStream(
|
||||
update_generator(),
|
||||
finalizer=AgentResponse.from_updates,
|
||||
)
|
||||
raise AssertionError("run(stream=False) should not be called when streaming succeeds")
|
||||
|
||||
mock_agent.run = mock_run
|
||||
|
||||
callback = RecordingCallback()
|
||||
entity = _make_entity(mock_agent, callback=callback, thread_id="session-1")
|
||||
@@ -247,7 +272,6 @@ class TestAgentEntityRunAgent:
|
||||
assert "Hello" in result.text
|
||||
assert callback.stream_mock.await_count == len(updates)
|
||||
assert callback.response_mock.await_count == 1
|
||||
mock_agent.run.assert_not_called()
|
||||
|
||||
# Validate callback arguments
|
||||
stream_calls = callback.stream_mock.await_args_list
|
||||
@@ -272,9 +296,8 @@ class TestAgentEntityRunAgent:
|
||||
"""Ensure the final callback fires even when streaming is unavailable."""
|
||||
mock_agent = Mock()
|
||||
mock_agent.name = "NonStreamingAgent"
|
||||
mock_agent.run_stream = None
|
||||
agent_response = _agent_response("Final response")
|
||||
mock_agent.run = AsyncMock(return_value=agent_response)
|
||||
mock_agent.run = _create_mock_run(response=agent_response)
|
||||
|
||||
callback = RecordingCallback()
|
||||
entity = _make_entity(mock_agent, callback=callback, thread_id="session-2")
|
||||
@@ -304,7 +327,7 @@ class TestAgentEntityRunAgent:
|
||||
"""Test that run_agent updates the conversation history."""
|
||||
mock_agent = Mock()
|
||||
mock_response = _agent_response("Agent response")
|
||||
mock_agent.run = AsyncMock(return_value=mock_response)
|
||||
mock_agent.run = _create_mock_run(response=mock_response)
|
||||
|
||||
entity = _make_entity(mock_agent)
|
||||
|
||||
@@ -327,7 +350,7 @@ class TestAgentEntityRunAgent:
|
||||
async def test_run_agent_increments_message_count(self) -> None:
|
||||
"""Test that run_agent increments the message count."""
|
||||
mock_agent = Mock()
|
||||
mock_agent.run = AsyncMock(return_value=_agent_response("Response"))
|
||||
mock_agent.run = _create_mock_run(response=_agent_response("Response"))
|
||||
|
||||
entity = _make_entity(mock_agent)
|
||||
|
||||
@@ -345,7 +368,7 @@ class TestAgentEntityRunAgent:
|
||||
async def test_run_requires_entity_thread_id(self) -> None:
|
||||
"""Test that AgentEntity.run rejects missing entity thread identifiers."""
|
||||
mock_agent = Mock()
|
||||
mock_agent.run = AsyncMock(return_value=_agent_response("Response"))
|
||||
mock_agent.run = _create_mock_run(response=_agent_response("Response"))
|
||||
|
||||
entity = _make_entity(mock_agent, thread_id="")
|
||||
|
||||
@@ -355,7 +378,7 @@ class TestAgentEntityRunAgent:
|
||||
async def test_run_agent_multiple_conversations(self) -> None:
|
||||
"""Test that run_agent maintains history across multiple messages."""
|
||||
mock_agent = Mock()
|
||||
mock_agent.run = AsyncMock(return_value=_agent_response("Response"))
|
||||
mock_agent.run = _create_mock_run(response=_agent_response("Response"))
|
||||
|
||||
entity = _make_entity(mock_agent)
|
||||
|
||||
@@ -419,7 +442,7 @@ class TestAgentEntityReset:
|
||||
async def test_reset_after_conversation(self) -> None:
|
||||
"""Test reset after a full conversation."""
|
||||
mock_agent = Mock()
|
||||
mock_agent.run = AsyncMock(return_value=_agent_response("Response"))
|
||||
mock_agent.run = _create_mock_run(response=_agent_response("Response"))
|
||||
|
||||
entity = _make_entity(mock_agent)
|
||||
|
||||
@@ -445,7 +468,7 @@ class TestErrorHandling:
|
||||
async def test_run_agent_handles_agent_exception(self) -> None:
|
||||
"""Test that run_agent handles agent exceptions."""
|
||||
mock_agent = Mock()
|
||||
mock_agent.run = AsyncMock(side_effect=Exception("Agent failed"))
|
||||
mock_agent.run = _create_mock_run(side_effect=Exception("Agent failed"))
|
||||
|
||||
entity = _make_entity(mock_agent)
|
||||
|
||||
@@ -461,7 +484,7 @@ class TestErrorHandling:
|
||||
async def test_run_agent_handles_value_error(self) -> None:
|
||||
"""Test that run_agent handles ValueError instances."""
|
||||
mock_agent = Mock()
|
||||
mock_agent.run = AsyncMock(side_effect=ValueError("Invalid input"))
|
||||
mock_agent.run = _create_mock_run(side_effect=ValueError("Invalid input"))
|
||||
|
||||
entity = _make_entity(mock_agent)
|
||||
|
||||
@@ -477,7 +500,7 @@ class TestErrorHandling:
|
||||
async def test_run_agent_handles_timeout_error(self) -> None:
|
||||
"""Test that run_agent handles TimeoutError instances."""
|
||||
mock_agent = Mock()
|
||||
mock_agent.run = AsyncMock(side_effect=TimeoutError("Request timeout"))
|
||||
mock_agent.run = _create_mock_run(side_effect=TimeoutError("Request timeout"))
|
||||
|
||||
entity = _make_entity(mock_agent)
|
||||
|
||||
@@ -492,7 +515,7 @@ class TestErrorHandling:
|
||||
async def test_run_agent_preserves_message_on_error(self) -> None:
|
||||
"""Test that run_agent preserves message information on error."""
|
||||
mock_agent = Mock()
|
||||
mock_agent.run = AsyncMock(side_effect=Exception("Error"))
|
||||
mock_agent.run = _create_mock_run(side_effect=Exception("Error"))
|
||||
|
||||
entity = _make_entity(mock_agent)
|
||||
|
||||
@@ -513,7 +536,7 @@ class TestConversationHistory:
|
||||
async def test_conversation_history_has_timestamps(self) -> None:
|
||||
"""Test that conversation history entries include timestamps."""
|
||||
mock_agent = Mock()
|
||||
mock_agent.run = AsyncMock(return_value=_agent_response("Response"))
|
||||
mock_agent.run = _create_mock_run(response=_agent_response("Response"))
|
||||
|
||||
entity = _make_entity(mock_agent)
|
||||
|
||||
@@ -533,17 +556,17 @@ class TestConversationHistory:
|
||||
entity = _make_entity(mock_agent)
|
||||
|
||||
# Send multiple messages with different responses
|
||||
mock_agent.run = AsyncMock(return_value=_agent_response("Response 1"))
|
||||
mock_agent.run = _create_mock_run(response=_agent_response("Response 1"))
|
||||
await entity.run(
|
||||
{"message": "Message 1", "correlationId": "corr-entity-history-2a"},
|
||||
)
|
||||
|
||||
mock_agent.run = AsyncMock(return_value=_agent_response("Response 2"))
|
||||
mock_agent.run = _create_mock_run(response=_agent_response("Response 2"))
|
||||
await entity.run(
|
||||
{"message": "Message 2", "correlationId": "corr-entity-history-2b"},
|
||||
)
|
||||
|
||||
mock_agent.run = AsyncMock(return_value=_agent_response("Response 3"))
|
||||
mock_agent.run = _create_mock_run(response=_agent_response("Response 3"))
|
||||
await entity.run(
|
||||
{"message": "Message 3", "correlationId": "corr-entity-history-2c"},
|
||||
)
|
||||
@@ -561,7 +584,7 @@ class TestConversationHistory:
|
||||
async def test_conversation_history_role_alternation(self) -> None:
|
||||
"""Test that conversation history alternates between user and assistant roles."""
|
||||
mock_agent = Mock()
|
||||
mock_agent.run = AsyncMock(return_value=_agent_response("Response"))
|
||||
mock_agent.run = _create_mock_run(response=_agent_response("Response"))
|
||||
|
||||
entity = _make_entity(mock_agent)
|
||||
|
||||
@@ -587,7 +610,7 @@ class TestRunRequestSupport:
|
||||
async def test_run_agent_with_run_request_object(self) -> None:
|
||||
"""Test run_agent with a RunRequest object."""
|
||||
mock_agent = Mock()
|
||||
mock_agent.run = AsyncMock(return_value=_agent_response("Response"))
|
||||
mock_agent.run = _create_mock_run(response=_agent_response("Response"))
|
||||
|
||||
entity = _make_entity(mock_agent)
|
||||
|
||||
@@ -606,7 +629,7 @@ class TestRunRequestSupport:
|
||||
async def test_run_agent_with_dict_request(self) -> None:
|
||||
"""Test run_agent with a dictionary request."""
|
||||
mock_agent = Mock()
|
||||
mock_agent.run = AsyncMock(return_value=_agent_response("Response"))
|
||||
mock_agent.run = _create_mock_run(response=_agent_response("Response"))
|
||||
|
||||
entity = _make_entity(mock_agent)
|
||||
|
||||
@@ -625,7 +648,7 @@ class TestRunRequestSupport:
|
||||
async def test_run_agent_with_string_raises_without_correlation(self) -> None:
|
||||
"""Test that run_agent rejects legacy string input without correlation ID."""
|
||||
mock_agent = Mock()
|
||||
mock_agent.run = AsyncMock(return_value=_agent_response("Response"))
|
||||
mock_agent.run = _create_mock_run(response=_agent_response("Response"))
|
||||
|
||||
entity = _make_entity(mock_agent)
|
||||
|
||||
@@ -635,7 +658,7 @@ class TestRunRequestSupport:
|
||||
async def test_run_agent_stores_role_in_history(self) -> None:
|
||||
"""Test that run_agent stores the role in conversation history."""
|
||||
mock_agent = Mock()
|
||||
mock_agent.run = AsyncMock(return_value=_agent_response("Response"))
|
||||
mock_agent.run = _create_mock_run(response=_agent_response("Response"))
|
||||
|
||||
entity = _make_entity(mock_agent)
|
||||
|
||||
@@ -657,7 +680,7 @@ class TestRunRequestSupport:
|
||||
"""Test run_agent with a JSON response format."""
|
||||
mock_agent = Mock()
|
||||
# Return JSON response
|
||||
mock_agent.run = AsyncMock(return_value=_agent_response('{"answer": 42}'))
|
||||
mock_agent.run = _create_mock_run(response=_agent_response('{"answer": 42}'))
|
||||
|
||||
entity = _make_entity(mock_agent)
|
||||
|
||||
@@ -676,7 +699,7 @@ class TestRunRequestSupport:
|
||||
async def test_run_agent_disable_tool_calls(self) -> None:
|
||||
"""Test run_agent with tool calls disabled."""
|
||||
mock_agent = Mock()
|
||||
mock_agent.run = AsyncMock(return_value=_agent_response("Response"))
|
||||
mock_agent.run = _create_mock_run(response=_agent_response("Response"))
|
||||
|
||||
entity = _make_entity(mock_agent)
|
||||
|
||||
@@ -686,7 +709,7 @@ class TestRunRequestSupport:
|
||||
|
||||
assert isinstance(result, AgentResponse)
|
||||
# Agent should have been called (tool disabling is framework-dependent)
|
||||
mock_agent.run.assert_called_once()
|
||||
assert result.text == "Response"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -77,7 +77,7 @@ class TestDurableAIAgentMessageNormalization:
|
||||
|
||||
def test_run_accepts_chat_message(self, test_agent: DurableAIAgent[Any], mock_executor: Mock) -> None:
|
||||
"""Verify run accepts and normalizes ChatMessage objects."""
|
||||
chat_msg = ChatMessage("user", ["Test message"])
|
||||
chat_msg = ChatMessage(role="user", text="Test message")
|
||||
test_agent.run(chat_msg)
|
||||
|
||||
mock_executor.run_durable_agent.assert_called_once()
|
||||
@@ -95,8 +95,8 @@ class TestDurableAIAgentMessageNormalization:
|
||||
def test_run_accepts_list_of_chat_messages(self, test_agent: DurableAIAgent[Any], mock_executor: Mock) -> None:
|
||||
"""Verify run accepts and joins list of ChatMessage objects."""
|
||||
messages = [
|
||||
ChatMessage("user", ["Message 1"]),
|
||||
ChatMessage("assistant", ["Message 2"]),
|
||||
ChatMessage(role="user", text="Message 1"),
|
||||
ChatMessage(role="assistant", text="Message 2"),
|
||||
]
|
||||
test_agent.run(messages)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user