mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Fix declarative package powerfx import crash and response_format kwarg error (#3841)
* Fix declarative package powerfx import crash and response_format kwarg error * Address PR feedback. Propagate kwargs for declarative workflows * move tests * Fix options merge logic
This commit is contained in:
committed by
GitHub
Unverified
parent
692fcd1888
commit
ff91473912
@@ -1,8 +1,10 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import builtins
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
@@ -905,3 +907,105 @@ tools:
|
||||
|
||||
# Verify project_connection_id is set from connection name
|
||||
assert mcp_tool.get("project_connection_id") == "my-oauth-connection"
|
||||
|
||||
|
||||
class TestProviderResponseFormat:
|
||||
"""response_format from outputSchema must be passed inside default_options."""
|
||||
|
||||
@staticmethod
|
||||
def _make_mock_prompt_agent(*, with_output_schema: bool = False) -> MagicMock:
|
||||
"""Create a mock PromptAgent to avoid serialization complexity."""
|
||||
mock_model = MagicMock()
|
||||
mock_model.id = "gpt-4"
|
||||
mock_model.connection = None
|
||||
|
||||
agent = MagicMock()
|
||||
agent.name = "test-agent"
|
||||
agent.description = "test"
|
||||
agent.instructions = "be helpful"
|
||||
agent.model = mock_model
|
||||
agent.tools = None
|
||||
|
||||
if with_output_schema:
|
||||
mock_schema = MagicMock()
|
||||
mock_schema.to_json_schema.return_value = {
|
||||
"type": "object",
|
||||
"properties": {"answer": {"type": "string"}},
|
||||
}
|
||||
agent.outputSchema = mock_schema
|
||||
else:
|
||||
agent.outputSchema = None
|
||||
|
||||
return agent
|
||||
|
||||
@staticmethod
|
||||
def _make_mock_provider() -> tuple[MagicMock, AsyncMock]:
|
||||
"""Create a mock provider class and its instance."""
|
||||
mock_agent = MagicMock()
|
||||
mock_provider_instance = AsyncMock()
|
||||
mock_provider_instance.create_agent = AsyncMock(return_value=mock_agent)
|
||||
mock_provider_class = MagicMock(return_value=mock_provider_instance)
|
||||
return mock_provider_class, mock_provider_instance
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_response_format_in_default_options(self):
|
||||
"""Provider.create_agent() should receive response_format inside default_options."""
|
||||
from agent_framework_declarative._loader import AgentFactory
|
||||
|
||||
prompt_agent = self._make_mock_prompt_agent(with_output_schema=True)
|
||||
mock_provider_class, mock_provider_instance = self._make_mock_provider()
|
||||
|
||||
mapping = {"package": "some_module", "name": "SomeProvider"}
|
||||
factory = AgentFactory()
|
||||
|
||||
original_import = builtins.__import__
|
||||
|
||||
def mock_import(name, *args, **kwargs):
|
||||
if name == "some_module":
|
||||
mod = MagicMock()
|
||||
mod.SomeProvider = mock_provider_class
|
||||
return mod
|
||||
return original_import(name, *args, **kwargs)
|
||||
|
||||
with (
|
||||
patch.object(builtins, "__import__", side_effect=mock_import),
|
||||
patch.object(factory, "_parse_tools", return_value=None),
|
||||
):
|
||||
await factory._create_agent_with_provider(prompt_agent, mapping)
|
||||
|
||||
mock_provider_instance.create_agent.assert_called_once()
|
||||
call_kwargs = mock_provider_instance.create_agent.call_args.kwargs
|
||||
|
||||
assert "response_format" not in call_kwargs
|
||||
default_options = call_kwargs.get("default_options")
|
||||
assert default_options is not None
|
||||
assert "response_format" in default_options
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_default_options_without_output_schema(self):
|
||||
"""When there's no outputSchema, default_options should be None."""
|
||||
from agent_framework_declarative._loader import AgentFactory
|
||||
|
||||
prompt_agent = self._make_mock_prompt_agent(with_output_schema=False)
|
||||
mock_provider_class, mock_provider_instance = self._make_mock_provider()
|
||||
|
||||
mapping = {"package": "some_module", "name": "SomeProvider"}
|
||||
factory = AgentFactory()
|
||||
|
||||
original_import = builtins.__import__
|
||||
|
||||
def mock_import(name, *args, **kwargs):
|
||||
if name == "some_module":
|
||||
mod = MagicMock()
|
||||
mod.SomeProvider = mock_provider_class
|
||||
return mod
|
||||
return original_import(name, *args, **kwargs)
|
||||
|
||||
with (
|
||||
patch.object(builtins, "__import__", side_effect=mock_import),
|
||||
patch.object(factory, "_parse_tools", return_value=None),
|
||||
):
|
||||
await factory._create_agent_with_provider(prompt_agent, mapping)
|
||||
|
||||
call_kwargs = mock_provider_instance.create_agent.call_args.kwargs
|
||||
assert call_kwargs.get("default_options") is None
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
"""Tests for the graph-based declarative workflow executors."""
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
@@ -1295,3 +1296,203 @@ class TestExtractJsonFromResponse:
|
||||
text = 'First: {"status": "pending"} then later: {"status": "complete", "id": 42}'
|
||||
result = _extract_json_from_response(text)
|
||||
assert result == {"status": "complete", "id": 42}
|
||||
|
||||
|
||||
class TestPowerFxConditionalImport:
|
||||
"""The _declarative_base module should be importable without dotnet/powerfx."""
|
||||
|
||||
def test_import_guard_exists(self):
|
||||
"""The powerfx import must be wrapped in try/except."""
|
||||
import agent_framework_declarative._workflows._declarative_base as base_mod
|
||||
|
||||
assert hasattr(base_mod, "DeclarativeWorkflowState")
|
||||
assert hasattr(base_mod, "Engine")
|
||||
|
||||
# Engine should either be the real class or None — never an ImportError
|
||||
engine = base_mod.Engine
|
||||
assert engine is None or callable(engine)
|
||||
|
||||
def test_eval_raises_when_engine_unavailable(self):
|
||||
"""eval() should raise RuntimeError when Engine is None."""
|
||||
import agent_framework_declarative._workflows._declarative_base as base_mod
|
||||
|
||||
mock_state = MagicMock()
|
||||
mock_state._data: dict[str, Any] = {}
|
||||
mock_state.get = MagicMock(side_effect=lambda k, d=None: mock_state._data.get(k, d))
|
||||
mock_state.set = MagicMock(side_effect=lambda k, v: mock_state._data.__setitem__(k, v))
|
||||
|
||||
state = DeclarativeWorkflowState(mock_state)
|
||||
state.initialize({"name": "test"})
|
||||
|
||||
original_engine = base_mod.Engine
|
||||
try:
|
||||
base_mod.Engine = None
|
||||
with pytest.raises(RuntimeError, match="PowerFx is not available"):
|
||||
state.eval("=Local.counter + 1")
|
||||
finally:
|
||||
base_mod.Engine = original_engine
|
||||
|
||||
def test_eval_passes_through_plain_strings_without_engine(self):
|
||||
"""Non-PowerFx strings (no leading '=') should work without Engine."""
|
||||
import agent_framework_declarative._workflows._declarative_base as base_mod
|
||||
|
||||
mock_state = MagicMock()
|
||||
mock_state._data: dict[str, Any] = {}
|
||||
mock_state.get = MagicMock(side_effect=lambda k, d=None: mock_state._data.get(k, d))
|
||||
mock_state.set = MagicMock(side_effect=lambda k, v: mock_state._data.__setitem__(k, v))
|
||||
|
||||
state = DeclarativeWorkflowState(mock_state)
|
||||
state.initialize()
|
||||
|
||||
original_engine = base_mod.Engine
|
||||
try:
|
||||
base_mod.Engine = None
|
||||
assert state.eval("hello world") == "hello world"
|
||||
assert state.eval("") == ""
|
||||
assert state.eval(42) == 42
|
||||
finally:
|
||||
base_mod.Engine = original_engine
|
||||
|
||||
|
||||
class TestExecutorKwargsForwarding:
|
||||
"""Workflow run kwargs should be forwarded through executor agent invocations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_agent_forwards_kwargs(self):
|
||||
"""InvokeAzureAgentExecutor should forward run_kwargs to agent.run()."""
|
||||
from agent_framework._workflows._const import WORKFLOW_RUN_KWARGS_KEY
|
||||
from agent_framework._workflows._state import State
|
||||
|
||||
from agent_framework_declarative._workflows._executors_agents import (
|
||||
InvokeAzureAgentExecutor,
|
||||
)
|
||||
|
||||
# Create a mock State with kwargs stored
|
||||
mock_state = MagicMock(spec=State)
|
||||
state_data: dict[str, Any] = {}
|
||||
|
||||
def mock_get(key, default=None):
|
||||
return state_data.get(key, default)
|
||||
|
||||
def mock_set(key, value):
|
||||
state_data[key] = value
|
||||
|
||||
mock_state.get = MagicMock(side_effect=mock_get)
|
||||
mock_state.set = MagicMock(side_effect=mock_set)
|
||||
|
||||
# Store kwargs in state like Workflow.run() does
|
||||
test_kwargs = {"user_token": "abc123", "service_config": {"endpoint": "http://test"}}
|
||||
state_data[WORKFLOW_RUN_KWARGS_KEY] = test_kwargs
|
||||
|
||||
# Initialize declarative state
|
||||
dws = DeclarativeWorkflowState(mock_state)
|
||||
dws.initialize({"input": "hello"})
|
||||
|
||||
# Create a mock agent
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "response text"
|
||||
mock_response.messages = []
|
||||
mock_response.tool_calls = []
|
||||
mock_agent = AsyncMock()
|
||||
mock_agent.run = AsyncMock(return_value=mock_response)
|
||||
|
||||
# Create a mock workflow context
|
||||
mock_ctx = MagicMock()
|
||||
mock_ctx.get_state = MagicMock(side_effect=mock_get)
|
||||
mock_ctx.yield_output = AsyncMock()
|
||||
|
||||
executor = InvokeAzureAgentExecutor.__new__(InvokeAzureAgentExecutor)
|
||||
executor._agents = {"test_agent": mock_agent}
|
||||
|
||||
await executor._invoke_agent_and_store_results(
|
||||
agent=mock_agent,
|
||||
agent_name="test_agent",
|
||||
input_text="hello",
|
||||
state=dws,
|
||||
ctx=mock_ctx,
|
||||
messages_var=None,
|
||||
response_obj_var=None,
|
||||
result_property=None,
|
||||
auto_send=True,
|
||||
)
|
||||
|
||||
# Verify agent.run was called with kwargs
|
||||
mock_agent.run.assert_called_once()
|
||||
call_kwargs = mock_agent.run.call_args
|
||||
|
||||
# Check options contains additional_function_arguments
|
||||
assert "options" in call_kwargs.kwargs
|
||||
assert call_kwargs.kwargs["options"]["additional_function_arguments"] == test_kwargs
|
||||
|
||||
# Check direct kwargs were passed
|
||||
assert call_kwargs.kwargs.get("user_token") == "abc123"
|
||||
assert call_kwargs.kwargs.get("service_config") == {"endpoint": "http://test"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_agent_merges_caller_options(self):
|
||||
"""Caller-provided options in run_kwargs should be merged, not cause TypeError."""
|
||||
from agent_framework._workflows._const import WORKFLOW_RUN_KWARGS_KEY
|
||||
from agent_framework._workflows._state import State
|
||||
|
||||
from agent_framework_declarative._workflows._executors_agents import (
|
||||
InvokeAzureAgentExecutor,
|
||||
)
|
||||
|
||||
mock_state = MagicMock(spec=State)
|
||||
state_data: dict[str, Any] = {}
|
||||
|
||||
def mock_get(key, default=None):
|
||||
return state_data.get(key, default)
|
||||
|
||||
def mock_set(key, value):
|
||||
state_data[key] = value
|
||||
|
||||
mock_state.get = MagicMock(side_effect=mock_get)
|
||||
mock_state.set = MagicMock(side_effect=mock_set)
|
||||
|
||||
# Include 'options' in run_kwargs to test merge behavior
|
||||
test_kwargs = {
|
||||
"user_token": "abc123",
|
||||
"options": {"temperature": 0.5},
|
||||
}
|
||||
state_data[WORKFLOW_RUN_KWARGS_KEY] = test_kwargs
|
||||
|
||||
dws = DeclarativeWorkflowState(mock_state)
|
||||
dws.initialize({"input": "hello"})
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "response text"
|
||||
mock_response.messages = []
|
||||
mock_response.tool_calls = []
|
||||
mock_agent = AsyncMock()
|
||||
mock_agent.run = AsyncMock(return_value=mock_response)
|
||||
|
||||
mock_ctx = MagicMock()
|
||||
mock_ctx.get_state = MagicMock(side_effect=mock_get)
|
||||
mock_ctx.yield_output = AsyncMock()
|
||||
|
||||
executor = InvokeAzureAgentExecutor.__new__(InvokeAzureAgentExecutor)
|
||||
executor._agents = {"test_agent": mock_agent}
|
||||
|
||||
await executor._invoke_agent_and_store_results(
|
||||
agent=mock_agent,
|
||||
agent_name="test_agent",
|
||||
input_text="hello",
|
||||
state=dws,
|
||||
ctx=mock_ctx,
|
||||
messages_var=None,
|
||||
response_obj_var=None,
|
||||
result_property=None,
|
||||
auto_send=True,
|
||||
)
|
||||
|
||||
mock_agent.run.assert_called_once()
|
||||
call_kwargs = mock_agent.run.call_args
|
||||
|
||||
# Caller options should be merged with additional_function_arguments
|
||||
merged_options = call_kwargs.kwargs["options"]
|
||||
assert merged_options["temperature"] == 0.5
|
||||
assert "additional_function_arguments" in merged_options
|
||||
|
||||
# Direct kwargs should be passed without 'options' (no duplicate keyword)
|
||||
assert call_kwargs.kwargs.get("user_token") == "abc123"
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -29,6 +30,7 @@ def create_action_context(
|
||||
inputs: dict[str, Any] | None = None,
|
||||
agents: dict[str, Any] | None = None,
|
||||
bindings: dict[str, Any] | None = None,
|
||||
run_kwargs: dict[str, Any] | None = None,
|
||||
) -> ActionContext:
|
||||
"""Helper to create an ActionContext for testing."""
|
||||
state = WorkflowState(inputs=inputs or {})
|
||||
@@ -47,6 +49,7 @@ def create_action_context(
|
||||
execute_actions=execute_actions,
|
||||
agents=agents or {},
|
||||
bindings=bindings or {},
|
||||
run_kwargs=run_kwargs or {},
|
||||
)
|
||||
async for event in handler(ctx):
|
||||
yield event
|
||||
@@ -57,6 +60,7 @@ def create_action_context(
|
||||
execute_actions=execute_actions,
|
||||
agents=agents or {},
|
||||
bindings=bindings or {},
|
||||
run_kwargs=run_kwargs or {},
|
||||
)
|
||||
|
||||
|
||||
@@ -422,3 +426,128 @@ class TestTryCatchHandler:
|
||||
|
||||
assert ctx.state.get("Local.try") == "ran"
|
||||
assert ctx.state.get("Local.finally") == "ran"
|
||||
|
||||
|
||||
class TestActionContextKwargs:
|
||||
"""ActionContext should carry and forward run_kwargs to agent invocations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_action_context_carries_run_kwargs(self):
|
||||
"""ActionContext should store and expose run_kwargs."""
|
||||
ctx = create_action_context(
|
||||
{"kind": "SetValue", "path": "Local.x", "value": "1"},
|
||||
run_kwargs={"user_token": "test123"},
|
||||
)
|
||||
assert ctx.run_kwargs == {"user_token": "test123"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_action_context_defaults_to_empty_kwargs(self):
|
||||
"""ActionContext.run_kwargs should default to empty dict."""
|
||||
ctx = create_action_context(
|
||||
{"kind": "SetValue", "path": "Local.x", "value": "1"},
|
||||
)
|
||||
assert ctx.run_kwargs == {}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_agent_handler_forwards_kwargs(self):
|
||||
"""handle_invoke_azure_agent should forward ctx.run_kwargs to agent.run()."""
|
||||
import agent_framework_declarative._workflows._actions_agents # noqa: F401
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "response"
|
||||
mock_response.messages = []
|
||||
mock_response.tool_calls = []
|
||||
|
||||
async def non_streaming_run(*args, **kwargs):
|
||||
if kwargs.get("stream"):
|
||||
raise TypeError("no streaming")
|
||||
return mock_response
|
||||
|
||||
mock_agent = AsyncMock()
|
||||
mock_agent.run = AsyncMock(side_effect=non_streaming_run)
|
||||
|
||||
test_kwargs = {"user_token": "secret", "api_key": "key123"}
|
||||
|
||||
state = WorkflowState()
|
||||
state.add_conversation_message(MagicMock(role="user", text="hello"))
|
||||
|
||||
ctx = create_action_context(
|
||||
action={
|
||||
"kind": "InvokeAzureAgent",
|
||||
"agent": "my_agent",
|
||||
},
|
||||
agents={"my_agent": mock_agent},
|
||||
run_kwargs=test_kwargs,
|
||||
)
|
||||
|
||||
handler = get_action_handler("InvokeAzureAgent")
|
||||
_ = [e async for e in handler(ctx)]
|
||||
|
||||
assert mock_agent.run.call_count >= 1
|
||||
|
||||
# Find the non-streaming fallback call
|
||||
for call in mock_agent.run.call_args_list:
|
||||
call_kw = call.kwargs
|
||||
if not call_kw.get("stream"):
|
||||
assert call_kw.get("user_token") == "secret"
|
||||
assert call_kw.get("api_key") == "key123"
|
||||
assert call_kw.get("options") == {"additional_function_arguments": test_kwargs}
|
||||
break
|
||||
else:
|
||||
# All calls were streaming — check the streaming call
|
||||
call_kw = mock_agent.run.call_args_list[0].kwargs
|
||||
assert call_kw.get("user_token") == "secret"
|
||||
assert call_kw.get("api_key") == "key123"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_agent_handler_merges_caller_options(self):
|
||||
"""Caller-provided options in run_kwargs should be merged, not cause TypeError."""
|
||||
import agent_framework_declarative._workflows._actions_agents # noqa: F401
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "response"
|
||||
mock_response.messages = []
|
||||
mock_response.tool_calls = []
|
||||
|
||||
async def non_streaming_run(*args, **kwargs):
|
||||
if kwargs.get("stream"):
|
||||
raise TypeError("no streaming")
|
||||
return mock_response
|
||||
|
||||
mock_agent = AsyncMock()
|
||||
mock_agent.run = AsyncMock(side_effect=non_streaming_run)
|
||||
|
||||
# Include 'options' in run_kwargs to test merge behavior
|
||||
test_kwargs = {"user_token": "secret", "options": {"temperature": 0.7}}
|
||||
|
||||
state = WorkflowState()
|
||||
state.add_conversation_message(MagicMock(role="user", text="hello"))
|
||||
|
||||
ctx = create_action_context(
|
||||
action={
|
||||
"kind": "InvokeAzureAgent",
|
||||
"agent": "my_agent",
|
||||
},
|
||||
agents={"my_agent": mock_agent},
|
||||
run_kwargs=test_kwargs,
|
||||
)
|
||||
|
||||
handler = get_action_handler("InvokeAzureAgent")
|
||||
_ = [e async for e in handler(ctx)]
|
||||
|
||||
assert mock_agent.run.call_count >= 1
|
||||
|
||||
# Find the non-streaming fallback call
|
||||
for call in mock_agent.run.call_args_list:
|
||||
call_kw = call.kwargs
|
||||
if not call_kw.get("stream"):
|
||||
# Caller options should be merged with additional_function_arguments
|
||||
assert call_kw["options"]["temperature"] == 0.7
|
||||
assert "additional_function_arguments" in call_kw["options"]
|
||||
# Direct kwargs should not include 'options' (no duplicate keyword)
|
||||
assert call_kw.get("user_token") == "secret"
|
||||
break
|
||||
else:
|
||||
call_kw = mock_agent.run.call_args_list[0].kwargs
|
||||
assert call_kw["options"]["temperature"] == 0.7
|
||||
assert "additional_function_arguments" in call_kw["options"]
|
||||
|
||||
Reference in New Issue
Block a user