diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_agent_run.py b/python/packages/ag-ui/agent_framework_ag_ui/_agent_run.py index 639a3f89b3..330a66dc10 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_agent_run.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_agent_run.py @@ -69,19 +69,23 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) # Keys that are internal to AG-UI orchestration and should not be passed to chat clients -AG_UI_INTERNAL_METADATA_KEYS = {"ag_ui_thread_id", "ag_ui_run_id", "current_state"} +AG_UI_INTERNAL_METADATA_KEYS = {"ag_ui_thread_id", "ag_ui_run_id", "current_state", "forwarded_props"} def _build_safe_metadata(thread_metadata: dict[str, Any] | None) -> dict[str, Any]: - """Build metadata dict with truncated string values for Azure compatibility. + """Build metadata dict with string values for Azure compatibility. - Azure has a 512 character limit per metadata value. + Azure has a 512 character limit per metadata value. String values that + already fit are kept as-is. Non-string values are JSON-serialized. If the + resulting string exceeds 512 characters the key is **dropped** (with a + warning) instead of truncated, because truncation can produce invalid JSON + that downstream consumers cannot decode. Args: thread_metadata: Raw metadata dict Returns: - Metadata with string values truncated to 512 chars + Metadata with safe string values (each <= 512 chars) """ if not thread_metadata: return {} @@ -89,7 +93,12 @@ def _build_safe_metadata(thread_metadata: dict[str, Any] | None) -> dict[str, An for key, value in thread_metadata.items(): value_str = value if isinstance(value, str) else json.dumps(value) if len(value_str) > 512: - value_str = value_str[:512] + logger.warning( + "Dropping metadata key %r: serialized value is %d chars (limit 512)", + key, + len(value_str), + ) + continue safe_metadata[key] = value_str return safe_metadata @@ -790,6 +799,10 @@ async def run_agent_stream( "ag_ui_thread_id": thread_id, "ag_ui_run_id": run_id, } + if "forwarded_props" in input_data: + base_metadata["forwarded_props"] = input_data["forwarded_props"] + elif "forwardedProps" in input_data: + base_metadata["forwarded_props"] = input_data["forwardedProps"] if flow.current_state: base_metadata["current_state"] = flow.current_state session.metadata = _build_safe_metadata(base_metadata) # type: ignore[attr-defined] diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_workflow_run.py b/python/packages/ag-ui/agent_framework_ag_ui/_workflow_run.py index d34cb7db61..211657e688 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_workflow_run.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_workflow_run.py @@ -4,6 +4,7 @@ from __future__ import annotations +import inspect import json import logging import uuid @@ -581,11 +582,33 @@ async def run_workflow_stream( flow.accumulated_text = "" return [TextMessageEndEvent(message_id=current_message_id)] + fwd_kwargs: dict[str, Any] = {} + if "forwarded_props" in input_data: + forwarded_props = input_data["forwarded_props"] + fwd_kwargs["function_invocation_kwargs"] = {"forwarded_props": forwarded_props} + elif "forwardedProps" in input_data: + forwarded_props = input_data["forwardedProps"] + fwd_kwargs["function_invocation_kwargs"] = {"forwarded_props": forwarded_props} + + # Only pass function_invocation_kwargs if the workflow.run signature accepts it + if fwd_kwargs: + try: + sig = inspect.signature(workflow.run) + params = sig.parameters + accepts_fwd = "function_invocation_kwargs" in params or any( + p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values() + ) + except (ValueError, TypeError): + accepts_fwd = False + if not accepts_fwd: + logger.debug("workflow.run() does not accept function_invocation_kwargs; dropping forwarded_props") + fwd_kwargs = {} + try: if responses: - event_stream = workflow.run(responses=responses, stream=True) + event_stream = workflow.run(responses=responses, stream=True, **fwd_kwargs) else: - event_stream = workflow.run(message=messages, stream=True) + event_stream = workflow.run(message=messages, stream=True, **fwd_kwargs) async for event in event_stream: event_type = getattr(event, "type", None) diff --git a/python/packages/ag-ui/tests/ag_ui/test_forwarded_props_in_metadata.py b/python/packages/ag-ui/tests/ag_ui/test_forwarded_props_in_metadata.py new file mode 100644 index 0000000000..fba70db17e --- /dev/null +++ b/python/packages/ag-ui/tests/ag_ui/test_forwarded_props_in_metadata.py @@ -0,0 +1,83 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Tests for forwarded_props inclusion in AG-UI session metadata.""" + +import json +from typing import Any + +from agent_framework_ag_ui._agent_run import AG_UI_INTERNAL_METADATA_KEYS, _build_safe_metadata + + +class TestForwardedPropsInSessionMetadata: + """Verify that forwarded_props is surfaced in session metadata and filtered from LLM metadata.""" + + def test_forwarded_props_in_internal_metadata_keys(self): + """forwarded_props is listed in AG_UI_INTERNAL_METADATA_KEYS to prevent LLM leakage.""" + assert "forwarded_props" in AG_UI_INTERNAL_METADATA_KEYS + + def test_forwarded_props_filtered_from_client_metadata(self): + """forwarded_props is filtered out when building LLM-bound client metadata.""" + session_metadata: dict[str, Any] = { + "ag_ui_thread_id": "t1", + "ag_ui_run_id": "r1", + "forwarded_props": '{"custom_flag": true}', + } + + client_metadata = {k: v for k, v in session_metadata.items() if k not in AG_UI_INTERNAL_METADATA_KEYS} + + assert "forwarded_props" not in client_metadata + assert "ag_ui_thread_id" not in client_metadata + + +class TestBuildSafeMetadata: + """Verify _build_safe_metadata handles various value types correctly.""" + + def test_string_value_unchanged(self): + result = _build_safe_metadata({"key": "hello"}) + assert result == {"key": "hello"} + + def test_dict_value_serialized_to_json(self): + result = _build_safe_metadata({"fp": {"flag": True, "source": "frontend"}}) + assert "fp" in result + assert isinstance(result["fp"], str) + # Must be valid, decodable JSON + decoded = json.loads(result["fp"]) + assert decoded == {"flag": True, "source": "frontend"} + + def test_empty_dict_serialized_to_json(self): + result = _build_safe_metadata({"fp": {}}) + assert result["fp"] == "{}" + assert json.loads(result["fp"]) == {} + + def test_value_within_limit_kept(self): + value = "x" * 512 + result = _build_safe_metadata({"key": value}) + assert result["key"] == value + + def test_value_exceeding_limit_dropped(self): + """Values exceeding 512 chars are dropped entirely (not truncated).""" + value = "x" * 513 + result = _build_safe_metadata({"key": value}) + assert "key" not in result + + def test_json_value_exceeding_limit_dropped(self): + """JSON-serialized dict exceeding 512 chars is dropped, not truncated into invalid JSON.""" + big_dict = {f"key_{i}": "v" * 100 for i in range(50)} + result = _build_safe_metadata({"forwarded_props": big_dict}) + assert "forwarded_props" not in result + + def test_other_keys_preserved_when_one_dropped(self): + """Dropping one oversized key does not affect other keys.""" + result = _build_safe_metadata( + { + "small": "ok", + "big": "x" * 600, + } + ) + assert result == {"small": "ok"} + + def test_none_input_returns_empty(self): + assert _build_safe_metadata(None) == {} + + def test_empty_input_returns_empty(self): + assert _build_safe_metadata({}) == {} diff --git a/python/packages/ag-ui/tests/ag_ui/test_run.py b/python/packages/ag-ui/tests/ag_ui/test_run.py index 18b0d0d7e4..392f0cd723 100644 --- a/python/packages/ag-ui/tests/ag_ui/test_run.py +++ b/python/packages/ag-ui/tests/ag_ui/test_run.py @@ -63,12 +63,12 @@ class TestBuildSafeMetadata: result = _build_safe_metadata(metadata) assert result == metadata - def test_truncates_long_strings(self): - """Truncates strings over 512 chars.""" + def test_drops_long_strings(self): + """Drops strings over 512 chars instead of truncating.""" long_value = "x" * 1000 metadata = {"key": long_value} result = _build_safe_metadata(metadata) - assert len(result["key"]) == 512 + assert "key" not in result def test_serializes_non_strings(self): """Serializes non-string values to JSON.""" @@ -77,12 +77,12 @@ class TestBuildSafeMetadata: assert result["count"] == "42" assert result["items"] == "[1, 2, 3]" - def test_truncates_serialized_values(self): - """Truncates serialized values over 512 chars.""" + def test_drops_oversized_serialized_values(self): + """Drops serialized values over 512 chars instead of truncating.""" long_list = list(range(200)) metadata = {"data": long_list} result = _build_safe_metadata(metadata) - assert len(result["data"]) == 512 + assert "data" not in result class TestHasOnlyToolCalls: diff --git a/python/packages/ag-ui/tests/ag_ui/test_workflow_run.py b/python/packages/ag-ui/tests/ag_ui/test_workflow_run.py index 26b44b03ba..a52cc4dd2c 100644 --- a/python/packages/ag-ui/tests/ag_ui/test_workflow_run.py +++ b/python/packages/ag-ui/tests/ag_ui/test_workflow_run.py @@ -1672,3 +1672,210 @@ async def test_workflow_run_non_terminal_status_emits_custom(): custom = [e for e in events if e.type == "CUSTOM" and e.name == "status"] assert len(custom) == 1 assert custom[0].value == {"state": "running"} + + +async def test_workflow_run_passes_forwarded_props_as_function_invocation_kwargs() -> None: + """forwarded_props from input_data is forwarded to workflow.run() via function_invocation_kwargs.""" + + class CapturingWorkflow: + def __init__(self) -> None: + self.captured_kwargs: dict[str, Any] = {} + + def run(self, **kwargs: Any): + self.captured_kwargs = dict(kwargs) + + async def _stream(): + yield SimpleNamespace(type="started") + + return _stream() + + workflow = CapturingWorkflow() + events = [ + event + async for event in run_workflow_stream( + { + "messages": [{"role": "user", "content": "hello"}], + "forwarded_props": {"custom_flag": True, "source": "copilotkit"}, + }, + cast(Any, workflow), + ) + ] + + event_types = [event.type for event in events] + assert "RUN_STARTED" in event_types + assert "RUN_FINISHED" in event_types + + assert workflow.captured_kwargs["stream"] is True + assert "function_invocation_kwargs" in workflow.captured_kwargs + assert workflow.captured_kwargs["function_invocation_kwargs"] == { + "forwarded_props": {"custom_flag": True, "source": "copilotkit"}, + } + + +async def test_workflow_run_omits_function_invocation_kwargs_when_no_forwarded_props() -> None: + """function_invocation_kwargs is not passed when forwarded_props is absent.""" + + class CapturingWorkflow: + def __init__(self) -> None: + self.captured_kwargs: dict[str, Any] = {} + + def run(self, **kwargs: Any): + self.captured_kwargs = dict(kwargs) + + async def _stream(): + yield SimpleNamespace(type="started") + + return _stream() + + workflow = CapturingWorkflow() + events = [ + event + async for event in run_workflow_stream( + {"messages": [{"role": "user", "content": "hello"}]}, + cast(Any, workflow), + ) + ] + + event_types = [event.type for event in events] + assert "RUN_STARTED" in event_types + assert workflow.captured_kwargs["stream"] is True + assert "function_invocation_kwargs" not in workflow.captured_kwargs + + +async def test_workflow_run_accepts_camel_case_forwarded_props() -> None: + """forwardedProps (camelCase) is accepted as an alternative key.""" + + class CapturingWorkflow: + def __init__(self) -> None: + self.captured_kwargs: dict[str, Any] = {} + + def run(self, **kwargs: Any): + self.captured_kwargs = dict(kwargs) + + async def _stream(): + yield SimpleNamespace(type="started") + + return _stream() + + workflow = CapturingWorkflow() + events = [ + event + async for event in run_workflow_stream( + { + "messages": [{"role": "user", "content": "hello"}], + "forwardedProps": {"source": "frontend"}, + }, + cast(Any, workflow), + ) + ] + + event_types = [event.type for event in events] + assert "RUN_STARTED" in event_types + + assert workflow.captured_kwargs["stream"] is True + assert "function_invocation_kwargs" in workflow.captured_kwargs + assert workflow.captured_kwargs["function_invocation_kwargs"] == { + "forwarded_props": {"source": "frontend"}, + } + + +async def test_workflow_run_passes_empty_dict_forwarded_props() -> None: + """An empty dict forwarded_props={} should still be forwarded (not dropped by truthiness).""" + + class CapturingWorkflow: + def __init__(self) -> None: + self.captured_kwargs: dict[str, Any] = {} + + def run(self, **kwargs: Any): + self.captured_kwargs = dict(kwargs) + + async def _stream(): + yield SimpleNamespace(type="started") + + return _stream() + + workflow = CapturingWorkflow() + events = [ + event + async for event in run_workflow_stream( + { + "messages": [{"role": "user", "content": "hello"}], + "forwarded_props": {}, + }, + cast(Any, workflow), + ) + ] + + event_types = [event.type for event in events] + assert "RUN_STARTED" in event_types + assert "RUN_FINISHED" in event_types + + assert workflow.captured_kwargs["stream"] is True + assert "function_invocation_kwargs" in workflow.captured_kwargs + assert workflow.captured_kwargs["function_invocation_kwargs"] == { + "forwarded_props": {}, + } + + +async def test_workflow_run_stream_true_always_passed() -> None: + """stream=True is always passed to workflow.run().""" + + class CapturingWorkflow: + def __init__(self) -> None: + self.captured_kwargs: dict[str, Any] = {} + + def run(self, **kwargs: Any): + self.captured_kwargs = dict(kwargs) + + async def _stream(): + yield SimpleNamespace(type="started") + + return _stream() + + workflow = CapturingWorkflow() + _ = [ + event + async for event in run_workflow_stream( + { + "messages": [{"role": "user", "content": "hello"}], + "forwarded_props": {"key": "val"}, + }, + cast(Any, workflow), + ) + ] + + assert workflow.captured_kwargs["stream"] is True + + +async def test_workflow_run_drops_fwd_kwargs_when_run_lacks_param() -> None: + """function_invocation_kwargs is silently dropped if workflow.run() does not accept it.""" + + class StrictWorkflow: + def __init__(self) -> None: + self.captured_kwargs: dict[str, Any] = {} + + def run(self, *, message: Any = None, responses: Any = None, stream: bool = False): + self.captured_kwargs = {"message": message, "responses": responses, "stream": stream} + + async def _stream(): + yield SimpleNamespace(type="started") + + return _stream() + + workflow = StrictWorkflow() + events = [ + event + async for event in run_workflow_stream( + { + "messages": [{"role": "user", "content": "hello"}], + "forwarded_props": {"custom": True}, + }, + cast(Any, workflow), + ) + ] + + event_types = [event.type for event in events] + assert "RUN_STARTED" in event_types + assert "RUN_FINISHED" in event_types + # No TypeError raised, and function_invocation_kwargs was not passed + assert "function_invocation_kwargs" not in workflow.captured_kwargs