mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Address comments
This commit is contained in:
@@ -3047,9 +3047,25 @@ class ResponseStream(AsyncIterable[UpdateT], Generic[UpdateT, FinalT]):
|
||||
update = hooked
|
||||
return update
|
||||
|
||||
async def _resolve_stream_with_pull_contexts(self) -> AsyncIterable[UpdateT]:
    """Return the resolved stream, resolving it under the registered pull contexts if needed.

    ``__await__`` and ``get_final_response`` go through this helper so that spans or
    contexts created while the source is being resolved (for example an Awaitable that
    opens child telemetry spans internally) see the same active context as iterator
    pulls do. ``__anext__`` does not use it: it resolves the stream inside its own
    ExitStack and calls ``_get_stream`` directly.

    Returns:
        Whatever ``_get_stream()`` returns.
    """
    with contextlib.ExitStack() as pull_contexts:
        # Only a not-yet-resolved stream needs the pull contexts. When ``_stream`` is
        # already set the stack stays empty, so entering and leaving it does nothing.
        if self._stream is None:
            for make_context in self._pull_context_manager_factories:
                pull_contexts.enter_context(make_context())
        return await self._get_stream()
|
||||
|
||||
def __await__(self) -> Any:
    """Make the stream awaitable; awaiting it resolves the stream and yields ``self``.

    Resolution goes exclusively through ``_resolve_stream_with_pull_contexts``.
    Awaiting ``_get_stream()`` beforehand would presumably leave ``_stream``
    populated, so the helper would take its already-resolved path and never enter
    the registered pull context managers.

    Returns:
        The awaitable iterator of an inner coroutine whose result is this stream.
    """

    async def _wrap() -> ResponseStream[UpdateT, FinalT]:
        await self._resolve_stream_with_pull_contexts()
        return self

    return _wrap().__await__()
