From b9926c095bc56c4c248da14015a048f03d9bbc88 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Tue, 28 Apr 2026 21:37:50 -0700 Subject: [PATCH] Address comments --- .../packages/core/agent_framework/_types.py | 24 ++- .../core/agent_framework/observability.py | 32 +-- .../core/tests/core/test_observability.py | 194 ++++++++++++++++++ 3 files changed, 234 insertions(+), 16 deletions(-) diff --git a/python/packages/core/agent_framework/_types.py b/python/packages/core/agent_framework/_types.py index a5f20a4029..5fbdcd1079 100644 --- a/python/packages/core/agent_framework/_types.py +++ b/python/packages/core/agent_framework/_types.py @@ -3047,9 +3047,25 @@ class ResponseStream(AsyncIterable[UpdateT], Generic[UpdateT, FinalT]): update = hooked return update + async def _resolve_stream_with_pull_contexts(self) -> AsyncIterable[UpdateT]: + """Resolve the underlying stream while activating any registered pull context managers. + + Used by ``__await__`` and ``get_final_response`` so that any spans/contexts created + during stream resolution (e.g. when the source is an Awaitable that internally + creates child telemetry spans) inherit the same active context as iterator pulls. + ``__anext__`` resolves the stream inside its own ExitStack and so calls ``_get_stream`` + directly. + """ + if self._stream is not None: + return await self._get_stream() + with contextlib.ExitStack() as stack: + for factory in self._pull_context_manager_factories: + stack.enter_context(factory()) + return await self._get_stream() + def __await__(self) -> Any: async def _wrap() -> ResponseStream[UpdateT, FinalT]: - await self._get_stream() + await self._resolve_stream_with_pull_contexts() return self return _wrap().__await__() @@ -3073,10 +3089,12 @@ class ResponseStream(AsyncIterable[UpdateT], Generic[UpdateT, FinalT]): """ if self._wrap_inner: if self._inner_stream is None: - # Use _get_stream() to resolve the awaitable - this properly handles + # Use _resolve_stream_with_pull_contexts() so that any spans/contexts + # created while resolving the awaitable (e.g. inner telemetry spans) + # inherit the same active context as iterator pulls. This also handles # the case where _stream_source and _inner_stream_source are the same # coroutine (e.g., from from_awaitable), avoiding double-await errors. - await self._get_stream() + await self._resolve_stream_with_pull_contexts() if self._inner_stream is None: raise RuntimeError("Inner stream not available") if not self._finalized and not self._consumed: diff --git a/python/packages/core/agent_framework/observability.py b/python/packages/core/agent_framework/observability.py index 2c71187d7f..051319926f 100644 --- a/python/packages/core/agent_framework/observability.py +++ b/python/packages/core/agent_framework/observability.py @@ -1301,18 +1301,23 @@ class ChatTelemetryLayer(Generic[OptionsCoT]): def _record_duration() -> None: duration_state["duration"] = perf_counter() - start_time - result_stream = cast( - ResponseStream[ChatResponseUpdate, ChatResponse[Any]], - super_get_response( - messages=messages, - stream=True, - options=opts, - compaction_strategy=compaction_strategy, - tokenizer=tokenizer, - function_invocation_kwargs=function_invocation_kwargs, - client_kwargs=merged_client_kwargs, - ), - ) + try: + result_stream = cast( + ResponseStream[ChatResponseUpdate, ChatResponse[Any]], + super_get_response( + messages=messages, + stream=True, + options=opts, + compaction_strategy=compaction_strategy, + tokenizer=tokenizer, + function_invocation_kwargs=function_invocation_kwargs, + client_kwargs=merged_client_kwargs, + ), + ) + except Exception as exception: + capture_exception(span=span, exception=exception, timestamp=time_ns()) + _close_span() + raise async def _finalize_stream() -> None: from ._types import ChatResponse @@ -1576,7 +1581,8 @@ class AgentTelemetryLayer: result_stream = ResponseStream.from_awaitable(run_result) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] else: raise RuntimeError("Streaming telemetry requires a ResponseStream result.") - except Exception: + except Exception as exception: + capture_exception(span=span, exception=exception, timestamp=time_ns()) INNER_RESPONSE_TELEMETRY_CAPTURED_FIELDS.reset(inner_response_telemetry_captured_fields_token) INNER_ACCUMULATED_USAGE.reset(inner_accumulated_usage_token) _close_span() diff --git a/python/packages/core/tests/core/test_observability.py b/python/packages/core/tests/core/test_observability.py index ce7222ecc1..71b59a351b 100644 --- a/python/packages/core/tests/core/test_observability.py +++ b/python/packages/core/tests/core/test_observability.py @@ -3603,3 +3603,197 @@ async def test_http_span_nested_under_chat_span(span_exporter: InMemorySpanExpor assert http_span.parent is not None assert http_span.parent.span_id == chat_span.context.span_id assert http_span.context.trace_id == chat_span.context.trace_id + + +# region Test ResponseStream.with_pull_context_manager + + +async def test_with_pull_context_manager_enters_and_exits_per_pull(): + """The registered factory is entered and exited symmetrically around each iterator pull.""" + import contextlib + + events: list[str] = [] + + @contextlib.contextmanager + def cm(): + events.append("enter") + try: + yield + finally: + events.append("exit") + + async def src() -> AsyncIterable[int]: + yield 1 + yield 2 + + stream: ResponseStream[int, list[int]] = ResponseStream(src(), finalizer=lambda updates: list(updates)) + stream.with_pull_context_manager(cm) + + pulled = [u async for u in stream] + + assert pulled == [1, 2] + # Enter/exit must be balanced and there must be at least one pair per yielded update. + assert events.count("enter") == events.count("exit") + assert events.count("enter") >= 2 + # Verify symmetric ordering (no overlapping pairs). + for i in range(0, len(events), 2): + assert events[i] == "enter" + assert events[i + 1] == "exit" + + +async def test_with_pull_context_manager_exits_on_iteration_error(): + """The pull context is exited even when the underlying stream raises mid-iteration.""" + import contextlib + + events: list[str] = [] + + @contextlib.contextmanager + def cm(): + events.append("enter") + try: + yield + finally: + events.append("exit") + + async def src() -> AsyncIterable[int]: + yield 1 + raise RuntimeError("boom") + + stream: ResponseStream[int, list[int]] = ResponseStream(src(), finalizer=lambda updates: list(updates)) + stream.with_pull_context_manager(cm) + + with pytest.raises(RuntimeError, match="boom"): + async for _ in stream: + pass + + # Enter/exit balanced even on the failing pull. + assert events.count("enter") == events.count("exit") + assert events.count("enter") >= 2 + + +async def test_with_pull_context_manager_wraps_stream_resolution_via_await(): + """Awaiting a ``from_awaitable`` stream resolves the inner stream under the pull contexts.""" + import contextlib + + events: list[str] = [] + + @contextlib.contextmanager + def cm(): + events.append("enter") + try: + yield + finally: + events.append("exit") + + async def inner() -> AsyncIterable[int]: + yield 1 + + async def make_stream() -> ResponseStream[int, list[int]]: + # Record that we resolve while a pull context is active. + events.append("resolving") + return ResponseStream(inner(), finalizer=lambda updates: list(updates)) + + stream: ResponseStream[int, list[int]] = ResponseStream.from_awaitable(make_stream()) + stream.with_pull_context_manager(cm) + + await stream # Triggers _resolve_stream_with_pull_contexts via __await__ + + assert "resolving" in events + resolve_index = events.index("resolving") + assert events[resolve_index - 1] == "enter" # Pull context active during resolution + + +# region Test streaming telemetry error paths + + +@pytest.mark.parametrize("enable_sensitive_data", [True], indirect=True) +async def test_chat_streaming_super_failure_closes_span(span_exporter: InMemorySpanExporter, enable_sensitive_data): + """If the underlying client raises synchronously when constructing the stream, the chat + span is ended and the exception is recorded (no span leak).""" + + class FailingClient(ChatTelemetryLayer, BaseChatClient[Any]): + def service_url(self): + return "https://test.example.com" + + def _inner_get_response( + self, *, messages: MutableSequence[Message], stream: bool, options: dict[str, Any], **kwargs: Any + ) -> Awaitable[ChatResponse] | ResponseStream[ChatResponseUpdate, ChatResponse]: + raise RuntimeError("inner failed") + + span_exporter.clear() + client = FailingClient() + with pytest.raises(RuntimeError, match="inner failed"): + client.get_response(stream=True, messages=[Message(role="user", contents=["Test"])], options={"model": "Test"}) + + spans = span_exporter.get_finished_spans() + chat_spans = [s for s in spans if s.attributes.get(OtelAttr.OPERATION.value) == OtelAttr.CHAT_COMPLETION_OPERATION] + assert len(chat_spans) == 1 + assert chat_spans[0].status.status_code == StatusCode.ERROR + + +@pytest.mark.parametrize("enable_sensitive_data", [True], indirect=True) +async def test_agent_streaming_execute_failure_closes_span_and_resets_contextvars( + span_exporter: InMemorySpanExporter, enable_sensitive_data +): + """If ``execute()`` raises synchronously during streaming agent invocation, the agent span is + ended, the exception is recorded, and the telemetry contextvars are reset.""" + from agent_framework.observability import ( + INNER_ACCUMULATED_USAGE, + INNER_RESPONSE_TELEMETRY_CAPTURED_FIELDS, + ) + + class _FailingExecuteAgent: + AGENT_PROVIDER_NAME = "test_provider" + + def __init__(self): + self._id = "failing_execute" + self._name = "Failing Execute" + self._description = "Agent whose stream call raises synchronously" + self._default_options: dict[str, Any] = {} + + @property + def id(self): + return self._id + + @property + def name(self): + return self._name + + @property + def description(self): + return self._description + + @property + def default_options(self): + return self._default_options + + def run(self, messages=None, *, stream: bool = False, session=None, **kwargs): + if stream: + raise RuntimeError("execute failed") + raise NotImplementedError + + class FailingExecuteAgent(AgentTelemetryLayer, _FailingExecuteAgent): + pass + + # Sentinel values to detect that contextvars were reset to their pre-call state. + sentinel_fields: set[str] = set() + sentinel_usage: dict[str, Any] = {} + fields_token = INNER_RESPONSE_TELEMETRY_CAPTURED_FIELDS.set(sentinel_fields) + usage_token = INNER_ACCUMULATED_USAGE.set(sentinel_usage) + try: + agent = FailingExecuteAgent() + span_exporter.clear() + with pytest.raises(RuntimeError, match="execute failed"): + agent.run(messages="Hello", stream=True) + + # Contextvars must be back to the sentinel values registered before the call. + assert INNER_RESPONSE_TELEMETRY_CAPTURED_FIELDS.get() is sentinel_fields + assert INNER_ACCUMULATED_USAGE.get() is sentinel_usage + finally: + INNER_ACCUMULATED_USAGE.reset(usage_token) + INNER_RESPONSE_TELEMETRY_CAPTURED_FIELDS.reset(fields_token) + + spans = span_exporter.get_finished_spans() + agent_spans = [s for s in spans if s.attributes.get(OtelAttr.OPERATION.value) == OtelAttr.AGENT_INVOKE_OPERATION] + assert len(agent_spans) == 1 + assert agent_spans[0].status.status_code == StatusCode.ERROR