Python: [Feature Branch] Merge from main to Azure AI branch (#2111)

* Do not build DevUI assets during .NET project build (#2010)

* .NET: Add unit tests for declarative executor SetMultipleVariables (#2016)

* Add unit tests for create conversation executor

* Update indentation and comment typo.

* Added unit tests for declarative executor SetMultipleVariablesExecutor

* Updated comments and syntactic sugar

* Python: DevUI: Use metadata.entity_id instead of model field (#1984)

* DevUI: Use metadata.entity_id for agent/workflow name instead of model field

* OpenAI Responses: add explicit request validation

* Review feedback

* .NET: DevUI - Do not automatically add/map OpenAI services/endpoints (#2014)

* Don't add OpenAIResponses as part of Dev UI

You should be able to add and remove Dev UI without impacting your other production endpoints.

* Remove `AddDevUI()` and do not map OpenAI endpoints from `MapDevUI()`

* Fix comment wording

* Revise documentation

---------

Co-authored-by: Daniel Roth <daroth@microsoft.com>

* Python: DevUI: Add OpenAI Responses API proxy support  + HIL for Workflows (#1737)

* DevUI: Add OpenAI Responses API proxy support with enhanced UI features

This commit adds support for proxying requests to OpenAI's Responses API,
allowing DevUI to route conversations to OpenAI models when configured to enable testing.

Backend changes:
- Add OpenAI proxy executor with conversation routing logic
- Enhance event mapper to support OpenAI Responses API format
- Extend server endpoints to handle OpenAI proxy mode
- Update models with OpenAI-specific response types
- Remove emojis from logging and CLI output for cleaner text

Frontend changes:
- Add settings modal with OpenAI proxy configuration UI
- Enhance agent and workflow views with improved state management
- Add new UI components (separator, switch) for settings
- Update debug panel with better event filtering
- Improve message renderers for OpenAI content types
- Update types and API client for OpenAI integration

* update ui, settings modal and workflow input form, add register cleanup hooks.

* add workflow HIL support, user mode, other fixes

* feat(devui): add human-in-the-loop (HIL) support with dynamic response schemas

Implement  HIL workflow support allowing workflows to pause for user input
with dynamically generated JSON schemas based on response handler type hints.

Key Features:
- Automatic response schema extraction from @response_handler decorators
- Dynamic form generation in UI based on Pydantic/dataclass response types
- Checkpoint-based conversation storage for HIL requests/responses
- Resume workflow execution after user provides HIL response

Backend Changes:
- Add extract_response_type_from_executor() to introspect response handlers
- Enrich RequestInfoEvent with response_schema via _enrich_request_info_event_with_response_schema()
- Map RequestInfoEvent to response.input.requested OpenAI event format
- Store HIL responses in conversation history and restore checkpoints

Frontend Changes:
- Add HILInputModal component with SchemaFormRenderer for dynamic forms
- Support Pydantic BaseModel and dataclass response types
- Render enum fields as dropdowns, strings as text/textarea, numbers, booleans, arrays, objects
- Display original request context alongside response form

Testing:
- Add  tests for checkpoint storage (test_checkpoints.py)
- Add schema generation tests for all input types (test_schema_generation.py)
- Validate end-to-end HIL flow with spam workflow sample

This enables workflows to seamlessly pause execution and request structured user input
with type-safe, validated forms generated automatically from response type annotations.

* improve HIL support, improve workflow execution view

* ui updates

* ui updates

* improve HIL for workflows, add auth and view modes

* update workflow

* security improvements , ui fixes

* fix mypy error

* update loading spinner in ui

---------

Co-authored-by: Mark Wallace <127216156+markwallace-microsoft@users.noreply.github.com>

* .NET: Remove launchSettings.json from .gitignore in dotnet/samples (#2006)

* Remove launchSettings.json from .gitignore in dotnet/samples

* Update dotnet/samples/GettingStarted/DevUI/DevUI_Step01_BasicUsage/Properties/launchSettings.json

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update dotnet/samples/AGUIClientServer/AGUIServer/Properties/launchSettings.json

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* DevUI: Serialize workflow input as string to maintain conformance with OpenAI Responses format (#2021)

Co-authored-by: Victor Dibia <chuvidi2003@gmail.com>

* Add Microsoft Agent Framework logo to assets (#2007)

* Updated package versions (#2027)

* DevUI: Prevent line breaks within words in the agent view (#2024)

Co-authored-by: Victor Dibia <chuvidi2003@gmail.com>

* .NET [AG-UI]: Adds support for shared state. (#1996)

* Product changes

* Tests

* Dojo project

* Cleanups

* Python: Fix underlying tool choice bug and all for return to previous Handoff subagent (#2037)

* Fix tool_choice override bug and add enable_return_to_previous support

* Add unit test for handoff checkpointing

* Handle tools when we have them

* added missing chatAgent params (#2044)

* .NET: fix ChatCompletions Tools serialization (#2043)

* fix serialization in chat completions on tools

* nit

* .NET: assign AgentCard's URL to mapped-endpoint if not defined explicitly (#2047)

* fix serialization in chat completions on tools

* nit

* write e2e test for agent card resolve + adjust behavior

* nit

* Version 1.0.0-preview.251110.1 (#2048)

* .NET: Remove moved OpenAPI sample and point to SK one. (#1997)

* Remove moved OpenAPI sample and point to SK one.

* Update dotnet/samples/GettingStarted/Agents/README.md

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Bump AWSSDK.Extensions.Bedrock.MEAI from 4.0.4.2 to 4.0.4.6 (#2031)

---
updated-dependencies:
- dependency-name: AWSSDK.Extensions.Bedrock.MEAI
  dependency-version: 4.0.4.6
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* .NET: Separate all memory and rag samples into their own folders (#2000)

* Separate all memory and rag samples into their own folders

* Fix broken link.

* Python: .Net: Dotnet devui compatibility fixes (#2026)

* DevUI: Add OpenAI Responses API proxy support with enhanced UI features

This commit adds support for proxying requests to OpenAI's Responses API,
allowing DevUI to route conversations to OpenAI models when configured to enable testing.

Backend changes:
- Add OpenAI proxy executor with conversation routing logic
- Enhance event mapper to support OpenAI Responses API format
- Extend server endpoints to handle OpenAI proxy mode
- Update models with OpenAI-specific response types
- Remove emojis from logging and CLI output for cleaner text

Frontend changes:
- Add settings modal with OpenAI proxy configuration UI
- Enhance agent and workflow views with improved state management
- Add new UI components (separator, switch) for settings
- Update debug panel with better event filtering
- Improve message renderers for OpenAI content types
- Update types and API client for OpenAI integration

* update ui, settings modal and workflow input form, add register cleanup hooks.

* add workflow HIL support, user mode, other fixes

* feat(devui): add human-in-the-loop (HIL) support with dynamic response schemas

Implement  HIL workflow support allowing workflows to pause for user input
with dynamically generated JSON schemas based on response handler type hints.

Key Features:
- Automatic response schema extraction from @response_handler decorators
- Dynamic form generation in UI based on Pydantic/dataclass response types
- Checkpoint-based conversation storage for HIL requests/responses
- Resume workflow execution after user provides HIL response

Backend Changes:
- Add extract_response_type_from_executor() to introspect response handlers
- Enrich RequestInfoEvent with response_schema via _enrich_request_info_event_with_response_schema()
- Map RequestInfoEvent to response.input.requested OpenAI event format
- Store HIL responses in conversation history and restore checkpoints

Frontend Changes:
- Add HILInputModal component with SchemaFormRenderer for dynamic forms
- Support Pydantic BaseModel and dataclass response types
- Render enum fields as dropdowns, strings as text/textarea, numbers, booleans, arrays, objects
- Display original request context alongside response form

Testing:
- Add  tests for checkpoint storage (test_checkpoints.py)
- Add schema generation tests for all input types (test_schema_generation.py)
- Validate end-to-end HIL flow with spam workflow sample

This enables workflows to seamlessly pause execution and request structured user input
with type-safe, validated forms generated automatically from response type annotations.

* improve HIL support, improve workflow execution view

* ui updates

* ui updates

* improve HIL for workflows, add auth and view modes

* update workflow

* security improvements , ui fixes

* fix mypy error

* update loading spinner in ui

* DevUI: Serialize workflow input as string to maintain conformance with OpenAI Responses format

* Phase 1: Add /meta endpoint and fix workflow event naming for .NET DevUI compatibility

* additional fixes for .NET DevUI workflow visualization item ID tracking

**Problem:**
.NET DevUI was generating different item IDs for ExecutorInvokedEvent and
ExecutorCompletedEvent, causing only the first executor to highlight in the
workflow graph. Long executor names and error messages also broke UI layout.

**Changes:**
- Add ExecutorActionItemResource to match Python DevUI implementation
- Track item IDs per executor using dictionary in AgentRunResponseUpdateExtensions
- Reuse same item ID across invoked/completed/failed events for proper pairing
- Add truncateText() utility to workflow-utils.ts
- Truncate executor names to 35 chars in execution timeline
- Truncate error messages to 150 chars in workflow graph nodes

** Details:**
- ExecutorActionItemResource registered with JSON source generation context
- Dictionary cleaned up after executor completion/failure to prevent memory leaks
- Frontend item tracking by unique item.id supports multiple executor runs
- All changes follow existing codebase patterns and conventions

Tested with review-workflow showing correct executor highlighting and state
transitions for sequential and concurrent executors.

* format fixes, remove cors tests

* remove unecessary attributes

---------

Co-authored-by: Mark Wallace <127216156+markwallace-microsoft@users.noreply.github.com>
Co-authored-by: Reuben Bond <reuben.bond@gmail.com>

* DevUI: support having both an agent and a workflow with the same id in discovery (#2023)

* Python: Fix Model ID attribute not showing up in `invoke_agent` span (#2061)

* Best effort to surface the model id to invoke agent span

* Fix tests

* Fix tests

* Version 1.0.0-preview.251107.2 (#2065)

* Version 1.0.0-preview.251110.2 (#2067)

* Update README.md to change Grafana links to Azure portal links for dashboard access (#1983)

* .NET - Enable build & test on branch `feature-foundry-agents` (#2068)

* Tests good, mkay

* Update .github/workflows/dotnet-build-and-test.yml

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Enable feature build pipelines

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Roger Barreto <19890735+rogerbarreto@users.noreply.github.com>

* Python: Add concrete AGUIChatClient (#2072)

* Add concrete AGUIChatClient

* Update logging docstrings and conventions

* PR feedback

* Updates to support client-side tool calls

* .NET: Move catalog samples to the HostedAgents folder (#2090)

* move catalog samples to the HostedAgents folder

* move the catalog samples' projects to the HostedAgents folder

* Bump OpenTelemetry.Instrumentation.Runtime from 1.12.0 to 1.13.0 (#1856)

---
updated-dependencies:
- dependency-name: OpenTelemetry.Instrumentation.Runtime
  dependency-version: 1.13.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* .NET: Bump Microsoft.SemanticKernel.Agents.Abstractions from 1.66.0 to 1.67.0 (#1962)

* Bump Microsoft.SemanticKernel.Agents.Abstractions from 1.66.0 to 1.67.0

---
updated-dependencies:
- dependency-name: Microsoft.SemanticKernel.Agents.Abstractions
  dependency-version: 1.67.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>

* .NET: Bump all Microsoft.SemanticKernel packages from 1.66.* to 1.67.* (#1969)

* Initial plan

* Update all Microsoft.SemanticKernel packages to 1.67.*

Co-authored-by: rogerbarreto <19890735+rogerbarreto@users.noreply.github.com>

* Remove unrelated changes to package-lock.json and yarn.lock

Co-authored-by: markwallace-microsoft <127216156+markwallace-microsoft@users.noreply.github.com>

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: rogerbarreto <19890735+rogerbarreto@users.noreply.github.com>
Co-authored-by: markwallace-microsoft <127216156+markwallace-microsoft@users.noreply.github.com>

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com>
Co-authored-by: rogerbarreto <19890735+rogerbarreto@users.noreply.github.com>
Co-authored-by: markwallace-microsoft <127216156+markwallace-microsoft@users.noreply.github.com>

* .NET: fix: WorkflowAsAgent Sample (#1787)

* fix: WorkflowAsAgent Sample

* Also makes ChatForwardingExecutor public

* feat: Expand ChatForwardingExecutor handled types

Make ChatForwardingExecutor match the input types of ChatProtocolExecutor.

* fix: Update for the new AgentRunResponseUpdate merge logic

AIAgent always sends out List<ChatMessage> now.

* Updated (#2076)

* Bump vite in /python/samples/demos/chatkit-integration/frontend (#1918)

Bumps [vite](https://github.com/vitejs/vite/tree/HEAD/packages/vite) from 7.1.9 to 7.1.12.
- [Release notes](https://github.com/vitejs/vite/releases)
- [Changelog](https://github.com/vitejs/vite/blob/v7.1.12/packages/vite/CHANGELOG.md)
- [Commits](https://github.com/vitejs/vite/commits/v7.1.12/packages/vite)

---
updated-dependencies:
- dependency-name: vite
  dependency-version: 7.1.12
  dependency-type: direct:development
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* Bump Roslynator.Analyzers from 4.14.0 to 4.14.1 (#1857)

---
updated-dependencies:
- dependency-name: Roslynator.Analyzers
  dependency-version: 4.14.1
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* Bump MishaKav/pytest-coverage-comment from 1.1.57 to 1.1.59 (#2034)

Bumps [MishaKav/pytest-coverage-comment](https://github.com/mishakav/pytest-coverage-comment) from 1.1.57 to 1.1.59.
- [Release notes](https://github.com/mishakav/pytest-coverage-comment/releases)
- [Changelog](https://github.com/MishaKav/pytest-coverage-comment/blob/main/CHANGELOG.md)
- [Commits](https://github.com/mishakav/pytest-coverage-comment/compare/v1.1.57...v1.1.59)

---
updated-dependencies:
- dependency-name: MishaKav/pytest-coverage-comment
  dependency-version: 1.1.59
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Chris <66376200+crickman@users.noreply.github.com>

* Python: Handle agent user input request in AgentExecutor (#2022)

* Handle agent user input request in AgentExecutor

* fix test

* Address comments

* Fix tests

* Fix tests

* Address comments

* Address comments

* Python: OpenAI Responses Image Generation Stream Support, Sample and Unit Tests (#1853)

* support for image gen streaming

* small fixes

* fixes

* added comment

* Python: Fix MCP Tool Parameter Descriptions Not Propagated to LLMs (#1978)

* mcp tool description fix

* small fix

* .NET: Allow extending agent run options via additional properties (#1872)

* Allow extending agent run options via additional properties

This mirrors the M.E.AI model in ChatOptions.AdditionalProperties which is very useful when building functionality pipelines.

Fixes https://github.com/microsoft/agent-framework/issues/1815

* Expand XML documentation

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Add AdditionalProperties tests to AgentRunOptions

Co-authored-by: kzu <169707+kzu@users.noreply.github.com>

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: kzu <169707+kzu@users.noreply.github.com>

* Python: Use the last entry in the task history to avoid empty responses (#2101)

* Use the last entry in the task history to avoid empty responses

* History only contains Messages

* Updated package versions (#2104)

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: Reuben Bond <203839+ReubenBond@users.noreply.github.com>
Co-authored-by: Peter Ibekwe <109177538+peibekwe@users.noreply.github.com>
Co-authored-by: Jeff Handley <jeffhandley@users.noreply.github.com>
Co-authored-by: Daniel Roth <daroth@microsoft.com>
Co-authored-by: Victor Dibia <chuvidi2003@gmail.com>
Co-authored-by: Mark Wallace <127216156+markwallace-microsoft@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Shawn Henry <sphenry@gmail.com>
Co-authored-by: Javier Calvarro Nelson <jacalvar@microsoft.com>
Co-authored-by: Evan Mattson <35585003+moonbox3@users.noreply.github.com>
Co-authored-by: Eduard van Valkenburg <eavanvalkenburg@users.noreply.github.com>
Co-authored-by: Korolev Dmitry <deagle.gross@gmail.com>
Co-authored-by: westey <164392973+westey-m@users.noreply.github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Reuben Bond <reuben.bond@gmail.com>
Co-authored-by: Tao Chen <taochen@microsoft.com>
Co-authored-by: wuweng <wuweng@microsoft.com>
Co-authored-by: Chris <66376200+crickman@users.noreply.github.com>
Co-authored-by: Roger Barreto <19890735+rogerbarreto@users.noreply.github.com>
Co-authored-by: SergeyMenshykh <68852919+SergeyMenshykh@users.noreply.github.com>
Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com>
Co-authored-by: Jacob Alber <jaalber@microsoft.com>
Co-authored-by: Giles Odigwe <79032838+giles17@users.noreply.github.com>
Co-authored-by: Daniel Cazzulino <daniel@cazzulino.com>
Co-authored-by: kzu <169707+kzu@users.noreply.github.com>
This commit is contained in:
Dmytro Struk
2025-11-11 23:12:09 -08:00
committed by GitHub
Unverified
parent 85fcd230bf
commit 361c47f30f
231 changed files with 19659 additions and 4143 deletions
@@ -587,9 +587,11 @@ class ChatAgent(BaseAgent):
name: str | None = None,
description: str | None = None,
chat_message_store_factory: Callable[[], ChatMessageStoreProtocol] | None = None,
conversation_id: str | None = None,
context_providers: ContextProvider | list[ContextProvider] | AggregateContextProvider | None = None,
middleware: Middleware | list[Middleware] | None = None,
# chat option params
allow_multiple_tool_calls: bool | None = None,
conversation_id: str | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str | int, float] | None = None,
max_tokens: int | None = None,
@@ -630,15 +632,17 @@ class ChatAgent(BaseAgent):
description: A brief description of the agent's purpose.
chat_message_store_factory: Factory function to create an instance of ChatMessageStoreProtocol.
If not provided, the default in-memory store will be used.
conversation_id: The conversation ID for service-managed threads.
Cannot be used together with chat_message_store_factory.
context_providers: The collection of multiple context providers to include during agent invocation.
middleware: List of middleware to intercept agent and function invocations.
allow_multiple_tool_calls: Whether to allow multiple tool calls in a single response.
conversation_id: The conversation ID for service-managed threads.
Cannot be used together with chat_message_store_factory.
frequency_penalty: The frequency penalty to use.
logit_bias: The logit bias to use.
max_tokens: The maximum number of tokens to generate.
metadata: Additional metadata to include in the request.
model_id: The model_id to use for the agent.
This overrides the model_id set in the chat client if it contains one.
presence_penalty: The presence penalty to use.
response_format: The format of the response.
seed: The random seed to use.
@@ -687,7 +691,8 @@ class ChatAgent(BaseAgent):
self._local_mcp_tools = [tool for tool in normalized_tools if isinstance(tool, MCPTool)]
agent_tools = [tool for tool in normalized_tools if not isinstance(tool, MCPTool)]
self.chat_options = ChatOptions(
model_id=model_id,
model_id=model_id or (str(chat_client.model_id) if hasattr(chat_client, "model_id") else None),
allow_multiple_tool_calls=allow_multiple_tool_calls,
conversation_id=conversation_id,
frequency_penalty=frequency_penalty,
instructions=instructions,
@@ -758,6 +763,7 @@ class ChatAgent(BaseAgent):
messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
*,
thread: AgentThread | None = None,
allow_multiple_tool_calls: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str | int, float] | None = None,
max_tokens: int | None = None,
@@ -793,6 +799,7 @@ class ChatAgent(BaseAgent):
Keyword Args:
thread: The thread to use for the agent.
allow_multiple_tool_calls: Whether to allow multiple tool calls in a single response.
frequency_penalty: The frequency penalty to use.
logit_bias: The logit bias to use.
max_tokens: The maximum number of tokens to generate.
@@ -844,6 +851,7 @@ class ChatAgent(BaseAgent):
co = run_chat_options & ChatOptions(
model_id=model_id,
conversation_id=thread.service_thread_id,
allow_multiple_tool_calls=allow_multiple_tool_calls,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
max_tokens=max_tokens,
@@ -887,6 +895,7 @@ class ChatAgent(BaseAgent):
messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
*,
thread: AgentThread | None = None,
allow_multiple_tool_calls: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str | int, float] | None = None,
max_tokens: int | None = None,
@@ -922,6 +931,7 @@ class ChatAgent(BaseAgent):
Keyword Args:
thread: The thread to use for the agent.
allow_multiple_tool_calls: Whether to allow multiple tool calls in a single response.
frequency_penalty: The frequency penalty to use.
logit_bias: The logit bias to use.
max_tokens: The maximum number of tokens to generate.
@@ -971,6 +981,7 @@ class ChatAgent(BaseAgent):
co = run_chat_options & ChatOptions(
conversation_id=thread.service_thread_id,
allow_multiple_tool_calls=allow_multiple_tool_calls,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
max_tokens=max_tokens,
@@ -224,7 +224,7 @@ def _merge_chat_options(
stop: str | Sequence[str] | None = None,
store: bool | None = None,
temperature: float | None = None,
tool_choice: ToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = "auto",
tool_choice: ToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = None,
tools: list[ToolProtocol | dict[str, Any] | Callable[..., Any]] | None = None,
top_p: float | None = None,
user: str | None = None,
@@ -496,7 +496,7 @@ class BaseChatClient(SerializationMixin, ABC):
stop: str | Sequence[str] | None = None,
store: bool | None = None,
temperature: float | None = None,
tool_choice: ToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = "auto",
tool_choice: ToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = None,
tools: ToolProtocol
| Callable[..., Any]
| MutableMapping[str, Any]
@@ -591,7 +591,7 @@ class BaseChatClient(SerializationMixin, ABC):
stop: str | Sequence[str] | None = None,
store: bool | None = None,
temperature: float | None = None,
tool_choice: ToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = "auto",
tool_choice: ToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = None,
tools: ToolProtocol
| Callable[..., Any]
| MutableMapping[str, Any]
@@ -714,6 +714,8 @@ class BaseChatClient(SerializationMixin, ABC):
chat_message_store_factory: Callable[[], ChatMessageStoreProtocol] | None = None,
context_providers: ContextProvider | list[ContextProvider] | AggregateContextProvider | None = None,
middleware: Middleware | list[Middleware] | None = None,
allow_multiple_tool_calls: bool | None = None,
conversation_id: str | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str | int, float] | None = None,
max_tokens: int | None = None,
@@ -751,6 +753,8 @@ class BaseChatClient(SerializationMixin, ABC):
If not provided, the default in-memory store will be used.
context_providers: Context providers to include during agent invocation.
middleware: List of middleware to intercept agent and function invocations.
allow_multiple_tool_calls: Whether to allow multiple tool calls per agent turn.
conversation_id: The conversation ID to associate with the agent's messages.
frequency_penalty: The frequency penalty to use.
logit_bias: The logit bias to use.
max_tokens: The maximum number of tokens to generate.
@@ -801,6 +805,8 @@ class BaseChatClient(SerializationMixin, ABC):
chat_message_store_factory=chat_message_store_factory,
context_providers=context_providers,
middleware=middleware,
allow_multiple_tool_calls=allow_multiple_tool_calls,
conversation_id=conversation_id,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
max_tokens=max_tokens,
+10 -3
View File
@@ -19,7 +19,7 @@ from mcp.client.websocket import websocket_client
from mcp.shared.context import RequestContext
from mcp.shared.exceptions import McpError
from mcp.shared.session import RequestResponder
from pydantic import BaseModel, create_model
from pydantic import BaseModel, Field, create_model
from ._tools import AIFunction, HostedMCPSpecificApproval
from ._types import ChatMessage, Contents, DataContent, Role, TextContent, UriContent
@@ -224,13 +224,20 @@ def _get_input_model_from_mcp_tool(tool: types.Tool) -> type[BaseModel]:
prop_details = json.loads(prop_details) if isinstance(prop_details, str) else prop_details
python_type = resolve_type(prop_details)
description = prop_details.get("description", "")
# Create field definition for create_model
if prop_name in required:
field_definitions[prop_name] = (python_type, ...)
field_definitions[prop_name] = (
(python_type, Field(description=description)) if description else (python_type, ...)
)
else:
default_value = prop_details.get("default", None)
field_definitions[prop_name] = (python_type, default_value)
field_definitions[prop_name] = (
(python_type, Field(default=default_value, description=description))
if description
else (python_type, default_value)
)
return create_model(f"{tool.name}_input", **field_definitions)
@@ -1525,6 +1525,12 @@ def _handle_function_calls_response(
prepped_messages = prepare_messages(messages)
response: "ChatResponse | None" = None
fcc_messages: "list[ChatMessage]" = []
# If tools are provided but tool_choice is not set, default to "auto" for function invocation
tools = _extract_tools(kwargs)
if tools and kwargs.get("tool_choice") is None:
kwargs["tool_choice"] = "auto"
for attempt_idx in range(config.max_iterations if config.enabled else 0):
fcc_todo = _collect_approval_responses(prepped_messages)
if fcc_todo:
@@ -1050,6 +1050,50 @@ class DataContent(BaseContent):
def has_top_level_media_type(self, top_level_media_type: Literal["application", "audio", "image", "text"]) -> bool:
return _has_top_level_media_type(self.media_type, top_level_media_type)
@staticmethod
def detect_image_format_from_base64(image_base64: str) -> str:
"""Detect image format from base64 data by examining the binary header.
Args:
image_base64: Base64 encoded image data
Returns:
Image format as string (png, jpeg, webp, gif) with png as fallback
"""
try:
# Constants for image format detection
# ~75 bytes of binary data should be enough to detect most image formats
FORMAT_DETECTION_BASE64_CHARS = 100
# Decode a small portion to detect format
decoded_data = base64.b64decode(image_base64[:FORMAT_DETECTION_BASE64_CHARS])
if decoded_data.startswith(b"\x89PNG"):
return "png"
if decoded_data.startswith(b"\xff\xd8\xff"):
return "jpeg"
if decoded_data.startswith(b"RIFF") and b"WEBP" in decoded_data[:12]:
return "webp"
if decoded_data.startswith(b"GIF87a") or decoded_data.startswith(b"GIF89a"):
return "gif"
return "png" # Default fallback
except Exception:
return "png" # Fallback if decoding fails
@classmethod
def create_data_uri_from_base64(cls, image_base64: str) -> tuple[str, str]:
"""Create a data URI and media type from base64 image data.
Args:
image_base64: Base64 encoded image data
Returns:
Tuple of (data_uri, media_type)
"""
format_type = cls.detect_image_format_from_base64(image_base64)
uri = f"data:image/{format_type};base64,{image_base64}"
media_type = f"image/{format_type}"
return uri, media_type
class UriContent(BaseContent):
"""Represents a URI content.
@@ -2,11 +2,14 @@
import logging
from dataclasses import dataclass
from typing import Any
from typing import Any, cast
from agent_framework import FunctionApprovalRequestContent, FunctionApprovalResponseContent
from .._agents import AgentProtocol, ChatAgent
from .._threads import AgentThread
from .._types import AgentRunResponse, AgentRunResponseUpdate, ChatMessage
from ._checkpoint_encoding import decode_checkpoint_value, encode_checkpoint_value
from ._conversation_state import encode_chat_messages
from ._events import (
AgentRunEvent,
@@ -14,6 +17,7 @@ from ._events import (
)
from ._executor import Executor, handler
from ._message_utils import normalize_messages_input
from ._request_info_mixin import response_handler
from ._workflow_context import WorkflowContext
logger = logging.getLogger(__name__)
@@ -83,6 +87,8 @@ class AgentExecutor(Executor):
super().__init__(exec_id)
self._agent = agent
self._agent_thread = agent_thread or self._agent.get_new_thread()
self._pending_agent_requests: dict[str, FunctionApprovalRequestContent] = {}
self._pending_responses_to_agent: list[FunctionApprovalResponseContent] = []
self._output_response = output_response
self._cache: list[ChatMessage] = []
@@ -93,50 +99,6 @@ class AgentExecutor(Executor):
return [AgentRunResponse]
return []
async def _run_agent_and_emit(self, ctx: WorkflowContext[AgentExecutorResponse, AgentRunResponse]) -> None:
"""Execute the underlying agent, emit events, and enqueue response.
Checks ctx.is_streaming() to determine whether to emit incremental AgentRunUpdateEvent
events (streaming mode) or a single AgentRunEvent (non-streaming mode).
"""
if ctx.is_streaming():
# Streaming mode: emit incremental updates
updates: list[AgentRunResponseUpdate] = []
async for update in self._agent.run_stream(
self._cache,
thread=self._agent_thread,
):
updates.append(update)
await ctx.add_event(AgentRunUpdateEvent(self.id, update))
if isinstance(self._agent, ChatAgent):
response_format = self._agent.chat_options.response_format
response = AgentRunResponse.from_agent_run_response_updates(
updates,
output_format_type=response_format,
)
else:
response = AgentRunResponse.from_agent_run_response_updates(updates)
else:
# Non-streaming mode: use run() and emit single event
response = await self._agent.run(
self._cache,
thread=self._agent_thread,
)
await ctx.add_event(AgentRunEvent(self.id, response))
if self._output_response:
await ctx.yield_output(response)
# Always construct a full conversation snapshot from inputs (cache)
# plus agent outputs (agent_run_response.messages). Do not mutate
# response.messages so AgentRunEvent remains faithful to the raw output.
full_conversation: list[ChatMessage] = list(self._cache) + list(response.messages)
agent_response = AgentExecutorResponse(self.id, response, full_conversation=full_conversation)
await ctx.send_message(agent_response)
self._cache.clear()
@handler
async def run(
self, request: AgentExecutorRequest, ctx: WorkflowContext[AgentExecutorResponse, AgentRunResponse]
@@ -192,6 +154,31 @@ class AgentExecutor(Executor):
self._cache = normalize_messages_input(messages)
await self._run_agent_and_emit(ctx)
@response_handler
async def handle_user_input_response(
self,
original_request: FunctionApprovalRequestContent,
response: FunctionApprovalResponseContent,
ctx: WorkflowContext[AgentExecutorResponse, AgentRunResponse],
) -> None:
"""Handle user input responses for function approvals during agent execution.
This will hold the executor's execution until all pending user input requests are resolved.
Args:
original_request: The original function approval request sent by the agent.
response: The user's response to the function approval request.
ctx: The workflow context for emitting events and outputs.
"""
self._pending_responses_to_agent.append(response)
self._pending_agent_requests.pop(original_request.id, None)
if not self._pending_agent_requests:
# All pending requests have been resolved; resume agent execution
self._cache = normalize_messages_input(ChatMessage(role="user", contents=self._pending_responses_to_agent))
self._pending_responses_to_agent.clear()
await self._run_agent_and_emit(ctx)
async def snapshot_state(self) -> dict[str, Any]:
"""Capture current executor state for checkpointing.
@@ -226,6 +213,8 @@ class AgentExecutor(Executor):
return {
"cache": encode_chat_messages(self._cache),
"agent_thread": serialized_thread,
"pending_agent_requests": encode_checkpoint_value(self._pending_agent_requests),
"pending_responses_to_agent": encode_checkpoint_value(self._pending_responses_to_agent),
}
async def restore_state(self, state: dict[str, Any]) -> None:
@@ -258,7 +247,109 @@ class AgentExecutor(Executor):
else:
self._agent_thread = self._agent.get_new_thread()
pending_requests_payload = state.get("pending_agent_requests")
if pending_requests_payload:
self._pending_agent_requests = decode_checkpoint_value(pending_requests_payload)
pending_responses_payload = state.get("pending_responses_to_agent")
if pending_responses_payload:
self._pending_responses_to_agent = decode_checkpoint_value(pending_responses_payload)
def reset(self) -> None:
"""Reset the internal cache of the executor."""
logger.debug("AgentExecutor %s: Resetting cache", self.id)
self._cache.clear()
async def _run_agent_and_emit(self, ctx: WorkflowContext[AgentExecutorResponse, AgentRunResponse]) -> None:
"""Execute the underlying agent, emit events, and enqueue response.
Checks ctx.is_streaming() to determine whether to emit incremental AgentRunUpdateEvent
events (streaming mode) or a single AgentRunEvent (non-streaming mode).
"""
if ctx.is_streaming():
# Streaming mode: emit incremental updates
response = await self._run_agent_streaming(cast(WorkflowContext, ctx))
else:
# Non-streaming mode: use run() and emit single event
response = await self._run_agent(cast(WorkflowContext, ctx))
if response is None:
# Agent did not complete (e.g., waiting for user input); do not emit response
logger.info("AgentExecutor %s: Agent did not complete, awaiting user input", self.id)
return
if self._output_response:
await ctx.yield_output(response)
# Always construct a full conversation snapshot from inputs (cache)
# plus agent outputs (agent_run_response.messages). Do not mutate
# response.messages so AgentRunEvent remains faithful to the raw output.
full_conversation: list[ChatMessage] = list(self._cache) + list(response.messages)
agent_response = AgentExecutorResponse(self.id, response, full_conversation=full_conversation)
await ctx.send_message(agent_response)
self._cache.clear()
async def _run_agent(self, ctx: WorkflowContext) -> AgentRunResponse | None:
"""Execute the underlying agent in non-streaming mode.
Args:
ctx: The workflow context for emitting events.
Returns:
The complete AgentRunResponse, or None if waiting for user input.
"""
response = await self._agent.run(
self._cache,
thread=self._agent_thread,
)
await ctx.add_event(AgentRunEvent(self.id, response))
# Handle any user input requests
if response.user_input_requests:
for user_input_request in response.user_input_requests:
self._pending_agent_requests[user_input_request.id] = user_input_request
await ctx.request_info(user_input_request, FunctionApprovalResponseContent)
return None
return response
async def _run_agent_streaming(self, ctx: WorkflowContext) -> AgentRunResponse | None:
"""Execute the underlying agent in streaming mode and collect the full response.
Args:
ctx: The workflow context for emitting events.
Returns:
The complete AgentRunResponse, or None if waiting for user input.
"""
updates: list[AgentRunResponseUpdate] = []
user_input_requests: list[FunctionApprovalRequestContent] = []
async for update in self._agent.run_stream(
self._cache,
thread=self._agent_thread,
):
updates.append(update)
await ctx.add_event(AgentRunUpdateEvent(self.id, update))
if update.user_input_requests:
user_input_requests.extend(update.user_input_requests)
# Build the final AgentRunResponse from the collected updates
if isinstance(self._agent, ChatAgent):
response_format = self._agent.chat_options.response_format
response = AgentRunResponse.from_agent_run_response_updates(
updates,
output_format_type=response_format,
)
else:
response = AgentRunResponse.from_agent_run_response_updates(updates)
# Handle any user input requests after the streaming completes
if user_input_requests:
for user_input_request in user_input_requests:
self._pending_agent_requests[user_input_request.id] = user_input_request
await ctx.request_info(user_input_request, FunctionApprovalResponseContent)
return None
return response
@@ -85,8 +85,8 @@ def _clone_chat_agent(agent: ChatAgent) -> ChatAgent:
# so we need to recombine them here to pass the complete tools list to the constructor.
# This makes sure MCP tools are preserved when cloning agents for handoff workflows.
all_tools = list(options.tools) if options.tools else []
if agent._local_mcp_tools:
all_tools.extend(agent._local_mcp_tools)
if agent._local_mcp_tools: # type: ignore
all_tools.extend(agent._local_mcp_tools) # type: ignore
return ChatAgent(
chat_client=agent.chat_client,
@@ -133,6 +133,14 @@ class _ConversationWithUserInput:
full_conversation: list[ChatMessage] = field(default_factory=lambda: []) # type: ignore[misc]
@dataclass
class _ConversationForUserInput:
"""Internal message from coordinator to gateway specifying which agent will receive the response."""
conversation: list[ChatMessage]
next_agent_id: str
class _AutoHandoffMiddleware(FunctionMiddleware):
"""Intercept handoff tool invocations and short-circuit execution with synthetic results."""
@@ -275,6 +283,7 @@ class _HandoffCoordinator(BaseGroupChatOrchestrator):
termination_condition: Callable[[list[ChatMessage]], bool | Awaitable[bool]],
id: str,
handoff_tool_targets: Mapping[str, str] | None = None,
return_to_previous: bool = False,
) -> None:
"""Create a coordinator that manages routing between specialists and the user."""
super().__init__(id)
@@ -284,6 +293,8 @@ class _HandoffCoordinator(BaseGroupChatOrchestrator):
self._input_gateway_id = input_gateway_id
self._termination_condition = termination_condition
self._handoff_tool_targets = {k.lower(): v for k, v in (handoff_tool_targets or {}).items()}
self._return_to_previous = return_to_previous
self._current_agent_id: str | None = None # Track the current agent handling conversation
def _get_author_name(self) -> str:
"""Get the coordinator name for orchestrator-generated messages."""
@@ -293,7 +304,7 @@ class _HandoffCoordinator(BaseGroupChatOrchestrator):
async def handle_agent_response(
self,
response: AgentExecutorResponse,
ctx: WorkflowContext[AgentExecutorRequest | list[ChatMessage], list[ChatMessage]],
ctx: WorkflowContext[AgentExecutorRequest | list[ChatMessage], list[ChatMessage] | _ConversationForUserInput],
) -> None:
"""Process an agent's response and determine whether to route, request input, or terminate."""
# Hydrate coordinator state (and detect new run) using checkpointable executor state
@@ -329,6 +340,9 @@ class _HandoffCoordinator(BaseGroupChatOrchestrator):
# Check for handoff from ANY agent (starting agent or specialist)
target = self._resolve_specialist(response.agent_run_response, conversation)
if target is not None:
# Update current agent when handoff occurs
self._current_agent_id = target
logger.info(f"Handoff detected: {source} -> {target}. Routing control to specialist '{target}'.")
await self._persist_state(ctx)
# Clean tool-related content before sending to next agent
cleaned = clean_conversation_for_handoff(conversation)
@@ -340,10 +354,15 @@ class _HandoffCoordinator(BaseGroupChatOrchestrator):
if not is_starting_agent and source not in self._specialist_ids:
raise RuntimeError(f"HandoffCoordinator received response from unknown executor '{source}'.")
# Update current agent when they respond without handoff
self._current_agent_id = source
logger.info(
f"Agent '{source}' responded without handoff. "
f"Requesting user input. Return-to-previous: {self._return_to_previous}"
)
await self._persist_state(ctx)
if await self._check_termination():
logger.info("Handoff workflow termination condition met. Ending conversation.")
# Clean the output conversation for display
cleaned_output = clean_conversation_for_handoff(conversation)
await ctx.yield_output(cleaned_output)
@@ -352,7 +371,13 @@ class _HandoffCoordinator(BaseGroupChatOrchestrator):
# Clean conversation before sending to gateway for user input request
# This removes tool messages that shouldn't be shown to users
cleaned_for_display = clean_conversation_for_handoff(conversation)
await ctx.send_message(cleaned_for_display, target_id=self._input_gateway_id)
# The awaiting_agent_id is the agent that just responded and is awaiting user input
# This is the source of the current response
next_agent_id = source
message_to_gateway = _ConversationForUserInput(conversation=cleaned_for_display, next_agent_id=next_agent_id)
await ctx.send_message(message_to_gateway, target_id=self._input_gateway_id) # type: ignore[arg-type]
@handler
async def handle_user_input(
@@ -367,14 +392,26 @@ class _HandoffCoordinator(BaseGroupChatOrchestrator):
# Check termination before sending to agent
if await self._check_termination():
logger.info("Handoff workflow termination condition met. Ending conversation.")
await ctx.yield_output(list(self._conversation))
return
# Clean before sending to starting agent
# Determine routing target based on return-to-previous setting
target_agent_id = self._starting_agent_id
if self._return_to_previous and self._current_agent_id:
# Route back to the current agent that's handling the conversation
target_agent_id = self._current_agent_id
logger.info(
f"Return-to-previous enabled: routing user input to current agent '{target_agent_id}' "
f"(bypassing coordinator '{self._starting_agent_id}')"
)
else:
logger.info(f"Routing user input to coordinator '{target_agent_id}'")
# Note: Stack is only used for specialist-to-specialist handoffs, not user input routing
# Clean before sending to target agent
cleaned = clean_conversation_for_handoff(self._conversation)
request = AgentExecutorRequest(messages=cleaned, should_respond=True)
await ctx.send_message(request, target_id=self._starting_agent_id)
await ctx.send_message(request, target_id=target_agent_id)
def _resolve_specialist(self, agent_response: AgentRunResponse, conversation: list[ChatMessage]) -> str | None:
"""Resolve the specialist executor id requested by the agent response, if any."""
@@ -444,22 +481,27 @@ class _HandoffCoordinator(BaseGroupChatOrchestrator):
def _snapshot_pattern_metadata(self) -> dict[str, Any]:
"""Serialize pattern-specific state.
Handoff has no additional metadata beyond base conversation state.
Includes the current agent for return-to-previous routing.
Returns:
Empty dict (no pattern-specific state)
Dict containing current agent if return-to-previous is enabled
"""
if self._return_to_previous:
return {
"current_agent_id": self._current_agent_id,
}
return {}
def _restore_pattern_metadata(self, metadata: dict[str, Any]) -> None:
"""Restore pattern-specific state.
Handoff has no additional metadata beyond base conversation state.
Restores the current agent for return-to-previous routing.
Args:
metadata: Pattern-specific state dict (ignored)
metadata: Pattern-specific state dict
"""
pass
if self._return_to_previous and "current_agent_id" in metadata:
self._current_agent_id = metadata["current_agent_id"]
def _restore_conversation_from_state(self, state: Mapping[str, Any]) -> list[ChatMessage]:
"""Rehydrate the coordinator's conversation history from checkpointed state.
@@ -507,8 +549,21 @@ class _UserInputGateway(Executor):
self._prompt = prompt or "Provide your next input for the conversation."
@handler
async def request_input(self, conversation: list[ChatMessage], ctx: WorkflowContext) -> None:
async def request_input(self, message: _ConversationForUserInput, ctx: WorkflowContext) -> None:
"""Emit a `HandoffUserInputRequest` capturing the conversation snapshot."""
if not message.conversation:
raise ValueError("Handoff workflow requires non-empty conversation before requesting user input.")
request = HandoffUserInputRequest(
conversation=list(message.conversation),
awaiting_agent_id=message.next_agent_id,
prompt=self._prompt,
source_executor_id=self.id,
)
await ctx.request_info(request, object)
@handler
async def request_input_legacy(self, conversation: list[ChatMessage], ctx: WorkflowContext) -> None:
"""Legacy handler for backward compatibility - emit user input request with starting agent."""
if not conversation:
raise ValueError("Handoff workflow requires non-empty conversation before requesting user input.")
request = HandoffUserInputRequest(
@@ -558,7 +613,7 @@ def _as_user_messages(payload: Any) -> list[ChatMessage]:
def _default_termination_condition(conversation: list[ChatMessage]) -> bool:
"""Default termination: stop after 10 user messages to prevent infinite loops."""
"""Default termination: stop after 10 user messages."""
user_message_count = sum(1 for msg in conversation if msg.role == Role.USER)
return user_message_count >= 10
@@ -743,6 +798,7 @@ class HandoffBuilder:
)
self._auto_register_handoff_tools: bool = True
self._handoff_config: dict[str, list[str]] = {} # Maps agent_id -> [target_agent_ids]
self._return_to_previous: bool = False
if participants:
self.participants(participants)
@@ -1198,6 +1254,77 @@ class HandoffBuilder:
self._termination_condition = condition
return self
def enable_return_to_previous(self, enabled: bool = True) -> "HandoffBuilder":
"""Enable direct return to the current agent after user input, bypassing the coordinator.
When enabled, after a specialist responds without requesting another handoff, user input
routes directly back to that same specialist instead of always routing back to the
coordinator agent for re-evaluation.
This is useful when a specialist needs multiple turns with the user to gather information
or resolve an issue, avoiding unnecessary coordinator involvement while maintaining context.
Flow Comparison:
**Default (disabled):**
User -> Coordinator -> Specialist -> User -> Coordinator -> Specialist -> ...
**With return_to_previous (enabled):**
User -> Coordinator -> Specialist -> User -> Specialist -> ...
Args:
enabled: Whether to enable return-to-previous routing. Default is True.
Returns:
Self for method chaining.
Example:
.. code-block:: python
workflow = (
HandoffBuilder(participants=[triage, technical_support, billing])
.set_coordinator("triage")
.add_handoff(triage, [technical_support, billing])
.enable_return_to_previous() # Enable direct return routing
.build()
)
# Flow: User asks question
# -> Triage routes to Technical Support
# -> Technical Support asks clarifying question
# -> User provides more info
# -> Routes back to Technical Support (not Triage)
# -> Technical Support continues helping
Multi-tier handoff example:
.. code-block:: python
workflow = (
HandoffBuilder(participants=[triage, specialist_a, specialist_b])
.set_coordinator("triage")
.add_handoff(triage, [specialist_a, specialist_b])
.add_handoff(specialist_a, specialist_b)
.enable_return_to_previous()
.build()
)
# Flow: User asks question
# -> Triage routes to Specialist A
# -> Specialist A hands off to Specialist B
# -> Specialist B asks clarifying question
# -> User provides more info
# -> Routes back to Specialist B (who is currently handling the conversation)
Note:
This feature routes to whichever agent most recently responded, whether that's
the coordinator or a specialist. The conversation continues with that agent until
they either hand off to another agent or the termination condition is met.
"""
self._return_to_previous = enabled
return self
def build(self) -> Workflow:
"""Construct the final Workflow instance from the configured builder.
@@ -1326,6 +1453,7 @@ class HandoffBuilder:
termination_condition=self._termination_condition,
id="handoff-coordinator",
handoff_tool_targets=handoff_tool_targets,
return_to_previous=self._return_to_previous,
)
wiring = _GroupChatConfig(
@@ -0,0 +1,35 @@
# Copyright (c) Microsoft. All rights reserved.
import importlib
from typing import Any
PACKAGE_NAME = "agent_framework_ag_ui"
PACKAGE_EXTRA = "ag-ui"
_IMPORTS = [
"__version__",
"AgentFrameworkAgent",
"add_agent_framework_fastapi_endpoint",
"AGUIChatClient",
"AGUIEventConverter",
"AGUIHttpService",
"ConfirmationStrategy",
"DefaultConfirmationStrategy",
"TaskPlannerConfirmationStrategy",
"RecipeConfirmationStrategy",
"DocumentWriterConfirmationStrategy",
]
def __getattr__(name: str) -> Any:
if name in _IMPORTS:
try:
return getattr(importlib.import_module(PACKAGE_NAME), name)
except ModuleNotFoundError as exc:
raise ModuleNotFoundError(
f"The '{PACKAGE_EXTRA}' extra is not installed, please do `pip install agent-framework-{PACKAGE_EXTRA}`"
) from exc
raise AttributeError(f"Module {PACKAGE_NAME} has no attribute {name}.")
def __dir__() -> list[str]:
return _IMPORTS
@@ -0,0 +1,29 @@
# Copyright (c) Microsoft. All rights reserved.
from agent_framework_ag_ui import (
AgentFrameworkAgent,
AGUIChatClient,
AGUIEventConverter,
AGUIHttpService,
ConfirmationStrategy,
DefaultConfirmationStrategy,
DocumentWriterConfirmationStrategy,
RecipeConfirmationStrategy,
TaskPlannerConfirmationStrategy,
__version__,
add_agent_framework_fastapi_endpoint,
)
__all__ = [
"AGUIChatClient",
"AGUIEventConverter",
"AGUIHttpService",
"AgentFrameworkAgent",
"ConfirmationStrategy",
"DefaultConfirmationStrategy",
"DocumentWriterConfirmationStrategy",
"RecipeConfirmationStrategy",
"TaskPlannerConfirmationStrategy",
"__version__",
"add_agent_framework_fastapi_endpoint",
]
@@ -846,6 +846,7 @@ def _trace_get_response(
kwargs.get("model_id")
or (chat_options.model_id if (chat_options := kwargs.get("chat_options")) else None)
or getattr(self, "model_id", None)
or "unknown"
)
service_url = str(
service_url_func()
@@ -933,6 +934,7 @@ def _trace_get_streaming_response(
kwargs.get("model_id")
or (chat_options.model_id if (chat_options := kwargs.get("chat_options")) else None)
or getattr(self, "model_id", None)
or "unknown"
)
service_url = str(
service_url_func()
@@ -1324,7 +1326,10 @@ def _get_span(
attributes: dict[str, Any],
span_name_attribute: str,
) -> Generator["trace.Span", Any, Any]:
"""Start a span for a agent run."""
"""Start a span for a agent run.
Note: `attributes` must contain the `span_name_attribute` key.
"""
span = get_tracer().start_span(f"{attributes[OtelAttr.OPERATION]} {attributes[span_name_attribute]}")
span.set_attributes(attributes)
with trace.use_span(
@@ -1353,7 +1358,8 @@ def _get_span_attributes(**kwargs: Any) -> dict[str, Any]:
attributes[SpanAttributes.LLM_SYSTEM] = system_name
if provider_name := kwargs.get("provider_name"):
attributes[OtelAttr.PROVIDER_NAME] = provider_name
attributes[SpanAttributes.LLM_REQUEST_MODEL] = kwargs.get("model", "unknown")
if model_id := kwargs.get("model", chat_options.model_id):
attributes[SpanAttributes.LLM_REQUEST_MODEL] = model_id
if service_url := kwargs.get("service_url"):
attributes[OtelAttr.ADDRESS] = service_url
if conversation_id := kwargs.get("conversation_id", chat_options.conversation_id):
@@ -276,6 +276,14 @@ class OpenAIBaseResponsesClient(OpenAIBase, BaseChatClient):
# Map the parameter name and remove the old one
mapped_tool[api_param] = mapped_tool.pop(user_param)
# Validate partial_images parameter for streaming image generation
# OpenAI API requires partial_images to be between 0-3 (inclusive) for image_generation tool
# Reference: https://platform.openai.com/docs/api-reference/responses/create#responses_create-tools-image_generation_tool-partial_images
if "partial_images" in mapped_tool:
partial_images = mapped_tool["partial_images"]
if not isinstance(partial_images, int) or partial_images < 0 or partial_images > 3:
raise ValueError("partial_images must be an integer between 0 and 3 (inclusive).")
response_tools.append(mapped_tool)
else:
response_tools.append(tool_dict)
@@ -707,29 +715,8 @@ class OpenAIBaseResponsesClient(OpenAIBase, BaseChatClient):
uri = item.result
media_type = None
if not uri.startswith("data:"):
# Raw base64 string - convert to proper data URI format
# Detect format from base64 data
import base64
try:
# Decode a small portion to detect format
decoded_data = base64.b64decode(uri[:100]) # First ~75 bytes should be enough
if decoded_data.startswith(b"\x89PNG"):
format_type = "png"
elif decoded_data.startswith(b"\xff\xd8\xff"):
format_type = "jpeg"
elif decoded_data.startswith(b"RIFF") and b"WEBP" in decoded_data[:12]:
format_type = "webp"
elif decoded_data.startswith(b"GIF87a") or decoded_data.startswith(b"GIF89a"):
format_type = "gif"
else:
# Default to png if format cannot be detected
format_type = "png"
except Exception:
# Fallback to png if decoding fails
format_type = "png"
uri = f"data:image/{format_type};base64,{uri}"
media_type = f"image/{format_type}"
# Raw base64 string - convert to proper data URI format using helper
uri, media_type = DataContent.create_data_uri_from_base64(uri)
else:
# Parse media type from existing data URI
try:
@@ -945,6 +932,25 @@ class OpenAIBaseResponsesClient(OpenAIBase, BaseChatClient):
raw_representation=event,
)
)
case "response.image_generation_call.partial_image":
# Handle streaming partial image generation
image_base64 = event.partial_image_b64
partial_index = event.partial_image_index
# Use helper function to create data URI from base64
uri, media_type = DataContent.create_data_uri_from_base64(image_base64)
contents.append(
DataContent(
uri=uri,
media_type=media_type,
additional_properties={
"partial_image_index": partial_index,
"is_partial_image": True,
},
raw_representation=event,
)
)
case _:
logger.debug("Unparsed event of type: %s: %s", event.type, event)
+5 -4
View File
@@ -4,7 +4,7 @@ description = "Microsoft Agent Framework for building AI Agents with Python. Thi
authors = [{ name = "Microsoft", email = "af-support@microsoft.com"}]
readme = "README.md"
requires-python = ">=3.10"
version = "1.0.0b251105"
version = "1.0.0b251111"
license-files = ["LICENSE"]
urls.homepage = "https://aka.ms/agent-framework"
urls.source = "https://github.com/microsoft/agent-framework/tree/main/python"
@@ -42,13 +42,14 @@ dependencies = [
[project.optional-dependencies]
all = [
"agent-framework-a2a",
"agent-framework-ag-ui",
"agent-framework-anthropic",
"agent-framework-azure-ai",
"agent-framework-copilotstudio",
"agent-framework-mem0",
"agent-framework-redis",
"agent-framework-devui",
"agent-framework-mem0",
"agent-framework-purview",
"agent-framework-anthropic",
"agent-framework-redis",
]
[tool.uv]
@@ -279,6 +279,45 @@ async def test_chat_client_streaming_observability(
assert span.attributes[OtelAttr.OUTPUT_MESSAGES] is not None
async def test_chat_client_without_model_id_observability(mock_chat_client, span_exporter: InMemorySpanExporter):
"""Test telemetry shouldn't fail when the model_id is not provided for unknown reason."""
client = use_observability(mock_chat_client)()
messages = [ChatMessage(role=Role.USER, text="Test")]
span_exporter.clear()
response = await client.get_response(messages=messages)
assert response is not None
spans = span_exporter.get_finished_spans()
assert len(spans) == 1
span = spans[0]
assert span.name == "chat unknown"
assert span.attributes[OtelAttr.OPERATION.value] == OtelAttr.CHAT_COMPLETION_OPERATION
assert span.attributes[SpanAttributes.LLM_REQUEST_MODEL] == "unknown"
async def test_chat_client_streaming_without_model_id_observability(
mock_chat_client, span_exporter: InMemorySpanExporter
):
"""Test streaming telemetry shouldn't fail when the model_id is not provided for unknown reason."""
client = use_observability(mock_chat_client)()
messages = [ChatMessage(role=Role.USER, text="Test")]
span_exporter.clear()
# Collect all yielded updates
updates = []
async for update in client.get_streaming_response(messages=messages):
updates.append(update)
# Verify we got the expected updates, this shouldn't be dependent on otel
assert len(updates) == 2
spans = span_exporter.get_finished_spans()
assert len(spans) == 1
span = spans[0]
assert span.name == "chat unknown"
assert span.attributes[OtelAttr.OPERATION.value] == OtelAttr.CHAT_COMPLETION_OPERATION
assert span.attributes[SpanAttributes.LLM_REQUEST_MODEL] == "unknown"
def test_prepend_user_agent_with_none_value():
"""Test prepend user agent with None value in headers."""
headers = {"User-Agent": None}
@@ -368,6 +407,7 @@ def mock_chat_agent():
self.name = "test_agent"
self.display_name = "Test Agent"
self.description = "Test agent description"
self.chat_options = ChatOptions(model_id="TestModel")
async def run(self, messages=None, *, thread=None, **kwargs):
return AgentRunResponse(
@@ -405,7 +445,7 @@ async def test_agent_instrumentation_enabled(
assert span.attributes[OtelAttr.AGENT_ID] == "test_agent_id"
assert span.attributes[OtelAttr.AGENT_NAME] == "Test Agent"
assert span.attributes[OtelAttr.AGENT_DESCRIPTION] == "Test agent description"
assert span.attributes[SpanAttributes.LLM_REQUEST_MODEL] == "unknown"
assert span.attributes[SpanAttributes.LLM_REQUEST_MODEL] == "TestModel"
assert span.attributes[OtelAttr.INPUT_TOKENS] == 15
assert span.attributes[OtelAttr.OUTPUT_TOKENS] == 25
if enable_sensitive_data:
@@ -433,7 +473,7 @@ async def test_agent_streaming_response_with_diagnostics_enabled_via_decorator(
assert span.attributes[OtelAttr.AGENT_ID] == "test_agent_id"
assert span.attributes[OtelAttr.AGENT_NAME] == "Test Agent"
assert span.attributes[OtelAttr.AGENT_DESCRIPTION] == "Test agent description"
assert span.attributes[SpanAttributes.LLM_REQUEST_MODEL] == "unknown"
assert span.attributes[SpanAttributes.LLM_REQUEST_MODEL] == "TestModel"
if enable_sensitive_data:
assert span.attributes.get(OtelAttr.OUTPUT_MESSAGES) is not None # Streaming, so no usage yet
@@ -1,5 +1,6 @@
# Copyright (c) Microsoft. All rights reserved.
import base64
from collections.abc import AsyncIterable
from typing import Any
@@ -166,6 +167,57 @@ def test_data_content_empty():
DataContent(uri="")
def test_data_content_detect_image_format_from_base64():
"""Test the detect_image_format_from_base64 static method."""
# Test each supported format
png_data = b"\x89PNG\r\n\x1a\n" + b"fake_data"
assert DataContent.detect_image_format_from_base64(base64.b64encode(png_data).decode()) == "png"
jpeg_data = b"\xff\xd8\xff\xe0" + b"fake_data"
assert DataContent.detect_image_format_from_base64(base64.b64encode(jpeg_data).decode()) == "jpeg"
webp_data = b"RIFF" + b"1234" + b"WEBP" + b"fake_data"
assert DataContent.detect_image_format_from_base64(base64.b64encode(webp_data).decode()) == "webp"
gif_data = b"GIF89a" + b"fake_data"
assert DataContent.detect_image_format_from_base64(base64.b64encode(gif_data).decode()) == "gif"
# Test fallback behavior
unknown_data = b"UNKNOWN_FORMAT"
assert DataContent.detect_image_format_from_base64(base64.b64encode(unknown_data).decode()) == "png"
# Test error handling
assert DataContent.detect_image_format_from_base64("invalid_base64!") == "png"
assert DataContent.detect_image_format_from_base64("") == "png"
def test_data_content_create_data_uri_from_base64():
"""Test the create_data_uri_from_base64 class method."""
# Test with PNG data
png_data = b"\x89PNG\r\n\x1a\n" + b"fake_data"
png_base64 = base64.b64encode(png_data).decode()
uri, media_type = DataContent.create_data_uri_from_base64(png_base64)
assert uri == f"data:image/png;base64,{png_base64}"
assert media_type == "image/png"
# Test with different format
jpeg_data = b"\xff\xd8\xff\xe0" + b"fake_data"
jpeg_base64 = base64.b64encode(jpeg_data).decode()
uri, media_type = DataContent.create_data_uri_from_base64(jpeg_base64)
assert uri == f"data:image/jpeg;base64,{jpeg_base64}"
assert media_type == "image/jpeg"
# Test fallback for unknown format
unknown_data = b"UNKNOWN_FORMAT"
unknown_base64 = base64.b64encode(unknown_data).decode()
uri, media_type = DataContent.create_data_uri_from_base64(unknown_base64)
assert uri == f"data:image/png;base64,{unknown_base64}"
assert media_type == "image/png"
# region UriContent
File diff suppressed because it is too large Load Diff
@@ -111,6 +111,10 @@ async def test_agent_executor_checkpoint_stores_and_restores_state() -> None:
chat_store_state = thread_state["chat_message_store_state"] # type: ignore[index]
assert "messages" in chat_store_state, "Message store state should include messages"
# Verify checkpoint contains pending requests from agents and responses to be sent
assert "pending_agent_requests" in executor_state
assert "pending_responses_to_agent" in executor_state
# Create a new agent and executor for restoration
# This simulates starting from a fresh state and restoring from checkpoint
restored_agent = _CountingAgent(id="test_agent", name="TestAgent")
@@ -5,19 +5,32 @@
from collections.abc import AsyncIterable
from typing import Any
from typing_extensions import Never
from agent_framework import (
AgentExecutor,
AgentExecutorResponse,
AgentRunResponse,
AgentRunResponseUpdate,
AgentRunUpdateEvent,
AgentThread,
BaseAgent,
ChatAgent,
ChatMessage,
ChatResponse,
ChatResponseUpdate,
FunctionApprovalRequestContent,
FunctionCallContent,
FunctionResultContent,
RequestInfoEvent,
Role,
TextContent,
WorkflowBuilder,
WorkflowContext,
WorkflowOutputEvent,
ai_function,
executor,
use_function_invocation,
)
@@ -120,3 +133,235 @@ async def test_agent_executor_emits_tool_calls_in_streaming_mode() -> None:
assert events[3].data is not None
assert isinstance(events[3].data.contents[0], TextContent)
assert "sunny" in events[3].data.contents[0].text
@ai_function(approval_mode="always_require")
def mock_tool_requiring_approval(query: str) -> str:
"""Mock tool that requires approval before execution."""
return f"Executed tool with query: {query}"
@use_function_invocation
class MockChatClient:
"""Simple implementation of a chat client."""
def __init__(self, parallel_request: bool = False) -> None:
self.additional_properties: dict[str, Any] = {}
self._iteration: int = 0
self._parallel_request: bool = parallel_request
async def get_response(
self,
messages: str | ChatMessage | list[str] | list[ChatMessage],
**kwargs: Any,
) -> ChatResponse:
if self._iteration == 0:
if self._parallel_request:
response = ChatResponse(
messages=ChatMessage(
role="assistant",
contents=[
FunctionCallContent(
call_id="1", name="mock_tool_requiring_approval", arguments='{"query": "test"}'
),
FunctionCallContent(
call_id="2", name="mock_tool_requiring_approval", arguments='{"query": "test"}'
),
],
)
)
else:
response = ChatResponse(
messages=ChatMessage(
role="assistant",
contents=[
FunctionCallContent(
call_id="1", name="mock_tool_requiring_approval", arguments='{"query": "test"}'
)
],
)
)
else:
response = ChatResponse(messages=ChatMessage(role="assistant", text="Tool executed successfully."))
self._iteration += 1
return response
async def get_streaming_response(
self,
messages: str | ChatMessage | list[str] | list[ChatMessage],
**kwargs: Any,
) -> AsyncIterable[ChatResponseUpdate]:
if self._iteration == 0:
if self._parallel_request:
yield ChatResponseUpdate(
contents=[
FunctionCallContent(
call_id="1", name="mock_tool_requiring_approval", arguments='{"query": "test"}'
),
FunctionCallContent(
call_id="2", name="mock_tool_requiring_approval", arguments='{"query": "test"}'
),
],
role="assistant",
)
else:
yield ChatResponseUpdate(
contents=[
FunctionCallContent(
call_id="1", name="mock_tool_requiring_approval", arguments='{"query": "test"}'
)
],
role="assistant",
)
else:
yield ChatResponseUpdate(text=TextContent(text="Tool executed "), role="assistant")
yield ChatResponseUpdate(contents=[TextContent(text="successfully.")], role="assistant")
self._iteration += 1
@executor(id="test_executor")
async def test_executor(agent_executor_response: AgentExecutorResponse, ctx: WorkflowContext[Never, str]) -> None:
await ctx.yield_output(agent_executor_response.agent_run_response.text)
async def test_agent_executor_tool_call_with_approval() -> None:
"""Test that AgentExecutor handles tool calls requiring approval."""
# Arrange
agent = ChatAgent(
chat_client=MockChatClient(),
name="ApprovalAgent",
tools=[mock_tool_requiring_approval],
)
workflow = WorkflowBuilder().set_start_executor(agent).add_edge(agent, test_executor).build()
# Act
events = await workflow.run("Invoke tool requiring approval")
# Assert
assert len(events.get_request_info_events()) == 1
approval_request = events.get_request_info_events()[0]
assert isinstance(approval_request.data, FunctionApprovalRequestContent)
assert approval_request.data.function_call.name == "mock_tool_requiring_approval"
assert approval_request.data.function_call.arguments == '{"query": "test"}'
# Act
events = await workflow.send_responses({approval_request.request_id: approval_request.data.create_response(True)})
# Assert
final_response = events.get_outputs()
assert len(final_response) == 1
assert final_response[0] == "Tool executed successfully."
async def test_agent_executor_tool_call_with_approval_streaming() -> None:
"""Test that AgentExecutor handles tool calls requiring approval in streaming mode."""
# Arrange
agent = ChatAgent(
chat_client=MockChatClient(),
name="ApprovalAgent",
tools=[mock_tool_requiring_approval],
)
workflow = WorkflowBuilder().set_start_executor(agent).add_edge(agent, test_executor).build()
# Act
request_info_events: list[RequestInfoEvent] = []
async for event in workflow.run_stream("Invoke tool requiring approval"):
if isinstance(event, RequestInfoEvent):
request_info_events.append(event)
# Assert
assert len(request_info_events) == 1
approval_request = request_info_events[0]
assert isinstance(approval_request.data, FunctionApprovalRequestContent)
assert approval_request.data.function_call.name == "mock_tool_requiring_approval"
assert approval_request.data.function_call.arguments == '{"query": "test"}'
# Act
output: str | None = None
async for event in workflow.send_responses_streaming({
approval_request.request_id: approval_request.data.create_response(True)
}):
if isinstance(event, WorkflowOutputEvent):
output = event.data
# Assert
assert output is not None
assert output == "Tool executed successfully."
async def test_agent_executor_parallel_tool_call_with_approval() -> None:
"""Test that AgentExecutor handles parallel tool calls requiring approval."""
# Arrange
agent = ChatAgent(
chat_client=MockChatClient(parallel_request=True),
name="ApprovalAgent",
tools=[mock_tool_requiring_approval],
)
workflow = WorkflowBuilder().set_start_executor(agent).add_edge(agent, test_executor).build()
# Act
events = await workflow.run("Invoke tool requiring approval")
# Assert
assert len(events.get_request_info_events()) == 2
for approval_request in events.get_request_info_events():
assert isinstance(approval_request.data, FunctionApprovalRequestContent)
assert approval_request.data.function_call.name == "mock_tool_requiring_approval"
assert approval_request.data.function_call.arguments == '{"query": "test"}'
# Act
responses = {
approval_request.request_id: approval_request.data.create_response(True) # type: ignore
for approval_request in events.get_request_info_events()
}
events = await workflow.send_responses(responses)
# Assert
final_response = events.get_outputs()
assert len(final_response) == 1
assert final_response[0] == "Tool executed successfully."
async def test_agent_executor_parallel_tool_call_with_approval_streaming() -> None:
"""Test that AgentExecutor handles parallel tool calls requiring approval in streaming mode."""
# Arrange
agent = ChatAgent(
chat_client=MockChatClient(parallel_request=True),
name="ApprovalAgent",
tools=[mock_tool_requiring_approval],
)
workflow = WorkflowBuilder().set_start_executor(agent).add_edge(agent, test_executor).build()
# Act
request_info_events: list[RequestInfoEvent] = []
async for event in workflow.run_stream("Invoke tool requiring approval"):
if isinstance(event, RequestInfoEvent):
request_info_events.append(event)
# Assert
assert len(request_info_events) == 2
for approval_request in request_info_events:
assert isinstance(approval_request.data, FunctionApprovalRequestContent)
assert approval_request.data.function_call.name == "mock_tool_requiring_approval"
assert approval_request.data.function_call.arguments == '{"query": "test"}'
# Act
responses = {
approval_request.request_id: approval_request.data.create_response(True) # type: ignore
for approval_request in request_info_events
}
output: str | None = None
async for event in workflow.send_responses_streaming(responses):
if isinstance(event, WorkflowOutputEvent):
output = event.data
# Assert
assert output is not None
assert output == "Tool executed successfully."
@@ -23,7 +23,7 @@ from agent_framework import (
WorkflowOutputEvent,
)
from agent_framework._mcp import MCPTool
from agent_framework._workflows._handoff import _clone_chat_agent
from agent_framework._workflows._handoff import _clone_chat_agent # type: ignore[reportPrivateUsage]
@dataclass
@@ -392,12 +392,218 @@ async def test_clone_chat_agent_preserves_mcp_tools() -> None:
)
assert hasattr(original_agent, "_local_mcp_tools")
assert len(original_agent._local_mcp_tools) == 1
assert original_agent._local_mcp_tools[0] == mock_mcp_tool
assert len(original_agent._local_mcp_tools) == 1 # type: ignore[reportPrivateUsage]
assert original_agent._local_mcp_tools[0] == mock_mcp_tool # type: ignore[reportPrivateUsage]
cloned_agent = _clone_chat_agent(original_agent)
assert hasattr(cloned_agent, "_local_mcp_tools")
assert len(cloned_agent._local_mcp_tools) == 1
assert cloned_agent._local_mcp_tools[0] == mock_mcp_tool
assert len(cloned_agent._local_mcp_tools) == 1 # type: ignore[reportPrivateUsage]
assert cloned_agent._local_mcp_tools[0] == mock_mcp_tool # type: ignore[reportPrivateUsage]
assert cloned_agent.chat_options.tools is not None
assert len(cloned_agent.chat_options.tools) == 1
async def test_return_to_previous_routing():
"""Test that return-to-previous routes back to the current specialist handling the conversation."""
triage = _RecordingAgent(name="triage", handoff_to="specialist_a")
specialist_a = _RecordingAgent(name="specialist_a", handoff_to="specialist_b")
specialist_b = _RecordingAgent(name="specialist_b")
workflow = (
HandoffBuilder(participants=[triage, specialist_a, specialist_b])
.set_coordinator(triage)
.add_handoff(triage, [specialist_a, specialist_b])
.add_handoff(specialist_a, specialist_b)
.enable_return_to_previous(True)
.with_termination_condition(lambda conv: sum(1 for m in conv if m.role == Role.USER) >= 4)
.build()
)
# Start conversation - triage hands off to specialist_a
events = await _drain(workflow.run_stream("Initial request"))
requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)]
assert requests
assert len(specialist_a.calls) > 0
# Specialist_a should have been called with initial request
initial_specialist_a_calls = len(specialist_a.calls)
# Second user message - specialist_a hands off to specialist_b
events = await _drain(workflow.send_responses_streaming({requests[-1].request_id: "Need more help"}))
requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)]
assert requests
# Specialist_b should have been called
assert len(specialist_b.calls) > 0
initial_specialist_b_calls = len(specialist_b.calls)
# Third user message - with return_to_previous, should route back to specialist_b (current agent)
events = await _drain(workflow.send_responses_streaming({requests[-1].request_id: "Follow up question"}))
third_requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)]
# Specialist_b should have been called again (return-to-previous routes to current agent)
assert len(specialist_b.calls) > initial_specialist_b_calls, (
"Specialist B should be called again due to return-to-previous routing to current agent"
)
# Specialist_a should NOT be called again (it's no longer the current agent)
assert len(specialist_a.calls) == initial_specialist_a_calls, (
"Specialist A should not be called again - specialist_b is the current agent"
)
# Triage should only have been called once at the start
assert len(triage.calls) == 1, "Triage should only be called once (initial routing)"
# Verify awaiting_agent_id is set to specialist_b (the agent that just responded)
if third_requests:
user_input_req = third_requests[-1].data
assert isinstance(user_input_req, HandoffUserInputRequest)
assert user_input_req.awaiting_agent_id == "specialist_b", (
f"Expected awaiting_agent_id 'specialist_b' but got '{user_input_req.awaiting_agent_id}'"
)
async def test_return_to_previous_disabled_routes_to_coordinator():
"""Test that with return-to-previous disabled, routing goes back to coordinator."""
triage = _RecordingAgent(name="triage", handoff_to="specialist_a")
specialist_a = _RecordingAgent(name="specialist_a", handoff_to="specialist_b")
specialist_b = _RecordingAgent(name="specialist_b")
workflow = (
HandoffBuilder(participants=[triage, specialist_a, specialist_b])
.set_coordinator(triage)
.add_handoff(triage, [specialist_a, specialist_b])
.add_handoff(specialist_a, specialist_b)
.enable_return_to_previous(False)
.with_termination_condition(lambda conv: sum(1 for m in conv if m.role == Role.USER) >= 3)
.build()
)
# Start conversation - triage hands off to specialist_a
events = await _drain(workflow.run_stream("Initial request"))
requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)]
assert requests
assert len(triage.calls) == 1
# Second user message - specialist_a hands off to specialist_b
events = await _drain(workflow.send_responses_streaming({requests[-1].request_id: "Need more help"}))
requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)]
assert requests
# Third user message - without return_to_previous, should route back to triage
await _drain(workflow.send_responses_streaming({requests[-1].request_id: "Follow up question"}))
# Triage should have been called twice total: initial + after specialist_b responds
assert len(triage.calls) == 2, "Triage should be called twice (initial + default routing to coordinator)"
async def test_return_to_previous_enabled():
"""Verify that enable_return_to_previous() keeps control with the current specialist."""
triage = _RecordingAgent(name="triage", handoff_to="specialist_a")
specialist_a = _RecordingAgent(name="specialist_a")
specialist_b = _RecordingAgent(name="specialist_b")
workflow = (
HandoffBuilder(participants=[triage, specialist_a, specialist_b])
.set_coordinator("triage")
.enable_return_to_previous(True)
.with_termination_condition(lambda conv: sum(1 for m in conv if m.role == Role.USER) >= 3)
.build()
)
# Start conversation - triage hands off to specialist_a
events = await _drain(workflow.run_stream("Initial request"))
requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)]
assert requests
assert len(triage.calls) == 1
assert len(specialist_a.calls) == 1
# Second user message - with return_to_previous, should route to specialist_a (not triage)
events = await _drain(workflow.send_responses_streaming({requests[-1].request_id: "Follow up question"}))
requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)]
assert requests
# Triage should only have been called once (initial) - specialist_a handles follow-up
assert len(triage.calls) == 1, "Triage should only be called once (initial)"
assert len(specialist_a.calls) == 2, "Specialist A should handle follow-up with return_to_previous enabled"
async def test_tool_choice_preserved_from_agent_config():
"""Verify that agent-level tool_choice configuration is preserved and not overridden."""
from unittest.mock import AsyncMock
from agent_framework import ChatResponse, ToolMode
# Create a mock chat client that records the tool_choice used
recorded_tool_choices: list[Any] = []
async def mock_get_response(messages: Any, **kwargs: Any) -> ChatResponse:
chat_options = kwargs.get("chat_options")
if chat_options:
recorded_tool_choices.append(chat_options.tool_choice)
return ChatResponse(
messages=[ChatMessage(role=Role.ASSISTANT, text="Response")],
response_id="test_response",
)
mock_client = MagicMock()
mock_client.get_response = AsyncMock(side_effect=mock_get_response)
# Create agent with specific tool_choice configuration
agent = ChatAgent(
chat_client=mock_client,
name="test_agent",
tool_choice=ToolMode(mode="required"), # type: ignore[arg-type]
)
# Run the agent
await agent.run("Test message")
# Verify tool_choice was preserved
assert len(recorded_tool_choices) > 0, "No tool_choice recorded"
last_tool_choice = recorded_tool_choices[-1]
assert last_tool_choice is not None, "tool_choice should not be None"
assert str(last_tool_choice) == "required", f"Expected 'required', got {last_tool_choice}"
async def test_return_to_previous_state_serialization():
"""Test that return_to_previous state is properly serialized/deserialized for checkpointing."""
from agent_framework._workflows._handoff import _HandoffCoordinator # type: ignore[reportPrivateUsage]
# Create a coordinator with return_to_previous enabled
coordinator = _HandoffCoordinator(
starting_agent_id="triage",
specialist_ids={"specialist_a": "specialist_a", "specialist_b": "specialist_b"},
input_gateway_id="gateway",
termination_condition=lambda conv: False,
id="test-coordinator",
return_to_previous=True,
)
# Set the current agent (simulating a handoff scenario)
coordinator._current_agent_id = "specialist_a" # type: ignore[reportPrivateUsage]
# Snapshot the state
state = coordinator.snapshot_state()
# Verify pattern metadata includes current_agent_id
assert "metadata" in state
assert "current_agent_id" in state["metadata"]
assert state["metadata"]["current_agent_id"] == "specialist_a"
# Create a new coordinator and restore state
coordinator2 = _HandoffCoordinator(
starting_agent_id="triage",
specialist_ids={"specialist_a": "specialist_a", "specialist_b": "specialist_b"},
input_gateway_id="gateway",
termination_condition=lambda conv: False,
id="test-coordinator",
return_to_previous=True,
)
# Restore state
coordinator2.restore_state(state)
# Verify current_agent_id was restored
assert coordinator2._current_agent_id == "specialist_a", "Current agent should be restored from checkpoint" # type: ignore[reportPrivateUsage]