|
||||
@@ -3073,10 +3089,12 @@ class ResponseStream(AsyncIterable[UpdateT], Generic[UpdateT, FinalT]):
|
||||
"""
|
||||
if self._wrap_inner:
|
||||
if self._inner_stream is None:
|
||||
# Use _get_stream() to resolve the awaitable - this properly handles
|
||||
# Use _resolve_stream_with_pull_contexts() so that any spans/contexts
|
||||
# created while resolving the awaitable (e.g. inner telemetry spans)
|
||||
# inherit the same active context as iterator pulls. This also handles
|
||||
# the case where _stream_source and _inner_stream_source are the same
|
||||
# coroutine (e.g., from from_awaitable), avoiding double-await errors.
|
||||
await self._get_stream()
|
||||
await self._resolve_stream_with_pull_contexts()
|
||||
if self._inner_stream is None:
|
||||
raise RuntimeError("Inner stream not available")
|
||||
if not self._finalized and not self._consumed:
|
||||
|
||||
@@ -1301,18 +1301,23 @@ class ChatTelemetryLayer(Generic[OptionsCoT]):
|
||||
def _record_duration() -> None:
|
||||
duration_state["duration"] = perf_counter() - start_time
|
||||
|
||||
result_stream = cast(
|
||||
ResponseStream[ChatResponseUpdate, ChatResponse[Any]],
|
||||
super_get_response(
|
||||
messages=messages,
|
||||
stream=True,
|
||||
options=opts,
|
||||
compaction_strategy=compaction_strategy,
|
||||
tokenizer=tokenizer,
|
||||
function_invocation_kwargs=function_invocation_kwargs,
|
||||
client_kwargs=merged_client_kwargs,
|
||||
),
|
||||
)
|
||||
try:
|
||||
result_stream = cast(
|
||||
ResponseStream[ChatResponseUpdate, ChatResponse[Any]],
|
||||
super_get_response(
|
||||
messages=messages,
|
||||
stream=True,
|
||||
options=opts,
|
||||
compaction_strategy=compaction_strategy,
|
||||
tokenizer=tokenizer,
|
||||
function_invocation_kwargs=function_invocation_kwargs,
|
||||
client_kwargs=merged_client_kwargs,
|
||||
),
|
||||
)
|
||||
except Exception as exception:
|
||||
capture_exception(span=span, exception=exception, timestamp=time_ns())
|
||||
_close_span()
|
||||
raise
|
||||
|
||||
async def _finalize_stream() -> None:
|
||||
from ._types import ChatResponse
|
||||
@@ -1576,7 +1581,8 @@ class AgentTelemetryLayer:
|
||||
result_stream = ResponseStream.from_awaitable(run_result) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
|
||||
else:
|
||||
raise RuntimeError("Streaming telemetry requires a ResponseStream result.")
|
||||
except Exception:
|
||||
except Exception as exception:
|
||||
capture_exception(span=span, exception=exception, timestamp=time_ns())
|
||||
INNER_RESPONSE_TELEMETRY_CAPTURED_FIELDS.reset(inner_response_telemetry_captured_fields_token)
|
||||
INNER_ACCUMULATED_USAGE.reset(inner_accumulated_usage_token)
|
||||
_close_span()
|
||||
|
||||
@@ -3603,3 +3603,197 @@ async def test_http_span_nested_under_chat_span(span_exporter: InMemorySpanExpor
|
||||
assert http_span.parent is not None
|
||||
assert http_span.parent.span_id == chat_span.context.span_id
|
||||
assert http_span.context.trace_id == chat_span.context.trace_id
|
||||
|
||||
|
||||
# region Test ResponseStream.with_pull_context_manager
|
||||
|
||||
|
||||
async def test_with_pull_context_manager_enters_and_exits_per_pull():
    """Every iterator pull is bracketed by one enter and one exit of the registered factory."""
    import contextlib

    trace: list[str] = []

    @contextlib.contextmanager
    def recording_context():
        # Log both edges of the context so ordering can be checked afterwards.
        trace.append("enter")
        try:
            yield
        finally:
            trace.append("exit")

    async def numbers() -> AsyncIterable[int]:
        yield 1
        yield 2

    stream: ResponseStream[int, list[int]] = ResponseStream(numbers(), finalizer=lambda updates: list(updates))
    stream.with_pull_context_manager(recording_context)

    received = []
    async for update in stream:
        received.append(update)

    assert received == [1, 2]

    # Enters and exits must balance, with at least one pair per yielded update.
    enters = trace.count("enter")
    exits = trace.count("exit")
    assert enters == exits
    assert enters >= 2

    # Pairs must never overlap: the log strictly alternates enter, exit, enter, exit, ...
    for opened, closed in zip(trace[::2], trace[1::2]):
        assert opened == "enter"
        assert closed == "exit"
|
||||
|
||||
|
||||
async def test_with_pull_context_manager_exits_on_iteration_error():
    """A pull that fails inside the underlying stream still leaves its pull context."""
    import contextlib

    trace: list[str] = []

    @contextlib.contextmanager
    def recording_context():
        trace.append("enter")
        try:
            yield
        finally:
            # Runs on the error path too, which is what this test is about.
            trace.append("exit")

    async def failing_source() -> AsyncIterable[int]:
        yield 1
        raise RuntimeError("boom")

    stream: ResponseStream[int, list[int]] = ResponseStream(
        failing_source(), finalizer=lambda updates: list(updates)
    )
    stream.with_pull_context_manager(recording_context)

    with pytest.raises(RuntimeError, match="boom"):
        async for _ in stream:
            pass

    # The failing pull must be balanced as well: one successful pull plus the failing one.
    enters, exits = trace.count("enter"), trace.count("exit")
    assert enters == exits
    assert enters >= 2
|
||||
|
||||
|
||||
async def test_with_pull_context_manager_wraps_stream_resolution_via_await():
    """Awaiting a ``from_awaitable`` stream resolves the inner stream under the pull contexts."""
    import contextlib

    events: list[str] = []

    @contextlib.contextmanager
    def cm():
        events.append("enter")
        try:
            yield
        finally:
            events.append("exit")

    async def inner() -> AsyncIterable[int]:
        yield 1

    async def make_stream() -> ResponseStream[int, list[int]]:
        # Record the moment of resolution so its position relative to "enter" can be checked.
        events.append("resolving")
        return ResponseStream(inner(), finalizer=lambda updates: list(updates))

    stream: ResponseStream[int, list[int]] = ResponseStream.from_awaitable(make_stream())
    stream.with_pull_context_manager(cm)

    await stream  # Triggers _resolve_stream_with_pull_contexts via __await__

    assert "resolving" in events
    resolve_index = events.index("resolving")
    # Guard the lookup below: with an index of 0, ``events[resolve_index - 1]`` would read
    # ``events[-1]`` (the last entry) instead of failing because nothing preceded resolution.
    assert resolve_index > 0
    assert events[resolve_index - 1] == "enter"  # Pull context active during resolution
|
||||
|
||||
|
||||
# region Test streaming telemetry error paths
|
||||
|
||||
|
||||
@pytest.mark.parametrize("enable_sensitive_data", [True], indirect=True)
async def test_chat_streaming_super_failure_closes_span(span_exporter: InMemorySpanExporter, enable_sensitive_data):
    """A synchronous failure while constructing the stream must not leak the chat span.

    The span has to be finished and carry an error status once the underlying client
    has raised.
    """

    class FailingClient(ChatTelemetryLayer, BaseChatClient[Any]):
        def service_url(self):
            return "https://test.example.com"

        def _inner_get_response(
            self, *, messages: MutableSequence[Message], stream: bool, options: dict[str, Any], **kwargs: Any
        ) -> Awaitable[ChatResponse] | ResponseStream[ChatResponseUpdate, ChatResponse]:
            # Fail before any stream object exists.
            raise RuntimeError("inner failed")

    span_exporter.clear()
    client = FailingClient()
    with pytest.raises(RuntimeError, match="inner failed"):
        user_message = Message(role="user", contents=["Test"])
        client.get_response(stream=True, messages=[user_message], options={"model": "Test"})

    # Exactly one chat-completion span should have been finished, flagged as an error.
    chat_spans = []
    for finished in span_exporter.get_finished_spans():
        if finished.attributes.get(OtelAttr.OPERATION.value) == OtelAttr.CHAT_COMPLETION_OPERATION:
            chat_spans.append(finished)
    assert len(chat_spans) == 1
    assert chat_spans[0].status.status_code == StatusCode.ERROR
|
||||
|
||||
|
||||
@pytest.mark.parametrize("enable_sensitive_data", [True], indirect=True)
async def test_agent_streaming_execute_failure_closes_span_and_resets_contextvars(
    span_exporter: InMemorySpanExporter, enable_sensitive_data
):
    """A synchronous failure during a streaming agent run must leave no telemetry residue.

    The agent span has to be finished with an error status, and both telemetry
    contextvars must hold the values they had before the call.
    """
    from agent_framework.observability import (
        INNER_ACCUMULATED_USAGE,
        INNER_RESPONSE_TELEMETRY_CAPTURED_FIELDS,
    )

    class _FailingExecuteAgent:
        AGENT_PROVIDER_NAME = "test_provider"

        def __init__(self):
            self._id = "failing_execute"
            self._name = "Failing Execute"
            self._description = "Agent whose stream call raises synchronously"
            self._default_options: dict[str, Any] = {}

        @property
        def id(self):
            return self._id

        @property
        def name(self):
            return self._name

        @property
        def description(self):
            return self._description

        @property
        def default_options(self):
            return self._default_options

        def run(self, messages=None, *, stream: bool = False, session=None, **kwargs):
            # Only the streaming path is exercised; it fails before producing anything.
            if not stream:
                raise NotImplementedError
            raise RuntimeError("execute failed")

    class FailingExecuteAgent(AgentTelemetryLayer, _FailingExecuteAgent):
        pass

    # Install known objects up front; identity checks after the call prove the reset.
    baseline_fields: set[str] = set()
    baseline_usage: dict[str, Any] = {}
    fields_token = INNER_RESPONSE_TELEMETRY_CAPTURED_FIELDS.set(baseline_fields)
    usage_token = INNER_ACCUMULATED_USAGE.set(baseline_usage)
    try:
        agent = FailingExecuteAgent()
        span_exporter.clear()
        with pytest.raises(RuntimeError, match="execute failed"):
            agent.run(messages="Hello", stream=True)

        # Both contextvars must point at the very objects registered before the call.
        assert INNER_RESPONSE_TELEMETRY_CAPTURED_FIELDS.get() is baseline_fields
        assert INNER_ACCUMULATED_USAGE.get() is baseline_usage
    finally:
        INNER_ACCUMULATED_USAGE.reset(usage_token)
        INNER_RESPONSE_TELEMETRY_CAPTURED_FIELDS.reset(fields_token)

    # Exactly one agent-invoke span should have been finished, flagged as an error.
    agent_spans = []
    for finished in span_exporter.get_finished_spans():
        if finished.attributes.get(OtelAttr.OPERATION.value) == OtelAttr.AGENT_INVOKE_OPERATION:
            agent_spans.append(finished)
    assert len(agent_spans) == 1
    assert agent_spans[0].status.status_code == StatusCode.ERROR
|
||||
|
||||
Reference in New Issue
Block a user