mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
3dc59c83b5
1177 lines
44 KiB
Python
# Copyright (c) Microsoft. All rights reserved.

"""Unit tests for AgentFunctionApp."""

# pyright: reportPrivateUsage=false

import json
from collections.abc import Awaitable, Callable
from typing import Any, TypeVar
from unittest.mock import ANY, AsyncMock, Mock, patch

import azure.durable_functions as df
import azure.functions as func
import pytest
from agent_framework import AgentResponse, ChatMessage
from agent_framework_durabletask import (
    MIMETYPE_APPLICATION_JSON,
    MIMETYPE_TEXT_PLAIN,
    THREAD_ID_HEADER,
    WAIT_FOR_RESPONSE_FIELD,
    WAIT_FOR_RESPONSE_HEADER,
    AgentEntity,
    AgentEntityStateProviderMixin,
    DurableAgentState,
)

from agent_framework_azurefunctions import AgentFunctionApp
from agent_framework_azurefunctions._entities import create_agent_entity

TFunc = TypeVar("TFunc", bound=Callable[..., Any])


def _identity_decorator(func: TFunc) -> TFunc:
    return func


class _InMemoryStateProvider(AgentEntityStateProviderMixin):
    def __init__(self, *, thread_id: str = "test-thread", initial_state: dict[str, Any] | None = None) -> None:
        self._thread_id = thread_id
        self._state_dict: dict[str, Any] = initial_state or {}

    def _get_state_dict(self) -> dict[str, Any]:
        return self._state_dict

    def _set_state_dict(self, state: dict[str, Any]) -> None:
        self._state_dict = state

    def _get_thread_id_from_entity(self) -> str:
        return self._thread_id


class TestAgentFunctionAppInit:
    """Test suite for AgentFunctionApp initialization."""

    def test_init_with_defaults(self) -> None:
        """Test initialization with default parameters."""
        mock_agent = Mock()
        mock_agent.name = "TestAgent"

        app = AgentFunctionApp(agents=[mock_agent])

        assert len(app.agents) == 1
        assert "TestAgent" in app.agents
        assert app.enable_health_check is True

    def test_init_with_custom_auth_level(self) -> None:
        """Test initialization with custom auth level."""
        mock_agent = Mock()
        mock_agent.name = "TestAgent"

        app = AgentFunctionApp(agents=[mock_agent], http_auth_level=func.AuthLevel.FUNCTION)

        # App should be created successfully
        assert "TestAgent" in app.agents

    def test_init_with_health_check_disabled(self) -> None:
        """Test initialization with health check disabled."""
        mock_agent = Mock()
        mock_agent.name = "TestAgent"

        app = AgentFunctionApp(agents=[mock_agent], enable_health_check=False)

        assert app.enable_health_check is False

    def test_init_with_http_endpoints_disabled(self) -> None:
        """Test initialization with HTTP endpoints disabled."""
        mock_agent = Mock()
        mock_agent.name = "TestAgent"

        app = AgentFunctionApp(agents=[mock_agent], enable_http_endpoints=False)

        assert app.enable_http_endpoints is False

    def test_init_stores_agent_reference(self) -> None:
        """Test that agent reference is stored correctly."""
        mock_agent = Mock()
        mock_agent.name = "TestAgent"

        app = AgentFunctionApp(agents=[mock_agent])

        assert app.agents["TestAgent"].name == "TestAgent"

    def test_add_agent_uses_specific_callback(self) -> None:
        """Verify that a per-agent callback overrides the default."""

        mock_agent = Mock()
        mock_agent.name = "CallbackAgent"
        specific_callback = Mock()

        with patch.object(AgentFunctionApp, "_setup_agent_functions") as setup_mock:
            app = AgentFunctionApp(default_callback=Mock())
            app.add_agent(mock_agent, callback=specific_callback)

        setup_mock.assert_called_once()
        _, _, passed_callback, enable_http_endpoint, _enable_mcp_tool_trigger = setup_mock.call_args[0]
        assert passed_callback is specific_callback
        assert enable_http_endpoint is True

    def test_default_callback_applied_when_no_specific(self) -> None:
        """Ensure the default callback is supplied when add_agent lacks override."""

        mock_agent = Mock()
        mock_agent.name = "DefaultAgent"
        default_callback = Mock()

        with patch.object(AgentFunctionApp, "_setup_agent_functions") as setup_mock:
            app = AgentFunctionApp(default_callback=default_callback)
            app.add_agent(mock_agent)

        setup_mock.assert_called_once()
        _, _, passed_callback, enable_http_endpoint, _enable_mcp_tool_trigger = setup_mock.call_args[0]
        assert passed_callback is default_callback
        assert enable_http_endpoint is True

    def test_init_with_agents_uses_default_callback(self) -> None:
        """Agents provided in __init__ should receive the default callback."""

        mock_agent = Mock()
        mock_agent.name = "InitAgent"
        default_callback = Mock()

        with patch.object(AgentFunctionApp, "_setup_agent_functions") as setup_mock:
            AgentFunctionApp(agents=[mock_agent], default_callback=default_callback)

        setup_mock.assert_called_once()
        _, _, passed_callback, enable_http_endpoint, _enable_mcp_tool_trigger = setup_mock.call_args[0]
        assert passed_callback is default_callback
        assert enable_http_endpoint is True


