Address comments

This commit is contained in:
Tao Chen
2026-04-28 21:37:50 -07:00
Unverified
parent b62ce6dc15
commit b9926c095b
3 changed files with 234 additions and 16 deletions
+21 -3
View File
@@ -3047,9 +3047,25 @@ class ResponseStream(AsyncIterable[UpdateT], Generic[UpdateT, FinalT]):
update = hooked
return update
async def _resolve_stream_with_pull_contexts(self) -> AsyncIterable[UpdateT]:
    """Resolve the underlying stream with the registered pull contexts active.

    ``__await__`` and ``get_final_response`` go through this helper so that spans or
    contexts created while the stream source is being resolved (e.g. when the source
    is an Awaitable that internally creates child telemetry spans) inherit the same
    active context as iterator pulls. ``__anext__`` resolves the stream inside its
    own ExitStack and so calls ``_get_stream`` directly instead of using this helper.

    Returns:
        The value produced by ``_get_stream()`` for this instance.
    """
    needs_resolution = self._stream is None
    with contextlib.ExitStack() as pull_contexts:
        if needs_resolution:
            # Only a stream that still has to be resolved runs under the pull
            # contexts; an already-resolved stream is returned without entering any.
            for make_context in self._pull_context_manager_factories:
                pull_contexts.enter_context(make_context())
        return await self._get_stream()
def __await__(self) -> Any:
    """Make the stream awaitable; awaiting it resolves the stream and yields ``self``.

    Resolution goes exclusively through ``_resolve_stream_with_pull_contexts()`` so
    that the registered pull context managers are active while the source is being
    resolved. Calling a bare ``_get_stream()`` first would resolve the source
    without entering those contexts, so no such call is made here.

    Returns:
        The iterator obtained from the wrapping coroutine's ``__await__``; the value
        delivered to the awaiting caller is this stream instance.
    """

    async def _wrap() -> ResponseStream[UpdateT, FinalT]:
        await self._resolve_stream_with_pull_contexts()
        return self

    return _wrap().__await__()
