Python: Expose forwardedProps to agents and tools via session metadata (#5264)

* Expose forwarded_props to agents and tools via session metadata (#5239)

Include forwarded_props from AG-UI request input_data in session.metadata
(agent runner) and function_invocation_kwargs (workflow runner) so that
agents, tools, and workflow executors can access request-level metadata
such as invocation source flags from CopilotKit.

- Add forwarded_props to base_metadata in _agent_run.py when present
- Add 'forwarded_props' to AG_UI_INTERNAL_METADATA_KEYS to filter it
  from LLM-bound client metadata
- Extract forwarded_props in _workflow_run.py and pass via
  function_invocation_kwargs to workflow.run()
- Accept both snake_case and camelCase keys (forwarded_props/forwardedProps)

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* fix(ag-ui): pass stream=True as literal to satisfy pyright overload resolution (#5239)

The previous fix passed stream=True via **kwargs dict, which prevented
pyright from resolving the Workflow.run() overload to the streaming
variant. Pass stream=True as an explicit keyword argument so pyright
can correctly infer the ResponseStream return type.

Also remove unused pytest import in test file.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* fix: address PR review feedback for forwarded_props (#5239)

- Use key-presence checks instead of truthiness for forwarded_props so
  empty dict {} is forwarded correctly
- Gate function_invocation_kwargs on workflow.run() signature inspection
  to avoid TypeError for workflows without **kwargs
- Change _build_safe_metadata to drop (with warning) keys whose
  serialized values exceed 512 chars instead of truncating into invalid
  JSON
- Rewrite metadata tests to exercise _build_safe_metadata directly with
  JSON-decodability and truncation assertions
- Add workflow tests for empty dict forwarded_props, stream=True
  assertion, and signature-gated kwarg dropping

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* test: add stream=True assertions to CapturingWorkflow tests (#5239)

Guard against accidental removal of the explicit stream=True kwarg
in all forwarded_props CapturingWorkflow test cases.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Address review feedback for #5239: Python: Expose forwardedProps to agents and tools via session metadata

---------

Co-authored-by: Copilot <copilot@github.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
Evan Mattson
2026-04-21 13:25:45 +09:00
committed by GitHub
Unverified
parent 04aaf0c1fe
commit 07f4c8a8d6
5 changed files with 339 additions and 13 deletions
@@ -69,19 +69,23 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
# Keys that are internal to AG-UI orchestration and should not be passed to chat clients
AG_UI_INTERNAL_METADATA_KEYS = {"ag_ui_thread_id", "ag_ui_run_id", "current_state"}
AG_UI_INTERNAL_METADATA_KEYS = {"ag_ui_thread_id", "ag_ui_run_id", "current_state", "forwarded_props"}
def _build_safe_metadata(thread_metadata: dict[str, Any] | None) -> dict[str, Any]:
"""Build metadata dict with truncated string values for Azure compatibility.
"""Build metadata dict with string values for Azure compatibility.
Azure has a 512 character limit per metadata value.
Azure has a 512 character limit per metadata value. String values that
already fit are kept as-is. Non-string values are JSON-serialized. If the
resulting string exceeds 512 characters the key is **dropped** (with a
warning) instead of truncated, because truncation can produce invalid JSON
that downstream consumers cannot decode.
Args:
thread_metadata: Raw metadata dict
Returns:
Metadata with string values truncated to 512 chars
Metadata with safe string values (each <= 512 chars)
"""
if not thread_metadata:
return {}
@@ -89,7 +93,12 @@ def _build_safe_metadata(thread_metadata: dict[str, Any] | None) -> dict[str, An
for key, value in thread_metadata.items():
value_str = value if isinstance(value, str) else json.dumps(value)
if len(value_str) > 512:
value_str = value_str[:512]
logger.warning(
"Dropping metadata key %r: serialized value is %d chars (limit 512)",
key,
len(value_str),
)
continue
safe_metadata[key] = value_str
return safe_metadata
@@ -790,6 +799,10 @@ async def run_agent_stream(
"ag_ui_thread_id": thread_id,
"ag_ui_run_id": run_id,
}
if "forwarded_props" in input_data:
base_metadata["forwarded_props"] = input_data["forwarded_props"]
elif "forwardedProps" in input_data:
base_metadata["forwarded_props"] = input_data["forwardedProps"]
if flow.current_state:
base_metadata["current_state"] = flow.current_state
session.metadata = _build_safe_metadata(base_metadata) # type: ignore[attr-defined]
@@ -4,6 +4,7 @@
from __future__ import annotations
import inspect
import json
import logging
import uuid
@@ -581,11 +582,33 @@ async def run_workflow_stream(
flow.accumulated_text = ""
return [TextMessageEndEvent(message_id=current_message_id)]
fwd_kwargs: dict[str, Any] = {}
if "forwarded_props" in input_data:
forwarded_props = input_data["forwarded_props"]
fwd_kwargs["function_invocation_kwargs"] = {"forwarded_props": forwarded_props}
elif "forwardedProps" in input_data:
forwarded_props = input_data["forwardedProps"]
fwd_kwargs["function_invocation_kwargs"] = {"forwarded_props": forwarded_props}
# Only pass function_invocation_kwargs if the workflow.run signature accepts it
if fwd_kwargs:
try:
sig = inspect.signature(workflow.run)
params = sig.parameters
accepts_fwd = "function_invocation_kwargs" in params or any(
p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()
)
except (ValueError, TypeError):
accepts_fwd = False
if not accepts_fwd:
logger.debug("workflow.run() does not accept function_invocation_kwargs; dropping forwarded_props")
fwd_kwargs = {}
try:
if responses:
event_stream = workflow.run(responses=responses, stream=True)
event_stream = workflow.run(responses=responses, stream=True, **fwd_kwargs)
else:
event_stream = workflow.run(message=messages, stream=True)
event_stream = workflow.run(message=messages, stream=True, **fwd_kwargs)
async for event in event_stream:
event_type = getattr(event, "type", None)
@@ -0,0 +1,83 @@
# Copyright (c) Microsoft. All rights reserved.
"""Tests for forwarded_props inclusion in AG-UI session metadata."""
import json
from typing import Any
from agent_framework_ag_ui._agent_run import AG_UI_INTERNAL_METADATA_KEYS, _build_safe_metadata
class TestForwardedPropsInSessionMetadata:
"""Verify that forwarded_props is surfaced in session metadata and filtered from LLM metadata."""
def test_forwarded_props_in_internal_metadata_keys(self):
"""forwarded_props is listed in AG_UI_INTERNAL_METADATA_KEYS to prevent LLM leakage."""
assert "forwarded_props" in AG_UI_INTERNAL_METADATA_KEYS
def test_forwarded_props_filtered_from_client_metadata(self):
"""forwarded_props is filtered out when building LLM-bound client metadata."""
session_metadata: dict[str, Any] = {
"ag_ui_thread_id": "t1",
"ag_ui_run_id": "r1",
"forwarded_props": '{"custom_flag": true}',
}
client_metadata = {k: v for k, v in session_metadata.items() if k not in AG_UI_INTERNAL_METADATA_KEYS}
assert "forwarded_props" not in client_metadata
assert "ag_ui_thread_id" not in client_metadata
class TestBuildSafeMetadata:
"""Verify _build_safe_metadata handles various value types correctly."""
def test_string_value_unchanged(self):
result = _build_safe_metadata({"key": "hello"})
assert result == {"key": "hello"}
def test_dict_value_serialized_to_json(self):
result = _build_safe_metadata({"fp": {"flag": True, "source": "frontend"}})
assert "fp" in result
assert isinstance(result["fp"], str)
# Must be valid, decodable JSON
decoded = json.loads(result["fp"])
assert decoded == {"flag": True, "source": "frontend"}
def test_empty_dict_serialized_to_json(self):
result = _build_safe_metadata({"fp": {}})
assert result["fp"] == "{}"
assert json.loads(result["fp"]) == {}
def test_value_within_limit_kept(self):
value = "x" * 512
result = _build_safe_metadata({"key": value})
assert result["key"] == value
def test_value_exceeding_limit_dropped(self):
"""Values exceeding 512 chars are dropped entirely (not truncated)."""
value = "x" * 513
result = _build_safe_metadata({"key": value})
assert "key" not in result
def test_json_value_exceeding_limit_dropped(self):
"""JSON-serialized dict exceeding 512 chars is dropped, not truncated into invalid JSON."""
big_dict = {f"key_{i}": "v" * 100 for i in range(50)}
result = _build_safe_metadata({"forwarded_props": big_dict})
assert "forwarded_props" not in result
def test_other_keys_preserved_when_one_dropped(self):
"""Dropping one oversized key does not affect other keys."""
result = _build_safe_metadata(
{
"small": "ok",
"big": "x" * 600,
}
)
assert result == {"small": "ok"}
def test_none_input_returns_empty(self):
assert _build_safe_metadata(None) == {}
def test_empty_input_returns_empty(self):
assert _build_safe_metadata({}) == {}
@@ -63,12 +63,12 @@ class TestBuildSafeMetadata:
result = _build_safe_metadata(metadata)
assert result == metadata
def test_truncates_long_strings(self):
"""Truncates strings over 512 chars."""
def test_drops_long_strings(self):
"""Drops strings over 512 chars instead of truncating."""
long_value = "x" * 1000
metadata = {"key": long_value}
result = _build_safe_metadata(metadata)
assert len(result["key"]) == 512
assert "key" not in result
def test_serializes_non_strings(self):
"""Serializes non-string values to JSON."""
@@ -77,12 +77,12 @@ class TestBuildSafeMetadata:
assert result["count"] == "42"
assert result["items"] == "[1, 2, 3]"
def test_truncates_serialized_values(self):
"""Truncates serialized values over 512 chars."""
def test_drops_oversized_serialized_values(self):
"""Drops serialized values over 512 chars instead of truncating."""
long_list = list(range(200))
metadata = {"data": long_list}
result = _build_safe_metadata(metadata)
assert len(result["data"]) == 512
assert "data" not in result
class TestHasOnlyToolCalls:
@@ -1672,3 +1672,210 @@ async def test_workflow_run_non_terminal_status_emits_custom():
custom = [e for e in events if e.type == "CUSTOM" and e.name == "status"]
assert len(custom) == 1
assert custom[0].value == {"state": "running"}
async def test_workflow_run_passes_forwarded_props_as_function_invocation_kwargs() -> None:
"""forwarded_props from input_data is forwarded to workflow.run() via function_invocation_kwargs."""
class CapturingWorkflow:
def __init__(self) -> None:
self.captured_kwargs: dict[str, Any] = {}
def run(self, **kwargs: Any):
self.captured_kwargs = dict(kwargs)
async def _stream():
yield SimpleNamespace(type="started")
return _stream()
workflow = CapturingWorkflow()
events = [
event
async for event in run_workflow_stream(
{
"messages": [{"role": "user", "content": "hello"}],
"forwarded_props": {"custom_flag": True, "source": "copilotkit"},
},
cast(Any, workflow),
)
]
event_types = [event.type for event in events]
assert "RUN_STARTED" in event_types
assert "RUN_FINISHED" in event_types
assert workflow.captured_kwargs["stream"] is True
assert "function_invocation_kwargs" in workflow.captured_kwargs
assert workflow.captured_kwargs["function_invocation_kwargs"] == {
"forwarded_props": {"custom_flag": True, "source": "copilotkit"},
}
async def test_workflow_run_omits_function_invocation_kwargs_when_no_forwarded_props() -> None:
"""function_invocation_kwargs is not passed when forwarded_props is absent."""
class CapturingWorkflow:
def __init__(self) -> None:
self.captured_kwargs: dict[str, Any] = {}
def run(self, **kwargs: Any):
self.captured_kwargs = dict(kwargs)
async def _stream():
yield SimpleNamespace(type="started")
return _stream()
workflow = CapturingWorkflow()
events = [
event
async for event in run_workflow_stream(
{"messages": [{"role": "user", "content": "hello"}]},
cast(Any, workflow),
)
]
event_types = [event.type for event in events]
assert "RUN_STARTED" in event_types
assert workflow.captured_kwargs["stream"] is True
assert "function_invocation_kwargs" not in workflow.captured_kwargs
async def test_workflow_run_accepts_camel_case_forwarded_props() -> None:
"""forwardedProps (camelCase) is accepted as an alternative key."""
class CapturingWorkflow:
def __init__(self) -> None:
self.captured_kwargs: dict[str, Any] = {}
def run(self, **kwargs: Any):
self.captured_kwargs = dict(kwargs)
async def _stream():
yield SimpleNamespace(type="started")
return _stream()
workflow = CapturingWorkflow()
events = [
event
async for event in run_workflow_stream(
{
"messages": [{"role": "user", "content": "hello"}],
"forwardedProps": {"source": "frontend"},
},
cast(Any, workflow),
)
]
event_types = [event.type for event in events]
assert "RUN_STARTED" in event_types
assert workflow.captured_kwargs["stream"] is True
assert "function_invocation_kwargs" in workflow.captured_kwargs
assert workflow.captured_kwargs["function_invocation_kwargs"] == {
"forwarded_props": {"source": "frontend"},
}
async def test_workflow_run_passes_empty_dict_forwarded_props() -> None:
"""An empty dict forwarded_props={} should still be forwarded (not dropped by truthiness)."""
class CapturingWorkflow:
def __init__(self) -> None:
self.captured_kwargs: dict[str, Any] = {}
def run(self, **kwargs: Any):
self.captured_kwargs = dict(kwargs)
async def _stream():
yield SimpleNamespace(type="started")
return _stream()
workflow = CapturingWorkflow()
events = [
event
async for event in run_workflow_stream(
{
"messages": [{"role": "user", "content": "hello"}],
"forwarded_props": {},
},
cast(Any, workflow),
)
]
event_types = [event.type for event in events]
assert "RUN_STARTED" in event_types
assert "RUN_FINISHED" in event_types
assert workflow.captured_kwargs["stream"] is True
assert "function_invocation_kwargs" in workflow.captured_kwargs
assert workflow.captured_kwargs["function_invocation_kwargs"] == {
"forwarded_props": {},
}
async def test_workflow_run_stream_true_always_passed() -> None:
"""stream=True is always passed to workflow.run()."""
class CapturingWorkflow:
def __init__(self) -> None:
self.captured_kwargs: dict[str, Any] = {}
def run(self, **kwargs: Any):
self.captured_kwargs = dict(kwargs)
async def _stream():
yield SimpleNamespace(type="started")
return _stream()
workflow = CapturingWorkflow()
_ = [
event
async for event in run_workflow_stream(
{
"messages": [{"role": "user", "content": "hello"}],
"forwarded_props": {"key": "val"},
},
cast(Any, workflow),
)
]
assert workflow.captured_kwargs["stream"] is True
async def test_workflow_run_drops_fwd_kwargs_when_run_lacks_param() -> None:
"""function_invocation_kwargs is silently dropped if workflow.run() does not accept it."""
class StrictWorkflow:
def __init__(self) -> None:
self.captured_kwargs: dict[str, Any] = {}
def run(self, *, message: Any = None, responses: Any = None, stream: bool = False):
self.captured_kwargs = {"message": message, "responses": responses, "stream": stream}
async def _stream():
yield SimpleNamespace(type="started")
return _stream()
workflow = StrictWorkflow()
events = [
event
async for event in run_workflow_stream(
{
"messages": [{"role": "user", "content": "hello"}],
"forwarded_props": {"custom": True},
},
cast(Any, workflow),
)
]
event_types = [event.type for event in events]
assert "RUN_STARTED" in event_types
assert "RUN_FINISHED" in event_types
# No TypeError raised, and function_invocation_kwargs was not passed
assert "function_invocation_kwargs" not in workflow.captured_kwargs