class TestAgentFunctionAppSetup:
    """Test suite for AgentFunctionApp setup and configuration."""

    def test_app_is_dfapp_instance(self) -> None:
        """Test that AgentFunctionApp is a DFApp instance."""
        mock_agent = Mock()
        mock_agent.name = "TestAgent"

        app = AgentFunctionApp(agents=[mock_agent])

        assert isinstance(app, df.DFApp)

    def test_setup_creates_http_trigger(self) -> None:
        """Test that setup creates an HTTP trigger."""
        mock_agent = Mock()
        mock_agent.name = "TestAgent"

        def passthrough_decorator(*args: Any, **kwargs: Any) -> Callable[[TFunc], TFunc]:
            def decorator(func: TFunc) -> TFunc:
                return func

            return decorator

        with (
            patch.object(AgentFunctionApp, "route", new=passthrough_decorator),
            patch.object(AgentFunctionApp, "durable_client_input", new=passthrough_decorator),
            patch.object(AgentFunctionApp, "entity_trigger", new=passthrough_decorator),
        ):
            app = AgentFunctionApp(agents=[mock_agent])

        # Verify agent is registered
        assert "TestAgent" in app.agents

    def test_http_function_name_uses_prefix_format(self) -> None:
        """Ensure function names follow the prefix-agent naming convention."""
        mock_agent = Mock()
        mock_agent.name = "Agent 42"

        captured_names: list[str] = []

        def capture_function_name(
            self: AgentFunctionApp, name: str, *args: Any, **kwargs: Any
        ) -> Callable[[TFunc], TFunc]:
            def decorator(func: TFunc) -> TFunc:
                captured_names.append(name)
                return func

            return decorator

        def passthrough_decorator(*args: Any, **kwargs: Any) -> Callable[[TFunc], TFunc]:
            def decorator(func: TFunc) -> TFunc:
                return func

            return decorator

        with (
            patch.object(AgentFunctionApp, "function_name", new=capture_function_name),
            patch.object(AgentFunctionApp, "route", new=passthrough_decorator),
            patch.object(AgentFunctionApp, "durable_client_input", new=passthrough_decorator),
            patch.object(AgentFunctionApp, "entity_trigger", new=passthrough_decorator),
        ):
            AgentFunctionApp(agents=[mock_agent])

        assert captured_names == ["http-Agent_42"]

    def test_setup_skips_http_trigger_when_disabled(self) -> None:
        """Test that HTTP trigger is not created when disabled."""
        mock_agent = Mock()
        mock_agent.name = "TestAgent"

        captured_routes: list[str | None] = []

        def capture_route(*args: Any, **kwargs: Any) -> Callable[[TFunc], TFunc]:
            def decorator(func: TFunc) -> TFunc:
                route_key = kwargs.get("route") if kwargs else None
                captured_routes.append(route_key)
                return func

            return decorator

        def passthrough_decorator(*args: Any, **kwargs: Any) -> Callable[[TFunc], TFunc]:
            def decorator(func: TFunc) -> TFunc:
                return func

            return decorator

        with (
            patch.object(AgentFunctionApp, "function_name", new=passthrough_decorator),
            patch.object(AgentFunctionApp, "route", new=capture_route),
            patch.object(AgentFunctionApp, "durable_client_input", new=passthrough_decorator),
            patch.object(AgentFunctionApp, "entity_trigger", new=passthrough_decorator),
        ):
            app = AgentFunctionApp(agents=[mock_agent], enable_http_endpoints=False)

        # Verify agent is registered
        assert "TestAgent" in app.agents

        # Verify that no HTTP run route was created
        run_route = f"agents/{mock_agent.name}/run"
        assert run_route not in captured_routes

    def test_agent_override_enables_http_route_when_app_disabled(self) -> None:
        """Agent-level override should enable HTTP route even when app disables it."""

        mock_agent = Mock()
        mock_agent.name = "OverrideAgent"

        with (
            patch.object(AgentFunctionApp, "_setup_http_run_route") as http_route_mock,
            patch.object(AgentFunctionApp, "_setup_agent_entity") as agent_entity_mock,
        ):
            app = AgentFunctionApp(enable_health_check=False, enable_http_endpoints=False)
            app.add_agent(mock_agent, enable_http_endpoint=True)

        http_route_mock.assert_called_once_with("OverrideAgent")
        agent_entity_mock.assert_called_once_with(mock_agent, "OverrideAgent", ANY)
        assert app._agent_metadata["OverrideAgent"].http_endpoint_enabled is True

    def test_agent_override_disables_http_route_when_app_enabled(self) -> None:
        """Agent-level override should disable HTTP route even when app enables it."""

        mock_agent = Mock()
        mock_agent.name = "DisabledOverride"

        with (
            patch.object(AgentFunctionApp, "_setup_http_run_route") as http_route_mock,
            patch.object(AgentFunctionApp, "_setup_agent_entity") as agent_entity_mock,
        ):
            app = AgentFunctionApp(enable_health_check=False, enable_http_endpoints=True)
            app.add_agent(mock_agent, enable_http_endpoint=False)

        http_route_mock.assert_not_called()
        agent_entity_mock.assert_called_once_with(mock_agent, "DisabledOverride", ANY)
        assert app._agent_metadata["DisabledOverride"].http_endpoint_enabled is False

    def test_multiple_apps_independent(self) -> None:
        """Test that multiple AgentFunctionApp instances are independent."""
        agent1 = Mock()
        agent1.name = "Agent1"
        agent2 = Mock()
        agent2.name = "Agent2"

        app1 = AgentFunctionApp(agents=[agent1])
        app2 = AgentFunctionApp(agents=[agent2])

        assert app1.agents["Agent1"].name == "Agent1"
        assert app2.agents["Agent2"].name == "Agent2"
        assert "Agent1" in app1.agents
        assert "Agent2" in app2.agents