@@ -3073,10 +3089,12 @@ class ResponseStream(AsyncIterable[UpdateT], Generic[UpdateT, FinalT]):
"""
if self._wrap_inner:
if self._inner_stream is None:
# Use _get_stream() to resolve the awaitable - this properly handles
# Use _resolve_stream_with_pull_contexts() so that any spans/contexts
# created while resolving the awaitable (e.g. inner telemetry spans)
# inherit the same active context as iterator pulls. This also handles
# the case where _stream_source and _inner_stream_source are the same
# coroutine (e.g., from from_awaitable), avoiding double-await errors.
await self._get_stream()
await self._resolve_stream_with_pull_contexts()
if self._inner_stream is None:
raise RuntimeError("Inner stream not available")
if not self._finalized and not self._consumed:
@@ -1301,18 +1301,23 @@ class ChatTelemetryLayer(Generic[OptionsCoT]):
def _record_duration() -> None:
duration_state["duration"] = perf_counter() - start_time
result_stream = cast(
ResponseStream[ChatResponseUpdate, ChatResponse[Any]],
super_get_response(
messages=messages,
stream=True,
options=opts,
compaction_strategy=compaction_strategy,
tokenizer=tokenizer,
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=merged_client_kwargs,
),
)
try:
result_stream = cast(
ResponseStream[ChatResponseUpdate, ChatResponse[Any]],
super_get_response(
messages=messages,
stream=True,
options=opts,
compaction_strategy=compaction_strategy,
tokenizer=tokenizer,
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=merged_client_kwargs,
),
)
except Exception as exception:
capture_exception(span=span, exception=exception, timestamp=time_ns())
_close_span()
raise
async def _finalize_stream() -> None:
from ._types import ChatResponse
@@ -1576,7 +1581,8 @@ class AgentTelemetryLayer:
result_stream = ResponseStream.from_awaitable(run_result) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
else:
raise RuntimeError("Streaming telemetry requires a ResponseStream result.")
except Exception:
except Exception as exception:
capture_exception(span=span, exception=exception, timestamp=time_ns())
INNER_RESPONSE_TELEMETRY_CAPTURED_FIELDS.reset(inner_response_telemetry_captured_fields_token)
INNER_ACCUMULATED_USAGE.reset(inner_accumulated_usage_token)
_close_span()
@@ -3603,3 +3603,197 @@ async def test_http_span_nested_under_chat_span(span_exporter: InMemorySpanExpor
assert http_span.parent is not None
assert http_span.parent.span_id == chat_span.context.span_id
assert http_span.context.trace_id == chat_span.context.trace_id
# region Test ResponseStream.with_pull_context_manager
async def test_with_pull_context_manager_enters_and_exits_per_pull():
    """Each iterator pull is wrapped in one balanced enter/exit of the registered factory."""
    import contextlib

    trace: list[str] = []

    @contextlib.contextmanager
    def tracked_pull():
        trace.append("enter")
        try:
            yield
        finally:
            trace.append("exit")

    async def two_updates() -> AsyncIterable[int]:
        yield 1
        yield 2

    stream: ResponseStream[int, list[int]] = ResponseStream(two_updates(), finalizer=lambda updates: list(updates))
    stream.with_pull_context_manager(tracked_pull)

    received = []
    async for update in stream:
        received.append(update)
    assert received == [1, 2]

    # Balanced enter/exit, with at least one pair per yielded update.
    enter_count = trace.count("enter")
    assert enter_count == trace.count("exit")
    assert enter_count >= 2

    # Strict alternation: every "enter" is immediately followed by its "exit",
    # i.e. pull contexts never overlap.
    for opened, closed in zip(trace[0::2], trace[1::2]):
        assert opened == "enter"
        assert closed == "exit"
async def test_with_pull_context_manager_exits_on_iteration_error():
    """A pull that raises from the underlying stream still exits its pull context."""
    import contextlib

    trace: list[str] = []

    @contextlib.contextmanager
    def tracked_pull():
        trace.append("enter")
        try:
            yield
        finally:
            trace.append("exit")

    async def fails_after_first() -> AsyncIterable[int]:
        yield 1
        raise RuntimeError("boom")

    stream: ResponseStream[int, list[int]] = ResponseStream(
        fails_after_first(), finalizer=lambda updates: list(updates)
    )
    stream.with_pull_context_manager(tracked_pull)

    with pytest.raises(RuntimeError, match="boom"):
        async for _ in stream:
            pass

    # The failing pull must be exited as well, so the counts match and reach two.
    enter_count = trace.count("enter")
    exit_count = trace.count("exit")
    assert enter_count == exit_count
    assert enter_count >= 2
async def test_with_pull_context_manager_wraps_stream_resolution_via_await():
    """Awaiting a ``from_awaitable`` stream resolves the inner stream under the pull contexts."""
    import contextlib

    events: list[str] = []

    @contextlib.contextmanager
    def cm():
        events.append("enter")
        try:
            yield
        finally:
            events.append("exit")

    async def inner() -> AsyncIterable[int]:
        yield 1

    async def make_stream() -> ResponseStream[int, list[int]]:
        # Record the moment of resolution so its position relative to "enter" can be checked.
        events.append("resolving")
        return ResponseStream(inner(), finalizer=lambda updates: list(updates))

    stream: ResponseStream[int, list[int]] = ResponseStream.from_awaitable(make_stream())
    stream.with_pull_context_manager(cm)

    await stream  # Triggers _resolve_stream_with_pull_contexts via __await__

    assert "resolving" in events
    resolve_index = events.index("resolving")
    # Guard against negative-index wrap-around: with resolve_index == 0 the lookup
    # below would read events[-1] (the last event) instead of a preceding one.
    assert resolve_index > 0, "stream was resolved before any pull context was entered"
    assert events[resolve_index - 1] == "enter"  # Pull context active during resolution
# region Test streaming telemetry error paths
@pytest.mark.parametrize("enable_sensitive_data", [True], indirect=True)
async def test_chat_streaming_super_failure_closes_span(span_exporter: InMemorySpanExporter, enable_sensitive_data):
    """A synchronous failure while constructing the stream must not leak the chat span.

    The inner client raises from ``_inner_get_response``; the test expects exactly one
    finished chat-completion span and that span to carry an ERROR status.
    """

    class FailingClient(ChatTelemetryLayer, BaseChatClient[Any]):
        def service_url(self):
            return "https://test.example.com"

        def _inner_get_response(
            self, *, messages: MutableSequence[Message], stream: bool, options: dict[str, Any], **kwargs: Any
        ) -> Awaitable[ChatResponse] | ResponseStream[ChatResponseUpdate, ChatResponse]:
            raise RuntimeError("inner failed")

    span_exporter.clear()
    failing_client = FailingClient()
    user_message = Message(role="user", contents=["Test"])

    # The error surfaces from the call itself, before any stream is returned.
    with pytest.raises(RuntimeError, match="inner failed"):
        failing_client.get_response(stream=True, messages=[user_message], options={"model": "Test"})

    finished_chat_spans = [
        finished
        for finished in span_exporter.get_finished_spans()
        if finished.attributes.get(OtelAttr.OPERATION.value) == OtelAttr.CHAT_COMPLETION_OPERATION
    ]
    assert len(finished_chat_spans) == 1
    (chat_span,) = finished_chat_spans
    assert chat_span.status.status_code == StatusCode.ERROR
@pytest.mark.parametrize("enable_sensitive_data", [True], indirect=True)
async def test_agent_streaming_execute_failure_closes_span_and_resets_contextvars(
    span_exporter: InMemorySpanExporter, enable_sensitive_data
):
    """A synchronous failure of a streaming ``run()`` ends the agent span and restores contextvars.

    The wrapped agent raises as soon as ``run(stream=True)`` is called. The test then
    checks that both telemetry contextvars hold the values set before the call and
    that exactly one agent-invoke span was finished with an ERROR status.
    """
    from agent_framework.observability import (
        INNER_ACCUMULATED_USAGE,
        INNER_RESPONSE_TELEMETRY_CAPTURED_FIELDS,
    )

    class _FailingExecuteAgent:
        AGENT_PROVIDER_NAME = "test_provider"

        def __init__(self):
            self._id = "failing_execute"
            self._name = "Failing Execute"
            self._description = "Agent whose stream call raises synchronously"
            self._default_options: dict[str, Any] = {}

        @property
        def id(self):
            return self._id

        @property
        def name(self):
            return self._name

        @property
        def description(self):
            return self._description

        @property
        def default_options(self):
            return self._default_options

        def run(self, messages=None, *, stream: bool = False, session=None, **kwargs):
            if not stream:
                raise NotImplementedError
            raise RuntimeError("execute failed")

    class FailingExecuteAgent(AgentTelemetryLayer, _FailingExecuteAgent):
        pass

    # Distinct sentinel objects let identity checks prove the contextvars were
    # restored to their pre-call state.
    sentinel_fields: set[str] = set()
    sentinel_usage: dict[str, Any] = {}
    captured_fields_token = INNER_RESPONSE_TELEMETRY_CAPTURED_FIELDS.set(sentinel_fields)
    accumulated_usage_token = INNER_ACCUMULATED_USAGE.set(sentinel_usage)
    try:
        failing_agent = FailingExecuteAgent()
        span_exporter.clear()

        with pytest.raises(RuntimeError, match="execute failed"):
            failing_agent.run(messages="Hello", stream=True)

        # Both contextvars must point at the sentinels registered before the call.
        assert INNER_RESPONSE_TELEMETRY_CAPTURED_FIELDS.get() is sentinel_fields
        assert INNER_ACCUMULATED_USAGE.get() is sentinel_usage
    finally:
        # Undo this test's own sets in reverse order of registration.
        INNER_ACCUMULATED_USAGE.reset(accumulated_usage_token)
        INNER_RESPONSE_TELEMETRY_CAPTURED_FIELDS.reset(captured_fields_token)

    finished_agent_spans = [
        finished
        for finished in span_exporter.get_finished_spans()
        if finished.attributes.get(OtelAttr.OPERATION.value) == OtelAttr.AGENT_INVOKE_OPERATION
    ]
    assert len(finished_agent_spans) == 1
    (agent_span,) = finished_agent_spans
    assert agent_span.status.status_code == StatusCode.ERROR