mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: [Feature Branch] Merge from main to Azure AI branch (#2111)
* Do not build DevUI assets during .NET project build (#2010) * .NET: Add unit tests for declarative executor SetMultipleVariables (#2016) * Add unit tests for create conversation executor * Update indentation and comment typo. * Added unit tests for declarative executor SetMultipleVariablesExecutor * Updated comments and syntactic sugar * Python: DevUI: Use metadata.entity_id instead of model field (#1984) * DevUI: Use metadata.entity_id for agent/workflow name instead of model field * OpenAI Responses: add explicit request validation * Review feedback * .NET: DevUI - Do not automatically add/map OpenAI services/endpoints (#2014) * Don't add OpenAIResponses as part of Dev UI You should be able to add and remove Dev UI without impacting your other production endpoints. * Remove `AddDevUI()` and do not map OpenAI endpoints from `MapDevUI()` * Fix comment wording * Revise documentation --------- Co-authored-by: Daniel Roth <daroth@microsoft.com> * Python: DevUI: Add OpenAI Responses API proxy support + HIL for Workflows (#1737) * DevUI: Add OpenAI Responses API proxy support with enhanced UI features This commit adds support for proxying requests to OpenAI's Responses API, allowing DevUI to route conversations to OpenAI models when configured to enable testing. Backend changes: - Add OpenAI proxy executor with conversation routing logic - Enhance event mapper to support OpenAI Responses API format - Extend server endpoints to handle OpenAI proxy mode - Update models with OpenAI-specific response types - Remove emojis from logging and CLI output for cleaner text Frontend changes: - Add settings modal with OpenAI proxy configuration UI - Enhance agent and workflow views with improved state management - Add new UI components (separator, switch) for settings - Update debug panel with better event filtering - Improve message renderers for OpenAI content types - Update types and API client for OpenAI integration * update ui, settings modal and workflow input form, add register cleanup hooks. * add workflow HIL support, user mode, other fixes * feat(devui): add human-in-the-loop (HIL) support with dynamic response schemas Implement HIL workflow support allowing workflows to pause for user input with dynamically generated JSON schemas based on response handler type hints. Key Features: - Automatic response schema extraction from @response_handler decorators - Dynamic form generation in UI based on Pydantic/dataclass response types - Checkpoint-based conversation storage for HIL requests/responses - Resume workflow execution after user provides HIL response Backend Changes: - Add extract_response_type_from_executor() to introspect response handlers - Enrich RequestInfoEvent with response_schema via _enrich_request_info_event_with_response_schema() - Map RequestInfoEvent to response.input.requested OpenAI event format - Store HIL responses in conversation history and restore checkpoints Frontend Changes: - Add HILInputModal component with SchemaFormRenderer for dynamic forms - Support Pydantic BaseModel and dataclass response types - Render enum fields as dropdowns, strings as text/textarea, numbers, booleans, arrays, objects - Display original request context alongside response form Testing: - Add tests for checkpoint storage (test_checkpoints.py) - Add schema generation tests for all input types (test_schema_generation.py) - Validate end-to-end HIL flow with spam workflow sample This enables workflows to seamlessly pause execution and request structured user input with type-safe, validated forms generated automatically from response type annotations. * improve HIL support, improve workflow execution view * ui updates * ui updates * improve HIL for workflows, add auth and view modes * update workflow * security improvements , ui fixes * fix mypy error * update loading spinner in ui --------- Co-authored-by: Mark Wallace <127216156+markwallace-microsoft@users.noreply.github.com> * .NET: Remove launchSettings.json from .gitignore in dotnet/samples (#2006) * Remove launchSettings.json from .gitignore in dotnet/samples * Update dotnet/samples/GettingStarted/DevUI/DevUI_Step01_BasicUsage/Properties/launchSettings.json Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update dotnet/samples/AGUIClientServer/AGUIServer/Properties/launchSettings.json Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * DevUI: Serialize workflow input as string to maintain conformance with OpenAI Responses format (#2021) Co-authored-by: Victor Dibia <chuvidi2003@gmail.com> * Add Microsoft Agent Framework logo to assets (#2007) * Updated package versions (#2027) * DevUI: Prevent line breaks within words in the agent view (#2024) Co-authored-by: Victor Dibia <chuvidi2003@gmail.com> * .NET [AG-UI]: Adds support for shared state. (#1996) * Product changes * Tests * Dojo project * Cleanups * Python: Fix underlying tool choice bug and all for return to previous Handoff subagent (#2037) * Fix tool_choice override bug and add enable_return_to_previous support * Add unit test for handoff checkpointing * Handle tools when we have them * added missing chatAgent params (#2044) * .NET: fix ChatCompletions Tools serialization (#2043) * fix serialization in chat completions on tools * nit * .NET: assign AgentCard's URL to mapped-endpoint if not defined explicitly (#2047) * fix serialization in chat completions on tools * nit * write e2e test for agent card resolve + adjust behavior * nit * Version 1.0.0-preview.251110.1 (#2048) * .NET: Remove moved OpenAPI sample and point to SK one. (#1997) * Remove moved OpenAPI sample and point to SK one. * Update dotnet/samples/GettingStarted/Agents/README.md Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Bump AWSSDK.Extensions.Bedrock.MEAI from 4.0.4.2 to 4.0.4.6 (#2031) --- updated-dependencies: - dependency-name: AWSSDK.Extensions.Bedrock.MEAI dependency-version: 4.0.4.6 dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * .NET: Separate all memory and rag samples into their own folders (#2000) * Separate all memory and rag samples into their own folders * Fix broken link. * Python: .Net: Dotnet devui compatibility fixes (#2026) * DevUI: Add OpenAI Responses API proxy support with enhanced UI features This commit adds support for proxying requests to OpenAI's Responses API, allowing DevUI to route conversations to OpenAI models when configured to enable testing. Backend changes: - Add OpenAI proxy executor with conversation routing logic - Enhance event mapper to support OpenAI Responses API format - Extend server endpoints to handle OpenAI proxy mode - Update models with OpenAI-specific response types - Remove emojis from logging and CLI output for cleaner text Frontend changes: - Add settings modal with OpenAI proxy configuration UI - Enhance agent and workflow views with improved state management - Add new UI components (separator, switch) for settings - Update debug panel with better event filtering - Improve message renderers for OpenAI content types - Update types and API client for OpenAI integration * update ui, settings modal and workflow input form, add register cleanup hooks. * add workflow HIL support, user mode, other fixes * feat(devui): add human-in-the-loop (HIL) support with dynamic response schemas Implement HIL workflow support allowing workflows to pause for user input with dynamically generated JSON schemas based on response handler type hints. Key Features: - Automatic response schema extraction from @response_handler decorators - Dynamic form generation in UI based on Pydantic/dataclass response types - Checkpoint-based conversation storage for HIL requests/responses - Resume workflow execution after user provides HIL response Backend Changes: - Add extract_response_type_from_executor() to introspect response handlers - Enrich RequestInfoEvent with response_schema via _enrich_request_info_event_with_response_schema() - Map RequestInfoEvent to response.input.requested OpenAI event format - Store HIL responses in conversation history and restore checkpoints Frontend Changes: - Add HILInputModal component with SchemaFormRenderer for dynamic forms - Support Pydantic BaseModel and dataclass response types - Render enum fields as dropdowns, strings as text/textarea, numbers, booleans, arrays, objects - Display original request context alongside response form Testing: - Add tests for checkpoint storage (test_checkpoints.py) - Add schema generation tests for all input types (test_schema_generation.py) - Validate end-to-end HIL flow with spam workflow sample This enables workflows to seamlessly pause execution and request structured user input with type-safe, validated forms generated automatically from response type annotations. * improve HIL support, improve workflow execution view * ui updates * ui updates * improve HIL for workflows, add auth and view modes * update workflow * security improvements , ui fixes * fix mypy error * update loading spinner in ui * DevUI: Serialize workflow input as string to maintain conformance with OpenAI Responses format * Phase 1: Add /meta endpoint and fix workflow event naming for .NET DevUI compatibility * additional fixes for .NET DevUI workflow visualization item ID tracking **Problem:** .NET DevUI was generating different item IDs for ExecutorInvokedEvent and ExecutorCompletedEvent, causing only the first executor to highlight in the workflow graph. Long executor names and error messages also broke UI layout. **Changes:** - Add ExecutorActionItemResource to match Python DevUI implementation - Track item IDs per executor using dictionary in AgentRunResponseUpdateExtensions - Reuse same item ID across invoked/completed/failed events for proper pairing - Add truncateText() utility to workflow-utils.ts - Truncate executor names to 35 chars in execution timeline - Truncate error messages to 150 chars in workflow graph nodes ** Details:** - ExecutorActionItemResource registered with JSON source generation context - Dictionary cleaned up after executor completion/failure to prevent memory leaks - Frontend item tracking by unique item.id supports multiple executor runs - All changes follow existing codebase patterns and conventions Tested with review-workflow showing correct executor highlighting and state transitions for sequential and concurrent executors. * format fixes, remove cors tests * remove unecessary attributes --------- Co-authored-by: Mark Wallace <127216156+markwallace-microsoft@users.noreply.github.com> Co-authored-by: Reuben Bond <reuben.bond@gmail.com> * DevUI: support having both an agent and a workflow with the same id in discovery (#2023) * Python: Fix Model ID attribute not showing up in `invoke_agent` span (#2061) * Best effort to surface the model id to invoke agent span * Fix tests * Fix tests * Version 1.0.0-preview.251107.2 (#2065) * Version 1.0.0-preview.251110.2 (#2067) * Update README.md to change Grafana links to Azure portal links for dashboard access (#1983) * .NET - Enable build & test on branch `feature-foundry-agents` (#2068) * Tests good, mkay * Update .github/workflows/dotnet-build-and-test.yml Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Enable feature build pipelines --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Roger Barreto <19890735+rogerbarreto@users.noreply.github.com> * Python: Add concrete AGUIChatClient (#2072) * Add concrete AGUIChatClient * Update logging docstrings and conventions * PR feedback * Updates to support client-side tool calls * .NET: Move catalog samples to the HostedAgents folder (#2090) * move catalog samples to the HostedAgents folder * move the catalog samples' projects to the HostedAgents folder * Bump OpenTelemetry.Instrumentation.Runtime from 1.12.0 to 1.13.0 (#1856) --- updated-dependencies: - dependency-name: OpenTelemetry.Instrumentation.Runtime dependency-version: 1.13.0 dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * .NET: Bump Microsoft.SemanticKernel.Agents.Abstractions from 1.66.0 to 1.67.0 (#1962) * Bump Microsoft.SemanticKernel.Agents.Abstractions from 1.66.0 to 1.67.0 --- updated-dependencies: - dependency-name: Microsoft.SemanticKernel.Agents.Abstractions dependency-version: 1.67.0 dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] <support@github.com> * .NET: Bump all Microsoft.SemanticKernel packages from 1.66.* to 1.67.* (#1969) * Initial plan * Update all Microsoft.SemanticKernel packages to 1.67.* Co-authored-by: rogerbarreto <19890735+rogerbarreto@users.noreply.github.com> * Remove unrelated changes to package-lock.json and yarn.lock Co-authored-by: markwallace-microsoft <127216156+markwallace-microsoft@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: rogerbarreto <19890735+rogerbarreto@users.noreply.github.com> Co-authored-by: markwallace-microsoft <127216156+markwallace-microsoft@users.noreply.github.com> --------- Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com> Co-authored-by: rogerbarreto <19890735+rogerbarreto@users.noreply.github.com> Co-authored-by: markwallace-microsoft <127216156+markwallace-microsoft@users.noreply.github.com> * .NET: fix: WorkflowAsAgent Sample (#1787) * fix: WorkflowAsAgent Sample * Also makes ChatForwardingExecutor public * feat: Expand ChatForwardingExecutor handled types Make ChatForwardingExecutor match the input types of ChatProtocolExecutor. * fix: Update for the new AgentRunResponseUpdate merge logic AIAgent always sends out List<ChatMessage> now. * Updated (#2076) * Bump vite in /python/samples/demos/chatkit-integration/frontend (#1918) Bumps [vite](https://github.com/vitejs/vite/tree/HEAD/packages/vite) from 7.1.9 to 7.1.12. - [Release notes](https://github.com/vitejs/vite/releases) - [Changelog](https://github.com/vitejs/vite/blob/v7.1.12/packages/vite/CHANGELOG.md) - [Commits](https://github.com/vitejs/vite/commits/v7.1.12/packages/vite) --- updated-dependencies: - dependency-name: vite dependency-version: 7.1.12 dependency-type: direct:development ... Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * Bump Roslynator.Analyzers from 4.14.0 to 4.14.1 (#1857) --- updated-dependencies: - dependency-name: Roslynator.Analyzers dependency-version: 4.14.1 dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * Bump MishaKav/pytest-coverage-comment from 1.1.57 to 1.1.59 (#2034) Bumps [MishaKav/pytest-coverage-comment](https://github.com/mishakav/pytest-coverage-comment) from 1.1.57 to 1.1.59. - [Release notes](https://github.com/mishakav/pytest-coverage-comment/releases) - [Changelog](https://github.com/MishaKav/pytest-coverage-comment/blob/main/CHANGELOG.md) - [Commits](https://github.com/mishakav/pytest-coverage-comment/compare/v1.1.57...v1.1.59) --- updated-dependencies: - dependency-name: MishaKav/pytest-coverage-comment dependency-version: 1.1.59 dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Chris <66376200+crickman@users.noreply.github.com> * Python: Handle agent user input request in AgentExecutor (#2022) * Handle agent user input request in AgentExecutor * fix test * Address comments * Fix tests * Fix tests * Address comments * Address comments * Python: OpenAI Responses Image Generation Stream Support, Sample and Unit Tests (#1853) * support for image gen streaming * small fixes * fixes * added comment * Python: Fix MCP Tool Parameter Descriptions Not Propagated to LLMs (#1978) * mcp tool description fix * small fix * .NET: Allow extending agent run options via additional properties (#1872) * Allow extending agent run options via additional properties This mirrors the M.E.AI model in ChatOptions.AdditionalProperties which is very useful when building functionality pipelines. Fixes https://github.com/microsoft/agent-framework/issues/1815 * Expand XML documentation Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Add AdditionalProperties tests to AgentRunOptions Co-authored-by: kzu <169707+kzu@users.noreply.github.com> --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: kzu <169707+kzu@users.noreply.github.com> * Python: Use the last entry in the task history to avoid empty responses (#2101) * Use the last entry in the task history to avoid empty responses * History only contains Messages * Updated package versions (#2104) --------- Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: Reuben Bond <203839+ReubenBond@users.noreply.github.com> Co-authored-by: Peter Ibekwe <109177538+peibekwe@users.noreply.github.com> Co-authored-by: Jeff Handley <jeffhandley@users.noreply.github.com> Co-authored-by: Daniel Roth <daroth@microsoft.com> Co-authored-by: Victor Dibia <chuvidi2003@gmail.com> Co-authored-by: Mark Wallace <127216156+markwallace-microsoft@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Shawn Henry <sphenry@gmail.com> Co-authored-by: Javier Calvarro Nelson <jacalvar@microsoft.com> Co-authored-by: Evan Mattson <35585003+moonbox3@users.noreply.github.com> Co-authored-by: Eduard van Valkenburg <eavanvalkenburg@users.noreply.github.com> Co-authored-by: Korolev Dmitry <deagle.gross@gmail.com> Co-authored-by: westey <164392973+westey-m@users.noreply.github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Reuben Bond <reuben.bond@gmail.com> Co-authored-by: Tao Chen <taochen@microsoft.com> Co-authored-by: wuweng <wuweng@microsoft.com> Co-authored-by: Chris <66376200+crickman@users.noreply.github.com> Co-authored-by: Roger Barreto <19890735+rogerbarreto@users.noreply.github.com> Co-authored-by: SergeyMenshykh <68852919+SergeyMenshykh@users.noreply.github.com> Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com> Co-authored-by: Jacob Alber <jaalber@microsoft.com> Co-authored-by: Giles Odigwe <79032838+giles17@users.noreply.github.com> Co-authored-by: Daniel Cazzulino <daniel@cazzulino.com> Co-authored-by: kzu <169707+kzu@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
85fcd230bf
commit
361c47f30f
@@ -587,9 +587,11 @@ class ChatAgent(BaseAgent):
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
chat_message_store_factory: Callable[[], ChatMessageStoreProtocol] | None = None,
|
||||
conversation_id: str | None = None,
|
||||
context_providers: ContextProvider | list[ContextProvider] | AggregateContextProvider | None = None,
|
||||
middleware: Middleware | list[Middleware] | None = None,
|
||||
# chat option params
|
||||
allow_multiple_tool_calls: bool | None = None,
|
||||
conversation_id: str | None = None,
|
||||
frequency_penalty: float | None = None,
|
||||
logit_bias: dict[str | int, float] | None = None,
|
||||
max_tokens: int | None = None,
|
||||
@@ -630,15 +632,17 @@ class ChatAgent(BaseAgent):
|
||||
description: A brief description of the agent's purpose.
|
||||
chat_message_store_factory: Factory function to create an instance of ChatMessageStoreProtocol.
|
||||
If not provided, the default in-memory store will be used.
|
||||
conversation_id: The conversation ID for service-managed threads.
|
||||
Cannot be used together with chat_message_store_factory.
|
||||
context_providers: The collection of multiple context providers to include during agent invocation.
|
||||
middleware: List of middleware to intercept agent and function invocations.
|
||||
allow_multiple_tool_calls: Whether to allow multiple tool calls in a single response.
|
||||
conversation_id: The conversation ID for service-managed threads.
|
||||
Cannot be used together with chat_message_store_factory.
|
||||
frequency_penalty: The frequency penalty to use.
|
||||
logit_bias: The logit bias to use.
|
||||
max_tokens: The maximum number of tokens to generate.
|
||||
metadata: Additional metadata to include in the request.
|
||||
model_id: The model_id to use for the agent.
|
||||
This overrides the model_id set in the chat client if it contains one.
|
||||
presence_penalty: The presence penalty to use.
|
||||
response_format: The format of the response.
|
||||
seed: The random seed to use.
|
||||
@@ -687,7 +691,8 @@ class ChatAgent(BaseAgent):
|
||||
self._local_mcp_tools = [tool for tool in normalized_tools if isinstance(tool, MCPTool)]
|
||||
agent_tools = [tool for tool in normalized_tools if not isinstance(tool, MCPTool)]
|
||||
self.chat_options = ChatOptions(
|
||||
model_id=model_id,
|
||||
model_id=model_id or (str(chat_client.model_id) if hasattr(chat_client, "model_id") else None),
|
||||
allow_multiple_tool_calls=allow_multiple_tool_calls,
|
||||
conversation_id=conversation_id,
|
||||
frequency_penalty=frequency_penalty,
|
||||
instructions=instructions,
|
||||
@@ -758,6 +763,7 @@ class ChatAgent(BaseAgent):
|
||||
messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
|
||||
*,
|
||||
thread: AgentThread | None = None,
|
||||
allow_multiple_tool_calls: bool | None = None,
|
||||
frequency_penalty: float | None = None,
|
||||
logit_bias: dict[str | int, float] | None = None,
|
||||
max_tokens: int | None = None,
|
||||
@@ -793,6 +799,7 @@ class ChatAgent(BaseAgent):
|
||||
|
||||
Keyword Args:
|
||||
thread: The thread to use for the agent.
|
||||
allow_multiple_tool_calls: Whether to allow multiple tool calls in a single response.
|
||||
frequency_penalty: The frequency penalty to use.
|
||||
logit_bias: The logit bias to use.
|
||||
max_tokens: The maximum number of tokens to generate.
|
||||
@@ -844,6 +851,7 @@ class ChatAgent(BaseAgent):
|
||||
co = run_chat_options & ChatOptions(
|
||||
model_id=model_id,
|
||||
conversation_id=thread.service_thread_id,
|
||||
allow_multiple_tool_calls=allow_multiple_tool_calls,
|
||||
frequency_penalty=frequency_penalty,
|
||||
logit_bias=logit_bias,
|
||||
max_tokens=max_tokens,
|
||||
@@ -887,6 +895,7 @@ class ChatAgent(BaseAgent):
|
||||
messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
|
||||
*,
|
||||
thread: AgentThread | None = None,
|
||||
allow_multiple_tool_calls: bool | None = None,
|
||||
frequency_penalty: float | None = None,
|
||||
logit_bias: dict[str | int, float] | None = None,
|
||||
max_tokens: int | None = None,
|
||||
@@ -922,6 +931,7 @@ class ChatAgent(BaseAgent):
|
||||
|
||||
Keyword Args:
|
||||
thread: The thread to use for the agent.
|
||||
allow_multiple_tool_calls: Whether to allow multiple tool calls in a single response.
|
||||
frequency_penalty: The frequency penalty to use.
|
||||
logit_bias: The logit bias to use.
|
||||
max_tokens: The maximum number of tokens to generate.
|
||||
@@ -971,6 +981,7 @@ class ChatAgent(BaseAgent):
|
||||
|
||||
co = run_chat_options & ChatOptions(
|
||||
conversation_id=thread.service_thread_id,
|
||||
allow_multiple_tool_calls=allow_multiple_tool_calls,
|
||||
frequency_penalty=frequency_penalty,
|
||||
logit_bias=logit_bias,
|
||||
max_tokens=max_tokens,
|
||||
|
||||
@@ -224,7 +224,7 @@ def _merge_chat_options(
|
||||
stop: str | Sequence[str] | None = None,
|
||||
store: bool | None = None,
|
||||
temperature: float | None = None,
|
||||
tool_choice: ToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = "auto",
|
||||
tool_choice: ToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = None,
|
||||
tools: list[ToolProtocol | dict[str, Any] | Callable[..., Any]] | None = None,
|
||||
top_p: float | None = None,
|
||||
user: str | None = None,
|
||||
@@ -496,7 +496,7 @@ class BaseChatClient(SerializationMixin, ABC):
|
||||
stop: str | Sequence[str] | None = None,
|
||||
store: bool | None = None,
|
||||
temperature: float | None = None,
|
||||
tool_choice: ToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = "auto",
|
||||
tool_choice: ToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = None,
|
||||
tools: ToolProtocol
|
||||
| Callable[..., Any]
|
||||
| MutableMapping[str, Any]
|
||||
@@ -591,7 +591,7 @@ class BaseChatClient(SerializationMixin, ABC):
|
||||
stop: str | Sequence[str] | None = None,
|
||||
store: bool | None = None,
|
||||
temperature: float | None = None,
|
||||
tool_choice: ToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = "auto",
|
||||
tool_choice: ToolMode | Literal["auto", "required", "none"] | dict[str, Any] | None = None,
|
||||
tools: ToolProtocol
|
||||
| Callable[..., Any]
|
||||
| MutableMapping[str, Any]
|
||||
@@ -714,6 +714,8 @@ class BaseChatClient(SerializationMixin, ABC):
|
||||
chat_message_store_factory: Callable[[], ChatMessageStoreProtocol] | None = None,
|
||||
context_providers: ContextProvider | list[ContextProvider] | AggregateContextProvider | None = None,
|
||||
middleware: Middleware | list[Middleware] | None = None,
|
||||
allow_multiple_tool_calls: bool | None = None,
|
||||
conversation_id: str | None = None,
|
||||
frequency_penalty: float | None = None,
|
||||
logit_bias: dict[str | int, float] | None = None,
|
||||
max_tokens: int | None = None,
|
||||
@@ -751,6 +753,8 @@ class BaseChatClient(SerializationMixin, ABC):
|
||||
If not provided, the default in-memory store will be used.
|
||||
context_providers: Context providers to include during agent invocation.
|
||||
middleware: List of middleware to intercept agent and function invocations.
|
||||
allow_multiple_tool_calls: Whether to allow multiple tool calls per agent turn.
|
||||
conversation_id: The conversation ID to associate with the agent's messages.
|
||||
frequency_penalty: The frequency penalty to use.
|
||||
logit_bias: The logit bias to use.
|
||||
max_tokens: The maximum number of tokens to generate.
|
||||
@@ -801,6 +805,8 @@ class BaseChatClient(SerializationMixin, ABC):
|
||||
chat_message_store_factory=chat_message_store_factory,
|
||||
context_providers=context_providers,
|
||||
middleware=middleware,
|
||||
allow_multiple_tool_calls=allow_multiple_tool_calls,
|
||||
conversation_id=conversation_id,
|
||||
frequency_penalty=frequency_penalty,
|
||||
logit_bias=logit_bias,
|
||||
max_tokens=max_tokens,
|
||||
|
||||
@@ -19,7 +19,7 @@ from mcp.client.websocket import websocket_client
|
||||
from mcp.shared.context import RequestContext
|
||||
from mcp.shared.exceptions import McpError
|
||||
from mcp.shared.session import RequestResponder
|
||||
from pydantic import BaseModel, create_model
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
|
||||
from ._tools import AIFunction, HostedMCPSpecificApproval
|
||||
from ._types import ChatMessage, Contents, DataContent, Role, TextContent, UriContent
|
||||
@@ -224,13 +224,20 @@ def _get_input_model_from_mcp_tool(tool: types.Tool) -> type[BaseModel]:
|
||||
prop_details = json.loads(prop_details) if isinstance(prop_details, str) else prop_details
|
||||
|
||||
python_type = resolve_type(prop_details)
|
||||
description = prop_details.get("description", "")
|
||||
|
||||
# Create field definition for create_model
|
||||
if prop_name in required:
|
||||
field_definitions[prop_name] = (python_type, ...)
|
||||
field_definitions[prop_name] = (
|
||||
(python_type, Field(description=description)) if description else (python_type, ...)
|
||||
)
|
||||
else:
|
||||
default_value = prop_details.get("default", None)
|
||||
field_definitions[prop_name] = (python_type, default_value)
|
||||
field_definitions[prop_name] = (
|
||||
(python_type, Field(default=default_value, description=description))
|
||||
if description
|
||||
else (python_type, default_value)
|
||||
)
|
||||
|
||||
return create_model(f"{tool.name}_input", **field_definitions)
|
||||
|
||||
|
||||
@@ -1525,6 +1525,12 @@ def _handle_function_calls_response(
|
||||
prepped_messages = prepare_messages(messages)
|
||||
response: "ChatResponse | None" = None
|
||||
fcc_messages: "list[ChatMessage]" = []
|
||||
|
||||
# If tools are provided but tool_choice is not set, default to "auto" for function invocation
|
||||
tools = _extract_tools(kwargs)
|
||||
if tools and kwargs.get("tool_choice") is None:
|
||||
kwargs["tool_choice"] = "auto"
|
||||
|
||||
for attempt_idx in range(config.max_iterations if config.enabled else 0):
|
||||
fcc_todo = _collect_approval_responses(prepped_messages)
|
||||
if fcc_todo:
|
||||
|
||||
@@ -1050,6 +1050,50 @@ class DataContent(BaseContent):
|
||||
def has_top_level_media_type(self, top_level_media_type: Literal["application", "audio", "image", "text"]) -> bool:
|
||||
return _has_top_level_media_type(self.media_type, top_level_media_type)
|
||||
|
||||
@staticmethod
|
||||
def detect_image_format_from_base64(image_base64: str) -> str:
|
||||
"""Detect image format from base64 data by examining the binary header.
|
||||
|
||||
Args:
|
||||
image_base64: Base64 encoded image data
|
||||
|
||||
Returns:
|
||||
Image format as string (png, jpeg, webp, gif) with png as fallback
|
||||
"""
|
||||
try:
|
||||
# Constants for image format detection
|
||||
# ~75 bytes of binary data should be enough to detect most image formats
|
||||
FORMAT_DETECTION_BASE64_CHARS = 100
|
||||
|
||||
# Decode a small portion to detect format
|
||||
decoded_data = base64.b64decode(image_base64[:FORMAT_DETECTION_BASE64_CHARS])
|
||||
if decoded_data.startswith(b"\x89PNG"):
|
||||
return "png"
|
||||
if decoded_data.startswith(b"\xff\xd8\xff"):
|
||||
return "jpeg"
|
||||
if decoded_data.startswith(b"RIFF") and b"WEBP" in decoded_data[:12]:
|
||||
return "webp"
|
||||
if decoded_data.startswith(b"GIF87a") or decoded_data.startswith(b"GIF89a"):
|
||||
return "gif"
|
||||
return "png" # Default fallback
|
||||
except Exception:
|
||||
return "png" # Fallback if decoding fails
|
||||
|
||||
@classmethod
|
||||
def create_data_uri_from_base64(cls, image_base64: str) -> tuple[str, str]:
|
||||
"""Create a data URI and media type from base64 image data.
|
||||
|
||||
Args:
|
||||
image_base64: Base64 encoded image data
|
||||
|
||||
Returns:
|
||||
Tuple of (data_uri, media_type)
|
||||
"""
|
||||
format_type = cls.detect_image_format_from_base64(image_base64)
|
||||
uri = f"data:image/{format_type};base64,{image_base64}"
|
||||
media_type = f"image/{format_type}"
|
||||
return uri, media_type
|
||||
|
||||
|
||||
class UriContent(BaseContent):
|
||||
"""Represents a URI content.
|
||||
|
||||
@@ -2,11 +2,14 @@
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
from agent_framework import FunctionApprovalRequestContent, FunctionApprovalResponseContent
|
||||
|
||||
from .._agents import AgentProtocol, ChatAgent
|
||||
from .._threads import AgentThread
|
||||
from .._types import AgentRunResponse, AgentRunResponseUpdate, ChatMessage
|
||||
from ._checkpoint_encoding import decode_checkpoint_value, encode_checkpoint_value
|
||||
from ._conversation_state import encode_chat_messages
|
||||
from ._events import (
|
||||
AgentRunEvent,
|
||||
@@ -14,6 +17,7 @@ from ._events import (
|
||||
)
|
||||
from ._executor import Executor, handler
|
||||
from ._message_utils import normalize_messages_input
|
||||
from ._request_info_mixin import response_handler
|
||||
from ._workflow_context import WorkflowContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -83,6 +87,8 @@ class AgentExecutor(Executor):
|
||||
super().__init__(exec_id)
|
||||
self._agent = agent
|
||||
self._agent_thread = agent_thread or self._agent.get_new_thread()
|
||||
self._pending_agent_requests: dict[str, FunctionApprovalRequestContent] = {}
|
||||
self._pending_responses_to_agent: list[FunctionApprovalResponseContent] = []
|
||||
self._output_response = output_response
|
||||
self._cache: list[ChatMessage] = []
|
||||
|
||||
@@ -93,50 +99,6 @@ class AgentExecutor(Executor):
|
||||
return [AgentRunResponse]
|
||||
return []
|
||||
|
||||
async def _run_agent_and_emit(self, ctx: WorkflowContext[AgentExecutorResponse, AgentRunResponse]) -> None:
|
||||
"""Execute the underlying agent, emit events, and enqueue response.
|
||||
|
||||
Checks ctx.is_streaming() to determine whether to emit incremental AgentRunUpdateEvent
|
||||
events (streaming mode) or a single AgentRunEvent (non-streaming mode).
|
||||
"""
|
||||
if ctx.is_streaming():
|
||||
# Streaming mode: emit incremental updates
|
||||
updates: list[AgentRunResponseUpdate] = []
|
||||
async for update in self._agent.run_stream(
|
||||
self._cache,
|
||||
thread=self._agent_thread,
|
||||
):
|
||||
updates.append(update)
|
||||
await ctx.add_event(AgentRunUpdateEvent(self.id, update))
|
||||
|
||||
if isinstance(self._agent, ChatAgent):
|
||||
response_format = self._agent.chat_options.response_format
|
||||
response = AgentRunResponse.from_agent_run_response_updates(
|
||||
updates,
|
||||
output_format_type=response_format,
|
||||
)
|
||||
else:
|
||||
response = AgentRunResponse.from_agent_run_response_updates(updates)
|
||||
else:
|
||||
# Non-streaming mode: use run() and emit single event
|
||||
response = await self._agent.run(
|
||||
self._cache,
|
||||
thread=self._agent_thread,
|
||||
)
|
||||
await ctx.add_event(AgentRunEvent(self.id, response))
|
||||
|
||||
if self._output_response:
|
||||
await ctx.yield_output(response)
|
||||
|
||||
# Always construct a full conversation snapshot from inputs (cache)
|
||||
# plus agent outputs (agent_run_response.messages). Do not mutate
|
||||
# response.messages so AgentRunEvent remains faithful to the raw output.
|
||||
full_conversation: list[ChatMessage] = list(self._cache) + list(response.messages)
|
||||
|
||||
agent_response = AgentExecutorResponse(self.id, response, full_conversation=full_conversation)
|
||||
await ctx.send_message(agent_response)
|
||||
self._cache.clear()
|
||||
|
||||
@handler
|
||||
async def run(
|
||||
self, request: AgentExecutorRequest, ctx: WorkflowContext[AgentExecutorResponse, AgentRunResponse]
|
||||
@@ -192,6 +154,31 @@ class AgentExecutor(Executor):
|
||||
self._cache = normalize_messages_input(messages)
|
||||
await self._run_agent_and_emit(ctx)
|
||||
|
||||
@response_handler
|
||||
async def handle_user_input_response(
|
||||
self,
|
||||
original_request: FunctionApprovalRequestContent,
|
||||
response: FunctionApprovalResponseContent,
|
||||
ctx: WorkflowContext[AgentExecutorResponse, AgentRunResponse],
|
||||
) -> None:
|
||||
"""Handle user input responses for function approvals during agent execution.
|
||||
|
||||
This will hold the executor's execution until all pending user input requests are resolved.
|
||||
|
||||
Args:
|
||||
original_request: The original function approval request sent by the agent.
|
||||
response: The user's response to the function approval request.
|
||||
ctx: The workflow context for emitting events and outputs.
|
||||
"""
|
||||
self._pending_responses_to_agent.append(response)
|
||||
self._pending_agent_requests.pop(original_request.id, None)
|
||||
|
||||
if not self._pending_agent_requests:
|
||||
# All pending requests have been resolved; resume agent execution
|
||||
self._cache = normalize_messages_input(ChatMessage(role="user", contents=self._pending_responses_to_agent))
|
||||
self._pending_responses_to_agent.clear()
|
||||
await self._run_agent_and_emit(ctx)
|
||||
|
||||
async def snapshot_state(self) -> dict[str, Any]:
|
||||
"""Capture current executor state for checkpointing.
|
||||
|
||||
@@ -226,6 +213,8 @@ class AgentExecutor(Executor):
|
||||
return {
|
||||
"cache": encode_chat_messages(self._cache),
|
||||
"agent_thread": serialized_thread,
|
||||
"pending_agent_requests": encode_checkpoint_value(self._pending_agent_requests),
|
||||
"pending_responses_to_agent": encode_checkpoint_value(self._pending_responses_to_agent),
|
||||
}
|
||||
|
||||
async def restore_state(self, state: dict[str, Any]) -> None:
|
||||
@@ -258,7 +247,109 @@ class AgentExecutor(Executor):
|
||||
else:
|
||||
self._agent_thread = self._agent.get_new_thread()
|
||||
|
||||
pending_requests_payload = state.get("pending_agent_requests")
|
||||
if pending_requests_payload:
|
||||
self._pending_agent_requests = decode_checkpoint_value(pending_requests_payload)
|
||||
|
||||
pending_responses_payload = state.get("pending_responses_to_agent")
|
||||
if pending_responses_payload:
|
||||
self._pending_responses_to_agent = decode_checkpoint_value(pending_responses_payload)
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset the internal cache of the executor."""
|
||||
logger.debug("AgentExecutor %s: Resetting cache", self.id)
|
||||
self._cache.clear()
|
||||
|
||||
async def _run_agent_and_emit(self, ctx: WorkflowContext[AgentExecutorResponse, AgentRunResponse]) -> None:
|
||||
"""Execute the underlying agent, emit events, and enqueue response.
|
||||
|
||||
Checks ctx.is_streaming() to determine whether to emit incremental AgentRunUpdateEvent
|
||||
events (streaming mode) or a single AgentRunEvent (non-streaming mode).
|
||||
"""
|
||||
if ctx.is_streaming():
|
||||
# Streaming mode: emit incremental updates
|
||||
response = await self._run_agent_streaming(cast(WorkflowContext, ctx))
|
||||
else:
|
||||
# Non-streaming mode: use run() and emit single event
|
||||
response = await self._run_agent(cast(WorkflowContext, ctx))
|
||||
|
||||
if response is None:
|
||||
# Agent did not complete (e.g., waiting for user input); do not emit response
|
||||
logger.info("AgentExecutor %s: Agent did not complete, awaiting user input", self.id)
|
||||
return
|
||||
|
||||
if self._output_response:
|
||||
await ctx.yield_output(response)
|
||||
|
||||
# Always construct a full conversation snapshot from inputs (cache)
|
||||
# plus agent outputs (agent_run_response.messages). Do not mutate
|
||||
# response.messages so AgentRunEvent remains faithful to the raw output.
|
||||
full_conversation: list[ChatMessage] = list(self._cache) + list(response.messages)
|
||||
|
||||
agent_response = AgentExecutorResponse(self.id, response, full_conversation=full_conversation)
|
||||
await ctx.send_message(agent_response)
|
||||
self._cache.clear()
|
||||
|
||||
async def _run_agent(self, ctx: WorkflowContext) -> AgentRunResponse | None:
|
||||
"""Execute the underlying agent in non-streaming mode.
|
||||
|
||||
Args:
|
||||
ctx: The workflow context for emitting events.
|
||||
|
||||
Returns:
|
||||
The complete AgentRunResponse, or None if waiting for user input.
|
||||
"""
|
||||
response = await self._agent.run(
|
||||
self._cache,
|
||||
thread=self._agent_thread,
|
||||
)
|
||||
await ctx.add_event(AgentRunEvent(self.id, response))
|
||||
|
||||
# Handle any user input requests
|
||||
if response.user_input_requests:
|
||||
for user_input_request in response.user_input_requests:
|
||||
self._pending_agent_requests[user_input_request.id] = user_input_request
|
||||
await ctx.request_info(user_input_request, FunctionApprovalResponseContent)
|
||||
return None
|
||||
|
||||
return response
|
||||
|
||||
async def _run_agent_streaming(self, ctx: WorkflowContext) -> AgentRunResponse | None:
|
||||
"""Execute the underlying agent in streaming mode and collect the full response.
|
||||
|
||||
Args:
|
||||
ctx: The workflow context for emitting events.
|
||||
|
||||
Returns:
|
||||
The complete AgentRunResponse, or None if waiting for user input.
|
||||
"""
|
||||
updates: list[AgentRunResponseUpdate] = []
|
||||
user_input_requests: list[FunctionApprovalRequestContent] = []
|
||||
async for update in self._agent.run_stream(
|
||||
self._cache,
|
||||
thread=self._agent_thread,
|
||||
):
|
||||
updates.append(update)
|
||||
await ctx.add_event(AgentRunUpdateEvent(self.id, update))
|
||||
|
||||
if update.user_input_requests:
|
||||
user_input_requests.extend(update.user_input_requests)
|
||||
|
||||
# Build the final AgentRunResponse from the collected updates
|
||||
if isinstance(self._agent, ChatAgent):
|
||||
response_format = self._agent.chat_options.response_format
|
||||
response = AgentRunResponse.from_agent_run_response_updates(
|
||||
updates,
|
||||
output_format_type=response_format,
|
||||
)
|
||||
else:
|
||||
response = AgentRunResponse.from_agent_run_response_updates(updates)
|
||||
|
||||
# Handle any user input requests after the streaming completes
|
||||
if user_input_requests:
|
||||
for user_input_request in user_input_requests:
|
||||
self._pending_agent_requests[user_input_request.id] = user_input_request
|
||||
await ctx.request_info(user_input_request, FunctionApprovalResponseContent)
|
||||
return None
|
||||
|
||||
return response
|
||||
|
||||
@@ -85,8 +85,8 @@ def _clone_chat_agent(agent: ChatAgent) -> ChatAgent:
|
||||
# so we need to recombine them here to pass the complete tools list to the constructor.
|
||||
# This makes sure MCP tools are preserved when cloning agents for handoff workflows.
|
||||
all_tools = list(options.tools) if options.tools else []
|
||||
if agent._local_mcp_tools:
|
||||
all_tools.extend(agent._local_mcp_tools)
|
||||
if agent._local_mcp_tools: # type: ignore
|
||||
all_tools.extend(agent._local_mcp_tools) # type: ignore
|
||||
|
||||
return ChatAgent(
|
||||
chat_client=agent.chat_client,
|
||||
@@ -133,6 +133,14 @@ class _ConversationWithUserInput:
|
||||
full_conversation: list[ChatMessage] = field(default_factory=lambda: []) # type: ignore[misc]
|
||||
|
||||
|
||||
@dataclass
|
||||
class _ConversationForUserInput:
|
||||
"""Internal message from coordinator to gateway specifying which agent will receive the response."""
|
||||
|
||||
conversation: list[ChatMessage]
|
||||
next_agent_id: str
|
||||
|
||||
|
||||
class _AutoHandoffMiddleware(FunctionMiddleware):
|
||||
"""Intercept handoff tool invocations and short-circuit execution with synthetic results."""
|
||||
|
||||
@@ -275,6 +283,7 @@ class _HandoffCoordinator(BaseGroupChatOrchestrator):
|
||||
termination_condition: Callable[[list[ChatMessage]], bool | Awaitable[bool]],
|
||||
id: str,
|
||||
handoff_tool_targets: Mapping[str, str] | None = None,
|
||||
return_to_previous: bool = False,
|
||||
) -> None:
|
||||
"""Create a coordinator that manages routing between specialists and the user."""
|
||||
super().__init__(id)
|
||||
@@ -284,6 +293,8 @@ class _HandoffCoordinator(BaseGroupChatOrchestrator):
|
||||
self._input_gateway_id = input_gateway_id
|
||||
self._termination_condition = termination_condition
|
||||
self._handoff_tool_targets = {k.lower(): v for k, v in (handoff_tool_targets or {}).items()}
|
||||
self._return_to_previous = return_to_previous
|
||||
self._current_agent_id: str | None = None # Track the current agent handling conversation
|
||||
|
||||
def _get_author_name(self) -> str:
|
||||
"""Get the coordinator name for orchestrator-generated messages."""
|
||||
@@ -293,7 +304,7 @@ class _HandoffCoordinator(BaseGroupChatOrchestrator):
|
||||
async def handle_agent_response(
|
||||
self,
|
||||
response: AgentExecutorResponse,
|
||||
ctx: WorkflowContext[AgentExecutorRequest | list[ChatMessage], list[ChatMessage]],
|
||||
ctx: WorkflowContext[AgentExecutorRequest | list[ChatMessage], list[ChatMessage] | _ConversationForUserInput],
|
||||
) -> None:
|
||||
"""Process an agent's response and determine whether to route, request input, or terminate."""
|
||||
# Hydrate coordinator state (and detect new run) using checkpointable executor state
|
||||
@@ -329,6 +340,9 @@ class _HandoffCoordinator(BaseGroupChatOrchestrator):
|
||||
# Check for handoff from ANY agent (starting agent or specialist)
|
||||
target = self._resolve_specialist(response.agent_run_response, conversation)
|
||||
if target is not None:
|
||||
# Update current agent when handoff occurs
|
||||
self._current_agent_id = target
|
||||
logger.info(f"Handoff detected: {source} -> {target}. Routing control to specialist '{target}'.")
|
||||
await self._persist_state(ctx)
|
||||
# Clean tool-related content before sending to next agent
|
||||
cleaned = clean_conversation_for_handoff(conversation)
|
||||
@@ -340,10 +354,15 @@ class _HandoffCoordinator(BaseGroupChatOrchestrator):
|
||||
if not is_starting_agent and source not in self._specialist_ids:
|
||||
raise RuntimeError(f"HandoffCoordinator received response from unknown executor '{source}'.")
|
||||
|
||||
# Update current agent when they respond without handoff
|
||||
self._current_agent_id = source
|
||||
logger.info(
|
||||
f"Agent '{source}' responded without handoff. "
|
||||
f"Requesting user input. Return-to-previous: {self._return_to_previous}"
|
||||
)
|
||||
await self._persist_state(ctx)
|
||||
|
||||
if await self._check_termination():
|
||||
logger.info("Handoff workflow termination condition met. Ending conversation.")
|
||||
# Clean the output conversation for display
|
||||
cleaned_output = clean_conversation_for_handoff(conversation)
|
||||
await ctx.yield_output(cleaned_output)
|
||||
@@ -352,7 +371,13 @@ class _HandoffCoordinator(BaseGroupChatOrchestrator):
|
||||
# Clean conversation before sending to gateway for user input request
|
||||
# This removes tool messages that shouldn't be shown to users
|
||||
cleaned_for_display = clean_conversation_for_handoff(conversation)
|
||||
await ctx.send_message(cleaned_for_display, target_id=self._input_gateway_id)
|
||||
|
||||
# The awaiting_agent_id is the agent that just responded and is awaiting user input
|
||||
# This is the source of the current response
|
||||
next_agent_id = source
|
||||
|
||||
message_to_gateway = _ConversationForUserInput(conversation=cleaned_for_display, next_agent_id=next_agent_id)
|
||||
await ctx.send_message(message_to_gateway, target_id=self._input_gateway_id) # type: ignore[arg-type]
|
||||
|
||||
@handler
|
||||
async def handle_user_input(
|
||||
@@ -367,14 +392,26 @@ class _HandoffCoordinator(BaseGroupChatOrchestrator):
|
||||
|
||||
# Check termination before sending to agent
|
||||
if await self._check_termination():
|
||||
logger.info("Handoff workflow termination condition met. Ending conversation.")
|
||||
await ctx.yield_output(list(self._conversation))
|
||||
return
|
||||
|
||||
# Clean before sending to starting agent
|
||||
# Determine routing target based on return-to-previous setting
|
||||
target_agent_id = self._starting_agent_id
|
||||
if self._return_to_previous and self._current_agent_id:
|
||||
# Route back to the current agent that's handling the conversation
|
||||
target_agent_id = self._current_agent_id
|
||||
logger.info(
|
||||
f"Return-to-previous enabled: routing user input to current agent '{target_agent_id}' "
|
||||
f"(bypassing coordinator '{self._starting_agent_id}')"
|
||||
)
|
||||
else:
|
||||
logger.info(f"Routing user input to coordinator '{target_agent_id}'")
|
||||
# Note: Stack is only used for specialist-to-specialist handoffs, not user input routing
|
||||
|
||||
# Clean before sending to target agent
|
||||
cleaned = clean_conversation_for_handoff(self._conversation)
|
||||
request = AgentExecutorRequest(messages=cleaned, should_respond=True)
|
||||
await ctx.send_message(request, target_id=self._starting_agent_id)
|
||||
await ctx.send_message(request, target_id=target_agent_id)
|
||||
|
||||
def _resolve_specialist(self, agent_response: AgentRunResponse, conversation: list[ChatMessage]) -> str | None:
|
||||
"""Resolve the specialist executor id requested by the agent response, if any."""
|
||||
@@ -444,22 +481,27 @@ class _HandoffCoordinator(BaseGroupChatOrchestrator):
|
||||
def _snapshot_pattern_metadata(self) -> dict[str, Any]:
|
||||
"""Serialize pattern-specific state.
|
||||
|
||||
Handoff has no additional metadata beyond base conversation state.
|
||||
Includes the current agent for return-to-previous routing.
|
||||
|
||||
Returns:
|
||||
Empty dict (no pattern-specific state)
|
||||
Dict containing current agent if return-to-previous is enabled
|
||||
"""
|
||||
if self._return_to_previous:
|
||||
return {
|
||||
"current_agent_id": self._current_agent_id,
|
||||
}
|
||||
return {}
|
||||
|
||||
def _restore_pattern_metadata(self, metadata: dict[str, Any]) -> None:
|
||||
"""Restore pattern-specific state.
|
||||
|
||||
Handoff has no additional metadata beyond base conversation state.
|
||||
Restores the current agent for return-to-previous routing.
|
||||
|
||||
Args:
|
||||
metadata: Pattern-specific state dict (ignored)
|
||||
metadata: Pattern-specific state dict
|
||||
"""
|
||||
pass
|
||||
if self._return_to_previous and "current_agent_id" in metadata:
|
||||
self._current_agent_id = metadata["current_agent_id"]
|
||||
|
||||
def _restore_conversation_from_state(self, state: Mapping[str, Any]) -> list[ChatMessage]:
|
||||
"""Rehydrate the coordinator's conversation history from checkpointed state.
|
||||
@@ -507,8 +549,21 @@ class _UserInputGateway(Executor):
|
||||
self._prompt = prompt or "Provide your next input for the conversation."
|
||||
|
||||
@handler
|
||||
async def request_input(self, conversation: list[ChatMessage], ctx: WorkflowContext) -> None:
|
||||
async def request_input(self, message: _ConversationForUserInput, ctx: WorkflowContext) -> None:
|
||||
"""Emit a `HandoffUserInputRequest` capturing the conversation snapshot."""
|
||||
if not message.conversation:
|
||||
raise ValueError("Handoff workflow requires non-empty conversation before requesting user input.")
|
||||
request = HandoffUserInputRequest(
|
||||
conversation=list(message.conversation),
|
||||
awaiting_agent_id=message.next_agent_id,
|
||||
prompt=self._prompt,
|
||||
source_executor_id=self.id,
|
||||
)
|
||||
await ctx.request_info(request, object)
|
||||
|
||||
@handler
|
||||
async def request_input_legacy(self, conversation: list[ChatMessage], ctx: WorkflowContext) -> None:
|
||||
"""Legacy handler for backward compatibility - emit user input request with starting agent."""
|
||||
if not conversation:
|
||||
raise ValueError("Handoff workflow requires non-empty conversation before requesting user input.")
|
||||
request = HandoffUserInputRequest(
|
||||
@@ -558,7 +613,7 @@ def _as_user_messages(payload: Any) -> list[ChatMessage]:
|
||||
|
||||
|
||||
def _default_termination_condition(conversation: list[ChatMessage]) -> bool:
|
||||
"""Default termination: stop after 10 user messages to prevent infinite loops."""
|
||||
"""Default termination: stop after 10 user messages."""
|
||||
user_message_count = sum(1 for msg in conversation if msg.role == Role.USER)
|
||||
return user_message_count >= 10
|
||||
|
||||
@@ -743,6 +798,7 @@ class HandoffBuilder:
|
||||
)
|
||||
self._auto_register_handoff_tools: bool = True
|
||||
self._handoff_config: dict[str, list[str]] = {} # Maps agent_id -> [target_agent_ids]
|
||||
self._return_to_previous: bool = False
|
||||
|
||||
if participants:
|
||||
self.participants(participants)
|
||||
@@ -1198,6 +1254,77 @@ class HandoffBuilder:
|
||||
self._termination_condition = condition
|
||||
return self
|
||||
|
||||
def enable_return_to_previous(self, enabled: bool = True) -> "HandoffBuilder":
|
||||
"""Enable direct return to the current agent after user input, bypassing the coordinator.
|
||||
|
||||
When enabled, after a specialist responds without requesting another handoff, user input
|
||||
routes directly back to that same specialist instead of always routing back to the
|
||||
coordinator agent for re-evaluation.
|
||||
|
||||
This is useful when a specialist needs multiple turns with the user to gather information
|
||||
or resolve an issue, avoiding unnecessary coordinator involvement while maintaining context.
|
||||
|
||||
Flow Comparison:
|
||||
|
||||
**Default (disabled):**
|
||||
User -> Coordinator -> Specialist -> User -> Coordinator -> Specialist -> ...
|
||||
|
||||
**With return_to_previous (enabled):**
|
||||
User -> Coordinator -> Specialist -> User -> Specialist -> ...
|
||||
|
||||
Args:
|
||||
enabled: Whether to enable return-to-previous routing. Default is True.
|
||||
|
||||
Returns:
|
||||
Self for method chaining.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
workflow = (
|
||||
HandoffBuilder(participants=[triage, technical_support, billing])
|
||||
.set_coordinator("triage")
|
||||
.add_handoff(triage, [technical_support, billing])
|
||||
.enable_return_to_previous() # Enable direct return routing
|
||||
.build()
|
||||
)
|
||||
|
||||
# Flow: User asks question
|
||||
# -> Triage routes to Technical Support
|
||||
# -> Technical Support asks clarifying question
|
||||
# -> User provides more info
|
||||
# -> Routes back to Technical Support (not Triage)
|
||||
# -> Technical Support continues helping
|
||||
|
||||
Multi-tier handoff example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
workflow = (
|
||||
HandoffBuilder(participants=[triage, specialist_a, specialist_b])
|
||||
.set_coordinator("triage")
|
||||
.add_handoff(triage, [specialist_a, specialist_b])
|
||||
.add_handoff(specialist_a, specialist_b)
|
||||
.enable_return_to_previous()
|
||||
.build()
|
||||
)
|
||||
|
||||
# Flow: User asks question
|
||||
# -> Triage routes to Specialist A
|
||||
# -> Specialist A hands off to Specialist B
|
||||
# -> Specialist B asks clarifying question
|
||||
# -> User provides more info
|
||||
# -> Routes back to Specialist B (who is currently handling the conversation)
|
||||
|
||||
Note:
|
||||
This feature routes to whichever agent most recently responded, whether that's
|
||||
the coordinator or a specialist. The conversation continues with that agent until
|
||||
they either hand off to another agent or the termination condition is met.
|
||||
"""
|
||||
self._return_to_previous = enabled
|
||||
return self
|
||||
|
||||
def build(self) -> Workflow:
|
||||
"""Construct the final Workflow instance from the configured builder.
|
||||
|
||||
@@ -1326,6 +1453,7 @@ class HandoffBuilder:
|
||||
termination_condition=self._termination_condition,
|
||||
id="handoff-coordinator",
|
||||
handoff_tool_targets=handoff_tool_targets,
|
||||
return_to_previous=self._return_to_previous,
|
||||
)
|
||||
|
||||
wiring = _GroupChatConfig(
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import importlib
|
||||
from typing import Any
|
||||
|
||||
PACKAGE_NAME = "agent_framework_ag_ui"
|
||||
PACKAGE_EXTRA = "ag-ui"
|
||||
_IMPORTS = [
|
||||
"__version__",
|
||||
"AgentFrameworkAgent",
|
||||
"add_agent_framework_fastapi_endpoint",
|
||||
"AGUIChatClient",
|
||||
"AGUIEventConverter",
|
||||
"AGUIHttpService",
|
||||
"ConfirmationStrategy",
|
||||
"DefaultConfirmationStrategy",
|
||||
"TaskPlannerConfirmationStrategy",
|
||||
"RecipeConfirmationStrategy",
|
||||
"DocumentWriterConfirmationStrategy",
|
||||
]
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
if name in _IMPORTS:
|
||||
try:
|
||||
return getattr(importlib.import_module(PACKAGE_NAME), name)
|
||||
except ModuleNotFoundError as exc:
|
||||
raise ModuleNotFoundError(
|
||||
f"The '{PACKAGE_EXTRA}' extra is not installed, please do `pip install agent-framework-{PACKAGE_EXTRA}`"
|
||||
) from exc
|
||||
raise AttributeError(f"Module {PACKAGE_NAME} has no attribute {name}.")
|
||||
|
||||
|
||||
def __dir__() -> list[str]:
|
||||
return _IMPORTS
|
||||
@@ -0,0 +1,29 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from agent_framework_ag_ui import (
|
||||
AgentFrameworkAgent,
|
||||
AGUIChatClient,
|
||||
AGUIEventConverter,
|
||||
AGUIHttpService,
|
||||
ConfirmationStrategy,
|
||||
DefaultConfirmationStrategy,
|
||||
DocumentWriterConfirmationStrategy,
|
||||
RecipeConfirmationStrategy,
|
||||
TaskPlannerConfirmationStrategy,
|
||||
__version__,
|
||||
add_agent_framework_fastapi_endpoint,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AGUIChatClient",
|
||||
"AGUIEventConverter",
|
||||
"AGUIHttpService",
|
||||
"AgentFrameworkAgent",
|
||||
"ConfirmationStrategy",
|
||||
"DefaultConfirmationStrategy",
|
||||
"DocumentWriterConfirmationStrategy",
|
||||
"RecipeConfirmationStrategy",
|
||||
"TaskPlannerConfirmationStrategy",
|
||||
"__version__",
|
||||
"add_agent_framework_fastapi_endpoint",
|
||||
]
|
||||
@@ -846,6 +846,7 @@ def _trace_get_response(
|
||||
kwargs.get("model_id")
|
||||
or (chat_options.model_id if (chat_options := kwargs.get("chat_options")) else None)
|
||||
or getattr(self, "model_id", None)
|
||||
or "unknown"
|
||||
)
|
||||
service_url = str(
|
||||
service_url_func()
|
||||
@@ -933,6 +934,7 @@ def _trace_get_streaming_response(
|
||||
kwargs.get("model_id")
|
||||
or (chat_options.model_id if (chat_options := kwargs.get("chat_options")) else None)
|
||||
or getattr(self, "model_id", None)
|
||||
or "unknown"
|
||||
)
|
||||
service_url = str(
|
||||
service_url_func()
|
||||
@@ -1324,7 +1326,10 @@ def _get_span(
|
||||
attributes: dict[str, Any],
|
||||
span_name_attribute: str,
|
||||
) -> Generator["trace.Span", Any, Any]:
|
||||
"""Start a span for a agent run."""
|
||||
"""Start a span for a agent run.
|
||||
|
||||
Note: `attributes` must contain the `span_name_attribute` key.
|
||||
"""
|
||||
span = get_tracer().start_span(f"{attributes[OtelAttr.OPERATION]} {attributes[span_name_attribute]}")
|
||||
span.set_attributes(attributes)
|
||||
with trace.use_span(
|
||||
@@ -1353,7 +1358,8 @@ def _get_span_attributes(**kwargs: Any) -> dict[str, Any]:
|
||||
attributes[SpanAttributes.LLM_SYSTEM] = system_name
|
||||
if provider_name := kwargs.get("provider_name"):
|
||||
attributes[OtelAttr.PROVIDER_NAME] = provider_name
|
||||
attributes[SpanAttributes.LLM_REQUEST_MODEL] = kwargs.get("model", "unknown")
|
||||
if model_id := kwargs.get("model", chat_options.model_id):
|
||||
attributes[SpanAttributes.LLM_REQUEST_MODEL] = model_id
|
||||
if service_url := kwargs.get("service_url"):
|
||||
attributes[OtelAttr.ADDRESS] = service_url
|
||||
if conversation_id := kwargs.get("conversation_id", chat_options.conversation_id):
|
||||
|
||||
@@ -276,6 +276,14 @@ class OpenAIBaseResponsesClient(OpenAIBase, BaseChatClient):
|
||||
# Map the parameter name and remove the old one
|
||||
mapped_tool[api_param] = mapped_tool.pop(user_param)
|
||||
|
||||
# Validate partial_images parameter for streaming image generation
|
||||
# OpenAI API requires partial_images to be between 0-3 (inclusive) for image_generation tool
|
||||
# Reference: https://platform.openai.com/docs/api-reference/responses/create#responses_create-tools-image_generation_tool-partial_images
|
||||
if "partial_images" in mapped_tool:
|
||||
partial_images = mapped_tool["partial_images"]
|
||||
if not isinstance(partial_images, int) or partial_images < 0 or partial_images > 3:
|
||||
raise ValueError("partial_images must be an integer between 0 and 3 (inclusive).")
|
||||
|
||||
response_tools.append(mapped_tool)
|
||||
else:
|
||||
response_tools.append(tool_dict)
|
||||
@@ -707,29 +715,8 @@ class OpenAIBaseResponsesClient(OpenAIBase, BaseChatClient):
|
||||
uri = item.result
|
||||
media_type = None
|
||||
if not uri.startswith("data:"):
|
||||
# Raw base64 string - convert to proper data URI format
|
||||
# Detect format from base64 data
|
||||
import base64
|
||||
|
||||
try:
|
||||
# Decode a small portion to detect format
|
||||
decoded_data = base64.b64decode(uri[:100]) # First ~75 bytes should be enough
|
||||
if decoded_data.startswith(b"\x89PNG"):
|
||||
format_type = "png"
|
||||
elif decoded_data.startswith(b"\xff\xd8\xff"):
|
||||
format_type = "jpeg"
|
||||
elif decoded_data.startswith(b"RIFF") and b"WEBP" in decoded_data[:12]:
|
||||
format_type = "webp"
|
||||
elif decoded_data.startswith(b"GIF87a") or decoded_data.startswith(b"GIF89a"):
|
||||
format_type = "gif"
|
||||
else:
|
||||
# Default to png if format cannot be detected
|
||||
format_type = "png"
|
||||
except Exception:
|
||||
# Fallback to png if decoding fails
|
||||
format_type = "png"
|
||||
uri = f"data:image/{format_type};base64,{uri}"
|
||||
media_type = f"image/{format_type}"
|
||||
# Raw base64 string - convert to proper data URI format using helper
|
||||
uri, media_type = DataContent.create_data_uri_from_base64(uri)
|
||||
else:
|
||||
# Parse media type from existing data URI
|
||||
try:
|
||||
@@ -945,6 +932,25 @@ class OpenAIBaseResponsesClient(OpenAIBase, BaseChatClient):
|
||||
raw_representation=event,
|
||||
)
|
||||
)
|
||||
case "response.image_generation_call.partial_image":
|
||||
# Handle streaming partial image generation
|
||||
image_base64 = event.partial_image_b64
|
||||
partial_index = event.partial_image_index
|
||||
|
||||
# Use helper function to create data URI from base64
|
||||
uri, media_type = DataContent.create_data_uri_from_base64(image_base64)
|
||||
|
||||
contents.append(
|
||||
DataContent(
|
||||
uri=uri,
|
||||
media_type=media_type,
|
||||
additional_properties={
|
||||
"partial_image_index": partial_index,
|
||||
"is_partial_image": True,
|
||||
},
|
||||
raw_representation=event,
|
||||
)
|
||||
)
|
||||
case _:
|
||||
logger.debug("Unparsed event of type: %s: %s", event.type, event)
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ description = "Microsoft Agent Framework for building AI Agents with Python. Thi
|
||||
authors = [{ name = "Microsoft", email = "af-support@microsoft.com"}]
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
version = "1.0.0b251105"
|
||||
version = "1.0.0b251111"
|
||||
license-files = ["LICENSE"]
|
||||
urls.homepage = "https://aka.ms/agent-framework"
|
||||
urls.source = "https://github.com/microsoft/agent-framework/tree/main/python"
|
||||
@@ -42,13 +42,14 @@ dependencies = [
|
||||
[project.optional-dependencies]
|
||||
all = [
|
||||
"agent-framework-a2a",
|
||||
"agent-framework-ag-ui",
|
||||
"agent-framework-anthropic",
|
||||
"agent-framework-azure-ai",
|
||||
"agent-framework-copilotstudio",
|
||||
"agent-framework-mem0",
|
||||
"agent-framework-redis",
|
||||
"agent-framework-devui",
|
||||
"agent-framework-mem0",
|
||||
"agent-framework-purview",
|
||||
"agent-framework-anthropic",
|
||||
"agent-framework-redis",
|
||||
]
|
||||
|
||||
[tool.uv]
|
||||
|
||||
@@ -279,6 +279,45 @@ async def test_chat_client_streaming_observability(
|
||||
assert span.attributes[OtelAttr.OUTPUT_MESSAGES] is not None
|
||||
|
||||
|
||||
async def test_chat_client_without_model_id_observability(mock_chat_client, span_exporter: InMemorySpanExporter):
|
||||
"""Test telemetry shouldn't fail when the model_id is not provided for unknown reason."""
|
||||
client = use_observability(mock_chat_client)()
|
||||
messages = [ChatMessage(role=Role.USER, text="Test")]
|
||||
span_exporter.clear()
|
||||
response = await client.get_response(messages=messages)
|
||||
|
||||
assert response is not None
|
||||
spans = span_exporter.get_finished_spans()
|
||||
assert len(spans) == 1
|
||||
span = spans[0]
|
||||
|
||||
assert span.name == "chat unknown"
|
||||
assert span.attributes[OtelAttr.OPERATION.value] == OtelAttr.CHAT_COMPLETION_OPERATION
|
||||
assert span.attributes[SpanAttributes.LLM_REQUEST_MODEL] == "unknown"
|
||||
|
||||
|
||||
async def test_chat_client_streaming_without_model_id_observability(
|
||||
mock_chat_client, span_exporter: InMemorySpanExporter
|
||||
):
|
||||
"""Test streaming telemetry shouldn't fail when the model_id is not provided for unknown reason."""
|
||||
client = use_observability(mock_chat_client)()
|
||||
messages = [ChatMessage(role=Role.USER, text="Test")]
|
||||
span_exporter.clear()
|
||||
# Collect all yielded updates
|
||||
updates = []
|
||||
async for update in client.get_streaming_response(messages=messages):
|
||||
updates.append(update)
|
||||
|
||||
# Verify we got the expected updates, this shouldn't be dependent on otel
|
||||
assert len(updates) == 2
|
||||
spans = span_exporter.get_finished_spans()
|
||||
assert len(spans) == 1
|
||||
span = spans[0]
|
||||
assert span.name == "chat unknown"
|
||||
assert span.attributes[OtelAttr.OPERATION.value] == OtelAttr.CHAT_COMPLETION_OPERATION
|
||||
assert span.attributes[SpanAttributes.LLM_REQUEST_MODEL] == "unknown"
|
||||
|
||||
|
||||
def test_prepend_user_agent_with_none_value():
|
||||
"""Test prepend user agent with None value in headers."""
|
||||
headers = {"User-Agent": None}
|
||||
@@ -368,6 +407,7 @@ def mock_chat_agent():
|
||||
self.name = "test_agent"
|
||||
self.display_name = "Test Agent"
|
||||
self.description = "Test agent description"
|
||||
self.chat_options = ChatOptions(model_id="TestModel")
|
||||
|
||||
async def run(self, messages=None, *, thread=None, **kwargs):
|
||||
return AgentRunResponse(
|
||||
@@ -405,7 +445,7 @@ async def test_agent_instrumentation_enabled(
|
||||
assert span.attributes[OtelAttr.AGENT_ID] == "test_agent_id"
|
||||
assert span.attributes[OtelAttr.AGENT_NAME] == "Test Agent"
|
||||
assert span.attributes[OtelAttr.AGENT_DESCRIPTION] == "Test agent description"
|
||||
assert span.attributes[SpanAttributes.LLM_REQUEST_MODEL] == "unknown"
|
||||
assert span.attributes[SpanAttributes.LLM_REQUEST_MODEL] == "TestModel"
|
||||
assert span.attributes[OtelAttr.INPUT_TOKENS] == 15
|
||||
assert span.attributes[OtelAttr.OUTPUT_TOKENS] == 25
|
||||
if enable_sensitive_data:
|
||||
@@ -433,7 +473,7 @@ async def test_agent_streaming_response_with_diagnostics_enabled_via_decorator(
|
||||
assert span.attributes[OtelAttr.AGENT_ID] == "test_agent_id"
|
||||
assert span.attributes[OtelAttr.AGENT_NAME] == "Test Agent"
|
||||
assert span.attributes[OtelAttr.AGENT_DESCRIPTION] == "Test agent description"
|
||||
assert span.attributes[SpanAttributes.LLM_REQUEST_MODEL] == "unknown"
|
||||
assert span.attributes[SpanAttributes.LLM_REQUEST_MODEL] == "TestModel"
|
||||
if enable_sensitive_data:
|
||||
assert span.attributes.get(OtelAttr.OUTPUT_MESSAGES) is not None # Streaming, so no usage yet
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import base64
|
||||
from collections.abc import AsyncIterable
|
||||
from typing import Any
|
||||
|
||||
@@ -166,6 +167,57 @@ def test_data_content_empty():
|
||||
DataContent(uri="")
|
||||
|
||||
|
||||
def test_data_content_detect_image_format_from_base64():
|
||||
"""Test the detect_image_format_from_base64 static method."""
|
||||
# Test each supported format
|
||||
png_data = b"\x89PNG\r\n\x1a\n" + b"fake_data"
|
||||
assert DataContent.detect_image_format_from_base64(base64.b64encode(png_data).decode()) == "png"
|
||||
|
||||
jpeg_data = b"\xff\xd8\xff\xe0" + b"fake_data"
|
||||
assert DataContent.detect_image_format_from_base64(base64.b64encode(jpeg_data).decode()) == "jpeg"
|
||||
|
||||
webp_data = b"RIFF" + b"1234" + b"WEBP" + b"fake_data"
|
||||
assert DataContent.detect_image_format_from_base64(base64.b64encode(webp_data).decode()) == "webp"
|
||||
|
||||
gif_data = b"GIF89a" + b"fake_data"
|
||||
assert DataContent.detect_image_format_from_base64(base64.b64encode(gif_data).decode()) == "gif"
|
||||
|
||||
# Test fallback behavior
|
||||
unknown_data = b"UNKNOWN_FORMAT"
|
||||
assert DataContent.detect_image_format_from_base64(base64.b64encode(unknown_data).decode()) == "png"
|
||||
|
||||
# Test error handling
|
||||
assert DataContent.detect_image_format_from_base64("invalid_base64!") == "png"
|
||||
assert DataContent.detect_image_format_from_base64("") == "png"
|
||||
|
||||
|
||||
def test_data_content_create_data_uri_from_base64():
|
||||
"""Test the create_data_uri_from_base64 class method."""
|
||||
# Test with PNG data
|
||||
png_data = b"\x89PNG\r\n\x1a\n" + b"fake_data"
|
||||
png_base64 = base64.b64encode(png_data).decode()
|
||||
uri, media_type = DataContent.create_data_uri_from_base64(png_base64)
|
||||
|
||||
assert uri == f"data:image/png;base64,{png_base64}"
|
||||
assert media_type == "image/png"
|
||||
|
||||
# Test with different format
|
||||
jpeg_data = b"\xff\xd8\xff\xe0" + b"fake_data"
|
||||
jpeg_base64 = base64.b64encode(jpeg_data).decode()
|
||||
uri, media_type = DataContent.create_data_uri_from_base64(jpeg_base64)
|
||||
|
||||
assert uri == f"data:image/jpeg;base64,{jpeg_base64}"
|
||||
assert media_type == "image/jpeg"
|
||||
|
||||
# Test fallback for unknown format
|
||||
unknown_data = b"UNKNOWN_FORMAT"
|
||||
unknown_base64 = base64.b64encode(unknown_data).decode()
|
||||
uri, media_type = DataContent.create_data_uri_from_base64(unknown_base64)
|
||||
|
||||
assert uri == f"data:image/png;base64,{unknown_base64}"
|
||||
assert media_type == "image/png"
|
||||
|
||||
|
||||
# region UriContent
|
||||
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -111,6 +111,10 @@ async def test_agent_executor_checkpoint_stores_and_restores_state() -> None:
|
||||
chat_store_state = thread_state["chat_message_store_state"] # type: ignore[index]
|
||||
assert "messages" in chat_store_state, "Message store state should include messages"
|
||||
|
||||
# Verify checkpoint contains pending requests from agents and responses to be sent
|
||||
assert "pending_agent_requests" in executor_state
|
||||
assert "pending_responses_to_agent" in executor_state
|
||||
|
||||
# Create a new agent and executor for restoration
|
||||
# This simulates starting from a fresh state and restoring from checkpoint
|
||||
restored_agent = _CountingAgent(id="test_agent", name="TestAgent")
|
||||
|
||||
@@ -5,19 +5,32 @@
|
||||
from collections.abc import AsyncIterable
|
||||
from typing import Any
|
||||
|
||||
from typing_extensions import Never
|
||||
|
||||
from agent_framework import (
|
||||
AgentExecutor,
|
||||
AgentExecutorResponse,
|
||||
AgentRunResponse,
|
||||
AgentRunResponseUpdate,
|
||||
AgentRunUpdateEvent,
|
||||
AgentThread,
|
||||
BaseAgent,
|
||||
ChatAgent,
|
||||
ChatMessage,
|
||||
ChatResponse,
|
||||
ChatResponseUpdate,
|
||||
FunctionApprovalRequestContent,
|
||||
FunctionCallContent,
|
||||
FunctionResultContent,
|
||||
RequestInfoEvent,
|
||||
Role,
|
||||
TextContent,
|
||||
WorkflowBuilder,
|
||||
WorkflowContext,
|
||||
WorkflowOutputEvent,
|
||||
ai_function,
|
||||
executor,
|
||||
use_function_invocation,
|
||||
)
|
||||
|
||||
|
||||
@@ -120,3 +133,235 @@ async def test_agent_executor_emits_tool_calls_in_streaming_mode() -> None:
|
||||
assert events[3].data is not None
|
||||
assert isinstance(events[3].data.contents[0], TextContent)
|
||||
assert "sunny" in events[3].data.contents[0].text
|
||||
|
||||
|
||||
@ai_function(approval_mode="always_require")
|
||||
def mock_tool_requiring_approval(query: str) -> str:
|
||||
"""Mock tool that requires approval before execution."""
|
||||
return f"Executed tool with query: {query}"
|
||||
|
||||
|
||||
@use_function_invocation
|
||||
class MockChatClient:
|
||||
"""Simple implementation of a chat client."""
|
||||
|
||||
def __init__(self, parallel_request: bool = False) -> None:
|
||||
self.additional_properties: dict[str, Any] = {}
|
||||
self._iteration: int = 0
|
||||
self._parallel_request: bool = parallel_request
|
||||
|
||||
async def get_response(
|
||||
self,
|
||||
messages: str | ChatMessage | list[str] | list[ChatMessage],
|
||||
**kwargs: Any,
|
||||
) -> ChatResponse:
|
||||
if self._iteration == 0:
|
||||
if self._parallel_request:
|
||||
response = ChatResponse(
|
||||
messages=ChatMessage(
|
||||
role="assistant",
|
||||
contents=[
|
||||
FunctionCallContent(
|
||||
call_id="1", name="mock_tool_requiring_approval", arguments='{"query": "test"}'
|
||||
),
|
||||
FunctionCallContent(
|
||||
call_id="2", name="mock_tool_requiring_approval", arguments='{"query": "test"}'
|
||||
),
|
||||
],
|
||||
)
|
||||
)
|
||||
else:
|
||||
response = ChatResponse(
|
||||
messages=ChatMessage(
|
||||
role="assistant",
|
||||
contents=[
|
||||
FunctionCallContent(
|
||||
call_id="1", name="mock_tool_requiring_approval", arguments='{"query": "test"}'
|
||||
)
|
||||
],
|
||||
)
|
||||
)
|
||||
else:
|
||||
response = ChatResponse(messages=ChatMessage(role="assistant", text="Tool executed successfully."))
|
||||
|
||||
self._iteration += 1
|
||||
return response
|
||||
|
||||
async def get_streaming_response(
|
||||
self,
|
||||
messages: str | ChatMessage | list[str] | list[ChatMessage],
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterable[ChatResponseUpdate]:
|
||||
if self._iteration == 0:
|
||||
if self._parallel_request:
|
||||
yield ChatResponseUpdate(
|
||||
contents=[
|
||||
FunctionCallContent(
|
||||
call_id="1", name="mock_tool_requiring_approval", arguments='{"query": "test"}'
|
||||
),
|
||||
FunctionCallContent(
|
||||
call_id="2", name="mock_tool_requiring_approval", arguments='{"query": "test"}'
|
||||
),
|
||||
],
|
||||
role="assistant",
|
||||
)
|
||||
else:
|
||||
yield ChatResponseUpdate(
|
||||
contents=[
|
||||
FunctionCallContent(
|
||||
call_id="1", name="mock_tool_requiring_approval", arguments='{"query": "test"}'
|
||||
)
|
||||
],
|
||||
role="assistant",
|
||||
)
|
||||
else:
|
||||
yield ChatResponseUpdate(text=TextContent(text="Tool executed "), role="assistant")
|
||||
yield ChatResponseUpdate(contents=[TextContent(text="successfully.")], role="assistant")
|
||||
|
||||
self._iteration += 1
|
||||
|
||||
|
||||
@executor(id="test_executor")
|
||||
async def test_executor(agent_executor_response: AgentExecutorResponse, ctx: WorkflowContext[Never, str]) -> None:
|
||||
await ctx.yield_output(agent_executor_response.agent_run_response.text)
|
||||
|
||||
|
||||
async def test_agent_executor_tool_call_with_approval() -> None:
|
||||
"""Test that AgentExecutor handles tool calls requiring approval."""
|
||||
# Arrange
|
||||
agent = ChatAgent(
|
||||
chat_client=MockChatClient(),
|
||||
name="ApprovalAgent",
|
||||
tools=[mock_tool_requiring_approval],
|
||||
)
|
||||
|
||||
workflow = WorkflowBuilder().set_start_executor(agent).add_edge(agent, test_executor).build()
|
||||
|
||||
# Act
|
||||
events = await workflow.run("Invoke tool requiring approval")
|
||||
|
||||
# Assert
|
||||
assert len(events.get_request_info_events()) == 1
|
||||
approval_request = events.get_request_info_events()[0]
|
||||
assert isinstance(approval_request.data, FunctionApprovalRequestContent)
|
||||
assert approval_request.data.function_call.name == "mock_tool_requiring_approval"
|
||||
assert approval_request.data.function_call.arguments == '{"query": "test"}'
|
||||
|
||||
# Act
|
||||
events = await workflow.send_responses({approval_request.request_id: approval_request.data.create_response(True)})
|
||||
|
||||
# Assert
|
||||
final_response = events.get_outputs()
|
||||
assert len(final_response) == 1
|
||||
assert final_response[0] == "Tool executed successfully."
|
||||
|
||||
|
||||
async def test_agent_executor_tool_call_with_approval_streaming() -> None:
|
||||
"""Test that AgentExecutor handles tool calls requiring approval in streaming mode."""
|
||||
# Arrange
|
||||
agent = ChatAgent(
|
||||
chat_client=MockChatClient(),
|
||||
name="ApprovalAgent",
|
||||
tools=[mock_tool_requiring_approval],
|
||||
)
|
||||
|
||||
workflow = WorkflowBuilder().set_start_executor(agent).add_edge(agent, test_executor).build()
|
||||
|
||||
# Act
|
||||
request_info_events: list[RequestInfoEvent] = []
|
||||
async for event in workflow.run_stream("Invoke tool requiring approval"):
|
||||
if isinstance(event, RequestInfoEvent):
|
||||
request_info_events.append(event)
|
||||
|
||||
# Assert
|
||||
assert len(request_info_events) == 1
|
||||
approval_request = request_info_events[0]
|
||||
assert isinstance(approval_request.data, FunctionApprovalRequestContent)
|
||||
assert approval_request.data.function_call.name == "mock_tool_requiring_approval"
|
||||
assert approval_request.data.function_call.arguments == '{"query": "test"}'
|
||||
|
||||
# Act
|
||||
output: str | None = None
|
||||
async for event in workflow.send_responses_streaming({
|
||||
approval_request.request_id: approval_request.data.create_response(True)
|
||||
}):
|
||||
if isinstance(event, WorkflowOutputEvent):
|
||||
output = event.data
|
||||
|
||||
# Assert
|
||||
assert output is not None
|
||||
assert output == "Tool executed successfully."
|
||||
|
||||
|
||||
async def test_agent_executor_parallel_tool_call_with_approval() -> None:
|
||||
"""Test that AgentExecutor handles parallel tool calls requiring approval."""
|
||||
# Arrange
|
||||
agent = ChatAgent(
|
||||
chat_client=MockChatClient(parallel_request=True),
|
||||
name="ApprovalAgent",
|
||||
tools=[mock_tool_requiring_approval],
|
||||
)
|
||||
|
||||
workflow = WorkflowBuilder().set_start_executor(agent).add_edge(agent, test_executor).build()
|
||||
|
||||
# Act
|
||||
events = await workflow.run("Invoke tool requiring approval")
|
||||
|
||||
# Assert
|
||||
assert len(events.get_request_info_events()) == 2
|
||||
for approval_request in events.get_request_info_events():
|
||||
assert isinstance(approval_request.data, FunctionApprovalRequestContent)
|
||||
assert approval_request.data.function_call.name == "mock_tool_requiring_approval"
|
||||
assert approval_request.data.function_call.arguments == '{"query": "test"}'
|
||||
|
||||
# Act
|
||||
responses = {
|
||||
approval_request.request_id: approval_request.data.create_response(True) # type: ignore
|
||||
for approval_request in events.get_request_info_events()
|
||||
}
|
||||
events = await workflow.send_responses(responses)
|
||||
|
||||
# Assert
|
||||
final_response = events.get_outputs()
|
||||
assert len(final_response) == 1
|
||||
assert final_response[0] == "Tool executed successfully."
|
||||
|
||||
|
||||
async def test_agent_executor_parallel_tool_call_with_approval_streaming() -> None:
|
||||
"""Test that AgentExecutor handles parallel tool calls requiring approval in streaming mode."""
|
||||
# Arrange
|
||||
agent = ChatAgent(
|
||||
chat_client=MockChatClient(parallel_request=True),
|
||||
name="ApprovalAgent",
|
||||
tools=[mock_tool_requiring_approval],
|
||||
)
|
||||
|
||||
workflow = WorkflowBuilder().set_start_executor(agent).add_edge(agent, test_executor).build()
|
||||
|
||||
# Act
|
||||
request_info_events: list[RequestInfoEvent] = []
|
||||
async for event in workflow.run_stream("Invoke tool requiring approval"):
|
||||
if isinstance(event, RequestInfoEvent):
|
||||
request_info_events.append(event)
|
||||
|
||||
# Assert
|
||||
assert len(request_info_events) == 2
|
||||
for approval_request in request_info_events:
|
||||
assert isinstance(approval_request.data, FunctionApprovalRequestContent)
|
||||
assert approval_request.data.function_call.name == "mock_tool_requiring_approval"
|
||||
assert approval_request.data.function_call.arguments == '{"query": "test"}'
|
||||
|
||||
# Act
|
||||
responses = {
|
||||
approval_request.request_id: approval_request.data.create_response(True) # type: ignore
|
||||
for approval_request in request_info_events
|
||||
}
|
||||
|
||||
output: str | None = None
|
||||
async for event in workflow.send_responses_streaming(responses):
|
||||
if isinstance(event, WorkflowOutputEvent):
|
||||
output = event.data
|
||||
|
||||
# Assert
|
||||
assert output is not None
|
||||
assert output == "Tool executed successfully."
|
||||
|
||||
@@ -23,7 +23,7 @@ from agent_framework import (
|
||||
WorkflowOutputEvent,
|
||||
)
|
||||
from agent_framework._mcp import MCPTool
|
||||
from agent_framework._workflows._handoff import _clone_chat_agent
|
||||
from agent_framework._workflows._handoff import _clone_chat_agent # type: ignore[reportPrivateUsage]
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -392,12 +392,218 @@ async def test_clone_chat_agent_preserves_mcp_tools() -> None:
|
||||
)
|
||||
|
||||
assert hasattr(original_agent, "_local_mcp_tools")
|
||||
assert len(original_agent._local_mcp_tools) == 1
|
||||
assert original_agent._local_mcp_tools[0] == mock_mcp_tool
|
||||
assert len(original_agent._local_mcp_tools) == 1 # type: ignore[reportPrivateUsage]
|
||||
assert original_agent._local_mcp_tools[0] == mock_mcp_tool # type: ignore[reportPrivateUsage]
|
||||
|
||||
cloned_agent = _clone_chat_agent(original_agent)
|
||||
|
||||
assert hasattr(cloned_agent, "_local_mcp_tools")
|
||||
assert len(cloned_agent._local_mcp_tools) == 1
|
||||
assert cloned_agent._local_mcp_tools[0] == mock_mcp_tool
|
||||
assert len(cloned_agent._local_mcp_tools) == 1 # type: ignore[reportPrivateUsage]
|
||||
assert cloned_agent._local_mcp_tools[0] == mock_mcp_tool # type: ignore[reportPrivateUsage]
|
||||
assert cloned_agent.chat_options.tools is not None
|
||||
assert len(cloned_agent.chat_options.tools) == 1
|
||||
|
||||
|
||||
async def test_return_to_previous_routing():
|
||||
"""Test that return-to-previous routes back to the current specialist handling the conversation."""
|
||||
triage = _RecordingAgent(name="triage", handoff_to="specialist_a")
|
||||
specialist_a = _RecordingAgent(name="specialist_a", handoff_to="specialist_b")
|
||||
specialist_b = _RecordingAgent(name="specialist_b")
|
||||
|
||||
workflow = (
|
||||
HandoffBuilder(participants=[triage, specialist_a, specialist_b])
|
||||
.set_coordinator(triage)
|
||||
.add_handoff(triage, [specialist_a, specialist_b])
|
||||
.add_handoff(specialist_a, specialist_b)
|
||||
.enable_return_to_previous(True)
|
||||
.with_termination_condition(lambda conv: sum(1 for m in conv if m.role == Role.USER) >= 4)
|
||||
.build()
|
||||
)
|
||||
|
||||
# Start conversation - triage hands off to specialist_a
|
||||
events = await _drain(workflow.run_stream("Initial request"))
|
||||
requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)]
|
||||
assert requests
|
||||
assert len(specialist_a.calls) > 0
|
||||
|
||||
# Specialist_a should have been called with initial request
|
||||
initial_specialist_a_calls = len(specialist_a.calls)
|
||||
|
||||
# Second user message - specialist_a hands off to specialist_b
|
||||
events = await _drain(workflow.send_responses_streaming({requests[-1].request_id: "Need more help"}))
|
||||
requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)]
|
||||
assert requests
|
||||
|
||||
# Specialist_b should have been called
|
||||
assert len(specialist_b.calls) > 0
|
||||
initial_specialist_b_calls = len(specialist_b.calls)
|
||||
|
||||
# Third user message - with return_to_previous, should route back to specialist_b (current agent)
|
||||
events = await _drain(workflow.send_responses_streaming({requests[-1].request_id: "Follow up question"}))
|
||||
third_requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)]
|
||||
|
||||
# Specialist_b should have been called again (return-to-previous routes to current agent)
|
||||
assert len(specialist_b.calls) > initial_specialist_b_calls, (
|
||||
"Specialist B should be called again due to return-to-previous routing to current agent"
|
||||
)
|
||||
|
||||
# Specialist_a should NOT be called again (it's no longer the current agent)
|
||||
assert len(specialist_a.calls) == initial_specialist_a_calls, (
|
||||
"Specialist A should not be called again - specialist_b is the current agent"
|
||||
)
|
||||
|
||||
# Triage should only have been called once at the start
|
||||
assert len(triage.calls) == 1, "Triage should only be called once (initial routing)"
|
||||
|
||||
# Verify awaiting_agent_id is set to specialist_b (the agent that just responded)
|
||||
if third_requests:
|
||||
user_input_req = third_requests[-1].data
|
||||
assert isinstance(user_input_req, HandoffUserInputRequest)
|
||||
assert user_input_req.awaiting_agent_id == "specialist_b", (
|
||||
f"Expected awaiting_agent_id 'specialist_b' but got '{user_input_req.awaiting_agent_id}'"
|
||||
)
|
||||
|
||||
|
||||
async def test_return_to_previous_disabled_routes_to_coordinator():
|
||||
"""Test that with return-to-previous disabled, routing goes back to coordinator."""
|
||||
triage = _RecordingAgent(name="triage", handoff_to="specialist_a")
|
||||
specialist_a = _RecordingAgent(name="specialist_a", handoff_to="specialist_b")
|
||||
specialist_b = _RecordingAgent(name="specialist_b")
|
||||
|
||||
workflow = (
|
||||
HandoffBuilder(participants=[triage, specialist_a, specialist_b])
|
||||
.set_coordinator(triage)
|
||||
.add_handoff(triage, [specialist_a, specialist_b])
|
||||
.add_handoff(specialist_a, specialist_b)
|
||||
.enable_return_to_previous(False)
|
||||
.with_termination_condition(lambda conv: sum(1 for m in conv if m.role == Role.USER) >= 3)
|
||||
.build()
|
||||
)
|
||||
|
||||
# Start conversation - triage hands off to specialist_a
|
||||
events = await _drain(workflow.run_stream("Initial request"))
|
||||
requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)]
|
||||
assert requests
|
||||
assert len(triage.calls) == 1
|
||||
|
||||
# Second user message - specialist_a hands off to specialist_b
|
||||
events = await _drain(workflow.send_responses_streaming({requests[-1].request_id: "Need more help"}))
|
||||
requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)]
|
||||
assert requests
|
||||
|
||||
# Third user message - without return_to_previous, should route back to triage
|
||||
await _drain(workflow.send_responses_streaming({requests[-1].request_id: "Follow up question"}))
|
||||
|
||||
# Triage should have been called twice total: initial + after specialist_b responds
|
||||
assert len(triage.calls) == 2, "Triage should be called twice (initial + default routing to coordinator)"
|
||||
|
||||
|
||||
async def test_return_to_previous_enabled():
|
||||
"""Verify that enable_return_to_previous() keeps control with the current specialist."""
|
||||
triage = _RecordingAgent(name="triage", handoff_to="specialist_a")
|
||||
specialist_a = _RecordingAgent(name="specialist_a")
|
||||
specialist_b = _RecordingAgent(name="specialist_b")
|
||||
|
||||
workflow = (
|
||||
HandoffBuilder(participants=[triage, specialist_a, specialist_b])
|
||||
.set_coordinator("triage")
|
||||
.enable_return_to_previous(True)
|
||||
.with_termination_condition(lambda conv: sum(1 for m in conv if m.role == Role.USER) >= 3)
|
||||
.build()
|
||||
)
|
||||
|
||||
# Start conversation - triage hands off to specialist_a
|
||||
events = await _drain(workflow.run_stream("Initial request"))
|
||||
requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)]
|
||||
assert requests
|
||||
assert len(triage.calls) == 1
|
||||
assert len(specialist_a.calls) == 1
|
||||
|
||||
# Second user message - with return_to_previous, should route to specialist_a (not triage)
|
||||
events = await _drain(workflow.send_responses_streaming({requests[-1].request_id: "Follow up question"}))
|
||||
requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)]
|
||||
assert requests
|
||||
|
||||
# Triage should only have been called once (initial) - specialist_a handles follow-up
|
||||
assert len(triage.calls) == 1, "Triage should only be called once (initial)"
|
||||
assert len(specialist_a.calls) == 2, "Specialist A should handle follow-up with return_to_previous enabled"
|
||||
|
||||
|
||||
async def test_tool_choice_preserved_from_agent_config():
|
||||
"""Verify that agent-level tool_choice configuration is preserved and not overridden."""
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from agent_framework import ChatResponse, ToolMode
|
||||
|
||||
# Create a mock chat client that records the tool_choice used
|
||||
recorded_tool_choices: list[Any] = []
|
||||
|
||||
async def mock_get_response(messages: Any, **kwargs: Any) -> ChatResponse:
|
||||
chat_options = kwargs.get("chat_options")
|
||||
if chat_options:
|
||||
recorded_tool_choices.append(chat_options.tool_choice)
|
||||
return ChatResponse(
|
||||
messages=[ChatMessage(role=Role.ASSISTANT, text="Response")],
|
||||
response_id="test_response",
|
||||
)
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_response = AsyncMock(side_effect=mock_get_response)
|
||||
|
||||
# Create agent with specific tool_choice configuration
|
||||
agent = ChatAgent(
|
||||
chat_client=mock_client,
|
||||
name="test_agent",
|
||||
tool_choice=ToolMode(mode="required"), # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
# Run the agent
|
||||
await agent.run("Test message")
|
||||
|
||||
# Verify tool_choice was preserved
|
||||
assert len(recorded_tool_choices) > 0, "No tool_choice recorded"
|
||||
last_tool_choice = recorded_tool_choices[-1]
|
||||
assert last_tool_choice is not None, "tool_choice should not be None"
|
||||
assert str(last_tool_choice) == "required", f"Expected 'required', got {last_tool_choice}"
|
||||
|
||||
|
||||
async def test_return_to_previous_state_serialization():
|
||||
"""Test that return_to_previous state is properly serialized/deserialized for checkpointing."""
|
||||
from agent_framework._workflows._handoff import _HandoffCoordinator # type: ignore[reportPrivateUsage]
|
||||
|
||||
# Create a coordinator with return_to_previous enabled
|
||||
coordinator = _HandoffCoordinator(
|
||||
starting_agent_id="triage",
|
||||
specialist_ids={"specialist_a": "specialist_a", "specialist_b": "specialist_b"},
|
||||
input_gateway_id="gateway",
|
||||
termination_condition=lambda conv: False,
|
||||
id="test-coordinator",
|
||||
return_to_previous=True,
|
||||
)
|
||||
|
||||
# Set the current agent (simulating a handoff scenario)
|
||||
coordinator._current_agent_id = "specialist_a" # type: ignore[reportPrivateUsage]
|
||||
|
||||
# Snapshot the state
|
||||
state = coordinator.snapshot_state()
|
||||
|
||||
# Verify pattern metadata includes current_agent_id
|
||||
assert "metadata" in state
|
||||
assert "current_agent_id" in state["metadata"]
|
||||
assert state["metadata"]["current_agent_id"] == "specialist_a"
|
||||
|
||||
# Create a new coordinator and restore state
|
||||
coordinator2 = _HandoffCoordinator(
|
||||
starting_agent_id="triage",
|
||||
specialist_ids={"specialist_a": "specialist_a", "specialist_b": "specialist_b"},
|
||||
input_gateway_id="gateway",
|
||||
termination_condition=lambda conv: False,
|
||||
id="test-coordinator",
|
||||
return_to_previous=True,
|
||||
)
|
||||
|
||||
# Restore state
|
||||
coordinator2.restore_state(state)
|
||||
|
||||
# Verify current_agent_id was restored
|
||||
assert coordinator2._current_agent_id == "specialist_a", "Current agent should be restored from checkpoint" # type: ignore[reportPrivateUsage]
|
||||
|
||||
Reference in New Issue
Block a user