mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Expose forwardedProps to agents and tools via session metadata (#5264)
* Expose forwarded_props to agents and tools via session metadata (#5239) Include forwarded_props from AG-UI request input_data in session.metadata (agent runner) and function_invocation_kwargs (workflow runner) so that agents, tools, and workflow executors can access request-level metadata such as invocation source flags from CopilotKit. - Add forwarded_props to base_metadata in _agent_run.py when present - Add 'forwarded_props' to AG_UI_INTERNAL_METADATA_KEYS to filter it from LLM-bound client metadata - Extract forwarded_props in _workflow_run.py and pass via function_invocation_kwargs to workflow.run() - Accept both snake_case and camelCase keys (forwarded_props/forwardedProps) Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * fix(ag-ui): pass stream=True as literal to satisfy pyright overload resolution (#5239) The previous fix passed stream=True via **kwargs dict, which prevented pyright from resolving the Workflow.run() overload to the streaming variant. Pass stream=True as an explicit keyword argument so pyright can correctly infer the ResponseStream return type. Also remove unused pytest import in test file. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * fix: address PR review feedback for forwarded_props (#5239) - Use key-presence checks instead of truthiness for forwarded_props so empty dict {} is forwarded correctly - Gate function_invocation_kwargs on workflow.run() signature inspection to avoid TypeError for workflows without **kwargs - Change _build_safe_metadata to drop (with warning) keys whose serialized values exceed 512 chars instead of truncating into invalid JSON - Rewrite metadata tests to exercise _build_safe_metadata directly with JSON-decodability and truncation assertions - Add workflow tests for empty dict forwarded_props, stream=True assertion, and signature-gated kwarg dropping Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * test: add stream=True assertions to CapturingWorkflow tests (#5239) Guard against accidental removal of the explicit stream=True kwarg in all forwarded_props CapturingWorkflow test cases. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Address review feedback for #5239: Python: Expose forwardedProps to agents and tools via session metadata --------- Co-authored-by: Copilot <copilot@github.com> Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
04aaf0c1fe
commit
07f4c8a8d6
@@ -69,19 +69,23 @@ if TYPE_CHECKING:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Keys that are internal to AG-UI orchestration and should not be passed to chat clients
|
||||
AG_UI_INTERNAL_METADATA_KEYS = {"ag_ui_thread_id", "ag_ui_run_id", "current_state"}
|
||||
AG_UI_INTERNAL_METADATA_KEYS = {"ag_ui_thread_id", "ag_ui_run_id", "current_state", "forwarded_props"}
|
||||
|
||||
|
||||
def _build_safe_metadata(thread_metadata: dict[str, Any] | None) -> dict[str, Any]:
|
||||
"""Build metadata dict with truncated string values for Azure compatibility.
|
||||
"""Build metadata dict with string values for Azure compatibility.
|
||||
|
||||
Azure has a 512 character limit per metadata value.
|
||||
Azure has a 512 character limit per metadata value. String values that
|
||||
already fit are kept as-is. Non-string values are JSON-serialized. If the
|
||||
resulting string exceeds 512 characters the key is **dropped** (with a
|
||||
warning) instead of truncated, because truncation can produce invalid JSON
|
||||
that downstream consumers cannot decode.
|
||||
|
||||
Args:
|
||||
thread_metadata: Raw metadata dict
|
||||
|
||||
Returns:
|
||||
Metadata with string values truncated to 512 chars
|
||||
Metadata with safe string values (each <= 512 chars)
|
||||
"""
|
||||
if not thread_metadata:
|
||||
return {}
|
||||
@@ -89,7 +93,12 @@ def _build_safe_metadata(thread_metadata: dict[str, Any] | None) -> dict[str, An
|
||||
for key, value in thread_metadata.items():
|
||||
value_str = value if isinstance(value, str) else json.dumps(value)
|
||||
if len(value_str) > 512:
|
||||
value_str = value_str[:512]
|
||||
logger.warning(
|
||||
"Dropping metadata key %r: serialized value is %d chars (limit 512)",
|
||||
key,
|
||||
len(value_str),
|
||||
)
|
||||
continue
|
||||
safe_metadata[key] = value_str
|
||||
return safe_metadata
|
||||
|
||||
@@ -790,6 +799,10 @@ async def run_agent_stream(
|
||||
"ag_ui_thread_id": thread_id,
|
||||
"ag_ui_run_id": run_id,
|
||||
}
|
||||
if "forwarded_props" in input_data:
|
||||
base_metadata["forwarded_props"] = input_data["forwarded_props"]
|
||||
elif "forwardedProps" in input_data:
|
||||
base_metadata["forwarded_props"] = input_data["forwardedProps"]
|
||||
if flow.current_state:
|
||||
base_metadata["current_state"] = flow.current_state
|
||||
session.metadata = _build_safe_metadata(base_metadata) # type: ignore[attr-defined]
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
@@ -581,11 +582,33 @@ async def run_workflow_stream(
|
||||
flow.accumulated_text = ""
|
||||
return [TextMessageEndEvent(message_id=current_message_id)]
|
||||
|
||||
fwd_kwargs: dict[str, Any] = {}
|
||||
if "forwarded_props" in input_data:
|
||||
forwarded_props = input_data["forwarded_props"]
|
||||
fwd_kwargs["function_invocation_kwargs"] = {"forwarded_props": forwarded_props}
|
||||
elif "forwardedProps" in input_data:
|
||||
forwarded_props = input_data["forwardedProps"]
|
||||
fwd_kwargs["function_invocation_kwargs"] = {"forwarded_props": forwarded_props}
|
||||
|
||||
# Only pass function_invocation_kwargs if the workflow.run signature accepts it
|
||||
if fwd_kwargs:
|
||||
try:
|
||||
sig = inspect.signature(workflow.run)
|
||||
params = sig.parameters
|
||||
accepts_fwd = "function_invocation_kwargs" in params or any(
|
||||
p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()
|
||||
)
|
||||
except (ValueError, TypeError):
|
||||
accepts_fwd = False
|
||||
if not accepts_fwd:
|
||||
logger.debug("workflow.run() does not accept function_invocation_kwargs; dropping forwarded_props")
|
||||
fwd_kwargs = {}
|
||||
|
||||
try:
|
||||
if responses:
|
||||
event_stream = workflow.run(responses=responses, stream=True)
|
||||
event_stream = workflow.run(responses=responses, stream=True, **fwd_kwargs)
|
||||
else:
|
||||
event_stream = workflow.run(message=messages, stream=True)
|
||||
event_stream = workflow.run(message=messages, stream=True, **fwd_kwargs)
|
||||
|
||||
async for event in event_stream:
|
||||
event_type = getattr(event, "type", None)
|
||||
|
||||
@@ -0,0 +1,83 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Tests for forwarded_props inclusion in AG-UI session metadata."""
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from agent_framework_ag_ui._agent_run import AG_UI_INTERNAL_METADATA_KEYS, _build_safe_metadata
|
||||
|
||||
|
||||
class TestForwardedPropsInSessionMetadata:
|
||||
"""Verify that forwarded_props is surfaced in session metadata and filtered from LLM metadata."""
|
||||
|
||||
def test_forwarded_props_in_internal_metadata_keys(self):
|
||||
"""forwarded_props is listed in AG_UI_INTERNAL_METADATA_KEYS to prevent LLM leakage."""
|
||||
assert "forwarded_props" in AG_UI_INTERNAL_METADATA_KEYS
|
||||
|
||||
def test_forwarded_props_filtered_from_client_metadata(self):
|
||||
"""forwarded_props is filtered out when building LLM-bound client metadata."""
|
||||
session_metadata: dict[str, Any] = {
|
||||
"ag_ui_thread_id": "t1",
|
||||
"ag_ui_run_id": "r1",
|
||||
"forwarded_props": '{"custom_flag": true}',
|
||||
}
|
||||
|
||||
client_metadata = {k: v for k, v in session_metadata.items() if k not in AG_UI_INTERNAL_METADATA_KEYS}
|
||||
|
||||
assert "forwarded_props" not in client_metadata
|
||||
assert "ag_ui_thread_id" not in client_metadata
|
||||
|
||||
|
||||
class TestBuildSafeMetadata:
|
||||
"""Verify _build_safe_metadata handles various value types correctly."""
|
||||
|
||||
def test_string_value_unchanged(self):
|
||||
result = _build_safe_metadata({"key": "hello"})
|
||||
assert result == {"key": "hello"}
|
||||
|
||||
def test_dict_value_serialized_to_json(self):
|
||||
result = _build_safe_metadata({"fp": {"flag": True, "source": "frontend"}})
|
||||
assert "fp" in result
|
||||
assert isinstance(result["fp"], str)
|
||||
# Must be valid, decodable JSON
|
||||
decoded = json.loads(result["fp"])
|
||||
assert decoded == {"flag": True, "source": "frontend"}
|
||||
|
||||
def test_empty_dict_serialized_to_json(self):
|
||||
result = _build_safe_metadata({"fp": {}})
|
||||
assert result["fp"] == "{}"
|
||||
assert json.loads(result["fp"]) == {}
|
||||
|
||||
def test_value_within_limit_kept(self):
|
||||
value = "x" * 512
|
||||
result = _build_safe_metadata({"key": value})
|
||||
assert result["key"] == value
|
||||
|
||||
def test_value_exceeding_limit_dropped(self):
|
||||
"""Values exceeding 512 chars are dropped entirely (not truncated)."""
|
||||
value = "x" * 513
|
||||
result = _build_safe_metadata({"key": value})
|
||||
assert "key" not in result
|
||||
|
||||
def test_json_value_exceeding_limit_dropped(self):
|
||||
"""JSON-serialized dict exceeding 512 chars is dropped, not truncated into invalid JSON."""
|
||||
big_dict = {f"key_{i}": "v" * 100 for i in range(50)}
|
||||
result = _build_safe_metadata({"forwarded_props": big_dict})
|
||||
assert "forwarded_props" not in result
|
||||
|
||||
def test_other_keys_preserved_when_one_dropped(self):
|
||||
"""Dropping one oversized key does not affect other keys."""
|
||||
result = _build_safe_metadata(
|
||||
{
|
||||
"small": "ok",
|
||||
"big": "x" * 600,
|
||||
}
|
||||
)
|
||||
assert result == {"small": "ok"}
|
||||
|
||||
def test_none_input_returns_empty(self):
|
||||
assert _build_safe_metadata(None) == {}
|
||||
|
||||
def test_empty_input_returns_empty(self):
|
||||
assert _build_safe_metadata({}) == {}
|
||||
@@ -63,12 +63,12 @@ class TestBuildSafeMetadata:
|
||||
result = _build_safe_metadata(metadata)
|
||||
assert result == metadata
|
||||
|
||||
def test_truncates_long_strings(self):
|
||||
"""Truncates strings over 512 chars."""
|
||||
def test_drops_long_strings(self):
|
||||
"""Drops strings over 512 chars instead of truncating."""
|
||||
long_value = "x" * 1000
|
||||
metadata = {"key": long_value}
|
||||
result = _build_safe_metadata(metadata)
|
||||
assert len(result["key"]) == 512
|
||||
assert "key" not in result
|
||||
|
||||
def test_serializes_non_strings(self):
|
||||
"""Serializes non-string values to JSON."""
|
||||
@@ -77,12 +77,12 @@ class TestBuildSafeMetadata:
|
||||
assert result["count"] == "42"
|
||||
assert result["items"] == "[1, 2, 3]"
|
||||
|
||||
def test_truncates_serialized_values(self):
|
||||
"""Truncates serialized values over 512 chars."""
|
||||
def test_drops_oversized_serialized_values(self):
|
||||
"""Drops serialized values over 512 chars instead of truncating."""
|
||||
long_list = list(range(200))
|
||||
metadata = {"data": long_list}
|
||||
result = _build_safe_metadata(metadata)
|
||||
assert len(result["data"]) == 512
|
||||
assert "data" not in result
|
||||
|
||||
|
||||
class TestHasOnlyToolCalls:
|
||||
|
||||
@@ -1672,3 +1672,210 @@ async def test_workflow_run_non_terminal_status_emits_custom():
|
||||
custom = [e for e in events if e.type == "CUSTOM" and e.name == "status"]
|
||||
assert len(custom) == 1
|
||||
assert custom[0].value == {"state": "running"}
|
||||
|
||||
|
||||
async def test_workflow_run_passes_forwarded_props_as_function_invocation_kwargs() -> None:
|
||||
"""forwarded_props from input_data is forwarded to workflow.run() via function_invocation_kwargs."""
|
||||
|
||||
class CapturingWorkflow:
|
||||
def __init__(self) -> None:
|
||||
self.captured_kwargs: dict[str, Any] = {}
|
||||
|
||||
def run(self, **kwargs: Any):
|
||||
self.captured_kwargs = dict(kwargs)
|
||||
|
||||
async def _stream():
|
||||
yield SimpleNamespace(type="started")
|
||||
|
||||
return _stream()
|
||||
|
||||
workflow = CapturingWorkflow()
|
||||
events = [
|
||||
event
|
||||
async for event in run_workflow_stream(
|
||||
{
|
||||
"messages": [{"role": "user", "content": "hello"}],
|
||||
"forwarded_props": {"custom_flag": True, "source": "copilotkit"},
|
||||
},
|
||||
cast(Any, workflow),
|
||||
)
|
||||
]
|
||||
|
||||
event_types = [event.type for event in events]
|
||||
assert "RUN_STARTED" in event_types
|
||||
assert "RUN_FINISHED" in event_types
|
||||
|
||||
assert workflow.captured_kwargs["stream"] is True
|
||||
assert "function_invocation_kwargs" in workflow.captured_kwargs
|
||||
assert workflow.captured_kwargs["function_invocation_kwargs"] == {
|
||||
"forwarded_props": {"custom_flag": True, "source": "copilotkit"},
|
||||
}
|
||||
|
||||
|
||||
async def test_workflow_run_omits_function_invocation_kwargs_when_no_forwarded_props() -> None:
|
||||
"""function_invocation_kwargs is not passed when forwarded_props is absent."""
|
||||
|
||||
class CapturingWorkflow:
|
||||
def __init__(self) -> None:
|
||||
self.captured_kwargs: dict[str, Any] = {}
|
||||
|
||||
def run(self, **kwargs: Any):
|
||||
self.captured_kwargs = dict(kwargs)
|
||||
|
||||
async def _stream():
|
||||
yield SimpleNamespace(type="started")
|
||||
|
||||
return _stream()
|
||||
|
||||
workflow = CapturingWorkflow()
|
||||
events = [
|
||||
event
|
||||
async for event in run_workflow_stream(
|
||||
{"messages": [{"role": "user", "content": "hello"}]},
|
||||
cast(Any, workflow),
|
||||
)
|
||||
]
|
||||
|
||||
event_types = [event.type for event in events]
|
||||
assert "RUN_STARTED" in event_types
|
||||
assert workflow.captured_kwargs["stream"] is True
|
||||
assert "function_invocation_kwargs" not in workflow.captured_kwargs
|
||||
|
||||
|
||||
async def test_workflow_run_accepts_camel_case_forwarded_props() -> None:
|
||||
"""forwardedProps (camelCase) is accepted as an alternative key."""
|
||||
|
||||
class CapturingWorkflow:
|
||||
def __init__(self) -> None:
|
||||
self.captured_kwargs: dict[str, Any] = {}
|
||||
|
||||
def run(self, **kwargs: Any):
|
||||
self.captured_kwargs = dict(kwargs)
|
||||
|
||||
async def _stream():
|
||||
yield SimpleNamespace(type="started")
|
||||
|
||||
return _stream()
|
||||
|
||||
workflow = CapturingWorkflow()
|
||||
events = [
|
||||
event
|
||||
async for event in run_workflow_stream(
|
||||
{
|
||||
"messages": [{"role": "user", "content": "hello"}],
|
||||
"forwardedProps": {"source": "frontend"},
|
||||
},
|
||||
cast(Any, workflow),
|
||||
)
|
||||
]
|
||||
|
||||
event_types = [event.type for event in events]
|
||||
assert "RUN_STARTED" in event_types
|
||||
|
||||
assert workflow.captured_kwargs["stream"] is True
|
||||
assert "function_invocation_kwargs" in workflow.captured_kwargs
|
||||
assert workflow.captured_kwargs["function_invocation_kwargs"] == {
|
||||
"forwarded_props": {"source": "frontend"},
|
||||
}
|
||||
|
||||
|
||||
async def test_workflow_run_passes_empty_dict_forwarded_props() -> None:
|
||||
"""An empty dict forwarded_props={} should still be forwarded (not dropped by truthiness)."""
|
||||
|
||||
class CapturingWorkflow:
|
||||
def __init__(self) -> None:
|
||||
self.captured_kwargs: dict[str, Any] = {}
|
||||
|
||||
def run(self, **kwargs: Any):
|
||||
self.captured_kwargs = dict(kwargs)
|
||||
|
||||
async def _stream():
|
||||
yield SimpleNamespace(type="started")
|
||||
|
||||
return _stream()
|
||||
|
||||
workflow = CapturingWorkflow()
|
||||
events = [
|
||||
event
|
||||
async for event in run_workflow_stream(
|
||||
{
|
||||
"messages": [{"role": "user", "content": "hello"}],
|
||||
"forwarded_props": {},
|
||||
},
|
||||
cast(Any, workflow),
|
||||
)
|
||||
]
|
||||
|
||||
event_types = [event.type for event in events]
|
||||
assert "RUN_STARTED" in event_types
|
||||
assert "RUN_FINISHED" in event_types
|
||||
|
||||
assert workflow.captured_kwargs["stream"] is True
|
||||
assert "function_invocation_kwargs" in workflow.captured_kwargs
|
||||
assert workflow.captured_kwargs["function_invocation_kwargs"] == {
|
||||
"forwarded_props": {},
|
||||
}
|
||||
|
||||
|
||||
async def test_workflow_run_stream_true_always_passed() -> None:
|
||||
"""stream=True is always passed to workflow.run()."""
|
||||
|
||||
class CapturingWorkflow:
|
||||
def __init__(self) -> None:
|
||||
self.captured_kwargs: dict[str, Any] = {}
|
||||
|
||||
def run(self, **kwargs: Any):
|
||||
self.captured_kwargs = dict(kwargs)
|
||||
|
||||
async def _stream():
|
||||
yield SimpleNamespace(type="started")
|
||||
|
||||
return _stream()
|
||||
|
||||
workflow = CapturingWorkflow()
|
||||
_ = [
|
||||
event
|
||||
async for event in run_workflow_stream(
|
||||
{
|
||||
"messages": [{"role": "user", "content": "hello"}],
|
||||
"forwarded_props": {"key": "val"},
|
||||
},
|
||||
cast(Any, workflow),
|
||||
)
|
||||
]
|
||||
|
||||
assert workflow.captured_kwargs["stream"] is True
|
||||
|
||||
|
||||
async def test_workflow_run_drops_fwd_kwargs_when_run_lacks_param() -> None:
|
||||
"""function_invocation_kwargs is silently dropped if workflow.run() does not accept it."""
|
||||
|
||||
class StrictWorkflow:
|
||||
def __init__(self) -> None:
|
||||
self.captured_kwargs: dict[str, Any] = {}
|
||||
|
||||
def run(self, *, message: Any = None, responses: Any = None, stream: bool = False):
|
||||
self.captured_kwargs = {"message": message, "responses": responses, "stream": stream}
|
||||
|
||||
async def _stream():
|
||||
yield SimpleNamespace(type="started")
|
||||
|
||||
return _stream()
|
||||
|
||||
workflow = StrictWorkflow()
|
||||
events = [
|
||||
event
|
||||
async for event in run_workflow_stream(
|
||||
{
|
||||
"messages": [{"role": "user", "content": "hello"}],
|
||||
"forwarded_props": {"custom": True},
|
||||
},
|
||||
cast(Any, workflow),
|
||||
)
|
||||
]
|
||||
|
||||
event_types = [event.type for event in events]
|
||||
assert "RUN_STARTED" in event_types
|
||||
assert "RUN_FINISHED" in event_types
|
||||
# No TypeError raised, and function_invocation_kwargs was not passed
|
||||
assert "function_invocation_kwargs" not in workflow.captured_kwargs
|
||||
|
||||
Reference in New Issue
Block a user