mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
WIP: debugging empty message bug
This commit is contained in:
@@ -29,7 +29,7 @@ from .._types import (
|
||||
UsageDetails,
|
||||
add_usage_details,
|
||||
)
|
||||
from ..exceptions import AgentInvalidRequestException, AgentInvalidResponseException
|
||||
from ..exceptions import AgentException, AgentInvalidRequestException, AgentInvalidResponseException
|
||||
from ._checkpoint import CheckpointStorage
|
||||
from ._events import (
|
||||
AGENT_FORWARDED_EVENT_TYPES,
|
||||
@@ -177,7 +177,7 @@ class WorkflowAgent(BaseAgent):
|
||||
|
||||
Args:
|
||||
messages: The message(s) to send to the workflow. Required for new runs,
|
||||
should be None when resuming from checkpoint.
|
||||
could be None if only restoring the underlying workflow from a checkpoint.
|
||||
|
||||
Keyword Args:
|
||||
stream: If True, returns an async iterable of updates. If False (default),
|
||||
@@ -411,34 +411,29 @@ class WorkflowAgent(BaseAgent):
|
||||
Yields:
|
||||
WorkflowEvent objects from the workflow execution.
|
||||
"""
|
||||
final_state: WorkflowRunState | None = None
|
||||
# Restore the workflow state if a checkpoint is provided
|
||||
if checkpoint_id is not None:
|
||||
if checkpoint_storage is None:
|
||||
raise AgentInvalidRequestException("checkpoint_storage must be provided when checkpoint_id is provided")
|
||||
logger.debug(f"Restoring workflow from checkpoint {checkpoint_id}")
|
||||
# Restore the workflow from checkpoint
|
||||
if streaming:
|
||||
async for event in self.workflow.run(
|
||||
async for _ in self.workflow.run(
|
||||
stream=True,
|
||||
checkpoint_id=checkpoint_id,
|
||||
checkpoint_storage=checkpoint_storage,
|
||||
):
|
||||
if event.type == "status":
|
||||
final_state = event.state
|
||||
pass
|
||||
else:
|
||||
run_result = await self.workflow.run(
|
||||
_ = await self.workflow.run(
|
||||
checkpoint_id=checkpoint_id,
|
||||
checkpoint_storage=checkpoint_storage,
|
||||
)
|
||||
final_state = run_result.get_final_state()
|
||||
if not input_messages:
|
||||
logger.info("No input messages provided; the workflow has been restored to the checkpoint state.")
|
||||
return
|
||||
|
||||
if final_state is None:
|
||||
raise AgentInvalidRequestException(
|
||||
"Workflow did not emit a final state event. Unable to determine workflow completion."
|
||||
)
|
||||
|
||||
# Set the default final state to IDLE if checkpoint was not provided
|
||||
final_state = final_state or WorkflowRunState.IDLE
|
||||
final_state = self._workflow.status
|
||||
logger.debug(f"Workflow state: {final_state}")
|
||||
|
||||
if final_state == WorkflowRunState.IDLE_WITH_PENDING_REQUESTS:
|
||||
@@ -453,6 +448,7 @@ class WorkflowAgent(BaseAgent):
|
||||
async for event in self.workflow.run(
|
||||
responses=function_responses,
|
||||
stream=True,
|
||||
checkpoint_storage=checkpoint_storage,
|
||||
function_invocation_kwargs=function_invocation_kwargs,
|
||||
client_kwargs=client_kwargs,
|
||||
):
|
||||
@@ -460,11 +456,12 @@ class WorkflowAgent(BaseAgent):
|
||||
else:
|
||||
for event in await self.workflow.run(
|
||||
responses=function_responses,
|
||||
checkpoint_storage=checkpoint_storage,
|
||||
function_invocation_kwargs=function_invocation_kwargs,
|
||||
client_kwargs=client_kwargs,
|
||||
):
|
||||
yield event
|
||||
else:
|
||||
elif final_state == WorkflowRunState.IDLE:
|
||||
if streaming:
|
||||
async for event in self.workflow.run(
|
||||
message=input_messages,
|
||||
@@ -482,6 +479,8 @@ class WorkflowAgent(BaseAgent):
|
||||
client_kwargs=client_kwargs,
|
||||
):
|
||||
yield event
|
||||
else:
|
||||
raise AgentException(f"The underlying workflow is in an invalid state to restart: {final_state}.")
|
||||
|
||||
# endregion Run Methods
|
||||
|
||||
@@ -725,6 +724,11 @@ class WorkflowAgent(BaseAgent):
|
||||
for message in input_messages:
|
||||
for content in message.contents:
|
||||
if content.type == "function_approval_response":
|
||||
request_id = content.additional_properties.get("request_id")
|
||||
if not request_id:
|
||||
raise AgentInvalidResponseException(
|
||||
"FunctionApprovalResponseContent must have a request_id in additional_properties."
|
||||
)
|
||||
# Parse the function arguments to recover request payload
|
||||
arguments_payload = content.function_call.arguments # type: ignore[attr-defined, union-attr]
|
||||
if isinstance(arguments_payload, str):
|
||||
@@ -740,8 +744,7 @@ class WorkflowAgent(BaseAgent):
|
||||
raise AgentInvalidResponseException(
|
||||
"FunctionApprovalResponseContent arguments must be a mapping or JSON string."
|
||||
)
|
||||
|
||||
function_responses[parsed_args.request_id] = parsed_args.data
|
||||
function_responses[request_id] = parsed_args.data
|
||||
elif content.type == "function_result":
|
||||
response_data = content.result if hasattr(content, "result") else str(content) # type: ignore[attr-defined]
|
||||
function_responses[content.call_id] = response_data # pyright: ignore[reportArgumentType]
|
||||
|
||||
@@ -360,6 +360,22 @@ class Workflow(DictConvertible):
|
||||
# Flag to prevent concurrent workflow executions
|
||||
self._is_running = False
|
||||
|
||||
# Current run-level status of this workflow instance. Updated in lockstep with
|
||||
# the status events emitted from `_run_workflow_with_tracing`. Defaults to IDLE
|
||||
# for a freshly built workflow that has not yet been run.
|
||||
self._status: WorkflowRunState = WorkflowRunState.IDLE
|
||||
|
||||
@property
|
||||
def status(self) -> WorkflowRunState:
|
||||
"""Return the current run-level status of this workflow instance.
|
||||
|
||||
Mirrors the most recent status event emitted by the workflow. Safe to read at
|
||||
any time: workflows run on a single asyncio event loop, and the underlying
|
||||
attribute is a single enum reference whose assignment is atomic under the
|
||||
CPython GIL, so no locking is required.
|
||||
"""
|
||||
return self._status
|
||||
|
||||
def _ensure_not_running(self) -> None:
|
||||
"""Ensure the workflow is not already running."""
|
||||
if self._is_running:
|
||||
@@ -513,8 +529,9 @@ class Workflow(DictConvertible):
|
||||
with _framework_event_origin():
|
||||
started = WorkflowEvent.started()
|
||||
yield started # noqa: RUF070
|
||||
self._status = WorkflowRunState.IN_PROGRESS
|
||||
with _framework_event_origin():
|
||||
in_progress = WorkflowEvent.status(WorkflowRunState.IN_PROGRESS)
|
||||
in_progress = WorkflowEvent.status(self._status)
|
||||
yield in_progress # noqa: RUF070
|
||||
|
||||
# Per-run reset for fresh-message runs only. We deliberately
|
||||
@@ -569,17 +586,20 @@ class Workflow(DictConvertible):
|
||||
|
||||
if event.type == "request_info" and not emitted_in_progress_pending:
|
||||
emitted_in_progress_pending = True
|
||||
self._status = WorkflowRunState.IN_PROGRESS_PENDING_REQUESTS
|
||||
with _framework_event_origin():
|
||||
pending_status = WorkflowEvent.status(WorkflowRunState.IN_PROGRESS_PENDING_REQUESTS)
|
||||
pending_status = WorkflowEvent.status(self._status)
|
||||
yield pending_status # noqa: RUF070
|
||||
# Workflow runs until idle - emit final status based on whether requests are pending
|
||||
if saw_request:
|
||||
self._status = WorkflowRunState.IDLE_WITH_PENDING_REQUESTS
|
||||
with _framework_event_origin():
|
||||
terminal_status = WorkflowEvent.status(WorkflowRunState.IDLE_WITH_PENDING_REQUESTS)
|
||||
terminal_status = WorkflowEvent.status(self._status)
|
||||
yield terminal_status
|
||||
else:
|
||||
self._status = WorkflowRunState.IDLE
|
||||
with _framework_event_origin():
|
||||
terminal_status = WorkflowEvent.status(WorkflowRunState.IDLE)
|
||||
terminal_status = WorkflowEvent.status(self._status)
|
||||
yield terminal_status
|
||||
|
||||
span.add_event(OtelAttr.WORKFLOW_COMPLETED)
|
||||
@@ -593,6 +613,7 @@ class Workflow(DictConvertible):
|
||||
with _framework_event_origin():
|
||||
failed_event = WorkflowEvent.failed(details)
|
||||
yield failed_event # noqa: RUF070
|
||||
self._status = WorkflowRunState.FAILED
|
||||
with _framework_event_origin():
|
||||
failed_status = WorkflowEvent.status(WorkflowRunState.FAILED)
|
||||
yield failed_status # noqa: RUF070
|
||||
|
||||
@@ -263,8 +263,9 @@ class TestWorkflowAgent:
|
||||
assert approval_request.function_call.name == function_call.name
|
||||
|
||||
# Verify the request is tracked in pending_requests
|
||||
assert len(agent.pending_requests) == 1
|
||||
assert function_call.call_id in agent.pending_requests
|
||||
pending_requests = await workflow._runner_context.get_pending_request_info_events()
|
||||
assert len(pending_requests) == 1
|
||||
assert function_call.call_id in pending_requests
|
||||
|
||||
# Now provide an approval response with updated arguments to test continuation
|
||||
response_args = WorkflowAgent.RequestInfoFunctionArgs(
|
||||
@@ -291,7 +292,8 @@ class TestWorkflowAgent:
|
||||
assert isinstance(continuation_result, AgentResponse)
|
||||
|
||||
# Verify cleanup - pending requests should be cleared after function response handling
|
||||
assert len(agent.pending_requests) == 0
|
||||
pending_requests = await workflow._runner_context.get_pending_request_info_events()
|
||||
assert len(pending_requests) == 0
|
||||
|
||||
def test_workflow_as_agent_method(self) -> None:
|
||||
"""Test that Workflow.as_agent() creates a properly configured WorkflowAgent."""
|
||||
|
||||
@@ -0,0 +1,149 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Tests for the ``Workflow.status`` property."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
|
||||
from agent_framework import (
|
||||
Executor,
|
||||
Workflow,
|
||||
WorkflowBuilder,
|
||||
WorkflowContext,
|
||||
WorkflowEvent,
|
||||
WorkflowRunState,
|
||||
handler,
|
||||
response_handler,
|
||||
)
|
||||
from agent_framework._workflows._executor import Executor as _Executor
|
||||
from agent_framework._workflows._request_info_mixin import RequestInfoMixin
|
||||
|
||||
|
||||
class PassThroughExecutor(Executor):
|
||||
"""Executor that yields its input as a workflow output and stops."""
|
||||
|
||||
@handler
|
||||
async def passthrough(self, msg: str, ctx: WorkflowContext[str, str]) -> None:
|
||||
await ctx.yield_output(msg)
|
||||
|
||||
|
||||
class FailingExecutor(Executor):
|
||||
"""Executor that raises at runtime to drive the FAILED status."""
|
||||
|
||||
@handler
|
||||
async def fail(self, msg: int, ctx: WorkflowContext) -> None: # pragma: no cover - invoked via workflow
|
||||
raise RuntimeError("boom")
|
||||
|
||||
|
||||
@dataclass
|
||||
class _ApprovalRequest:
|
||||
prompt: str
|
||||
request_id: str = ""
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if not self.request_id:
|
||||
import uuid
|
||||
|
||||
self.request_id = str(uuid.uuid4())
|
||||
|
||||
|
||||
class ApprovalExecutor(_Executor, RequestInfoMixin):
|
||||
"""Executor that issues a single request_info call and finalizes on response."""
|
||||
|
||||
def __init__(self, id: str = "approval"):
|
||||
super().__init__(id=id)
|
||||
|
||||
@handler
|
||||
async def start(self, message: str, ctx: WorkflowContext[str, str]) -> None:
|
||||
await ctx.request_info(_ApprovalRequest(prompt=message), bool)
|
||||
|
||||
@response_handler
|
||||
async def on_response(
|
||||
self, original_request: _ApprovalRequest, approved: bool, ctx: WorkflowContext[str, str]
|
||||
) -> None:
|
||||
await ctx.yield_output(f"approved={approved}")
|
||||
|
||||
|
||||
def _build_passthrough_workflow() -> Workflow:
|
||||
executor = PassThroughExecutor(id="p")
|
||||
return WorkflowBuilder(start_executor=executor, output_from=[executor]).build()
|
||||
|
||||
|
||||
def _build_failing_workflow() -> Workflow:
|
||||
# FailingExecutor has no workflow_output_types, so we leave designation
|
||||
# implicit; the deprecation warning is filtered at call sites that need it.
|
||||
return WorkflowBuilder(start_executor=FailingExecutor(id="f")).build()
|
||||
|
||||
|
||||
def _build_approval_workflow() -> Workflow:
|
||||
executor = ApprovalExecutor(id="approval")
|
||||
return WorkflowBuilder(start_executor=executor, output_from=[executor]).build()
|
||||
|
||||
|
||||
async def test_status_default_is_idle_before_first_run():
|
||||
wf = _build_passthrough_workflow()
|
||||
assert wf.status is WorkflowRunState.IDLE
|
||||
|
||||
|
||||
async def test_status_is_idle_after_successful_run():
|
||||
wf = _build_passthrough_workflow()
|
||||
await wf.run("hello")
|
||||
assert wf.status is WorkflowRunState.IDLE
|
||||
|
||||
|
||||
async def test_status_is_failed_after_failure():
|
||||
wf = _build_failing_workflow()
|
||||
with pytest.raises(RuntimeError, match="boom"):
|
||||
await wf.run(0)
|
||||
assert wf.status is WorkflowRunState.FAILED
|
||||
|
||||
|
||||
async def test_status_transitions_during_streaming_run():
|
||||
"""Workflow.status mirrors the most recent emitted status event."""
|
||||
wf = _build_passthrough_workflow()
|
||||
observed: list[WorkflowRunState] = []
|
||||
|
||||
async for event in wf.run("hi", stream=True):
|
||||
if isinstance(event, WorkflowEvent) and event.type == "status":
|
||||
# By the time a status event surfaces to the consumer, the property
|
||||
# must already reflect that state (updated in lockstep with emission).
|
||||
assert wf.status == event.state
|
||||
observed.append(event.state) # type: ignore
|
||||
|
||||
# IN_PROGRESS must precede IDLE; both must appear.
|
||||
assert WorkflowRunState.IN_PROGRESS in observed
|
||||
assert observed[-1] is WorkflowRunState.IDLE
|
||||
assert wf.status is WorkflowRunState.IDLE
|
||||
|
||||
|
||||
async def test_status_idle_with_pending_requests_then_resolves_to_idle():
|
||||
wf = _build_approval_workflow()
|
||||
|
||||
request_event: WorkflowEvent | None = None
|
||||
async for event in wf.run("please approve", stream=True):
|
||||
if isinstance(event, WorkflowEvent) and event.type == "request_info":
|
||||
request_event = event
|
||||
|
||||
assert request_event is not None
|
||||
assert wf.status is WorkflowRunState.IDLE_WITH_PENDING_REQUESTS
|
||||
|
||||
async for _ in wf.run(stream=True, responses={request_event.request_id: True}):
|
||||
pass
|
||||
|
||||
assert wf.status is WorkflowRunState.IDLE
|
||||
|
||||
|
||||
async def test_status_in_progress_pending_requests_observed_mid_run():
|
||||
"""While streaming, status reaches IN_PROGRESS_PENDING_REQUESTS after a request_info event."""
|
||||
wf = _build_approval_workflow()
|
||||
seen_states: list[WorkflowRunState] = []
|
||||
|
||||
async for event in wf.run("please approve", stream=True):
|
||||
if isinstance(event, WorkflowEvent) and event.type == "status":
|
||||
seen_states.append(event.state) # type: ignore
|
||||
|
||||
assert WorkflowRunState.IN_PROGRESS in seen_states
|
||||
assert WorkflowRunState.IN_PROGRESS_PENDING_REQUESTS in seen_states
|
||||
assert seen_states[-1] is WorkflowRunState.IDLE_WITH_PENDING_REQUESTS
|
||||
assert wf.status is WorkflowRunState.IDLE_WITH_PENDING_REQUESTS
|
||||
@@ -155,7 +155,7 @@ async def main() -> None:
|
||||
print("=== Demonstration of a tool with approvals ===\n")
|
||||
|
||||
await run_weather_agent_with_approval(stream=False)
|
||||
await run_weather_agent_with_approval(stream=True)
|
||||
# await run_weather_agent_with_approval(stream=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Generated
+4069
-4082
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user