class TestWaitForResponseAndCorrelationId:
    """Tests for wait_for_response flag and correlation ID handling."""

    def _create_app(self) -> AgentFunctionApp:
        mock_agent = Mock()
        mock_agent.__class__.__name__ = "MockAgent"
        mock_agent.name = "MockAgent"
        return AgentFunctionApp(agents=[mock_agent], enable_health_check=False)

    def _make_request(
        self,
        headers: dict[str, str] | None = None,
        params: dict[str, str] | None = None,
    ) -> Mock:
        request = Mock()
        request.headers = headers or {}
        request.params = params or {}
        return request

    def test_wait_for_response_header_true(self) -> None:
        """Test that the wait-for-response header is honored."""
        app = self._create_app()
        request = self._make_request(headers={WAIT_FOR_RESPONSE_HEADER: "true"})

        assert app._should_wait_for_response(request, {}) is True

    def test_wait_for_response_body_snake_case(self) -> None:
        """Test that payload controls wait_for_response."""
        app = self._create_app()
        request = self._make_request()

        assert app._should_wait_for_response(request, {WAIT_FOR_RESPONSE_FIELD: "true"}) is True
        assert app._should_wait_for_response(request, {WAIT_FOR_RESPONSE_FIELD: "false"}) is False
        assert app._should_wait_for_response(request, {WAIT_FOR_RESPONSE_FIELD: "0"}) is False

    def test_wait_for_response_query_parameter(self) -> None:
        """Test that query parameter controls wait_for_response."""
        app = self._create_app()
        request = self._make_request(params={WAIT_FOR_RESPONSE_FIELD: "true"})

        assert app._should_wait_for_response(request, {}) is True

    def test_wait_for_response_query_precedence(self) -> None:
        """Test that query parameter overrides body value."""
        app = self._create_app()
        request = self._make_request(params={WAIT_FOR_RESPONSE_FIELD: "false"})

        assert app._should_wait_for_response(request, {WAIT_FOR_RESPONSE_FIELD: "true"}) is False


class TestAgentEntityOperations:
    """Test suite for entity operations."""

    async def test_entity_run_agent_operation(self) -> None:
        """Test that entity can run agent operation."""
        mock_agent = Mock()
        mock_agent.run = AsyncMock(
            return_value=AgentResponse(messages=[ChatMessage(role="assistant", text="Test response")])
        )

        entity = AgentEntity(mock_agent, state_provider=_InMemoryStateProvider(thread_id="test-conv-123"))

        result = await entity.run({
            "message": "Test message",
            "correlationId": "corr-app-entity-1",
        })

        assert isinstance(result, AgentResponse)
        assert result.text == "Test response"
        assert entity.state.message_count == 2

    async def test_entity_stores_conversation_history(self) -> None:
        """Test that the entity stores conversation history."""
        mock_agent = Mock()
        mock_agent.run = AsyncMock(
            return_value=AgentResponse(messages=[ChatMessage(role="assistant", text="Response 1")])
        )

        entity = AgentEntity(mock_agent, state_provider=_InMemoryStateProvider(thread_id="conv-1"))

        # Send first message
        await entity.run({"message": "Message 1", "correlationId": "corr-app-entity-2"})

        # Each conversation turn creates 2 entries: request and response
        history = entity.state.data.conversation_history[0].messages  # First request entry
        assert len(history) == 1  # Just the user message

        # Send second message
        await entity.run({"message": "Message 2", "correlationId": "corr-app-entity-2b"})

        # Now we have 4 entries total (2 requests + 2 responses)
        # Access the second request entry (index 2)
        history2 = entity.state.data.conversation_history[2].messages  # Second request entry
        assert len(history2) == 1  # Just the user message

        # The first request entry still holds the original user message
        user_msg = history[0]
        user_role = getattr(user_msg.role, "value", user_msg.role)
        assert user_role == "user"
        assert user_msg.text == "Message 1"

        # The first response entry (index 1) holds the assistant reply
        assistant_msg = entity.state.data.conversation_history[1].messages[0]
        assistant_role = getattr(assistant_msg.role, "value", assistant_msg.role)
        assert assistant_role == "assistant"
        assert assistant_msg.text == "Response 1"

