mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: (samples): adopt AzureOpenAIResponsesClient, reorganize orchestration examples, and fix workflow/orchestration bugs (#3873)
* adopt AzureOpenAIResponsesClient, reorganize orchestration examples, and fix workflow/orchestration bugs * Updates * add comment
This commit is contained in:
committed by
GitHub
Unverified
parent
8457533c69
commit
1b10b051fd
@@ -2,6 +2,7 @@
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, cast
|
||||
|
||||
@@ -292,10 +293,10 @@ class AgentExecutor(Executor):
|
||||
# Non-streaming mode: use run() and emit single event
|
||||
response = await self._run_agent(cast(WorkflowContext[Never, AgentResponse], ctx))
|
||||
|
||||
# Always extend full conversation with cached messages plus agent outputs
|
||||
# (agent_response.messages) after each run. This is to avoid losing context
|
||||
# when agent did not complete and the cache is cleared when responses come back.
|
||||
self._full_conversation.extend(list(self._cache) + (list(response.messages) if response else []))
|
||||
# Snapshot current conversation as cache + latest agent outputs.
|
||||
# Do not append to prior snapshots: callers may provide full-history messages
|
||||
# in request.messages, and extending would duplicate prior turns.
|
||||
self._full_conversation = list(self._cache) + (list(response.messages) if response else [])
|
||||
|
||||
if response is None:
|
||||
# Agent did not complete (e.g., waiting for user input); do not emit response
|
||||
@@ -315,12 +316,7 @@ class AgentExecutor(Executor):
|
||||
Returns:
|
||||
The complete AgentResponse, or None if waiting for user input.
|
||||
"""
|
||||
run_kwargs: dict[str, Any] = ctx.get_state(WORKFLOW_RUN_KWARGS_KEY, {})
|
||||
|
||||
# Build options dict with additional_function_arguments for tool kwargs propagation
|
||||
options: dict[str, Any] | None = None
|
||||
if run_kwargs:
|
||||
options = {"additional_function_arguments": run_kwargs}
|
||||
run_kwargs, options = self._prepare_agent_run_args(ctx.get_state(WORKFLOW_RUN_KWARGS_KEY, {}))
|
||||
|
||||
response = await self._agent.run(
|
||||
self._cache,
|
||||
@@ -349,12 +345,7 @@ class AgentExecutor(Executor):
|
||||
Returns:
|
||||
The complete AgentResponse, or None if waiting for user input.
|
||||
"""
|
||||
run_kwargs: dict[str, Any] = ctx.get_state(WORKFLOW_RUN_KWARGS_KEY) or {}
|
||||
|
||||
# Build options dict with additional_function_arguments for tool kwargs propagation
|
||||
options: dict[str, Any] | None = None
|
||||
if run_kwargs:
|
||||
options = {"additional_function_arguments": run_kwargs}
|
||||
run_kwargs, options = self._prepare_agent_run_args(ctx.get_state(WORKFLOW_RUN_KWARGS_KEY) or {})
|
||||
|
||||
updates: list[AgentResponseUpdate] = []
|
||||
user_input_requests: list[Content] = []
|
||||
@@ -389,3 +380,55 @@ class AgentExecutor(Executor):
|
||||
return None
|
||||
|
||||
return response
|
||||
|
||||
@staticmethod
|
||||
def _prepare_agent_run_args(raw_run_kwargs: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any] | None]:
|
||||
"""Prepare kwargs and options for agent.run(), avoiding duplicate option passing.
|
||||
|
||||
Workflow-level kwargs are propagated to tool calls through
|
||||
`options.additional_function_arguments`. If workflow kwargs include an
|
||||
`options` key, merge it into the final options object and remove it from
|
||||
kwargs before spreading `**run_kwargs`.
|
||||
"""
|
||||
run_kwargs = dict(raw_run_kwargs)
|
||||
options_from_workflow = run_kwargs.pop("options", None)
|
||||
workflow_additional_args = run_kwargs.pop("additional_function_arguments", None)
|
||||
|
||||
options: dict[str, Any] = {}
|
||||
if options_from_workflow is not None:
|
||||
if isinstance(options_from_workflow, Mapping):
|
||||
for key, value in options_from_workflow.items():
|
||||
if isinstance(key, str):
|
||||
options[key] = value
|
||||
else:
|
||||
logger.warning(
|
||||
"Ignoring non-mapping workflow 'options' kwarg of type %s for AgentExecutor %s.",
|
||||
type(options_from_workflow).__name__,
|
||||
AgentExecutor.__name__,
|
||||
)
|
||||
|
||||
existing_additional_args = options.get("additional_function_arguments")
|
||||
if isinstance(existing_additional_args, Mapping):
|
||||
additional_args = {key: value for key, value in existing_additional_args.items() if isinstance(key, str)}
|
||||
else:
|
||||
additional_args = {}
|
||||
|
||||
if workflow_additional_args is not None:
|
||||
if isinstance(workflow_additional_args, Mapping):
|
||||
additional_args.update({
|
||||
key: value for key, value in workflow_additional_args.items() if isinstance(key, str)
|
||||
})
|
||||
else:
|
||||
logger.warning(
|
||||
"Ignoring non-mapping workflow 'additional_function_arguments' kwarg of type %s for AgentExecutor %s.", # noqa: E501
|
||||
type(workflow_additional_args).__name__,
|
||||
AgentExecutor.__name__,
|
||||
)
|
||||
|
||||
if run_kwargs:
|
||||
additional_args.update(run_kwargs)
|
||||
|
||||
if additional_args:
|
||||
options["additional_function_arguments"] = additional_args
|
||||
|
||||
return run_kwargs, options or None
|
||||
|
||||
@@ -190,6 +190,10 @@ class Runner:
|
||||
# Save executor states into the shared state before creating the checkpoint,
|
||||
# so that they are included in the checkpoint payload.
|
||||
await self._save_executor_states()
|
||||
# `on_checkpoint_save()` writes via State.set(), which stages values in the
|
||||
# pending buffer. Checkpoints serialize committed state only, so commit here
|
||||
# to ensure executor snapshots are captured in this checkpoint.
|
||||
self._state.commit()
|
||||
|
||||
checkpoint_id = await self._ctx.create_checkpoint(
|
||||
self._workflow_name,
|
||||
|
||||
@@ -942,8 +942,16 @@ class RawOpenAIResponsesClient( # type: ignore[misc]
|
||||
"""Prepare content for the OpenAI Responses API format."""
|
||||
match content.type:
|
||||
case "text":
|
||||
if role == "assistant":
|
||||
# Assistant history is represented as output text items; Azure validation
|
||||
# requires `annotations` to be present for this type.
|
||||
return {
|
||||
"type": "output_text",
|
||||
"text": content.text,
|
||||
"annotations": [],
|
||||
}
|
||||
return {
|
||||
"type": "output_text" if role == "assistant" else "input_text",
|
||||
"type": "input_text",
|
||||
"text": content.text,
|
||||
}
|
||||
case "text_reasoning":
|
||||
|
||||
@@ -677,6 +677,40 @@ def test_prepare_content_for_openai_hosted_vector_store_content() -> None:
|
||||
assert result == {}
|
||||
|
||||
|
||||
def test_prepare_content_for_openai_text_uses_role_specific_type() -> None:
|
||||
"""Text content should use input_text for user and output_text for assistant."""
|
||||
client = OpenAIResponsesClient(model_id="test-model", api_key="test-key")
|
||||
|
||||
text_content = Content.from_text(text="hello")
|
||||
|
||||
user_result = client._prepare_content_for_openai("user", text_content, {})
|
||||
assistant_result = client._prepare_content_for_openai("assistant", text_content, {})
|
||||
|
||||
assert user_result["type"] == "input_text"
|
||||
assert assistant_result["type"] == "output_text"
|
||||
assert assistant_result["annotations"] == []
|
||||
assert user_result["text"] == "hello"
|
||||
assert assistant_result["text"] == "hello"
|
||||
|
||||
|
||||
def test_prepare_messages_for_openai_assistant_history_uses_output_text_with_annotations() -> None:
|
||||
"""Assistant history should be output_text and include required annotations."""
|
||||
client = OpenAIResponsesClient(model_id="test-model", api_key="test-key")
|
||||
|
||||
messages = [
|
||||
Message(role="user", text="What is async/await?"),
|
||||
Message(role="assistant", text="Async/await enables non-blocking concurrency."),
|
||||
]
|
||||
|
||||
prepared = client._prepare_messages_for_openai(messages)
|
||||
|
||||
assert prepared[0]["role"] == "user"
|
||||
assert prepared[0]["content"][0]["type"] == "input_text"
|
||||
assert prepared[1]["role"] == "assistant"
|
||||
assert prepared[1]["content"][0]["type"] == "output_text"
|
||||
assert prepared[1]["content"][0]["annotations"] == []
|
||||
|
||||
|
||||
def test_parse_response_from_openai_with_mcp_server_tool_result() -> None:
|
||||
"""Test _parse_response_from_openai with MCP server tool result."""
|
||||
client = OpenAIResponsesClient(model_id="test-model", api_key="test-key")
|
||||
|
||||
@@ -8,6 +8,7 @@ from typing_extensions import Never
|
||||
|
||||
from agent_framework import (
|
||||
AgentExecutor,
|
||||
AgentExecutorRequest,
|
||||
AgentExecutorResponse,
|
||||
AgentResponse,
|
||||
AgentResponseUpdate,
|
||||
@@ -150,3 +151,64 @@ async def test_sequential_adapter_uses_full_conversation() -> None:
|
||||
assert len(seen) == 2
|
||||
assert seen[0].role == "user" and "hello seq" in (seen[0].text or "")
|
||||
assert seen[1].role == "assistant" and "A1 reply" in (seen[1].text or "")
|
||||
|
||||
|
||||
class _RoundTripCoordinator(Executor):
|
||||
"""Loops once back to the same agent with full conversation + feedback."""
|
||||
|
||||
def __init__(self, *, target_agent_id: str, id: str = "round_trip_coordinator") -> None:
|
||||
super().__init__(id=id)
|
||||
self._target_agent_id = target_agent_id
|
||||
self._seen = 0
|
||||
|
||||
@handler
|
||||
async def handle_response(
|
||||
self,
|
||||
response: AgentExecutorResponse,
|
||||
ctx: WorkflowContext[Never, dict[str, Any]],
|
||||
) -> None:
|
||||
self._seen += 1
|
||||
if self._seen == 1:
|
||||
assert response.full_conversation is not None
|
||||
await ctx.send_message(
|
||||
AgentExecutorRequest(
|
||||
messages=list(response.full_conversation) + [Message(role="user", text="apply feedback")],
|
||||
should_respond=True,
|
||||
),
|
||||
target_id=self._target_agent_id,
|
||||
)
|
||||
return
|
||||
|
||||
assert response.full_conversation is not None
|
||||
await ctx.yield_output({
|
||||
"roles": [m.role for m in response.full_conversation],
|
||||
"texts": [m.text for m in response.full_conversation],
|
||||
})
|
||||
|
||||
|
||||
async def test_agent_executor_full_conversation_round_trip_does_not_duplicate_history() -> None:
|
||||
"""When full history is replayed, AgentExecutor should not duplicate prior turns."""
|
||||
agent = _SimpleAgent(id="writer_agent", name="Writer", reply_text="draft reply")
|
||||
agent_exec = AgentExecutor(agent, id="writer_agent")
|
||||
coordinator = _RoundTripCoordinator(target_agent_id="writer_agent")
|
||||
|
||||
wf = (
|
||||
WorkflowBuilder(start_executor=agent_exec, output_executors=[coordinator])
|
||||
.add_edge(agent_exec, coordinator)
|
||||
.add_edge(coordinator, agent_exec)
|
||||
.build()
|
||||
)
|
||||
|
||||
result = await wf.run("initial prompt")
|
||||
outputs = result.get_outputs()
|
||||
assert len(outputs) == 1
|
||||
payload = outputs[0]
|
||||
assert isinstance(payload, dict)
|
||||
|
||||
# Expected conversation after one loop:
|
||||
# user(initial), assistant(first reply), user(feedback), assistant(second reply)
|
||||
assert payload["roles"] == ["user", "assistant", "user", "assistant"]
|
||||
assert payload["texts"][0] == "initial prompt"
|
||||
assert payload["texts"][1] == "draft reply"
|
||||
assert payload["texts"][2] == "apply feedback"
|
||||
assert payload["texts"][3] == "draft reply"
|
||||
|
||||
@@ -72,6 +72,41 @@ class _KwargsCapturingAgent(BaseAgent):
|
||||
return _run()
|
||||
|
||||
|
||||
class _OptionsAwareAgent(BaseAgent):
|
||||
"""Test agent that captures explicit `options` and kwargs passed to run()."""
|
||||
|
||||
captured_options: list[dict[str, Any] | None]
|
||||
captured_kwargs: list[dict[str, Any]]
|
||||
|
||||
def __init__(self, name: str = "options_agent") -> None:
|
||||
super().__init__(name=name, description="Test agent for options capture")
|
||||
self.captured_options = []
|
||||
self.captured_kwargs = []
|
||||
|
||||
def run(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
*,
|
||||
stream: bool = False,
|
||||
thread: AgentThread | None = None,
|
||||
options: dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]:
|
||||
self.captured_options.append(dict(options) if options is not None else None)
|
||||
self.captured_kwargs.append(dict(kwargs))
|
||||
if stream:
|
||||
|
||||
async def _stream() -> AsyncIterable[AgentResponseUpdate]:
|
||||
yield AgentResponseUpdate(contents=[Content.from_text(text=f"{self.name} response")])
|
||||
|
||||
return ResponseStream(_stream(), finalizer=AgentResponse.from_updates)
|
||||
|
||||
async def _run() -> AgentResponse:
|
||||
return AgentResponse(messages=[Message("assistant", [f"{self.name} response"])])
|
||||
|
||||
return _run()
|
||||
|
||||
|
||||
# region Sequential Builder Tests
|
||||
|
||||
|
||||
@@ -131,6 +166,106 @@ async def test_sequential_run_kwargs_flow() -> None:
|
||||
assert agent.captured_kwargs[0].get("custom_data") == {"test": True}
|
||||
|
||||
|
||||
async def test_sequential_run_options_does_not_conflict_with_agent_options() -> None:
|
||||
"""Test workflow.run(options=...) does not conflict with Agent.run(options=...)."""
|
||||
agent = _OptionsAwareAgent(name="options_agent")
|
||||
workflow = SequentialBuilder(participants=[agent]).build()
|
||||
|
||||
custom_data = {"session_id": "abc123"}
|
||||
user_token = {"user_name": "alice"}
|
||||
provided_options = {
|
||||
"store": False,
|
||||
"additional_function_arguments": {"source": "workflow-options"},
|
||||
}
|
||||
|
||||
async for event in workflow.run(
|
||||
"test message",
|
||||
stream=True,
|
||||
options=provided_options,
|
||||
custom_data=custom_data,
|
||||
user_token=user_token,
|
||||
):
|
||||
if event.type == "status" and event.state == WorkflowRunState.IDLE:
|
||||
break
|
||||
|
||||
assert len(agent.captured_options) >= 1
|
||||
captured_options = agent.captured_options[0]
|
||||
assert captured_options is not None
|
||||
assert captured_options.get("store") is False
|
||||
|
||||
additional_args = captured_options.get("additional_function_arguments")
|
||||
assert isinstance(additional_args, dict)
|
||||
assert additional_args.get("source") == "workflow-options"
|
||||
assert additional_args.get("custom_data") == custom_data
|
||||
assert additional_args.get("user_token") == user_token
|
||||
|
||||
# "options" should be passed once via the dedicated options parameter,
|
||||
# not duplicated in **kwargs.
|
||||
assert len(agent.captured_kwargs) >= 1
|
||||
captured_kwargs = agent.captured_kwargs[0]
|
||||
assert "options" not in captured_kwargs
|
||||
assert captured_kwargs.get("custom_data") == custom_data
|
||||
assert captured_kwargs.get("user_token") == user_token
|
||||
|
||||
|
||||
async def test_sequential_run_additional_function_arguments_flattened() -> None:
|
||||
"""Test workflow.run(additional_function_arguments=...) maps directly to tool kwargs."""
|
||||
agent = _OptionsAwareAgent(name="options_agent")
|
||||
workflow = SequentialBuilder(participants=[agent]).build()
|
||||
|
||||
custom_data = {"session_id": "abc123"}
|
||||
user_token = {"user_name": "alice"}
|
||||
|
||||
async for event in workflow.run(
|
||||
"test message",
|
||||
stream=True,
|
||||
additional_function_arguments={"custom_data": custom_data, "user_token": user_token},
|
||||
):
|
||||
if event.type == "status" and event.state == WorkflowRunState.IDLE:
|
||||
break
|
||||
|
||||
assert len(agent.captured_options) >= 1
|
||||
captured_options = agent.captured_options[0]
|
||||
assert captured_options is not None
|
||||
|
||||
additional_args = captured_options.get("additional_function_arguments")
|
||||
assert isinstance(additional_args, dict)
|
||||
assert additional_args.get("custom_data") == custom_data
|
||||
assert additional_args.get("user_token") == user_token
|
||||
assert "additional_function_arguments" not in additional_args
|
||||
|
||||
assert len(agent.captured_kwargs) >= 1
|
||||
captured_kwargs = agent.captured_kwargs[0]
|
||||
assert "additional_function_arguments" not in captured_kwargs
|
||||
|
||||
|
||||
async def test_sequential_run_additional_function_arguments_merges_with_options() -> None:
|
||||
"""Test workflow additional_function_arguments merges with workflow options."""
|
||||
agent = _OptionsAwareAgent(name="options_agent")
|
||||
workflow = SequentialBuilder(participants=[agent]).build()
|
||||
|
||||
async for event in workflow.run(
|
||||
"test message",
|
||||
stream=True,
|
||||
options={"additional_function_arguments": {"source": "workflow-options"}},
|
||||
additional_function_arguments={"custom_data": {"session_id": "abc123"}},
|
||||
user_token={"user_name": "alice"},
|
||||
):
|
||||
if event.type == "status" and event.state == WorkflowRunState.IDLE:
|
||||
break
|
||||
|
||||
assert len(agent.captured_options) >= 1
|
||||
captured_options = agent.captured_options[0]
|
||||
assert captured_options is not None
|
||||
|
||||
additional_args = captured_options.get("additional_function_arguments")
|
||||
assert isinstance(additional_args, dict)
|
||||
assert additional_args.get("source") == "workflow-options"
|
||||
assert additional_args.get("custom_data") == {"session_id": "abc123"}
|
||||
assert additional_args.get("user_token") == {"user_name": "alice"}
|
||||
assert "additional_function_arguments" not in additional_args
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user