mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Fix declarative package powerfx import crash and response_format kwarg error (#3841)
* Fix declarative package powerfx import crash and response_format kwarg error * Address PR feedback. Propagate kwargs for declarative workflows * move tests * Fix options merge logic
This commit is contained in:
committed by
GitHub
Unverified
parent
692fcd1888
commit
ff91473912
@@ -605,10 +605,11 @@ class AgentFactory:
|
||||
# Parse tools
|
||||
tools = self._parse_tools(prompt_agent.tools) if prompt_agent.tools else None
|
||||
|
||||
# Parse response format
|
||||
response_format = None
|
||||
# Parse response format into default_options
|
||||
default_options: dict[str, Any] | None = None
|
||||
if prompt_agent.outputSchema:
|
||||
response_format = _create_model_from_json_schema("agent", prompt_agent.outputSchema.to_json_schema())
|
||||
default_options = {"response_format": response_format}
|
||||
|
||||
# Create the agent using the provider
|
||||
# The provider's create_agent returns a Agent directly
|
||||
@@ -620,7 +621,7 @@ class AgentFactory:
|
||||
instructions=prompt_agent.instructions,
|
||||
description=prompt_agent.description,
|
||||
tools=tools,
|
||||
response_format=response_format,
|
||||
default_options=default_options,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
+24
-4
@@ -327,6 +327,16 @@ async def handle_invoke_azure_agent(ctx: ActionContext) -> AsyncGenerator[Workfl
|
||||
max_iterations = 100 # Safety limit
|
||||
|
||||
# Start external loop if configured
|
||||
# Build options for kwargs propagation to agent tools
|
||||
run_kwargs = ctx.run_kwargs
|
||||
options: dict[str, Any] | None = None
|
||||
if run_kwargs:
|
||||
# Merge caller-provided options to avoid duplicate keyword argument
|
||||
options = dict(run_kwargs.get("options") or {})
|
||||
options["additional_function_arguments"] = run_kwargs
|
||||
# Exclude 'options' from splat to avoid TypeError on duplicate keyword
|
||||
run_kwargs = {k: v for k, v in run_kwargs.items() if k != "options"}
|
||||
|
||||
while True:
|
||||
# Invoke the agent
|
||||
try:
|
||||
@@ -337,7 +347,7 @@ async def handle_invoke_azure_agent(ctx: ActionContext) -> AsyncGenerator[Workfl
|
||||
updates: list[Any] = []
|
||||
tool_calls: list[Any] = []
|
||||
|
||||
async for chunk in agent.run(messages, stream=True):
|
||||
async for chunk in agent.run(messages, stream=True, options=options, **run_kwargs):
|
||||
updates.append(chunk)
|
||||
|
||||
# Yield streaming events for text chunks
|
||||
@@ -403,7 +413,7 @@ async def handle_invoke_azure_agent(ctx: ActionContext) -> AsyncGenerator[Workfl
|
||||
|
||||
except TypeError:
|
||||
# Agent doesn't support streaming, fall back to non-streaming
|
||||
response = await agent.run(messages)
|
||||
response = await agent.run(messages, options=options, **run_kwargs)
|
||||
|
||||
text = response.text
|
||||
response_messages = response.messages
|
||||
@@ -570,6 +580,16 @@ async def handle_invoke_prompt_agent(ctx: ActionContext) -> AsyncGenerator[Workf
|
||||
|
||||
logger.debug(f"InvokePromptAgent: calling '{agent_name}' with {len(messages)} messages")
|
||||
|
||||
# Build options for kwargs propagation to agent tools
|
||||
prompt_run_kwargs = ctx.run_kwargs
|
||||
prompt_options: dict[str, Any] | None = None
|
||||
if prompt_run_kwargs:
|
||||
# Merge caller-provided options to avoid duplicate keyword argument
|
||||
prompt_options = dict(prompt_run_kwargs.get("options") or {})
|
||||
prompt_options["additional_function_arguments"] = prompt_run_kwargs
|
||||
# Exclude 'options' from splat to avoid TypeError on duplicate keyword
|
||||
prompt_run_kwargs = {k: v for k, v in prompt_run_kwargs.items() if k != "options"}
|
||||
|
||||
# Invoke the agent
|
||||
try:
|
||||
if hasattr(agent, "run"):
|
||||
@@ -577,7 +597,7 @@ async def handle_invoke_prompt_agent(ctx: ActionContext) -> AsyncGenerator[Workf
|
||||
try:
|
||||
updates: list[Any] = []
|
||||
|
||||
async for chunk in agent.run(messages, stream=True):
|
||||
async for chunk in agent.run(messages, stream=True, options=prompt_options, **prompt_run_kwargs):
|
||||
updates.append(chunk)
|
||||
|
||||
if hasattr(chunk, "text") and chunk.text:
|
||||
@@ -607,7 +627,7 @@ async def handle_invoke_prompt_agent(ctx: ActionContext) -> AsyncGenerator[Workf
|
||||
|
||||
except TypeError:
|
||||
# Agent doesn't support streaming, fall back to non-streaming
|
||||
response = await agent.run(messages)
|
||||
response = await agent.run(messages, options=prompt_options, **prompt_run_kwargs)
|
||||
text = response.text
|
||||
response_messages = response.messages
|
||||
|
||||
|
||||
+16
-2
@@ -37,7 +37,13 @@ from agent_framework._workflows import (
|
||||
WorkflowContext,
|
||||
)
|
||||
from agent_framework._workflows._state import State
|
||||
from powerfx import Engine
|
||||
|
||||
try:
|
||||
from powerfx import Engine
|
||||
except (ImportError, RuntimeError):
|
||||
# ImportError: powerfx package not installed
|
||||
# RuntimeError: .NET runtime not available or misconfigured
|
||||
Engine = None # type: ignore[assignment, misc]
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
from typing import TypedDict # type: ignore # pragma: no cover
|
||||
@@ -339,7 +345,8 @@ class DeclarativeWorkflowState:
|
||||
undefined variables (matching legacy fallback parser behavior).
|
||||
|
||||
Raises:
|
||||
ImportError: If the powerfx package is not installed.
|
||||
RuntimeError: If the powerfx package is not installed and the
|
||||
expression requires PowerFx evaluation.
|
||||
"""
|
||||
if not expression:
|
||||
return expression
|
||||
@@ -363,6 +370,13 @@ class DeclarativeWorkflowState:
|
||||
# Replace them with their evaluated results before sending to PowerFx
|
||||
formula = self._preprocess_custom_functions(formula)
|
||||
|
||||
if Engine is None:
|
||||
raise RuntimeError(
|
||||
f"PowerFx is not available (dotnet runtime not installed). "
|
||||
f"Expression '={formula[:80]}' cannot be evaluated. "
|
||||
f"Install dotnet and the powerfx package for full PowerFx support."
|
||||
)
|
||||
|
||||
engine = Engine()
|
||||
symbols = self._to_powerfx_symbols()
|
||||
try:
|
||||
|
||||
+13
-1
@@ -656,10 +656,22 @@ class InvokeAzureAgentExecutor(DeclarativeActionExecutor):
|
||||
if isinstance(messages_for_agent, list) and messages_for_agent:
|
||||
_validate_conversation_history(messages_for_agent, agent_name)
|
||||
|
||||
# Retrieve kwargs passed to workflow.run() so they propagate to agent tools
|
||||
from agent_framework._workflows._const import WORKFLOW_RUN_KWARGS_KEY
|
||||
|
||||
run_kwargs: dict[str, Any] = ctx.get_state(WORKFLOW_RUN_KWARGS_KEY, {})
|
||||
options: dict[str, Any] | None = None
|
||||
if run_kwargs:
|
||||
# Merge caller-provided options to avoid duplicate keyword argument
|
||||
options = dict(run_kwargs.get("options") or {})
|
||||
options["additional_function_arguments"] = run_kwargs
|
||||
# Exclude 'options' from splat to avoid TypeError on duplicate keyword
|
||||
run_kwargs = {k: v for k, v in run_kwargs.items() if k != "options"}
|
||||
|
||||
# Use run() method to get properly structured messages (including tool calls and results)
|
||||
# This is critical for multi-turn conversations where tool calls must be followed
|
||||
# by their results in the message history
|
||||
result: Any = await agent.run(messages_for_agent)
|
||||
result: Any = await agent.run(messages_for_agent, options=options, **run_kwargs)
|
||||
if hasattr(result, "text") and result.text:
|
||||
accumulated_response = str(result.text)
|
||||
if auto_send:
|
||||
|
||||
@@ -10,7 +10,7 @@ has a corresponding handler registered via the @action_handler decorator.
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncGenerator, Callable
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
|
||||
|
||||
from agent_framework import get_logger
|
||||
@@ -44,6 +44,9 @@ class ActionContext:
|
||||
bindings: dict[str, Any]
|
||||
"""Function bindings for tool calls."""
|
||||
|
||||
run_kwargs: dict[str, Any] = field(default_factory=dict)
|
||||
"""Kwargs from workflow.run() to forward to agent invocations."""
|
||||
|
||||
@property
|
||||
def action_id(self) -> str | None:
|
||||
"""Get the action's unique identifier."""
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import builtins
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
@@ -905,3 +907,105 @@ tools:
|
||||
|
||||
# Verify project_connection_id is set from connection name
|
||||
assert mcp_tool.get("project_connection_id") == "my-oauth-connection"
|
||||
|
||||
|
||||
class TestProviderResponseFormat:
|
||||
"""response_format from outputSchema must be passed inside default_options."""
|
||||
|
||||
@staticmethod
|
||||
def _make_mock_prompt_agent(*, with_output_schema: bool = False) -> MagicMock:
|
||||
"""Create a mock PromptAgent to avoid serialization complexity."""
|
||||
mock_model = MagicMock()
|
||||
mock_model.id = "gpt-4"
|
||||
mock_model.connection = None
|
||||
|
||||
agent = MagicMock()
|
||||
agent.name = "test-agent"
|
||||
agent.description = "test"
|
||||
agent.instructions = "be helpful"
|
||||
agent.model = mock_model
|
||||
agent.tools = None
|
||||
|
||||
if with_output_schema:
|
||||
mock_schema = MagicMock()
|
||||
mock_schema.to_json_schema.return_value = {
|
||||
"type": "object",
|
||||
"properties": {"answer": {"type": "string"}},
|
||||
}
|
||||
agent.outputSchema = mock_schema
|
||||
else:
|
||||
agent.outputSchema = None
|
||||
|
||||
return agent
|
||||
|
||||
@staticmethod
|
||||
def _make_mock_provider() -> tuple[MagicMock, AsyncMock]:
|
||||
"""Create a mock provider class and its instance."""
|
||||
mock_agent = MagicMock()
|
||||
mock_provider_instance = AsyncMock()
|
||||
mock_provider_instance.create_agent = AsyncMock(return_value=mock_agent)
|
||||
mock_provider_class = MagicMock(return_value=mock_provider_instance)
|
||||
return mock_provider_class, mock_provider_instance
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_response_format_in_default_options(self):
|
||||
"""Provider.create_agent() should receive response_format inside default_options."""
|
||||
from agent_framework_declarative._loader import AgentFactory
|
||||
|
||||
prompt_agent = self._make_mock_prompt_agent(with_output_schema=True)
|
||||
mock_provider_class, mock_provider_instance = self._make_mock_provider()
|
||||
|
||||
mapping = {"package": "some_module", "name": "SomeProvider"}
|
||||
factory = AgentFactory()
|
||||
|
||||
original_import = builtins.__import__
|
||||
|
||||
def mock_import(name, *args, **kwargs):
|
||||
if name == "some_module":
|
||||
mod = MagicMock()
|
||||
mod.SomeProvider = mock_provider_class
|
||||
return mod
|
||||
return original_import(name, *args, **kwargs)
|
||||
|
||||
with (
|
||||
patch.object(builtins, "__import__", side_effect=mock_import),
|
||||
patch.object(factory, "_parse_tools", return_value=None),
|
||||
):
|
||||
await factory._create_agent_with_provider(prompt_agent, mapping)
|
||||
|
||||
mock_provider_instance.create_agent.assert_called_once()
|
||||
call_kwargs = mock_provider_instance.create_agent.call_args.kwargs
|
||||
|
||||
assert "response_format" not in call_kwargs
|
||||
default_options = call_kwargs.get("default_options")
|
||||
assert default_options is not None
|
||||
assert "response_format" in default_options
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_default_options_without_output_schema(self):
|
||||
"""When there's no outputSchema, default_options should be None."""
|
||||
from agent_framework_declarative._loader import AgentFactory
|
||||
|
||||
prompt_agent = self._make_mock_prompt_agent(with_output_schema=False)
|
||||
mock_provider_class, mock_provider_instance = self._make_mock_provider()
|
||||
|
||||
mapping = {"package": "some_module", "name": "SomeProvider"}
|
||||
factory = AgentFactory()
|
||||
|
||||
original_import = builtins.__import__
|
||||
|
||||
def mock_import(name, *args, **kwargs):
|
||||
if name == "some_module":
|
||||
mod = MagicMock()
|
||||
mod.SomeProvider = mock_provider_class
|
||||
return mod
|
||||
return original_import(name, *args, **kwargs)
|
||||
|
||||
with (
|
||||
patch.object(builtins, "__import__", side_effect=mock_import),
|
||||
patch.object(factory, "_parse_tools", return_value=None),
|
||||
):
|
||||
await factory._create_agent_with_provider(prompt_agent, mapping)
|
||||
|
||||
call_kwargs = mock_provider_instance.create_agent.call_args.kwargs
|
||||
assert call_kwargs.get("default_options") is None
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
"""Tests for the graph-based declarative workflow executors."""
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
@@ -1295,3 +1296,203 @@ class TestExtractJsonFromResponse:
|
||||
text = 'First: {"status": "pending"} then later: {"status": "complete", "id": 42}'
|
||||
result = _extract_json_from_response(text)
|
||||
assert result == {"status": "complete", "id": 42}
|
||||
|
||||
|
||||
class TestPowerFxConditionalImport:
|
||||
"""The _declarative_base module should be importable without dotnet/powerfx."""
|
||||
|
||||
def test_import_guard_exists(self):
|
||||
"""The powerfx import must be wrapped in try/except."""
|
||||
import agent_framework_declarative._workflows._declarative_base as base_mod
|
||||
|
||||
assert hasattr(base_mod, "DeclarativeWorkflowState")
|
||||
assert hasattr(base_mod, "Engine")
|
||||
|
||||
# Engine should either be the real class or None — never an ImportError
|
||||
engine = base_mod.Engine
|
||||
assert engine is None or callable(engine)
|
||||
|
||||
def test_eval_raises_when_engine_unavailable(self):
|
||||
"""eval() should raise RuntimeError when Engine is None."""
|
||||
import agent_framework_declarative._workflows._declarative_base as base_mod
|
||||
|
||||
mock_state = MagicMock()
|
||||
mock_state._data: dict[str, Any] = {}
|
||||
mock_state.get = MagicMock(side_effect=lambda k, d=None: mock_state._data.get(k, d))
|
||||
mock_state.set = MagicMock(side_effect=lambda k, v: mock_state._data.__setitem__(k, v))
|
||||
|
||||
state = DeclarativeWorkflowState(mock_state)
|
||||
state.initialize({"name": "test"})
|
||||
|
||||
original_engine = base_mod.Engine
|
||||
try:
|
||||
base_mod.Engine = None
|
||||
with pytest.raises(RuntimeError, match="PowerFx is not available"):
|
||||
state.eval("=Local.counter + 1")
|
||||
finally:
|
||||
base_mod.Engine = original_engine
|
||||
|
||||
def test_eval_passes_through_plain_strings_without_engine(self):
|
||||
"""Non-PowerFx strings (no leading '=') should work without Engine."""
|
||||
import agent_framework_declarative._workflows._declarative_base as base_mod
|
||||
|
||||
mock_state = MagicMock()
|
||||
mock_state._data: dict[str, Any] = {}
|
||||
mock_state.get = MagicMock(side_effect=lambda k, d=None: mock_state._data.get(k, d))
|
||||
mock_state.set = MagicMock(side_effect=lambda k, v: mock_state._data.__setitem__(k, v))
|
||||
|
||||
state = DeclarativeWorkflowState(mock_state)
|
||||
state.initialize()
|
||||
|
||||
original_engine = base_mod.Engine
|
||||
try:
|
||||
base_mod.Engine = None
|
||||
assert state.eval("hello world") == "hello world"
|
||||
assert state.eval("") == ""
|
||||
assert state.eval(42) == 42
|
||||
finally:
|
||||
base_mod.Engine = original_engine
|
||||
|
||||
|
||||
class TestExecutorKwargsForwarding:
|
||||
"""Workflow run kwargs should be forwarded through executor agent invocations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_agent_forwards_kwargs(self):
|
||||
"""InvokeAzureAgentExecutor should forward run_kwargs to agent.run()."""
|
||||
from agent_framework._workflows._const import WORKFLOW_RUN_KWARGS_KEY
|
||||
from agent_framework._workflows._state import State
|
||||
|
||||
from agent_framework_declarative._workflows._executors_agents import (
|
||||
InvokeAzureAgentExecutor,
|
||||
)
|
||||
|
||||
# Create a mock State with kwargs stored
|
||||
mock_state = MagicMock(spec=State)
|
||||
state_data: dict[str, Any] = {}
|
||||
|
||||
def mock_get(key, default=None):
|
||||
return state_data.get(key, default)
|
||||
|
||||
def mock_set(key, value):
|
||||
state_data[key] = value
|
||||
|
||||
mock_state.get = MagicMock(side_effect=mock_get)
|
||||
mock_state.set = MagicMock(side_effect=mock_set)
|
||||
|
||||
# Store kwargs in state like Workflow.run() does
|
||||
test_kwargs = {"user_token": "abc123", "service_config": {"endpoint": "http://test"}}
|
||||
state_data[WORKFLOW_RUN_KWARGS_KEY] = test_kwargs
|
||||
|
||||
# Initialize declarative state
|
||||
dws = DeclarativeWorkflowState(mock_state)
|
||||
dws.initialize({"input": "hello"})
|
||||
|
||||
# Create a mock agent
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "response text"
|
||||
mock_response.messages = []
|
||||
mock_response.tool_calls = []
|
||||
mock_agent = AsyncMock()
|
||||
mock_agent.run = AsyncMock(return_value=mock_response)
|
||||
|
||||
# Create a mock workflow context
|
||||
mock_ctx = MagicMock()
|
||||
mock_ctx.get_state = MagicMock(side_effect=mock_get)
|
||||
mock_ctx.yield_output = AsyncMock()
|
||||
|
||||
executor = InvokeAzureAgentExecutor.__new__(InvokeAzureAgentExecutor)
|
||||
executor._agents = {"test_agent": mock_agent}
|
||||
|
||||
await executor._invoke_agent_and_store_results(
|
||||
agent=mock_agent,
|
||||
agent_name="test_agent",
|
||||
input_text="hello",
|
||||
state=dws,
|
||||
ctx=mock_ctx,
|
||||
messages_var=None,
|
||||
response_obj_var=None,
|
||||
result_property=None,
|
||||
auto_send=True,
|
||||
)
|
||||
|
||||
# Verify agent.run was called with kwargs
|
||||
mock_agent.run.assert_called_once()
|
||||
call_kwargs = mock_agent.run.call_args
|
||||
|
||||
# Check options contains additional_function_arguments
|
||||
assert "options" in call_kwargs.kwargs
|
||||
assert call_kwargs.kwargs["options"]["additional_function_arguments"] == test_kwargs
|
||||
|
||||
# Check direct kwargs were passed
|
||||
assert call_kwargs.kwargs.get("user_token") == "abc123"
|
||||
assert call_kwargs.kwargs.get("service_config") == {"endpoint": "http://test"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_agent_merges_caller_options(self):
|
||||
"""Caller-provided options in run_kwargs should be merged, not cause TypeError."""
|
||||
from agent_framework._workflows._const import WORKFLOW_RUN_KWARGS_KEY
|
||||
from agent_framework._workflows._state import State
|
||||
|
||||
from agent_framework_declarative._workflows._executors_agents import (
|
||||
InvokeAzureAgentExecutor,
|
||||
)
|
||||
|
||||
mock_state = MagicMock(spec=State)
|
||||
state_data: dict[str, Any] = {}
|
||||
|
||||
def mock_get(key, default=None):
|
||||
return state_data.get(key, default)
|
||||
|
||||
def mock_set(key, value):
|
||||
state_data[key] = value
|
||||
|
||||
mock_state.get = MagicMock(side_effect=mock_get)
|
||||
mock_state.set = MagicMock(side_effect=mock_set)
|
||||
|
||||
# Include 'options' in run_kwargs to test merge behavior
|
||||
test_kwargs = {
|
||||
"user_token": "abc123",
|
||||
"options": {"temperature": 0.5},
|
||||
}
|
||||
state_data[WORKFLOW_RUN_KWARGS_KEY] = test_kwargs
|
||||
|
||||
dws = DeclarativeWorkflowState(mock_state)
|
||||
dws.initialize({"input": "hello"})
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "response text"
|
||||
mock_response.messages = []
|
||||
mock_response.tool_calls = []
|
||||
mock_agent = AsyncMock()
|
||||
mock_agent.run = AsyncMock(return_value=mock_response)
|
||||
|
||||
mock_ctx = MagicMock()
|
||||
mock_ctx.get_state = MagicMock(side_effect=mock_get)
|
||||
mock_ctx.yield_output = AsyncMock()
|
||||
|
||||
executor = InvokeAzureAgentExecutor.__new__(InvokeAzureAgentExecutor)
|
||||
executor._agents = {"test_agent": mock_agent}
|
||||
|
||||
await executor._invoke_agent_and_store_results(
|
||||
agent=mock_agent,
|
||||
agent_name="test_agent",
|
||||
input_text="hello",
|
||||
state=dws,
|
||||
ctx=mock_ctx,
|
||||
messages_var=None,
|
||||
response_obj_var=None,
|
||||
result_property=None,
|
||||
auto_send=True,
|
||||
)
|
||||
|
||||
mock_agent.run.assert_called_once()
|
||||
call_kwargs = mock_agent.run.call_args
|
||||
|
||||
# Caller options should be merged with additional_function_arguments
|
||||
merged_options = call_kwargs.kwargs["options"]
|
||||
assert merged_options["temperature"] == 0.5
|
||||
assert "additional_function_arguments" in merged_options
|
||||
|
||||
# Direct kwargs should be passed without 'options' (no duplicate keyword)
|
||||
assert call_kwargs.kwargs.get("user_token") == "abc123"
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -29,6 +30,7 @@ def create_action_context(
|
||||
inputs: dict[str, Any] | None = None,
|
||||
agents: dict[str, Any] | None = None,
|
||||
bindings: dict[str, Any] | None = None,
|
||||
run_kwargs: dict[str, Any] | None = None,
|
||||
) -> ActionContext:
|
||||
"""Helper to create an ActionContext for testing."""
|
||||
state = WorkflowState(inputs=inputs or {})
|
||||
@@ -47,6 +49,7 @@ def create_action_context(
|
||||
execute_actions=execute_actions,
|
||||
agents=agents or {},
|
||||
bindings=bindings or {},
|
||||
run_kwargs=run_kwargs or {},
|
||||
)
|
||||
async for event in handler(ctx):
|
||||
yield event
|
||||
@@ -57,6 +60,7 @@ def create_action_context(
|
||||
execute_actions=execute_actions,
|
||||
agents=agents or {},
|
||||
bindings=bindings or {},
|
||||
run_kwargs=run_kwargs or {},
|
||||
)
|
||||
|
||||
|
||||
@@ -422,3 +426,128 @@ class TestTryCatchHandler:
|
||||
|
||||
assert ctx.state.get("Local.try") == "ran"
|
||||
assert ctx.state.get("Local.finally") == "ran"
|
||||
|
||||
|
||||
class TestActionContextKwargs:
|
||||
"""ActionContext should carry and forward run_kwargs to agent invocations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_action_context_carries_run_kwargs(self):
|
||||
"""ActionContext should store and expose run_kwargs."""
|
||||
ctx = create_action_context(
|
||||
{"kind": "SetValue", "path": "Local.x", "value": "1"},
|
||||
run_kwargs={"user_token": "test123"},
|
||||
)
|
||||
assert ctx.run_kwargs == {"user_token": "test123"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_action_context_defaults_to_empty_kwargs(self):
|
||||
"""ActionContext.run_kwargs should default to empty dict."""
|
||||
ctx = create_action_context(
|
||||
{"kind": "SetValue", "path": "Local.x", "value": "1"},
|
||||
)
|
||||
assert ctx.run_kwargs == {}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_agent_handler_forwards_kwargs(self):
|
||||
"""handle_invoke_azure_agent should forward ctx.run_kwargs to agent.run()."""
|
||||
import agent_framework_declarative._workflows._actions_agents # noqa: F401
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "response"
|
||||
mock_response.messages = []
|
||||
mock_response.tool_calls = []
|
||||
|
||||
async def non_streaming_run(*args, **kwargs):
|
||||
if kwargs.get("stream"):
|
||||
raise TypeError("no streaming")
|
||||
return mock_response
|
||||
|
||||
mock_agent = AsyncMock()
|
||||
mock_agent.run = AsyncMock(side_effect=non_streaming_run)
|
||||
|
||||
test_kwargs = {"user_token": "secret", "api_key": "key123"}
|
||||
|
||||
state = WorkflowState()
|
||||
state.add_conversation_message(MagicMock(role="user", text="hello"))
|
||||
|
||||
ctx = create_action_context(
|
||||
action={
|
||||
"kind": "InvokeAzureAgent",
|
||||
"agent": "my_agent",
|
||||
},
|
||||
agents={"my_agent": mock_agent},
|
||||
run_kwargs=test_kwargs,
|
||||
)
|
||||
|
||||
handler = get_action_handler("InvokeAzureAgent")
|
||||
_ = [e async for e in handler(ctx)]
|
||||
|
||||
assert mock_agent.run.call_count >= 1
|
||||
|
||||
# Find the non-streaming fallback call
|
||||
for call in mock_agent.run.call_args_list:
|
||||
call_kw = call.kwargs
|
||||
if not call_kw.get("stream"):
|
||||
assert call_kw.get("user_token") == "secret"
|
||||
assert call_kw.get("api_key") == "key123"
|
||||
assert call_kw.get("options") == {"additional_function_arguments": test_kwargs}
|
||||
break
|
||||
else:
|
||||
# All calls were streaming — check the streaming call
|
||||
call_kw = mock_agent.run.call_args_list[0].kwargs
|
||||
assert call_kw.get("user_token") == "secret"
|
||||
assert call_kw.get("api_key") == "key123"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_agent_handler_merges_caller_options(self):
|
||||
"""Caller-provided options in run_kwargs should be merged, not cause TypeError."""
|
||||
import agent_framework_declarative._workflows._actions_agents # noqa: F401
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "response"
|
||||
mock_response.messages = []
|
||||
mock_response.tool_calls = []
|
||||
|
||||
async def non_streaming_run(*args, **kwargs):
|
||||
if kwargs.get("stream"):
|
||||
raise TypeError("no streaming")
|
||||
return mock_response
|
||||
|
||||
mock_agent = AsyncMock()
|
||||
mock_agent.run = AsyncMock(side_effect=non_streaming_run)
|
||||
|
||||
# Include 'options' in run_kwargs to test merge behavior
|
||||
test_kwargs = {"user_token": "secret", "options": {"temperature": 0.7}}
|
||||
|
||||
state = WorkflowState()
|
||||
state.add_conversation_message(MagicMock(role="user", text="hello"))
|
||||
|
||||
ctx = create_action_context(
|
||||
action={
|
||||
"kind": "InvokeAzureAgent",
|
||||
"agent": "my_agent",
|
||||
},
|
||||
agents={"my_agent": mock_agent},
|
||||
run_kwargs=test_kwargs,
|
||||
)
|
||||
|
||||
handler = get_action_handler("InvokeAzureAgent")
|
||||
_ = [e async for e in handler(ctx)]
|
||||
|
||||
assert mock_agent.run.call_count >= 1
|
||||
|
||||
# Find the non-streaming fallback call
|
||||
for call in mock_agent.run.call_args_list:
|
||||
call_kw = call.kwargs
|
||||
if not call_kw.get("stream"):
|
||||
# Caller options should be merged with additional_function_arguments
|
||||
assert call_kw["options"]["temperature"] == 0.7
|
||||
assert "additional_function_arguments" in call_kw["options"]
|
||||
# Direct kwargs should not include 'options' (no duplicate keyword)
|
||||
assert call_kw.get("user_token") == "secret"
|
||||
break
|
||||
else:
|
||||
call_kw = mock_agent.run.call_args_list[0].kwargs
|
||||
assert call_kw["options"]["temperature"] == 0.7
|
||||
assert "additional_function_arguments" in call_kw["options"]
|
||||
|
||||
Reference in New Issue
Block a user