async def test_entity_increments_message_count(self) -> None:
|
|
"""Test that the entity increments the message count."""
|
|
mock_agent = Mock()
|
|
mock_agent.run = AsyncMock(
|
|
return_value=AgentResponse(messages=[ChatMessage(role="assistant", text="Response")])
|
|
)
|
|
|
|
entity = AgentEntity(mock_agent, state_provider=_InMemoryStateProvider(thread_id="conv-1"))
|
|
|
|
assert len(entity.state.data.conversation_history) == 0
|
|
|
|
await entity.run({"message": "Message 1", "correlationId": "corr-app-entity-3a"})
|
|
assert len(entity.state.data.conversation_history) == 2
|
|
|
|
await entity.run({"message": "Message 2", "correlationId": "corr-app-entity-3b"})
|
|
assert len(entity.state.data.conversation_history) == 4
|
|
|
|
def test_entity_reset(self) -> None:
|
|
"""Test that entity reset clears state."""
|
|
mock_agent = Mock()
|
|
entity = AgentEntity(mock_agent, state_provider=_InMemoryStateProvider())
|
|
|
|
# Set some state
|
|
entity.state = DurableAgentState()
|
|
|
|
# Reset
|
|
entity.reset()
|
|
|
|
assert len(entity.state.data.conversation_history) == 0
|
|
|
|
|
|
class TestAgentEntityFactory:
    """Test suite for the entity factory function."""

    def test_create_agent_entity_returns_function(self) -> None:
        """Test that create_agent_entity returns a function."""
        mock_agent = Mock()
        entity_function = create_agent_entity(mock_agent)

        assert callable(entity_function)

    def test_entity_function_handles_run_operation(self) -> None:
        """Test that the entity function handles the run operation."""
        mock_agent = Mock()
        mock_agent.run = AsyncMock(
            return_value=AgentResponse(messages=[ChatMessage(role="assistant", text="Response")])
        )

        entity_function = create_agent_entity(mock_agent)

        # Mock context
        mock_context = Mock()
        mock_context.operation_name = "run"
        mock_context.get_input.return_value = {
            "message": "Test message",
            "correlationId": "corr-app-factory-1",
        }
        mock_context.get_state.return_value = None

        # Execute entity function
        entity_function(mock_context)

        # Verify result was set
        assert mock_context.set_result.called
        assert mock_context.set_state.called
        result_call = mock_context.set_result.call_args[0][0]
        assert "error" not in result_call

    def test_entity_function_handles_run_agent_operation(self) -> None:
        """Test that the entity function handles the deprecated run_agent operation for backward compatibility."""
        mock_agent = Mock()
        mock_agent.run = AsyncMock(
            return_value=AgentResponse(messages=[ChatMessage(role="assistant", text="Response")])
        )

        entity_function = create_agent_entity(mock_agent)

        # Mock context
        mock_context = Mock()
        mock_context.operation_name = "run_agent"
        mock_context.get_input.return_value = {
            "message": "Test message",
            "correlationId": "corr-app-factory-1",
        }
        mock_context.get_state.return_value = None

        # Execute entity function
        entity_function(mock_context)

        # Verify result was set
        assert mock_context.set_result.called
        assert mock_context.set_state.called
        result_call = mock_context.set_result.call_args[0][0]
        assert "error" not in result_call

    def test_entity_function_handles_reset_operation(self) -> None:
        """Test that the entity function handles the reset operation."""
        mock_agent = Mock()
        entity_function = create_agent_entity(mock_agent)

        # Mock context
        mock_context = Mock()
        mock_context.operation_name = "reset"
        mock_context.get_state.return_value = {
            "schemaVersion": "1.0.0",
            "data": {
                "conversationHistory": [
                    {
                        "$type": "request",
                        "correlationId": "corr-reset-test",
                        "createdAt": "2024-01-01T00:00:00Z",
                        "messages": [
                            {
                                "role": "user",
                                "contents": [
                                    {
                                        "$type": "text",
                                        "text": "test",
                                    }
                                ],
                            }
                        ],
                    }
                ],
            },
        }

        # Execute entity function
        entity_function(mock_context)

        # Verify result was set
        assert mock_context.set_result.called
        result_call = mock_context.set_result.call_args[0][0]
        assert result_call["status"] == "reset"

    def test_entity_function_handles_unknown_operation(self) -> None:
        """Test that the entity function handles an unknown operation."""
        mock_agent = Mock()
        entity_function = create_agent_entity(mock_agent)

        # Mock context with unknown operation
        mock_context = Mock()
        mock_context.operation_name = "unknown_operation"
        mock_context.get_state.return_value = None

        # Execute entity function
        entity_function(mock_context)

        # Verify error result was set
        assert mock_context.set_result.called
        result_call = mock_context.set_result.call_args[0][0]
        assert "error" in result_call
        assert "unknown_operation" in result_call["error"]

    def test_entity_function_restores_state(self) -> None:
        """Test that the entity function restores state from the context."""
        mock_agent = Mock()
        entity_function = create_agent_entity(mock_agent)

        # Mock context with existing state
        existing_state = {
            "schemaVersion": "1.0.0",
            "data": {
                "conversationHistory": [
                    {
                        "$type": "request",
                        "correlationId": "corr-existing-1",
                        "createdAt": "2024-01-01T00:00:00Z",
                        "messages": [
                            {
                                "role": "user",
                                "contents": [
                                    {
                                        "$type": "text",
                                        "text": "msg1",
                                    }
                                ],
                            }
                        ],
                    },
                    {
                        "$type": "response",
                        "correlationId": "corr-existing-1",
                        "createdAt": "2024-01-01T00:05:00Z",
                        "messages": [
                            {
                                "role": "assistant",
                                "contents": [
                                    {
                                        "$type": "text",
                                        "text": "resp1",
                                    }
                                ],
                            }
                        ],
                    },
                ],
            },
        }

        mock_context = Mock()
        mock_context.operation_name = "run"
        mock_context.get_input.return_value = {
            "message": "Test message",
            "correlationId": "corr-restore-1",
        }
        mock_context.get_state.return_value = existing_state

        with patch.object(DurableAgentState, "from_dict", wraps=DurableAgentState.from_dict) as from_dict_mock:
            entity_function(mock_context)

        from_dict_mock.assert_called_once_with(existing_state)


class TestErrorHandling:
    """Test suite for error handling."""

    async def test_entity_handles_agent_error(self) -> None:
        """Test that the entity handles agent execution errors."""
        mock_agent = Mock()
        mock_agent.run = AsyncMock(side_effect=Exception("Agent error"))

        entity = AgentEntity(mock_agent, state_provider=_InMemoryStateProvider(thread_id="conv-1"))

        result = await entity.run({
            "message": "Test message",
            "correlationId": "corr-app-error-1",
        })

        assert isinstance(result, AgentResponse)
        assert len(result.messages) == 1
        content = result.messages[0].contents[0]
        assert content.type == "error"
        assert "Agent error" in (content.message or "")
        assert content.error_code == "Exception"

    def test_entity_function_handles_exception(self) -> None:
        """Test that the entity function handles exceptions gracefully."""
        mock_agent = Mock()
        # The agent raises if invoked; get_input is also made to fail below
        mock_agent.run = AsyncMock(side_effect=Exception("Test error"))

        entity_function = create_agent_entity(mock_agent)

        mock_context = Mock()
        mock_context.operation_name = "run"
        mock_context.get_input.side_effect = Exception("Input error")
        mock_context.get_state.return_value = None

        # Execute entity function - should not raise
        entity_function(mock_context)

        # Verify error result was set
        assert mock_context.set_result.called
        result_call = mock_context.set_result.call_args[0][0]
        assert "error" in result_call


class TestIncomingRequestParsing:
    """Tests for parsing run requests with JSON and plain text bodies."""

    def _create_app(self) -> AgentFunctionApp:
        mock_agent = Mock()
        mock_agent.name = "ParserAgent"
        return AgentFunctionApp(agents=[mock_agent], enable_health_check=False)

    def test_parse_plain_text_body(self) -> None:
        """Test parsing a plain-text request body."""
        app = self._create_app()

        request = Mock()
        request.headers = {}
        request.params = {}
        request.get_json.side_effect = ValueError("Invalid JSON")
        request.get_body.return_value = b"Plain text message"

        req_body, message, response_format = app._parse_incoming_request(request)

        assert req_body == {}
        assert message == "Plain text message"

        assert response_format == "text"

    def test_parse_plain_text_trims_whitespace(self) -> None:
        """Plain-text parser returns an empty string when the body contains only whitespace."""
        app = self._create_app()

        request = Mock()
        request.headers = {}
        request.params = {}
        request.get_json.side_effect = ValueError("Invalid JSON")
        request.get_body.return_value = b" "

        req_body, message, response_format = app._parse_incoming_request(request)

        assert req_body == {}
        assert message == ""
        assert response_format == "text"

    def test_accept_header_prefers_json(self) -> None:
        """Test that the Accept header can force JSON responses for plain-text bodies."""
        app = self._create_app()

        request = Mock()
        request.headers = {"accept": MIMETYPE_APPLICATION_JSON}
        request.params = {}
        request.get_json.side_effect = ValueError("Invalid JSON")
        request.get_body.return_value = b"Plain text message"

        _, message, response_format = app._parse_incoming_request(request)

        assert message == "Plain text message"
        assert response_format == "json"

    def test_extract_thread_id_from_query_params(self) -> None:
        """Test thread identifier extraction from query parameters."""
        app = self._create_app()

        request = Mock()
        request.params = {"thread_id": "query-thread"}
        req_body: dict[str, Any] = {}

        thread_id = app._resolve_thread_id(request, req_body)

        assert thread_id == "query-thread"


class TestHttpRunRoute:
    """Tests for the HTTP run route behavior."""

    @staticmethod
    def _get_run_handler(agent: Mock) -> Callable[[func.HttpRequest, Any], Awaitable[func.HttpResponse]]:
        captured_handlers: dict[str | None, Callable[..., Awaitable[func.HttpResponse]]] = {}

        def capture_decorator(*args: Any, **kwargs: Any) -> Callable[[TFunc], TFunc]:
            def decorator(func: TFunc) -> TFunc:
                return func

            return decorator

        def capture_route(*args: Any, **kwargs: Any) -> Callable[[TFunc], TFunc]:
            def decorator(func: TFunc) -> TFunc:
                route_key = kwargs.get("route") if kwargs else None
                captured_handlers[route_key] = func
                return func

            return decorator

        with (
            patch.object(AgentFunctionApp, "function_name", new=capture_decorator),
            patch.object(AgentFunctionApp, "route", new=capture_route),
            patch.object(AgentFunctionApp, "durable_client_input", new=capture_decorator),
            patch.object(AgentFunctionApp, "entity_trigger", new=capture_decorator),
        ):
            AgentFunctionApp(agents=[agent], enable_health_check=False)

        run_route = f"agents/{agent.name}/run"
        return captured_handlers[run_route]

    async def test_http_run_accepts_plain_text(self) -> None:
        """Test that the HTTP handler accepts plain-text requests."""
        mock_agent = Mock()
        mock_agent.name = "HttpAgent"

        handler = self._get_run_handler(mock_agent)

        request = Mock()
        request.headers = {WAIT_FOR_RESPONSE_HEADER: "false"}
        request.params = {}
        request.route_params = {}
        request.get_json.side_effect = ValueError("Invalid JSON")
        request.get_body.return_value = b"Plain text via HTTP"

        client = AsyncMock()

        response = await handler(request, client)

        assert response.status_code == 202
        assert response.mimetype == MIMETYPE_TEXT_PLAIN
        assert response.headers.get(THREAD_ID_HEADER) is not None
        assert response.get_body().decode("utf-8") == "Agent request accepted"

        signal_args = client.signal_entity.call_args[0]
        run_request = signal_args[2]

        assert run_request["message"] == "Plain text via HTTP"
        assert run_request["role"] == "user"
        assert "thread_id" not in run_request

    async def test_http_run_accept_header_returns_json(self) -> None:
        """Test that Accept header requesting JSON results in JSON response."""
        mock_agent = Mock()
        mock_agent.name = "HttpAgentJson"

        handler = self._get_run_handler(mock_agent)

        request = Mock()
        request.headers = {WAIT_FOR_RESPONSE_HEADER: "false", "Accept": MIMETYPE_APPLICATION_JSON}
        request.params = {}
        request.route_params = {}
        request.get_json.side_effect = ValueError("Invalid JSON")
        request.get_body.return_value = b"Plain text via HTTP"

        client = AsyncMock()

        response = await handler(request, client)

        assert response.status_code == 202
        assert response.mimetype == MIMETYPE_APPLICATION_JSON
        assert response.headers.get(THREAD_ID_HEADER) is None
        body = response.get_body().decode("utf-8")
        assert '"status": "accepted"' in body

    async def test_http_run_rejects_empty_message(self) -> None:
        """Test that the HTTP handler rejects empty messages with a 400 response."""
        mock_agent = Mock()
        mock_agent.name = "HttpAgentEmpty"

        handler = self._get_run_handler(mock_agent)

        request = Mock()
        request.headers = {WAIT_FOR_RESPONSE_HEADER: "false"}
        request.params = {}
        request.route_params = {}
        request.get_json.side_effect = ValueError("Invalid JSON")
        request.get_body.return_value = b" "

        client = AsyncMock()

        response = await handler(request, client)

        assert response.status_code == 400
        assert response.mimetype == MIMETYPE_TEXT_PLAIN
        assert response.headers.get(THREAD_ID_HEADER) is not None
        assert response.get_body().decode("utf-8") == "Message is required"
        client.signal_entity.assert_not_called()


class TestMCPToolEndpoint:
    """Test suite for MCP tool endpoint functionality."""

    def test_init_with_mcp_tool_endpoint_enabled(self) -> None:
        """Test initialization with MCP tool endpoint enabled."""
        mock_agent = Mock()
        mock_agent.name = "TestAgent"

        app = AgentFunctionApp(agents=[mock_agent], enable_mcp_tool_trigger=True)

        assert app.enable_mcp_tool_trigger is True

    def test_init_with_mcp_tool_endpoint_disabled(self) -> None:
        """Test initialization with MCP tool endpoint disabled (default)."""
        mock_agent = Mock()
        mock_agent.name = "TestAgent"

        app = AgentFunctionApp(agents=[mock_agent])

        assert app.enable_mcp_tool_trigger is False

    def test_add_agent_with_mcp_tool_trigger_enabled(self) -> None:
        """Test adding an agent with MCP tool trigger explicitly enabled."""
        mock_agent = Mock()
        mock_agent.name = "MCPAgent"
        mock_agent.description = "Test MCP Agent"

        with patch.object(AgentFunctionApp, "_setup_agent_functions") as setup_mock:
            app = AgentFunctionApp()
            app.add_agent(mock_agent, enable_mcp_tool_trigger=True)

        setup_mock.assert_called_once()
        _, _, _, _, enable_mcp = setup_mock.call_args[0]
        assert enable_mcp is True

    def test_add_agent_with_mcp_tool_trigger_disabled(self) -> None:
        """Test adding an agent with MCP tool trigger explicitly disabled."""
        mock_agent = Mock()
        mock_agent.name = "NoMCPAgent"

        with patch.object(AgentFunctionApp, "_setup_agent_functions") as setup_mock:
            app = AgentFunctionApp(enable_mcp_tool_trigger=True)
            app.add_agent(mock_agent, enable_mcp_tool_trigger=False)

        setup_mock.assert_called_once()
        _, _, _, _, enable_mcp = setup_mock.call_args[0]
        assert enable_mcp is False

    def test_agent_override_enables_mcp_when_app_disabled(self) -> None:
        """Test that per-agent override can enable MCP when app-level is disabled."""
        mock_agent = Mock()
        mock_agent.name = "OverrideAgent"

        with patch.object(AgentFunctionApp, "_setup_mcp_tool_trigger") as mcp_setup_mock:
            app = AgentFunctionApp(enable_mcp_tool_trigger=False)
            app.add_agent(mock_agent, enable_mcp_tool_trigger=True)

        mcp_setup_mock.assert_called_once()

    def test_agent_override_disables_mcp_when_app_enabled(self) -> None:
        """Test that per-agent override can disable MCP when app-level is enabled."""
        mock_agent = Mock()
        mock_agent.name = "NoOverrideAgent"

        with patch.object(AgentFunctionApp, "_setup_mcp_tool_trigger") as mcp_setup_mock:
            app = AgentFunctionApp(enable_mcp_tool_trigger=True)
            app.add_agent(mock_agent, enable_mcp_tool_trigger=False)

        mcp_setup_mock.assert_not_called()

    def test_setup_mcp_tool_trigger_registers_decorators(self) -> None:
        """Test that _setup_mcp_tool_trigger registers the correct decorators."""
        mock_agent = Mock()
        mock_agent.name = "MCPToolAgent"
        mock_agent.description = "Test MCP Tool"

        app = AgentFunctionApp()

        # Mock the decorators
        with (
            patch.object(app, "function_name") as func_name_mock,
            patch.object(app, "mcp_tool_trigger") as mcp_trigger_mock,
            patch.object(app, "durable_client_input") as client_mock,
        ):
            # Setup mock decorator chain
            func_name_mock.return_value = _identity_decorator
            mcp_trigger_mock.return_value = _identity_decorator
            client_mock.return_value = _identity_decorator

            app._setup_mcp_tool_trigger(mock_agent.name, mock_agent.description)

            # Verify decorators were called with correct parameters
            func_name_mock.assert_called_once()
            mcp_trigger_mock.assert_called_once_with(
                arg_name="context",
                tool_name=mock_agent.name,
                description=mock_agent.description,
                tool_properties=ANY,
                data_type=func.DataType.UNDEFINED,
            )
            client_mock.assert_called_once_with(client_name="client")

    def test_setup_mcp_tool_trigger_uses_default_description(self) -> None:
        """Test that _setup_mcp_tool_trigger uses default description when none provided."""
        mock_agent = Mock()
        mock_agent.name = "NoDescAgent"

        app = AgentFunctionApp()

        with (
            patch.object(app, "function_name", return_value=_identity_decorator),
            patch.object(app, "mcp_tool_trigger") as mcp_trigger_mock,
            patch.object(app, "durable_client_input", return_value=_identity_decorator),
        ):
            mcp_trigger_mock.return_value = _identity_decorator

            app._setup_mcp_tool_trigger(mock_agent.name, None)

            # Verify default description was used
            call_args = mcp_trigger_mock.call_args
            assert call_args[1]["description"] == f"Interact with {mock_agent.name} agent"

    async def test_handle_mcp_tool_invocation_with_json_string(self) -> None:
        """Test _handle_mcp_tool_invocation with JSON string context."""
        mock_agent = Mock()
        mock_agent.name = "TestAgent"

        app = AgentFunctionApp(agents=[mock_agent])
        client = AsyncMock()

        # Mock the entity response
        mock_state = Mock()
        mock_state.entity_state = {
            "schemaVersion": "1.0.0",
            "data": {"conversationHistory": []},
        }
        client.read_entity_state.return_value = mock_state

        # Create JSON string context
        context = '{"arguments": {"query": "test query", "threadId": "test-thread"}}'

        with patch.object(app, "_get_response_from_entity") as get_response_mock:
            get_response_mock.return_value = {"status": "success", "response": "Test response"}

            result = await app._handle_mcp_tool_invocation("TestAgent", context, client)

            assert result == "Test response"
            get_response_mock.assert_called_once()

    async def test_handle_mcp_tool_invocation_with_json_context(self) -> None:
        """Test _handle_mcp_tool_invocation with a context serialized via json.dumps."""
        mock_agent = Mock()
        mock_agent.name = "TestAgent"

        app = AgentFunctionApp(agents=[mock_agent])
        client = AsyncMock()

        # Mock the entity response
        mock_state = Mock()
        mock_state.entity_state = {
            "schemaVersion": "1.0.0",
            "data": {"conversationHistory": []},
        }
        client.read_entity_state.return_value = mock_state

        # Create JSON string context by serializing a dict
        context = json.dumps({"arguments": {"query": "test query", "threadId": "test-thread"}})

        with patch.object(app, "_get_response_from_entity") as get_response_mock:
            get_response_mock.return_value = {"status": "success", "response": "Test response"}

            result = await app._handle_mcp_tool_invocation("TestAgent", context, client)

            assert result == "Test response"
            get_response_mock.assert_called_once()

    async def test_handle_mcp_tool_invocation_missing_query(self) -> None:
        """Test _handle_mcp_tool_invocation raises ValueError when query is missing."""
        mock_agent = Mock()
        mock_agent.name = "TestAgent"

        app = AgentFunctionApp(agents=[mock_agent])
        client = AsyncMock()

        # Context missing query (as JSON string)
        context = json.dumps({"arguments": {}})

        with pytest.raises(ValueError, match="missing required 'query' argument"):
            await app._handle_mcp_tool_invocation("TestAgent", context, client)

    async def test_handle_mcp_tool_invocation_invalid_json(self) -> None:
        """Test _handle_mcp_tool_invocation raises ValueError for invalid JSON."""
        mock_agent = Mock()
        mock_agent.name = "TestAgent"

        app = AgentFunctionApp(agents=[mock_agent])
        client = AsyncMock()

        # Invalid JSON string
        context = "not valid json"

        with pytest.raises(ValueError, match="Invalid MCP context format"):
            await app._handle_mcp_tool_invocation("TestAgent", context, client)

    async def test_handle_mcp_tool_invocation_runtime_error(self) -> None:
        """Test _handle_mcp_tool_invocation raises RuntimeError when agent fails."""
        mock_agent = Mock()
        mock_agent.name = "TestAgent"

        app = AgentFunctionApp(agents=[mock_agent])
        client = AsyncMock()

        # Mock the entity response
        mock_state = Mock()
        mock_state.entity_state = {
            "schemaVersion": "1.0.0",
            "data": {"conversationHistory": []},
        }
        client.read_entity_state.return_value = mock_state

        context = '{"arguments": {"query": "test query"}}'

        with patch.object(app, "_get_response_from_entity") as get_response_mock:
            get_response_mock.return_value = {"status": "failed", "error": "Agent error"}

            with pytest.raises(RuntimeError, match="Agent execution failed"):
                await app._handle_mcp_tool_invocation("TestAgent", context, client)

    async def test_handle_mcp_tool_invocation_ignores_agent_name_in_thread_id(self) -> None:
        """Test that MCP tool invocation uses the agent_name parameter, not the name from thread_id."""
        mock_agent = Mock()
        mock_agent.name = "PlantAdvisor"

        app = AgentFunctionApp(agents=[mock_agent])
        client = AsyncMock()

        # Mock the entity response
        mock_state = Mock()
        mock_state.entity_state = {
            "schemaVersion": "1.0.0",
            "data": {"conversationHistory": []},
        }
        client.read_entity_state.return_value = mock_state

        # Thread ID contains a different agent name (@StockAdvisor@test123)
        # but we're invoking PlantAdvisor - it should use PlantAdvisor's entity
        context = json.dumps({"arguments": {"query": "test query", "threadId": "@StockAdvisor@test123"}})

        with patch.object(app, "_get_response_from_entity") as get_response_mock:
            get_response_mock.return_value = {"status": "success", "response": "Test response"}

            await app._handle_mcp_tool_invocation("PlantAdvisor", context, client)

            # Verify signal_entity was called with PlantAdvisor's entity, not StockAdvisor's
            client.signal_entity.assert_called_once()
            call_args = client.signal_entity.call_args
            entity_id = call_args[0][0]

            # Entity name should be dafx-PlantAdvisor, not dafx-StockAdvisor
            assert entity_id.name == "dafx-PlantAdvisor"
            assert entity_id.key == "test123"

    async def test_handle_mcp_tool_invocation_uses_plain_thread_id_as_key(self) -> None:
        """Test that a plain thread_id (not in @name@key format) is used as-is for the key."""
        mock_agent = Mock()
        mock_agent.name = "TestAgent"

        app = AgentFunctionApp(agents=[mock_agent])
        client = AsyncMock()

        mock_state = Mock()
        mock_state.entity_state = {
            "schemaVersion": "1.0.0",
            "data": {"conversationHistory": []},
        }
        client.read_entity_state.return_value = mock_state

        # Plain thread_id without @name@key format
        context = json.dumps({"arguments": {"query": "test query", "threadId": "simple-thread-123"}})

        with patch.object(app, "_get_response_from_entity") as get_response_mock:
            get_response_mock.return_value = {"status": "success", "response": "Test response"}

            await app._handle_mcp_tool_invocation("TestAgent", context, client)

            client.signal_entity.assert_called_once()
            call_args = client.signal_entity.call_args
            entity_id = call_args[0][0]

            assert entity_id.name == "dafx-TestAgent"
            assert entity_id.key == "simple-thread-123"

    def test_health_check_includes_mcp_tool_enabled(self) -> None:
        """Test that health check endpoint includes mcp_tool_enabled field."""
        mock_agent = Mock()
        mock_agent.name = "HealthAgent"

        app = AgentFunctionApp(agents=[mock_agent], enable_mcp_tool_trigger=True)

        # Capture the health check handler function
        captured_handler: Callable[[func.HttpRequest], func.HttpResponse] | None = None

        def capture_decorator(*args: Any, **kwargs: Any) -> Callable[[TFunc], TFunc]:
            def decorator(func: TFunc) -> TFunc:
                nonlocal captured_handler
                captured_handler = func
                return func

            return decorator

        with patch.object(app, "route", side_effect=capture_decorator):
            app._setup_health_route()

        # Verify we captured the handler
        assert captured_handler is not None

        # Call the health handler
        request = Mock()
        response = captured_handler(request)

        # Verify response includes mcp_tool_enabled
        import json

        body = json.loads(response.get_body().decode("utf-8"))
        assert "agents" in body
        assert len(body["agents"]) == 1
        assert "mcp_tool_enabled" in body["agents"][0]
        assert body["agents"][0]["mcp_tool_enabled"] is True


if __name__ == "__main__":
    pytest.main([__file__, "-v", "--tb=short"])