mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Merge branch 'main' into feature-foundry-agents
This commit is contained in:
@@ -165,10 +165,10 @@
|
||||
<Project Path="samples/GettingStarted/Workflows/_Foundational/07_MixedWorkflowAgentsAndExecutors/07_MixedWorkflowAgentsAndExecutors.csproj" />
|
||||
<Project Path="samples/GettingStarted/Workflows/_Foundational/08_WriterCriticWorkflow/08_WriterCriticWorkflow.csproj" />
|
||||
</Folder>
|
||||
<Folder Name="/Samples/Catalog/">
|
||||
<Project Path="samples/Catalog/AgentsInWorkflows/AgentsInWorkflows.csproj" />
|
||||
<Project Path="samples/Catalog/AgentWithTextSearchRag/AgentWithTextSearchRag.csproj" />
|
||||
<Project Path="samples/Catalog/DeepResearchAgent/DeepResearchAgent.csproj" />
|
||||
<Folder Name="/Samples/HostedAgents/">
|
||||
<Project Path="samples/HostedAgents/AgentsInWorkflows/AgentsInWorkflows.csproj" />
|
||||
<Project Path="samples/HostedAgents/AgentWithTextSearchRag/AgentWithTextSearchRag.csproj" />
|
||||
<Project Path="samples/HostedAgents/DeepResearchAgent/DeepResearchAgent.csproj" />
|
||||
</Folder>
|
||||
<Folder Name="/Solution Items/">
|
||||
<File Path=".editorconfig" />
|
||||
|
||||
@@ -10,6 +10,8 @@ pip install agent-framework-ag-ui
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Server (Host an AI Agent)
|
||||
|
||||
```python
|
||||
from fastapi import FastAPI
|
||||
from agent_framework import ChatAgent
|
||||
@@ -23,6 +25,7 @@ agent = ChatAgent(
|
||||
chat_client=AzureOpenAIChatClient(
|
||||
endpoint="https://your-resource.openai.azure.com/",
|
||||
deployment_name="gpt-4o-mini",
|
||||
api_key="your-api-key",
|
||||
),
|
||||
)
|
||||
|
||||
@@ -33,9 +36,38 @@ add_agent_framework_fastapi_endpoint(app, agent, "/")
|
||||
# Run with: uvicorn main:app --reload
|
||||
```
|
||||
|
||||
### Client (Connect to an AG-UI Server)
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
from agent_framework import TextContent
|
||||
from agent_framework_ag_ui import AGUIChatClient
|
||||
|
||||
async def main():
|
||||
async with AGUIChatClient(endpoint="http://localhost:8000/") as client:
|
||||
# Stream responses
|
||||
async for update in client.get_streaming_response("Hello!"):
|
||||
for content in update.contents:
|
||||
if isinstance(content, TextContent):
|
||||
print(content.text, end="", flush=True)
|
||||
print()
|
||||
|
||||
asyncio.run(main())
|
||||
```
|
||||
|
||||
The `AGUIChatClient` supports:
|
||||
- Streaming and non-streaming responses
|
||||
- Hybrid tool execution (client-side + server-side tools)
|
||||
- Automatic thread management for conversation continuity
|
||||
- Integration with `ChatAgent` for client-side history management
|
||||
|
||||
## Documentation
|
||||
|
||||
- **[Getting Started Tutorial](getting_started/)** - Step-by-step guide to building your first AG-UI server and client
|
||||
- **[Getting Started Tutorial](getting_started/)** - Step-by-step guide to building AG-UI servers and clients
|
||||
- Server setup with FastAPI
|
||||
- Client examples using `AGUIChatClient`
|
||||
- Hybrid tool execution (client-side + server-side)
|
||||
- Thread management and conversation continuity
|
||||
- **[Examples](agent_framework_ag_ui_examples/)** - Complete examples for AG-UI features
|
||||
|
||||
## Features
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
import importlib.metadata
|
||||
|
||||
from ._agent import AgentFrameworkAgent
|
||||
from ._client import AGUIChatClient
|
||||
from ._confirmation_strategies import (
|
||||
ConfirmationStrategy,
|
||||
DefaultConfirmationStrategy,
|
||||
@@ -13,6 +14,8 @@ from ._confirmation_strategies import (
|
||||
TaskPlannerConfirmationStrategy,
|
||||
)
|
||||
from ._endpoint import add_agent_framework_fastapi_endpoint
|
||||
from ._event_converters import AGUIEventConverter
|
||||
from ._http_service import AGUIHttpService
|
||||
|
||||
try:
|
||||
__version__ = importlib.metadata.version(__name__)
|
||||
@@ -22,6 +25,9 @@ except importlib.metadata.PackageNotFoundError:
|
||||
__all__ = [
|
||||
"AgentFrameworkAgent",
|
||||
"add_agent_framework_fastapi_endpoint",
|
||||
"AGUIChatClient",
|
||||
"AGUIEventConverter",
|
||||
"AGUIHttpService",
|
||||
"ConfirmationStrategy",
|
||||
"DefaultConfirmationStrategy",
|
||||
"TaskPlannerConfirmationStrategy",
|
||||
|
||||
@@ -0,0 +1,407 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""AG-UI Chat Client implementation."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import AsyncIterable, MutableSequence
|
||||
from functools import wraps
|
||||
from typing import Any, TypeVar, cast
|
||||
|
||||
import httpx
|
||||
from agent_framework import (
|
||||
AIFunction,
|
||||
BaseChatClient,
|
||||
ChatMessage,
|
||||
ChatOptions,
|
||||
ChatResponse,
|
||||
ChatResponseUpdate,
|
||||
DataContent,
|
||||
FunctionCallContent,
|
||||
)
|
||||
from agent_framework._middleware import use_chat_middleware
|
||||
from agent_framework._tools import use_function_invocation
|
||||
from agent_framework._types import BaseContent, Contents
|
||||
from agent_framework.observability import use_observability
|
||||
|
||||
from ._event_converters import AGUIEventConverter
|
||||
from ._http_service import AGUIHttpService
|
||||
from ._message_adapters import agent_framework_messages_to_agui
|
||||
from ._utils import convert_tools_to_agui_format
|
||||
|
||||
logger: logging.Logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ServerFunctionCallContent(BaseContent):
|
||||
"""Wrapper for server function calls to prevent client re-execution.
|
||||
|
||||
All function calls from the remote server are server-side executions.
|
||||
This wrapper prevents @use_function_invocation from trying to execute them again.
|
||||
"""
|
||||
|
||||
function_call_content: FunctionCallContent
|
||||
|
||||
def __init__(self, function_call_content: FunctionCallContent) -> None:
|
||||
"""Initialize with the function call content."""
|
||||
super().__init__(type="server_function_call")
|
||||
self.function_call_content = function_call_content
|
||||
|
||||
|
||||
def _unwrap_server_function_call_contents(contents: MutableSequence[Contents | dict[str, Any]]) -> None:
|
||||
"""Replace ServerFunctionCallContent instances with their underlying call content."""
|
||||
for idx, content in enumerate(contents):
|
||||
if isinstance(content, ServerFunctionCallContent):
|
||||
contents[idx] = content.function_call_content # type: ignore[assignment]
|
||||
|
||||
|
||||
TBaseChatClient = TypeVar("TBaseChatClient", bound=type[BaseChatClient])
|
||||
|
||||
|
||||
def _apply_server_function_call_unwrap(chat_client: TBaseChatClient) -> TBaseChatClient:
|
||||
"""Class decorator that unwraps server-side function calls after tool handling."""
|
||||
|
||||
original_get_streaming_response = chat_client.get_streaming_response
|
||||
|
||||
@wraps(original_get_streaming_response)
|
||||
async def streaming_wrapper(self, *args: Any, **kwargs: Any) -> AsyncIterable[ChatResponseUpdate]:
|
||||
async for update in original_get_streaming_response(self, *args, **kwargs):
|
||||
_unwrap_server_function_call_contents(cast(MutableSequence[Contents | dict[str, Any]], update.contents))
|
||||
yield update
|
||||
|
||||
chat_client.get_streaming_response = streaming_wrapper # type: ignore[assignment]
|
||||
|
||||
original_get_response = chat_client.get_response
|
||||
|
||||
@wraps(original_get_response)
|
||||
async def response_wrapper(self, *args: Any, **kwargs: Any) -> ChatResponse:
|
||||
response = await original_get_response(self, *args, **kwargs)
|
||||
if response.messages:
|
||||
for message in response.messages:
|
||||
_unwrap_server_function_call_contents(
|
||||
cast(MutableSequence[Contents | dict[str, Any]], message.contents)
|
||||
)
|
||||
return response
|
||||
|
||||
chat_client.get_response = response_wrapper # type: ignore[assignment]
|
||||
return chat_client
|
||||
|
||||
|
||||
@_apply_server_function_call_unwrap
|
||||
@use_function_invocation
|
||||
@use_observability
|
||||
@use_chat_middleware
|
||||
class AGUIChatClient(BaseChatClient):
|
||||
"""Chat client for communicating with AG-UI compliant servers.
|
||||
|
||||
This client implements the BaseChatClient interface and automatically handles:
|
||||
- Thread ID management for conversation continuity
|
||||
- State synchronization between client and server
|
||||
- Server-Sent Events (SSE) streaming
|
||||
- Event conversion to Agent Framework types
|
||||
|
||||
Important: Message History Management
|
||||
This client sends exactly the messages it receives to the server. It does NOT
|
||||
automatically maintain conversation history. The server must handle history via thread_id.
|
||||
|
||||
For stateless servers: Use ChatAgent wrapper which will send full message history on each
|
||||
request. However, even with ChatAgent, the server must echo back all context for the
|
||||
agent to maintain history across turns.
|
||||
|
||||
Important: Tool Handling (Hybrid Execution - matches .NET)
|
||||
1. Client tool metadata sent to server - LLM knows about both client and server tools
|
||||
2. Server has its own tools that execute server-side
|
||||
3. When LLM calls a client tool, @use_function_invocation executes it locally
|
||||
4. Both client and server tools work together (hybrid pattern)
|
||||
|
||||
The wrapping ChatAgent's @use_function_invocation handles client tool execution
|
||||
automatically when the server's LLM decides to call them.
|
||||
|
||||
Examples:
|
||||
Direct usage (server manages thread history):
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from agent_framework.ag_ui import AGUIChatClient
|
||||
|
||||
client = AGUIChatClient(endpoint="http://localhost:8888/")
|
||||
|
||||
# First message - thread ID auto-generated
|
||||
response = await client.get_response("Hello!")
|
||||
thread_id = response.additional_properties.get("thread_id")
|
||||
|
||||
# Second message - server retrieves history using thread_id
|
||||
response2 = await client.get_response(
|
||||
"How are you?",
|
||||
metadata={"thread_id": thread_id}
|
||||
)
|
||||
|
||||
Recommended usage with ChatAgent (client manages history):
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from agent_framework import ChatAgent
|
||||
from agent_framework.ag_ui import AGUIChatClient
|
||||
|
||||
client = AGUIChatClient(endpoint="http://localhost:8888/")
|
||||
agent = ChatAgent(name="assistant", client=client)
|
||||
thread = await agent.get_new_thread()
|
||||
|
||||
# ChatAgent automatically maintains history and sends full context
|
||||
response = await agent.run("Hello!", thread=thread)
|
||||
response2 = await agent.run("How are you?", thread=thread)
|
||||
|
||||
Streaming usage:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
async for update in client.get_streaming_response("Tell me a story"):
|
||||
if update.contents:
|
||||
for content in update.contents:
|
||||
if hasattr(content, "text"):
|
||||
print(content.text, end="", flush=True)
|
||||
|
||||
Context manager:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
async with AGUIChatClient(endpoint="http://localhost:8888/") as client:
|
||||
response = await client.get_response("Hello!")
|
||||
print(response.messages[0].text)
|
||||
"""
|
||||
|
||||
OTEL_PROVIDER_NAME = "agui"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
endpoint: str,
|
||||
http_client: httpx.AsyncClient | None = None,
|
||||
timeout: float = 60.0,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize the AG-UI chat client.
|
||||
|
||||
Args:
|
||||
endpoint: The AG-UI server endpoint URL (e.g., "http://localhost:8888/")
|
||||
http_client: Optional httpx.AsyncClient instance. If None, one will be created.
|
||||
timeout: Request timeout in seconds (default: 60.0)
|
||||
additional_properties: Additional properties to store
|
||||
**kwargs: Additional arguments passed to BaseChatClient
|
||||
"""
|
||||
super().__init__(additional_properties=additional_properties, **kwargs)
|
||||
self._http_service = AGUIHttpService(
|
||||
endpoint=endpoint,
|
||||
http_client=http_client,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the HTTP client."""
|
||||
await self._http_service.close()
|
||||
|
||||
async def __aenter__(self) -> "AGUIChatClient":
|
||||
"""Enter async context manager."""
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *args: Any) -> None:
|
||||
"""Exit async context manager."""
|
||||
await self.close()
|
||||
|
||||
def _register_server_tool_placeholder(self, tool_name: str) -> None:
|
||||
"""Register a declaration-only placeholder so function invocation skips execution."""
|
||||
|
||||
config = getattr(self, "function_invocation_configuration", None)
|
||||
if not config:
|
||||
return
|
||||
if any(getattr(tool, "name", None) == tool_name for tool in config.additional_tools):
|
||||
return
|
||||
|
||||
placeholder: AIFunction[Any, Any] = AIFunction(
|
||||
name=tool_name,
|
||||
description="Server-managed tool placeholder (AG-UI)",
|
||||
func=None,
|
||||
)
|
||||
config.additional_tools = list(config.additional_tools) + [placeholder]
|
||||
registered: set[str] = getattr(self, "_registered_server_tools", set())
|
||||
registered.add(tool_name)
|
||||
self._registered_server_tools = registered # type: ignore[attr-defined]
|
||||
from agent_framework._logging import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
logger.debug(f"[AGUIChatClient] Registered server placeholder: {tool_name}")
|
||||
|
||||
def _extract_state_from_messages(
|
||||
self, messages: MutableSequence[ChatMessage]
|
||||
) -> tuple[list[ChatMessage], dict[str, Any] | None]:
|
||||
"""Extract state from last message if present.
|
||||
|
||||
Args:
|
||||
messages: List of chat messages
|
||||
|
||||
Returns:
|
||||
Tuple of (messages_without_state, state_dict)
|
||||
"""
|
||||
if not messages:
|
||||
return list(messages), None
|
||||
|
||||
last_message = messages[-1]
|
||||
|
||||
for content in last_message.contents:
|
||||
if isinstance(content, DataContent) and content.media_type == "application/json":
|
||||
try:
|
||||
uri = content.uri
|
||||
if uri.startswith("data:application/json;base64,"):
|
||||
import base64
|
||||
|
||||
encoded_data = uri.split(",", 1)[1]
|
||||
decoded_bytes = base64.b64decode(encoded_data)
|
||||
state = json.loads(decoded_bytes.decode("utf-8"))
|
||||
|
||||
messages_without_state = list(messages[:-1]) if len(messages) > 1 else []
|
||||
return messages_without_state, state
|
||||
except (json.JSONDecodeError, ValueError, KeyError) as e:
|
||||
from agent_framework._logging import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
logger.warning(f"Failed to extract state from message: {e}")
|
||||
|
||||
return list(messages), None
|
||||
|
||||
def _convert_messages_to_agui_format(self, messages: list[ChatMessage]) -> list[dict[str, Any]]:
|
||||
"""Convert Agent Framework messages to AG-UI format.
|
||||
|
||||
Args:
|
||||
messages: List of ChatMessage objects
|
||||
|
||||
Returns:
|
||||
List of AG-UI formatted message dictionaries
|
||||
"""
|
||||
return agent_framework_messages_to_agui(messages)
|
||||
|
||||
def _get_thread_id(self, chat_options: ChatOptions) -> str:
|
||||
"""Get or generate thread ID from chat options.
|
||||
|
||||
Args:
|
||||
chat_options: Chat options containing metadata
|
||||
|
||||
Returns:
|
||||
Thread ID string
|
||||
"""
|
||||
thread_id = None
|
||||
if chat_options.metadata:
|
||||
thread_id = chat_options.metadata.get("thread_id")
|
||||
|
||||
if not thread_id:
|
||||
thread_id = f"thread_{uuid.uuid4().hex}"
|
||||
|
||||
return thread_id
|
||||
|
||||
async def _inner_get_response(
|
||||
self,
|
||||
*,
|
||||
messages: MutableSequence[ChatMessage],
|
||||
chat_options: ChatOptions,
|
||||
**kwargs: Any,
|
||||
) -> ChatResponse:
|
||||
"""Internal method to get non-streaming response.
|
||||
|
||||
Keyword Args:
|
||||
messages: List of chat messages
|
||||
chat_options: Chat options for the request
|
||||
**kwargs: Additional keyword arguments
|
||||
|
||||
Returns:
|
||||
ChatResponse object
|
||||
"""
|
||||
return await ChatResponse.from_chat_response_generator(
|
||||
self._inner_get_streaming_response(
|
||||
messages=messages,
|
||||
chat_options=chat_options,
|
||||
**kwargs,
|
||||
)
|
||||
)
|
||||
|
||||
async def _inner_get_streaming_response(
|
||||
self,
|
||||
*,
|
||||
messages: MutableSequence[ChatMessage],
|
||||
chat_options: ChatOptions,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterable[ChatResponseUpdate]:
|
||||
"""Internal method to get streaming response.
|
||||
|
||||
Keyword Args:
|
||||
messages: List of chat messages
|
||||
chat_options: Chat options for the request
|
||||
**kwargs: Additional keyword arguments
|
||||
|
||||
Yields:
|
||||
ChatResponseUpdate objects
|
||||
"""
|
||||
messages_to_send, state = self._extract_state_from_messages(messages)
|
||||
|
||||
thread_id = self._get_thread_id(chat_options)
|
||||
run_id = f"run_{uuid.uuid4().hex}"
|
||||
|
||||
agui_messages = self._convert_messages_to_agui_format(messages_to_send)
|
||||
|
||||
# Send client tools to server so LLM knows about them
|
||||
# Client tools execute via ChatAgent's @use_function_invocation wrapper
|
||||
agui_tools = convert_tools_to_agui_format(chat_options.tools)
|
||||
|
||||
# Build set of client tool names (matches .NET clientToolSet)
|
||||
# Used to distinguish client vs server tools in response stream
|
||||
client_tool_set: set[str] = set()
|
||||
if chat_options.tools:
|
||||
for tool in chat_options.tools:
|
||||
if hasattr(tool, "name"):
|
||||
client_tool_set.add(tool.name) # type: ignore[arg-type]
|
||||
self._last_client_tool_set = client_tool_set # type: ignore[attr-defined]
|
||||
|
||||
logger.debug(
|
||||
"[AGUIChatClient] Preparing request",
|
||||
extra={
|
||||
"thread_id": thread_id,
|
||||
"run_id": run_id,
|
||||
"client_tools": list(client_tool_set),
|
||||
"messages": [msg.text for msg in messages_to_send if msg.text],
|
||||
},
|
||||
)
|
||||
logger.debug(f"[AGUIChatClient] Client tool set: {client_tool_set}")
|
||||
|
||||
converter = AGUIEventConverter()
|
||||
|
||||
async for event in self._http_service.post_run(
|
||||
thread_id=thread_id,
|
||||
run_id=run_id,
|
||||
messages=agui_messages,
|
||||
state=state,
|
||||
tools=agui_tools,
|
||||
):
|
||||
logger.debug(f"[AGUIChatClient] Raw AG-UI event: {event}")
|
||||
update = converter.convert_event(event)
|
||||
if update is not None:
|
||||
logger.debug(
|
||||
"[AGUIChatClient] Converted update",
|
||||
extra={"role": update.role, "contents": [type(c).__name__ for c in update.contents]},
|
||||
)
|
||||
# Distinguish client vs server tools
|
||||
for i, content in enumerate(update.contents):
|
||||
if isinstance(content, FunctionCallContent):
|
||||
logger.debug(
|
||||
f"[AGUIChatClient] Function call: {content.name}, in client_tool_set: {content.name in client_tool_set}"
|
||||
)
|
||||
if content.name in client_tool_set:
|
||||
# Client tool - let @use_function_invocation execute it
|
||||
if not content.additional_properties:
|
||||
content.additional_properties = {}
|
||||
content.additional_properties["agui_thread_id"] = thread_id
|
||||
else:
|
||||
# Server tool - wrap so @use_function_invocation ignores it
|
||||
logger.debug(f"[AGUIChatClient] Wrapping server tool: {content.name}")
|
||||
self._register_server_tool_placeholder(content.name)
|
||||
update.contents[i] = ServerFunctionCallContent(content) # type: ignore
|
||||
|
||||
yield update
|
||||
@@ -0,0 +1,209 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Event converter for AG-UI protocol events to Agent Framework types."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from agent_framework import (
|
||||
ChatResponseUpdate,
|
||||
ErrorContent,
|
||||
FinishReason,
|
||||
FunctionCallContent,
|
||||
FunctionResultContent,
|
||||
Role,
|
||||
TextContent,
|
||||
)
|
||||
|
||||
|
||||
class AGUIEventConverter:
|
||||
"""Converter for AG-UI events to Agent Framework types.
|
||||
|
||||
Handles conversion of AG-UI protocol events to ChatResponseUpdate objects
|
||||
while maintaining state, aggregating content, and tracking metadata.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the converter with fresh state."""
|
||||
self.current_message_id: str | None = None
|
||||
self.current_tool_call_id: str | None = None
|
||||
self.current_tool_name: str | None = None
|
||||
self.accumulated_tool_args: str = ""
|
||||
self.thread_id: str | None = None
|
||||
self.run_id: str | None = None
|
||||
|
||||
def convert_event(self, event: dict[str, Any]) -> ChatResponseUpdate | None:
|
||||
"""Convert a single AG-UI event to ChatResponseUpdate.
|
||||
|
||||
Args:
|
||||
event: AG-UI event dictionary
|
||||
|
||||
Returns:
|
||||
ChatResponseUpdate if event produces content, None otherwise
|
||||
|
||||
Examples:
|
||||
RUN_STARTED event:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
converter = AGUIEventConverter()
|
||||
event = {"type": "RUN_STARTED", "threadId": "t1", "runId": "r1"}
|
||||
update = converter.convert_event(event)
|
||||
assert update.additional_properties["thread_id"] == "t1"
|
||||
|
||||
TEXT_MESSAGE_CONTENT event:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
event = {"type": "TEXT_MESSAGE_CONTENT", "messageId": "m1", "delta": "Hello"}
|
||||
update = converter.convert_event(event)
|
||||
assert update.contents[0].text == "Hello"
|
||||
"""
|
||||
event_type = event.get("type", "")
|
||||
|
||||
if event_type == "RUN_STARTED":
|
||||
return self._handle_run_started(event)
|
||||
elif event_type == "TEXT_MESSAGE_START":
|
||||
return self._handle_text_message_start(event)
|
||||
elif event_type == "TEXT_MESSAGE_CONTENT":
|
||||
return self._handle_text_message_content(event)
|
||||
elif event_type == "TEXT_MESSAGE_END":
|
||||
return self._handle_text_message_end(event)
|
||||
elif event_type == "TOOL_CALL_START":
|
||||
return self._handle_tool_call_start(event)
|
||||
elif event_type == "TOOL_CALL_ARGS":
|
||||
return self._handle_tool_call_args(event)
|
||||
elif event_type == "TOOL_CALL_END":
|
||||
return self._handle_tool_call_end(event)
|
||||
elif event_type == "TOOL_CALL_RESULT":
|
||||
return self._handle_tool_call_result(event)
|
||||
elif event_type == "RUN_FINISHED":
|
||||
return self._handle_run_finished(event)
|
||||
elif event_type == "RUN_ERROR":
|
||||
return self._handle_run_error(event)
|
||||
|
||||
return None
|
||||
|
||||
def _handle_run_started(self, event: dict[str, Any]) -> ChatResponseUpdate:
|
||||
"""Handle RUN_STARTED event."""
|
||||
self.thread_id = event.get("threadId")
|
||||
self.run_id = event.get("runId")
|
||||
|
||||
return ChatResponseUpdate(
|
||||
role=Role.ASSISTANT,
|
||||
contents=[],
|
||||
additional_properties={
|
||||
"thread_id": self.thread_id,
|
||||
"run_id": self.run_id,
|
||||
},
|
||||
)
|
||||
|
||||
def _handle_text_message_start(self, event: dict[str, Any]) -> ChatResponseUpdate | None:
|
||||
"""Handle TEXT_MESSAGE_START event."""
|
||||
self.current_message_id = event.get("messageId")
|
||||
return ChatResponseUpdate(
|
||||
role=Role.ASSISTANT,
|
||||
message_id=self.current_message_id,
|
||||
contents=[],
|
||||
)
|
||||
|
||||
def _handle_text_message_content(self, event: dict[str, Any]) -> ChatResponseUpdate:
|
||||
"""Handle TEXT_MESSAGE_CONTENT event."""
|
||||
message_id = event.get("messageId")
|
||||
delta = event.get("delta", "")
|
||||
|
||||
if message_id != self.current_message_id:
|
||||
self.current_message_id = message_id
|
||||
|
||||
return ChatResponseUpdate(
|
||||
role=Role.ASSISTANT,
|
||||
message_id=self.current_message_id,
|
||||
contents=[TextContent(text=delta)],
|
||||
)
|
||||
|
||||
def _handle_text_message_end(self, event: dict[str, Any]) -> ChatResponseUpdate | None:
|
||||
"""Handle TEXT_MESSAGE_END event."""
|
||||
return None
|
||||
|
||||
def _handle_tool_call_start(self, event: dict[str, Any]) -> ChatResponseUpdate:
|
||||
"""Handle TOOL_CALL_START event."""
|
||||
self.current_tool_call_id = event.get("toolCallId")
|
||||
self.current_tool_name = event.get("toolName") or event.get("toolCallName") or event.get("tool_call_name")
|
||||
self.accumulated_tool_args = ""
|
||||
|
||||
return ChatResponseUpdate(
|
||||
role=Role.ASSISTANT,
|
||||
contents=[
|
||||
FunctionCallContent(
|
||||
call_id=self.current_tool_call_id or "",
|
||||
name=self.current_tool_name or "",
|
||||
arguments="",
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
def _handle_tool_call_args(self, event: dict[str, Any]) -> ChatResponseUpdate:
|
||||
"""Handle TOOL_CALL_ARGS event."""
|
||||
delta = event.get("delta", "")
|
||||
self.accumulated_tool_args += delta
|
||||
|
||||
return ChatResponseUpdate(
|
||||
role=Role.ASSISTANT,
|
||||
contents=[
|
||||
FunctionCallContent(
|
||||
call_id=self.current_tool_call_id or "",
|
||||
name=self.current_tool_name or "",
|
||||
arguments=delta,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
def _handle_tool_call_end(self, event: dict[str, Any]) -> ChatResponseUpdate | None:
|
||||
"""Handle TOOL_CALL_END event."""
|
||||
self.accumulated_tool_args = ""
|
||||
return None
|
||||
|
||||
def _handle_tool_call_result(self, event: dict[str, Any]) -> ChatResponseUpdate:
|
||||
"""Handle TOOL_CALL_RESULT event."""
|
||||
tool_call_id = event.get("toolCallId", "")
|
||||
result = event.get("result") if event.get("result") is not None else event.get("content")
|
||||
|
||||
return ChatResponseUpdate(
|
||||
role=Role.TOOL,
|
||||
contents=[
|
||||
FunctionResultContent(
|
||||
call_id=tool_call_id,
|
||||
result=result,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
def _handle_run_finished(self, event: dict[str, Any]) -> ChatResponseUpdate:
|
||||
"""Handle RUN_FINISHED event."""
|
||||
return ChatResponseUpdate(
|
||||
role=Role.ASSISTANT,
|
||||
finish_reason=FinishReason.STOP,
|
||||
contents=[],
|
||||
additional_properties={
|
||||
"thread_id": self.thread_id,
|
||||
"run_id": self.run_id,
|
||||
},
|
||||
)
|
||||
|
||||
def _handle_run_error(self, event: dict[str, Any]) -> ChatResponseUpdate:
|
||||
"""Handle RUN_ERROR event."""
|
||||
error_message = event.get("message", "Unknown error")
|
||||
|
||||
return ChatResponseUpdate(
|
||||
role=Role.ASSISTANT,
|
||||
finish_reason=FinishReason.CONTENT_FILTER,
|
||||
contents=[
|
||||
ErrorContent(
|
||||
message=error_message,
|
||||
error_code="RUN_ERROR",
|
||||
)
|
||||
],
|
||||
additional_properties={
|
||||
"thread_id": self.thread_id,
|
||||
"run_id": self.run_id,
|
||||
},
|
||||
)
|
||||
@@ -107,7 +107,7 @@ class AgentFrameworkEventBridge:
|
||||
# Skip text content if we're about to emit confirm_changes
|
||||
# The summary should only appear after user confirms
|
||||
if self.should_stop_after_confirm:
|
||||
logger.debug(" >>> Skipping text content - waiting for confirm_changes response")
|
||||
logger.debug("Skipping text content - waiting for confirm_changes response")
|
||||
# Save the summary text to show after confirmation
|
||||
self.suppressed_summary += content.text
|
||||
continue
|
||||
@@ -156,7 +156,7 @@ class AgentFrameworkEventBridge:
|
||||
tool_call_name=content.name,
|
||||
parent_message_id=self.current_message_id,
|
||||
)
|
||||
logger.info(f" >>> Emitting ToolCallStartEvent with name='{content.name}', id='{tool_call_id}'")
|
||||
logger.info(f"Emitting ToolCallStartEvent with name='{content.name}', id='{tool_call_id}'")
|
||||
events.append(tool_start_event)
|
||||
|
||||
# Track tool call for MessagesSnapshotEvent
|
||||
@@ -186,7 +186,7 @@ class AgentFrameworkEventBridge:
|
||||
# If it's a dict, convert to JSON
|
||||
delta_str = json.dumps(content.arguments)
|
||||
|
||||
logger.info(f" >>> Emitting ToolCallArgsEvent with delta: {delta_str!r}..., id='{tool_call_id}'")
|
||||
logger.info(f"Emitting ToolCallArgsEvent with delta: {delta_str!r}..., id='{tool_call_id}'")
|
||||
args_event = ToolCallArgsEvent(
|
||||
tool_call_id=tool_call_id,
|
||||
delta=delta_str,
|
||||
@@ -211,7 +211,7 @@ class AgentFrameworkEventBridge:
|
||||
self.streaming_tool_args += json.dumps(content.arguments)
|
||||
|
||||
logger.debug(
|
||||
f" >>> Predictive state: accumulated {len(self.streaming_tool_args)} chars for tool '{self.current_tool_call_name}'"
|
||||
f"Predictive state: accumulated {len(self.streaming_tool_args)} chars for tool '{self.current_tool_call_name}'"
|
||||
)
|
||||
|
||||
# Try to parse accumulated arguments (may be incomplete JSON)
|
||||
@@ -262,11 +262,11 @@ class AgentFrameworkEventBridge:
|
||||
else str(partial_value)
|
||||
)
|
||||
logger.info(
|
||||
f" >>> StateDeltaEvent #{self.state_delta_count} for '{state_key}': "
|
||||
f"StateDeltaEvent #{self.state_delta_count} for '{state_key}': "
|
||||
f"op=replace, path=/{state_key}, value={value_preview}"
|
||||
)
|
||||
elif self.state_delta_count % 100 == 0:
|
||||
logger.info(f" >>> StateDeltaEvent #{self.state_delta_count} emitted")
|
||||
logger.info(f"StateDeltaEvent #{self.state_delta_count} emitted")
|
||||
|
||||
events.append(state_delta_event)
|
||||
self.last_emitted_state[state_key] = partial_value
|
||||
@@ -312,11 +312,11 @@ class AgentFrameworkEventBridge:
|
||||
else str(state_value)
|
||||
)
|
||||
logger.info(
|
||||
f" >>> StateDeltaEvent #{self.state_delta_count} for '{state_key}': "
|
||||
f"StateDeltaEvent #{self.state_delta_count} for '{state_key}': "
|
||||
f"op=replace, path=/{state_key}, value={value_preview}"
|
||||
)
|
||||
elif self.state_delta_count % 100 == 0: # Also log every 100th
|
||||
logger.info(f" >>> StateDeltaEvent #{self.state_delta_count} emitted")
|
||||
logger.info(f"StateDeltaEvent #{self.state_delta_count} emitted")
|
||||
|
||||
events.append(state_delta_event)
|
||||
|
||||
@@ -360,7 +360,7 @@ class AgentFrameworkEventBridge:
|
||||
],
|
||||
)
|
||||
logger.info(
|
||||
f" >>> Emitting StateDeltaEvent for key '{state_key}', value type: {type(state_value)}"
|
||||
f"Emitting StateDeltaEvent for key '{state_key}', value type: {type(state_value)}"
|
||||
)
|
||||
events.append(state_delta_event)
|
||||
|
||||
@@ -376,13 +376,13 @@ class AgentFrameworkEventBridge:
|
||||
end_event = ToolCallEndEvent(
|
||||
tool_call_id=content.call_id,
|
||||
)
|
||||
logger.info(f" >>> Emitting ToolCallEndEvent for completed tool call '{content.call_id}'")
|
||||
logger.info(f"Emitting ToolCallEndEvent for completed tool call '{content.call_id}'")
|
||||
events.append(end_event)
|
||||
|
||||
# Log total StateDeltaEvent count for this tool call
|
||||
if self.state_delta_count > 0:
|
||||
logger.info(
|
||||
f" >>> Tool call '{content.call_id}' complete: emitted {self.state_delta_count} StateDeltaEvents total"
|
||||
f"Tool call '{content.call_id}' complete: emitted {self.state_delta_count} StateDeltaEvents total"
|
||||
)
|
||||
|
||||
# Reset streaming accumulator and counter for next tool call
|
||||
@@ -410,11 +410,13 @@ class AgentFrameworkEventBridge:
|
||||
events.append(result_event)
|
||||
|
||||
# Track tool result for MessagesSnapshotEvent
|
||||
# AG-UI protocol expects: { role: "tool", toolCallId: ..., content: ... }
|
||||
# Use camelCase for Pydantic's alias_generator=to_camel
|
||||
self.tool_results.append(
|
||||
{
|
||||
"id": result_message_id,
|
||||
"role": "tool",
|
||||
"tool_call_id": content.call_id,
|
||||
"toolCallId": content.call_id,
|
||||
"content": result_content,
|
||||
}
|
||||
)
|
||||
@@ -422,6 +424,9 @@ class AgentFrameworkEventBridge:
|
||||
# Emit MessagesSnapshotEvent with the complete conversation including tool calls and results
|
||||
# This is required for CopilotKit's useCopilotAction to detect tool result
|
||||
if self.pending_tool_calls and self.tool_results:
|
||||
# Import message adapter
|
||||
from ._message_adapters import agent_framework_messages_to_agui
|
||||
|
||||
# Build assistant message with tool_calls
|
||||
assistant_message = {
|
||||
"id": generate_event_id(),
|
||||
@@ -429,14 +434,19 @@ class AgentFrameworkEventBridge:
|
||||
"tool_calls": self.pending_tool_calls.copy(), # Copy the accumulated tool calls
|
||||
}
|
||||
|
||||
# Convert Agent Framework messages to AG-UI format (adds required 'id' field)
|
||||
converted_input_messages = agent_framework_messages_to_agui(self.input_messages)
|
||||
|
||||
# Build complete messages array: input messages + assistant message + tool results
|
||||
all_messages = list(self.input_messages) + [assistant_message] + self.tool_results.copy()
|
||||
all_messages = converted_input_messages + [assistant_message] + self.tool_results.copy()
|
||||
|
||||
# Emit MessagesSnapshotEvent using the proper event type
|
||||
# Note: messages are dict[str, Any] but Pydantic will validate them as Message types
|
||||
messages_snapshot_event = MessagesSnapshotEvent(
|
||||
type=EventType.MESSAGES_SNAPSHOT, messages=all_messages
|
||||
type=EventType.MESSAGES_SNAPSHOT,
|
||||
messages=all_messages, # type: ignore[arg-type]
|
||||
)
|
||||
logger.info(f" >>> Emitting MessagesSnapshotEvent with {len(all_messages)} messages")
|
||||
logger.info(f"Emitting MessagesSnapshotEvent with {len(all_messages)} messages")
|
||||
events.append(messages_snapshot_event)
|
||||
|
||||
# After tool execution, emit StateSnapshotEvent if we have pending state updates
|
||||
@@ -466,7 +476,7 @@ class AgentFrameworkEventBridge:
|
||||
# If so, emit a confirm_changes tool call for the UI modal
|
||||
tool_was_predictive = False
|
||||
logger.debug(
|
||||
f" >>> Checking predictive state: current_tool='{self.current_tool_call_name}', "
|
||||
f"Checking predictive state: current_tool='{self.current_tool_call_name}', "
|
||||
f"predict_config={list(self.predict_state_config.keys()) if self.predict_state_config else 'None'}"
|
||||
)
|
||||
for state_key, config in self.predict_state_config.items():
|
||||
@@ -474,7 +484,7 @@ class AgentFrameworkEventBridge:
|
||||
# We need to match against self.current_tool_call_name
|
||||
if self.current_tool_call_name and config["tool"] == self.current_tool_call_name:
|
||||
logger.info(
|
||||
f" >>> Tool '{self.current_tool_call_name}' matches predictive config for state key '{state_key}'"
|
||||
f"Tool '{self.current_tool_call_name}' matches predictive config for state key '{state_key}'"
|
||||
)
|
||||
tool_was_predictive = True
|
||||
break
|
||||
@@ -483,7 +493,7 @@ class AgentFrameworkEventBridge:
|
||||
# Emit confirm_changes tool call sequence
|
||||
confirm_call_id = generate_event_id()
|
||||
|
||||
logger.info(" >>> Emitting confirm_changes tool call for predictive update")
|
||||
logger.info("Emitting confirm_changes tool call for predictive update")
|
||||
|
||||
# Track confirm_changes tool call for MessagesSnapshotEvent (so it persists after RUN_FINISHED)
|
||||
self.pending_tool_calls.append(
|
||||
@@ -518,6 +528,9 @@ class AgentFrameworkEventBridge:
|
||||
events.append(confirm_end)
|
||||
|
||||
# Emit MessagesSnapshotEvent so confirm_changes persists after RUN_FINISHED
|
||||
# Import message adapter
|
||||
from ._message_adapters import agent_framework_messages_to_agui
|
||||
|
||||
# Build assistant message with pending confirm_changes tool call
|
||||
assistant_message = {
|
||||
"id": generate_event_id(),
|
||||
@@ -525,23 +538,28 @@ class AgentFrameworkEventBridge:
|
||||
"tool_calls": self.pending_tool_calls.copy(), # Includes confirm_changes
|
||||
}
|
||||
|
||||
# Convert Agent Framework messages to AG-UI format (adds required 'id' field)
|
||||
converted_input_messages = agent_framework_messages_to_agui(self.input_messages)
|
||||
|
||||
# Build complete messages array: input messages + assistant message + any tool results
|
||||
all_messages = list(self.input_messages) + [assistant_message] + self.tool_results.copy()
|
||||
all_messages = converted_input_messages + [assistant_message] + self.tool_results.copy()
|
||||
|
||||
# Emit MessagesSnapshotEvent
|
||||
# Note: messages are dict[str, Any] but Pydantic will validate them as Message types
|
||||
messages_snapshot_event = MessagesSnapshotEvent(
|
||||
type=EventType.MESSAGES_SNAPSHOT, messages=all_messages
|
||||
type=EventType.MESSAGES_SNAPSHOT,
|
||||
messages=all_messages, # type: ignore[arg-type]
|
||||
)
|
||||
logger.info(
|
||||
f" >>> Emitting MessagesSnapshotEvent for confirm_changes with {len(all_messages)} messages"
|
||||
f"Emitting MessagesSnapshotEvent for confirm_changes with {len(all_messages)} messages"
|
||||
)
|
||||
events.append(messages_snapshot_event)
|
||||
|
||||
# Set flag to stop the run after this - we're waiting for user response
|
||||
self.should_stop_after_confirm = True
|
||||
logger.info(" >>> Set flag to stop run after confirm_changes")
|
||||
logger.info("Set flag to stop run after confirm_changes")
|
||||
elif tool_was_predictive:
|
||||
logger.info(" >>> Skipping confirm_changes - require_confirmation is False")
|
||||
logger.info("Skipping confirm_changes - require_confirmation is False")
|
||||
|
||||
# Clear pending updates and reset tool name tracker
|
||||
self.pending_state_updates.clear()
|
||||
@@ -580,7 +598,7 @@ class AgentFrameworkEventBridge:
|
||||
# Update current state
|
||||
self.current_state[state_key] = state_value
|
||||
logger.info(
|
||||
f" >>> Emitting StateSnapshotEvent for key '{state_key}', value type: {type(state_value)}"
|
||||
f"Emitting StateSnapshotEvent for key '{state_key}', value type: {type(state_value)}"
|
||||
)
|
||||
|
||||
# Emit state snapshot
|
||||
@@ -596,7 +614,7 @@ class AgentFrameworkEventBridge:
|
||||
tool_call_id=content.function_call.call_id,
|
||||
)
|
||||
logger.info(
|
||||
f" >>> Emitting ToolCallEndEvent for approval-required tool '{content.function_call.call_id}'"
|
||||
f"Emitting ToolCallEndEvent for approval-required tool '{content.function_call.call_id}'"
|
||||
)
|
||||
events.append(end_event)
|
||||
|
||||
@@ -615,7 +633,7 @@ class AgentFrameworkEventBridge:
|
||||
},
|
||||
},
|
||||
)
|
||||
logger.info(f" >>> Emitting function_approval_request custom event for '{content.function_call.name}'")
|
||||
logger.info(f"Emitting function_approval_request custom event for '{content.function_call.name}'")
|
||||
events.append(approval_event)
|
||||
|
||||
return events
|
||||
|
||||
@@ -0,0 +1,157 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""HTTP service for AG-UI protocol communication."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import AsyncIterable
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AGUIHttpService:
|
||||
"""HTTP service for AG-UI protocol communication.
|
||||
|
||||
Handles HTTP POST requests and Server-Sent Events (SSE) stream parsing
|
||||
for the AG-UI protocol.
|
||||
|
||||
Examples:
|
||||
Basic usage:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
service = AGUIHttpService("http://localhost:8888/")
|
||||
async for event in service.post_run(
|
||||
thread_id="thread_123",
|
||||
run_id="run_456",
|
||||
messages=[{"role": "user", "content": "Hello"}]
|
||||
):
|
||||
print(event["type"])
|
||||
|
||||
With context manager:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
async with AGUIHttpService("http://localhost:8888/") as service:
|
||||
async for event in service.post_run(...):
|
||||
print(event)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
endpoint: str,
|
||||
http_client: httpx.AsyncClient | None = None,
|
||||
timeout: float = 60.0,
|
||||
) -> None:
|
||||
"""Initialize the HTTP service.
|
||||
|
||||
Args:
|
||||
endpoint: AG-UI server endpoint URL (e.g., "http://localhost:8888/")
|
||||
http_client: Optional httpx AsyncClient. If None, creates a new one.
|
||||
timeout: Request timeout in seconds (default: 60.0)
|
||||
"""
|
||||
self.endpoint = endpoint.rstrip("/")
|
||||
self._owns_client = http_client is None
|
||||
self.http_client = http_client or httpx.AsyncClient(timeout=timeout)
|
||||
|
||||
async def post_run(
|
||||
self,
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
messages: list[dict[str, Any]],
|
||||
state: dict[str, Any] | None = None,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
) -> AsyncIterable[dict[str, Any]]:
|
||||
"""Post a run request and stream AG-UI events.
|
||||
|
||||
Args:
|
||||
thread_id: Thread identifier for conversation continuity
|
||||
run_id: Unique run identifier
|
||||
messages: List of messages in AG-UI format
|
||||
state: Optional state object to send to server
|
||||
tools: Optional list of tools available to the agent
|
||||
|
||||
Yields:
|
||||
AG-UI event dictionaries parsed from SSE stream
|
||||
|
||||
Raises:
|
||||
httpx.HTTPStatusError: If the HTTP request fails
|
||||
ValueError: If SSE parsing encounters invalid data
|
||||
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
|
||||
service = AGUIHttpService("http://localhost:8888/")
|
||||
async for event in service.post_run(
|
||||
thread_id="thread_abc",
|
||||
run_id="run_123",
|
||||
messages=[{"role": "user", "content": "Hello"}],
|
||||
state={"user_context": {"name": "Alice"}}
|
||||
):
|
||||
if event["type"] == "TEXT_MESSAGE_CONTENT":
|
||||
print(event["delta"])
|
||||
"""
|
||||
# Build request payload
|
||||
request_data: dict[str, Any] = {
|
||||
"thread_id": thread_id,
|
||||
"run_id": run_id,
|
||||
"messages": messages,
|
||||
}
|
||||
|
||||
if state is not None:
|
||||
request_data["state"] = state
|
||||
|
||||
if tools is not None:
|
||||
request_data["tools"] = tools
|
||||
|
||||
logger.debug(
|
||||
f"Posting run to {self.endpoint}: thread_id={thread_id}, run_id={run_id}, "
|
||||
f"messages={len(messages)}, has_state={state is not None}, has_tools={tools is not None}"
|
||||
)
|
||||
|
||||
# Stream the response using SSE
|
||||
async with self.http_client.stream(
|
||||
"POST",
|
||||
self.endpoint,
|
||||
json=request_data,
|
||||
headers={"Accept": "text/event-stream"},
|
||||
) as response:
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(f"HTTP request failed: {e.response.status_code} - {e.response.text}")
|
||||
raise
|
||||
|
||||
async for line in response.aiter_lines():
|
||||
# Parse Server-Sent Events format
|
||||
if line.startswith("data: "):
|
||||
data = line[6:] # Remove "data: " prefix
|
||||
try:
|
||||
event = json.loads(data)
|
||||
logger.debug(f"Received event: {event.get('type', 'UNKNOWN')}")
|
||||
yield event
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(f"Failed to parse SSE data: {data}. Error: {e}")
|
||||
# Continue processing other events instead of failing
|
||||
continue
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the HTTP client if owned by this service.
|
||||
|
||||
Only closes the client if it was created by this service instance.
|
||||
If an external client was provided, it remains the caller's
|
||||
responsibility to close it.
|
||||
"""
|
||||
if self._owns_client and self.http_client:
|
||||
await self.http_client.aclose()
|
||||
|
||||
async def __aenter__(self) -> "AGUIHttpService":
|
||||
"""Enter async context manager."""
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *args: Any) -> None:
|
||||
"""Exit async context manager and clean up resources."""
|
||||
await self.close()
|
||||
@@ -2,12 +2,13 @@
|
||||
|
||||
"""Message format conversion between AG-UI and Agent Framework."""
|
||||
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
from agent_framework import (
|
||||
ChatMessage,
|
||||
FunctionApprovalResponseContent,
|
||||
FunctionCallContent,
|
||||
FunctionResultContent,
|
||||
Role,
|
||||
TextContent,
|
||||
)
|
||||
@@ -46,7 +47,7 @@ def agui_messages_to_agent_framework(messages: list[dict[str, Any]]) -> list[Cha
|
||||
result_content = msg.get("result", msg.get("content", ""))
|
||||
|
||||
chat_msg = ChatMessage(
|
||||
role=Role.ASSISTANT, # Tool results are assistant messages
|
||||
role=Role.TOOL, # Tool results must be tool role
|
||||
contents=[FunctionResultContent(call_id=tool_call_id, result=result_content)],
|
||||
)
|
||||
|
||||
@@ -56,6 +57,42 @@ def agui_messages_to_agent_framework(messages: list[dict[str, Any]]) -> list[Cha
|
||||
result.append(chat_msg)
|
||||
continue
|
||||
|
||||
# If assistant message includes tool calls, convert to FunctionCallContent(s)
|
||||
tool_calls = msg.get("tool_calls") or msg.get("toolCalls")
|
||||
if tool_calls:
|
||||
contents: list[Any] = []
|
||||
# Include any assistant text content if present
|
||||
content_text = msg.get("content")
|
||||
if isinstance(content_text, str) and content_text:
|
||||
contents.append(TextContent(text=content_text))
|
||||
# Convert each tool call entry
|
||||
for tc in tool_calls:
|
||||
if not isinstance(tc, dict):
|
||||
continue
|
||||
# Cast to typed dict for proper type inference
|
||||
tc_dict = cast(dict[str, Any], tc)
|
||||
tc_type = tc_dict.get("type")
|
||||
if tc_type == "function":
|
||||
func_data = tc_dict.get("function", {})
|
||||
func_dict = cast(dict[str, Any], func_data) if isinstance(func_data, dict) else {}
|
||||
|
||||
call_id = str(tc_dict.get("id", ""))
|
||||
name = str(func_dict.get("name", ""))
|
||||
arguments = func_dict.get("arguments")
|
||||
|
||||
contents.append(
|
||||
FunctionCallContent(
|
||||
call_id=call_id,
|
||||
name=name,
|
||||
arguments=arguments,
|
||||
)
|
||||
)
|
||||
chat_msg = ChatMessage(role=Role.ASSISTANT, contents=contents)
|
||||
if "id" in msg:
|
||||
chat_msg.message_id = msg["id"]
|
||||
result.append(chat_msg)
|
||||
continue
|
||||
|
||||
role_str = msg.get("role", "user")
|
||||
|
||||
# Handle tool result messages (with role="tool")
|
||||
@@ -78,11 +115,11 @@ def agui_messages_to_agent_framework(messages: list[dict[str, Any]]) -> list[Cha
|
||||
|
||||
# Backend tool results have non-empty content WITHOUT "accepted" field
|
||||
if tool_call_id and result_content and not is_approval:
|
||||
# Backend tool execution - convert to FunctionResultContent
|
||||
# Tool execution result - convert to FunctionResultContent with correct role
|
||||
from agent_framework import FunctionResultContent
|
||||
|
||||
chat_msg = ChatMessage(
|
||||
role=Role.ASSISTANT, # Tool results are assistant messages
|
||||
role=Role.TOOL,
|
||||
contents=[FunctionResultContent(call_id=tool_call_id, result=result_content)],
|
||||
)
|
||||
|
||||
@@ -97,9 +134,8 @@ def agui_messages_to_agent_framework(messages: list[dict[str, Any]]) -> list[Cha
|
||||
chat_msg = ChatMessage(
|
||||
role=Role.USER, # Approval responses are user messages
|
||||
contents=[TextContent(text=content)],
|
||||
additional_properties={"is_tool_result": True, "tool_call_id": msg.get("toolCallId", "")},
|
||||
)
|
||||
# Mark this as a tool result so we can detect it later
|
||||
chat_msg.metadata = {"is_tool_result": True, "tool_call_id": msg.get("toolCallId", "")} # type: ignore[attr-defined]
|
||||
|
||||
if "id" in msg:
|
||||
chat_msg.message_id = msg["id"]
|
||||
@@ -112,7 +148,7 @@ def agui_messages_to_agent_framework(messages: list[dict[str, Any]]) -> list[Cha
|
||||
# Check if this message contains function approvals
|
||||
if "function_approvals" in msg and msg["function_approvals"]:
|
||||
# Convert function approvals to FunctionApprovalResponseContent
|
||||
contents: list[Any] = []
|
||||
approval_contents: list[Any] = []
|
||||
for approval in msg["function_approvals"]:
|
||||
# Create FunctionCallContent with the modified arguments
|
||||
func_call = FunctionCallContent(
|
||||
@@ -127,9 +163,9 @@ def agui_messages_to_agent_framework(messages: list[dict[str, Any]]) -> list[Cha
|
||||
id=approval.get("id", ""),
|
||||
function_call=func_call,
|
||||
)
|
||||
contents.append(approval_response)
|
||||
approval_contents.append(approval_response)
|
||||
|
||||
chat_msg = ChatMessage(role=role, contents=contents) # type: ignore[arg-type]
|
||||
chat_msg = ChatMessage(role=role, contents=approval_contents) # type: ignore[arg-type]
|
||||
else:
|
||||
# Regular text message
|
||||
content = msg.get("content", "")
|
||||
@@ -146,21 +182,44 @@ def agui_messages_to_agent_framework(messages: list[dict[str, Any]]) -> list[Cha
|
||||
return result
|
||||
|
||||
|
||||
def agent_framework_messages_to_agui(messages: list[ChatMessage]) -> list[dict[str, Any]]:
|
||||
def agent_framework_messages_to_agui(messages: list[ChatMessage] | list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""Convert Agent Framework messages to AG-UI format.
|
||||
|
||||
Args:
|
||||
messages: List of Agent Framework ChatMessage objects
|
||||
messages: List of Agent Framework ChatMessage objects or AG-UI dicts (already converted)
|
||||
|
||||
Returns:
|
||||
List of AG-UI message dictionaries
|
||||
"""
|
||||
from ._utils import generate_event_id
|
||||
|
||||
result: list[dict[str, Any]] = []
|
||||
for msg in messages:
|
||||
# If already a dict (AG-UI format), ensure it has an ID and normalize keys for Pydantic
|
||||
if isinstance(msg, dict):
|
||||
# Always work on a copy to avoid mutating input
|
||||
normalized_msg = msg.copy()
|
||||
# Ensure ID exists
|
||||
if "id" not in normalized_msg:
|
||||
normalized_msg["id"] = generate_event_id()
|
||||
# Normalize tool_call_id to toolCallId for Pydantic's alias_generator=to_camel
|
||||
if normalized_msg.get("role") == "tool":
|
||||
if "tool_call_id" in normalized_msg:
|
||||
normalized_msg["toolCallId"] = normalized_msg["tool_call_id"]
|
||||
del normalized_msg["tool_call_id"]
|
||||
elif "toolCallId" not in normalized_msg:
|
||||
# Tool message missing toolCallId - add empty string to satisfy schema
|
||||
normalized_msg["toolCallId"] = ""
|
||||
# Always append the normalized copy, not the original
|
||||
result.append(normalized_msg)
|
||||
continue
|
||||
|
||||
# Convert ChatMessage to AG-UI format
|
||||
role = _FRAMEWORK_TO_AGUI_ROLE.get(msg.role, "user")
|
||||
|
||||
content_text = ""
|
||||
tool_calls: list[dict[str, Any]] = []
|
||||
tool_result_call_id: str | None = None
|
||||
|
||||
for content in msg.contents:
|
||||
if isinstance(content, TextContent):
|
||||
@@ -176,18 +235,32 @@ def agent_framework_messages_to_agui(messages: list[ChatMessage]) -> list[dict[s
|
||||
},
|
||||
}
|
||||
)
|
||||
elif isinstance(content, FunctionResultContent):
|
||||
# Tool result content - extract call_id and result
|
||||
tool_result_call_id = content.call_id
|
||||
# Serialize result to string
|
||||
if isinstance(content.result, dict):
|
||||
import json
|
||||
|
||||
content_text = json.dumps(content.result) # type: ignore
|
||||
elif content.result is not None:
|
||||
content_text = str(content.result)
|
||||
|
||||
agui_msg: dict[str, Any] = {
|
||||
"id": msg.message_id if msg.message_id else generate_event_id(), # Always include id
|
||||
"role": role,
|
||||
"content": content_text,
|
||||
}
|
||||
|
||||
if msg.message_id:
|
||||
agui_msg["id"] = msg.message_id
|
||||
|
||||
if tool_calls:
|
||||
agui_msg["tool_calls"] = tool_calls
|
||||
|
||||
# If this is a tool result message, add toolCallId (using camelCase for Pydantic)
|
||||
if tool_result_call_id:
|
||||
agui_msg["toolCallId"] = tool_result_call_id
|
||||
# Tool result messages should have role="tool"
|
||||
agui_msg["role"] = "tool"
|
||||
|
||||
result.append(agui_msg)
|
||||
|
||||
return result
|
||||
|
||||
@@ -16,9 +16,9 @@ from ag_ui.core import (
|
||||
TextMessageEndEvent,
|
||||
TextMessageStartEvent,
|
||||
)
|
||||
from agent_framework import AgentProtocol, AgentThread, TextContent
|
||||
from agent_framework import AgentProtocol, AgentThread, ChatAgent, TextContent
|
||||
|
||||
from ._utils import generate_event_id
|
||||
from ._utils import convert_agui_tools_to_agent_framework, generate_event_id
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ._agent import AgentConfig
|
||||
@@ -142,14 +142,10 @@ class HumanInTheLoopOrchestrator(Orchestrator):
|
||||
True if last message is a tool result
|
||||
"""
|
||||
msg = context.last_message
|
||||
if not msg or not hasattr(msg, "metadata"):
|
||||
if not msg:
|
||||
return False
|
||||
|
||||
metadata = getattr(msg, "metadata", None)
|
||||
if not metadata:
|
||||
return False
|
||||
|
||||
return bool(metadata.get("is_tool_result", False))
|
||||
return bool(msg.additional_properties.get("is_tool_result", False))
|
||||
|
||||
async def run(
|
||||
self,
|
||||
@@ -274,8 +270,10 @@ class DefaultOrchestrator(Orchestrator):
|
||||
current_state: dict[str, Any] = initial_state.copy() if initial_state else {}
|
||||
|
||||
# Check if agent uses structured outputs (response_format)
|
||||
chat_options = getattr(context.agent, "chat_options", None)
|
||||
response_format = getattr(chat_options, "response_format", None) if chat_options else None
|
||||
# Use isinstance to narrow type for proper attribute access
|
||||
response_format = None
|
||||
if isinstance(context.agent, ChatAgent):
|
||||
response_format = context.agent.chat_options.response_format
|
||||
skip_text_content = response_format is not None
|
||||
|
||||
# Create event bridge
|
||||
@@ -334,9 +332,8 @@ class DefaultOrchestrator(Orchestrator):
|
||||
if context.messages:
|
||||
await thread.on_new_messages(context.messages)
|
||||
|
||||
# Get the last message as the new input
|
||||
new_message = context.last_message
|
||||
if not new_message:
|
||||
# Use the full incoming message batch to preserve tool-call adjacency
|
||||
if not context.messages:
|
||||
logger.warning("No messages provided in AG-UI input")
|
||||
yield event_bridge.create_run_finished_event()
|
||||
return
|
||||
@@ -362,11 +359,68 @@ Never replace existing data - always append or merge."""
|
||||
)
|
||||
messages_to_run.append(state_context_msg)
|
||||
|
||||
messages_to_run.append(new_message)
|
||||
# Preserve order from client to satisfy provider constraints (assistant tool_calls must
|
||||
# immediately precede tool result messages). Using the full batch avoids reordering.
|
||||
messages_to_run.extend(context.messages)
|
||||
|
||||
# Handle client tools for hybrid execution
|
||||
# Client sends tool metadata, server merges with its own tools.
|
||||
# Client tools have func=None (declaration-only), so @use_function_invocation
|
||||
# will return the function call without executing (passes back to client).
|
||||
from agent_framework import BaseChatClient
|
||||
|
||||
client_tools = convert_agui_tools_to_agent_framework(context.input_data.get("tools"))
|
||||
|
||||
# Extract server tools - use type narrowing when possible
|
||||
server_tools: list[Any] = []
|
||||
if isinstance(context.agent, ChatAgent):
|
||||
server_tools = context.agent.chat_options.tools or []
|
||||
else:
|
||||
# AgentProtocol allows duck-typed implementations - fallback to attribute access
|
||||
# This supports test mocks and custom agent implementations
|
||||
try:
|
||||
chat_options_attr = getattr(context.agent, "chat_options", None)
|
||||
if chat_options_attr is not None:
|
||||
server_tools = getattr(chat_options_attr, "tools", None) or []
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
# Register client tools as additional (declaration-only) so they are not executed on server
|
||||
if client_tools:
|
||||
if isinstance(context.agent, ChatAgent):
|
||||
# Type-safe path for ChatAgent
|
||||
chat_client = context.agent.chat_client
|
||||
if (
|
||||
isinstance(chat_client, BaseChatClient)
|
||||
and chat_client.function_invocation_configuration is not None
|
||||
):
|
||||
chat_client.function_invocation_configuration.additional_tools = client_tools
|
||||
logger.debug(
|
||||
f"[TOOLS] Registered {len(client_tools)} client tools as additional_tools (declaration-only)"
|
||||
)
|
||||
else:
|
||||
# Fallback for AgentProtocol implementations (test mocks, custom agents)
|
||||
try:
|
||||
chat_client_attr = getattr(context.agent, "chat_client", None)
|
||||
if chat_client_attr is not None:
|
||||
fic = getattr(chat_client_attr, "function_invocation_configuration", None)
|
||||
if fic is not None:
|
||||
fic.additional_tools = client_tools # type: ignore[attr-defined]
|
||||
logger.debug(
|
||||
f"[TOOLS] Registered {len(client_tools)} client tools as additional_tools (declaration-only)"
|
||||
)
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
combined_tools: list[Any] = []
|
||||
if server_tools:
|
||||
combined_tools.extend(server_tools)
|
||||
if client_tools:
|
||||
combined_tools.extend(client_tools)
|
||||
|
||||
# Collect all updates to get the final structured output
|
||||
all_updates: list[Any] = []
|
||||
async for update in context.agent.run_stream(messages_to_run, thread=thread):
|
||||
async for update in context.agent.run_stream(messages_to_run, thread=thread, tools=combined_tools or None):
|
||||
all_updates.append(update)
|
||||
events = await event_bridge.from_agent_run_update(update)
|
||||
for event in events:
|
||||
@@ -374,7 +428,7 @@ Never replace existing data - always append or merge."""
|
||||
|
||||
# After agent completes, check if we should stop (waiting for user to confirm changes)
|
||||
if event_bridge.should_stop_after_confirm:
|
||||
logger.info(" >>> Stopping run after confirm_changes - waiting for user response")
|
||||
logger.info("Stopping run after confirm_changes - waiting for user response")
|
||||
yield event_bridge.create_run_finished_event()
|
||||
return
|
||||
|
||||
|
||||
@@ -4,10 +4,13 @@
|
||||
|
||||
import copy
|
||||
import uuid
|
||||
from collections.abc import Callable, MutableMapping, Sequence
|
||||
from dataclasses import asdict, is_dataclass
|
||||
from datetime import date, datetime
|
||||
from typing import Any
|
||||
|
||||
from agent_framework import AIFunction, ToolProtocol
|
||||
|
||||
|
||||
def generate_event_id() -> str:
|
||||
"""Generate a unique event ID."""
|
||||
@@ -55,3 +58,109 @@ def make_json_safe(obj: Any) -> Any: # noqa: ANN401
|
||||
if isinstance(obj, dict):
|
||||
return {key: make_json_safe(value) for key, value in obj.items()} # type: ignore[misc]
|
||||
return str(obj)
|
||||
|
||||
|
||||
def convert_agui_tools_to_agent_framework(
|
||||
agui_tools: list[dict[str, Any]] | None,
|
||||
) -> list[AIFunction[Any, Any]] | None:
|
||||
"""Convert AG-UI tool definitions to Agent Framework AIFunction declarations.
|
||||
|
||||
Creates declaration-only AIFunction instances (no executable implementation).
|
||||
These are used to tell the LLM about available tools. The actual execution
|
||||
happens on the client side via @use_function_invocation.
|
||||
|
||||
CRITICAL: These tools MUST have func=None so that declaration_only returns True.
|
||||
This prevents the server from trying to execute client-side tools.
|
||||
|
||||
Args:
|
||||
agui_tools: List of AG-UI tool definitions with name, description, parameters
|
||||
|
||||
Returns:
|
||||
List of AIFunction declarations, or None if no tools provided
|
||||
"""
|
||||
if not agui_tools:
|
||||
return None
|
||||
|
||||
result: list[AIFunction[Any, Any]] = []
|
||||
for tool_def in agui_tools:
|
||||
# Create declaration-only AIFunction (func=None means no implementation)
|
||||
# When func=None, the declaration_only property returns True,
|
||||
# which tells @use_function_invocation to return the function call
|
||||
# without executing it (so it can be sent back to the client)
|
||||
func: AIFunction[Any, Any] = AIFunction(
|
||||
name=tool_def.get("name", ""),
|
||||
description=tool_def.get("description", ""),
|
||||
func=None, # CRITICAL: Makes declaration_only=True
|
||||
input_model=tool_def.get("parameters", {}),
|
||||
)
|
||||
result.append(func)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def convert_tools_to_agui_format(
|
||||
tools: (
|
||||
ToolProtocol
|
||||
| Callable[..., Any]
|
||||
| MutableMapping[str, Any]
|
||||
| Sequence[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]]
|
||||
| None
|
||||
),
|
||||
) -> list[dict[str, Any]] | None:
|
||||
"""Convert tools to AG-UI format.
|
||||
|
||||
This sends only the metadata (name, description, JSON schema) to the server.
|
||||
The actual executable implementation stays on the client side.
|
||||
The @use_function_invocation decorator handles client-side execution when
|
||||
the server requests a function.
|
||||
|
||||
Args:
|
||||
tools: Tools to convert (single tool or sequence of tools)
|
||||
|
||||
Returns:
|
||||
List of tool specifications in AG-UI format, or None if no tools provided
|
||||
"""
|
||||
if not tools:
|
||||
return None
|
||||
|
||||
# Normalize to list
|
||||
if not isinstance(tools, list):
|
||||
tool_list: list[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]] = [tools] # type: ignore[list-item]
|
||||
else:
|
||||
tool_list = tools # type: ignore[assignment]
|
||||
|
||||
results: list[dict[str, Any]] = []
|
||||
|
||||
for tool in tool_list:
|
||||
if isinstance(tool, dict):
|
||||
# Already in dict format, pass through
|
||||
results.append(tool) # type: ignore[arg-type]
|
||||
elif isinstance(tool, AIFunction):
|
||||
# Convert AIFunction to AG-UI tool format
|
||||
results.append(
|
||||
{
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": tool.parameters(),
|
||||
}
|
||||
)
|
||||
elif callable(tool):
|
||||
# Convert callable to AIFunction first, then to AG-UI format
|
||||
from agent_framework import ai_function
|
||||
|
||||
ai_func = ai_function(tool)
|
||||
results.append(
|
||||
{
|
||||
"name": ai_func.name,
|
||||
"description": ai_func.description,
|
||||
"parameters": ai_func.parameters(),
|
||||
}
|
||||
)
|
||||
elif isinstance(tool, ToolProtocol):
|
||||
# Handle other ToolProtocol implementations
|
||||
# For now, we'll skip non-AIFunction tools as they may not have
|
||||
# the parameters() method. This matches .NET behavior which only
|
||||
# converts AIFunctionDeclaration instances.
|
||||
continue
|
||||
|
||||
return results if results else None
|
||||
|
||||
@@ -14,7 +14,7 @@ pip install agent-framework-ag-ui
|
||||
from fastapi import FastAPI
|
||||
from agent_framework import ChatAgent
|
||||
from agent_framework.azure import AzureOpenAIChatClient
|
||||
from agent_framework_ag_ui import add_agent_framework_fastapi_endpoint
|
||||
from agent_framework.ag_ui import add_agent_framework_fastapi_endpoint
|
||||
|
||||
# Create your agent
|
||||
agent = ChatAgent(
|
||||
@@ -104,7 +104,7 @@ State is injected as system messages and updated via predictive state updates:
|
||||
```python
|
||||
from agent_framework import ChatAgent
|
||||
from agent_framework.azure import AzureOpenAIChatClient
|
||||
from agent_framework_ag_ui import AgentFrameworkAgent
|
||||
from agent_framework.ag_ui import AgentFrameworkAgent
|
||||
|
||||
# Create your agent
|
||||
agent = ChatAgent(
|
||||
@@ -141,7 +141,7 @@ Predictive state updates automatically stream tool arguments as optimistic state
|
||||
```python
|
||||
from agent_framework import ChatAgent
|
||||
from agent_framework.azure import AzureOpenAIChatClient
|
||||
from agent_framework_ag_ui import AgentFrameworkAgent
|
||||
from agent_framework.ag_ui import AgentFrameworkAgent
|
||||
|
||||
# Create your agent
|
||||
agent = ChatAgent(
|
||||
@@ -170,7 +170,7 @@ Provide domain-specific confirmation messages:
|
||||
from typing import Any
|
||||
from agent_framework import ChatAgent
|
||||
from agent_framework.azure import AzureOpenAIChatClient
|
||||
from agent_framework_ag_ui import AgentFrameworkAgent, ConfirmationStrategy
|
||||
from agent_framework.ag_ui import AgentFrameworkAgent, ConfirmationStrategy
|
||||
|
||||
class CustomConfirmationStrategy(ConfirmationStrategy):
|
||||
def on_approval_accepted(self, steps: list[dict[str, Any]]) -> str:
|
||||
@@ -216,7 +216,7 @@ def sensitive_action(param: str) -> str:
|
||||
Add custom execution flows by implementing the Orchestrator pattern:
|
||||
|
||||
```python
|
||||
from agent_framework_ag_ui._orchestrators import Orchestrator, ExecutionContext
|
||||
from agent_framework.ag_ui._orchestrators import Orchestrator, ExecutionContext
|
||||
|
||||
class MyCustomOrchestrator(Orchestrator):
|
||||
def can_handle(self, context: ExecutionContext) -> bool:
|
||||
|
||||
@@ -128,7 +128,7 @@ class TaskStepsAgentWithExecution:
|
||||
import uuid
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.info(">>> TaskStepsAgentWithExecution.run_agent() called - wrapper is active")
|
||||
logger.info("TaskStepsAgentWithExecution.run_agent() called - wrapper is active")
|
||||
|
||||
# First, run the base agent to generate the plan - buffer text messages
|
||||
final_state: dict[str, Any] | None = None
|
||||
@@ -138,41 +138,41 @@ class TaskStepsAgentWithExecution:
|
||||
|
||||
async for event in self._base_agent.run_agent(input_data):
|
||||
event_type_str = str(event.type) if hasattr(event, "type") else type(event).__name__
|
||||
logger.info(f">>> Processing event: {event_type_str}")
|
||||
logger.info(f"Processing event: {event_type_str}")
|
||||
|
||||
match event:
|
||||
case StateSnapshotEvent(snapshot=snapshot):
|
||||
final_state = snapshot
|
||||
logger.info(f">>> Captured STATE_SNAPSHOT event with state: {final_state}")
|
||||
logger.info(f"Captured STATE_SNAPSHOT event with state: {final_state}")
|
||||
yield event
|
||||
case RunFinishedEvent():
|
||||
run_finished_event = event
|
||||
logger.info(">>> Captured RUN_FINISHED event - will send after step execution and summary")
|
||||
logger.info("Captured RUN_FINISHED event - will send after step execution and summary")
|
||||
case ToolCallStartEvent(tool_call_id=call_id):
|
||||
tool_call_id = call_id
|
||||
logger.info(f">>> Captured tool_call_id: {tool_call_id}")
|
||||
logger.info(f"Captured tool_call_id: {tool_call_id}")
|
||||
yield event
|
||||
case TextMessageStartEvent() | TextMessageContentEvent() | TextMessageEndEvent():
|
||||
buffered_text_events.append(event)
|
||||
logger.info(f">>> Buffered {event_type_str} from first LLM call")
|
||||
logger.info(f"Buffered {event_type_str} from first LLM call")
|
||||
case _:
|
||||
logger.info(f">>> Yielding event immediately: {event_type_str}")
|
||||
logger.info(f"Yielding event immediately: {event_type_str}")
|
||||
yield event
|
||||
|
||||
logger.info(f">>> Base agent completed. Final state: {final_state}")
|
||||
logger.info(f"Base agent completed. Final state: {final_state}")
|
||||
|
||||
# Now simulate executing the steps
|
||||
if final_state and "steps" in final_state:
|
||||
steps = final_state["steps"]
|
||||
logger.info(f">>> Starting step execution simulation for {len(steps)} steps")
|
||||
logger.info(f"Starting step execution simulation for {len(steps)} steps")
|
||||
|
||||
for i in range(len(steps)):
|
||||
logger.info(f">>> Simulating execution of step {i + 1}/{len(steps)}: {steps[i].get('description')}")
|
||||
logger.info(f"Simulating execution of step {i + 1}/{len(steps)}: {steps[i].get('description')}")
|
||||
await asyncio.sleep(1.0) # Simulate work
|
||||
|
||||
# Update step to completed
|
||||
steps[i]["status"] = "completed"
|
||||
logger.info(f">>> Step {i + 1} marked as completed")
|
||||
logger.info(f"Step {i + 1} marked as completed")
|
||||
|
||||
# Send delta event with manual JSON patch format
|
||||
delta_event = StateDeltaEvent(
|
||||
@@ -185,7 +185,7 @@ class TaskStepsAgentWithExecution:
|
||||
}
|
||||
],
|
||||
)
|
||||
logger.info(f">>> Yielding StateDeltaEvent for step {i + 1}")
|
||||
logger.info(f"Yielding StateDeltaEvent for step {i + 1}")
|
||||
yield delta_event
|
||||
|
||||
# Send final snapshot
|
||||
@@ -193,11 +193,11 @@ class TaskStepsAgentWithExecution:
|
||||
type=EventType.STATE_SNAPSHOT,
|
||||
snapshot={"steps": steps},
|
||||
)
|
||||
logger.info(">>> Yielding final StateSnapshotEvent with all steps completed")
|
||||
logger.info("Yielding final StateSnapshotEvent with all steps completed")
|
||||
yield final_snapshot
|
||||
|
||||
# SECOND LLM call: Stream summary from chat client directly
|
||||
logger.info(">>> Making SECOND LLM call to generate summary after step execution")
|
||||
logger.info("Making SECOND LLM call to generate summary after step execution")
|
||||
|
||||
# Get the underlying chat agent and client
|
||||
chat_agent = self._base_agent.agent # type: ignore
|
||||
@@ -236,7 +236,7 @@ class TaskStepsAgentWithExecution:
|
||||
)
|
||||
|
||||
# Stream the LLM response and manually emit text events
|
||||
logger.info(">>> Calling chat client for summary")
|
||||
logger.info("Calling chat client for summary")
|
||||
|
||||
message_id = str(uuid.uuid4())
|
||||
|
||||
@@ -268,7 +268,7 @@ class TaskStepsAgentWithExecution:
|
||||
type=EventType.TEXT_MESSAGE_END,
|
||||
message_id=message_id,
|
||||
)
|
||||
logger.info(f">>> Summary complete: {accumulated_text}")
|
||||
logger.info(f"Summary complete: {accumulated_text}")
|
||||
|
||||
# Build complete message for persistence
|
||||
summary_message = {
|
||||
@@ -285,7 +285,7 @@ class TaskStepsAgentWithExecution:
|
||||
messages=final_messages,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f">>> Error generating summary: {e}")
|
||||
logger.error(f"Error generating summary: {e}")
|
||||
# Generate a new message ID for the error
|
||||
error_message_id = str(uuid.uuid4())
|
||||
# Yield TEXT_MESSAGE_START for error
|
||||
@@ -306,11 +306,11 @@ class TaskStepsAgentWithExecution:
|
||||
message_id=error_message_id,
|
||||
)
|
||||
else:
|
||||
logger.warning(f">>> No steps found in final_state to execute. final_state={final_state}")
|
||||
logger.warning(f"No steps found in final_state to execute. final_state={final_state}")
|
||||
|
||||
# Finally send the original RUN_FINISHED event
|
||||
if run_finished_event:
|
||||
logger.info(">>> Yielding original RUN_FINISHED event")
|
||||
logger.info("Yielding original RUN_FINISHED event")
|
||||
yield run_finished_event
|
||||
|
||||
|
||||
|
||||
@@ -2,6 +2,135 @@
|
||||
|
||||
The AG-UI (Agent UI) protocol provides a standardized way for client applications to interact with AI agents over HTTP. This tutorial demonstrates how to build both server and client applications using the AG-UI protocol with Python.
|
||||
|
||||
## Quick Start - Client Examples
|
||||
|
||||
If you want to quickly try out the AG-UI client, we provide three ready-to-use examples:
|
||||
|
||||
### Basic Interactive Client (`client.py`)
|
||||
|
||||
A simple command-line chat client that demonstrates:
|
||||
- Streaming responses in real-time
|
||||
- Automatic thread management for conversation continuity
|
||||
- Direct `AGUIChatClient` usage (caller manages message history)
|
||||
|
||||
**Run:**
|
||||
```bash
|
||||
python client.py
|
||||
```
|
||||
|
||||
**Note:** This example sends only the current message to the server. The server is responsible for maintaining conversation history using the thread_id.
|
||||
|
||||
### Advanced Features Client (`client_advanced.py`)
|
||||
|
||||
Demonstrates advanced capabilities:
|
||||
- Tool/function calling
|
||||
- Both streaming and non-streaming responses
|
||||
- Multi-turn conversations
|
||||
- Error handling patterns
|
||||
|
||||
**Run:**
|
||||
```bash
|
||||
python client_advanced.py
|
||||
```
|
||||
|
||||
**Note:** This example shows direct `AGUIChatClient` usage. Tool execution and conversation continuity depend on server-side configuration and capabilities.
|
||||
|
||||
### ChatAgent Integration (`client_with_agent.py`)
|
||||
|
||||
Best practice example using `ChatAgent` wrapper with **AgentThread**
|
||||
- **AgentThread** maintains conversation state
|
||||
- Client-side conversation history management via `thread.message_store`
|
||||
- **Hybrid tool execution**: client-side + server-side tools simultaneously
|
||||
- Full conversation history sent on each request
|
||||
- Tool calling with conversation context
|
||||
|
||||
**To demonstrate hybrid tools:**
|
||||
|
||||
1. **Start server with server-side tool** (Terminal 1):
|
||||
```bash
|
||||
# Server has get_time_zone tool
|
||||
python server.py
|
||||
```
|
||||
|
||||
2. **Run client with client-side tool** (Terminal 2):
|
||||
```bash
|
||||
# Client has get_weather tool
|
||||
python client_with_agent.py
|
||||
```
|
||||
|
||||
All examples require a running AG-UI server (see Step 1 below for setup).
|
||||
|
||||
## Understanding AG-UI Architecture
|
||||
|
||||
### Thread Management
|
||||
|
||||
The AG-UI protocol supports two approaches to conversation history:
|
||||
|
||||
1. **Server-Managed Threads** (client.py, client_advanced.py)
|
||||
- Client sends only the current message + thread_id
|
||||
- Server maintains full conversation history
|
||||
- Requires server to support stateful thread storage
|
||||
- Lighter network payload
|
||||
|
||||
2. **Client-Managed History** (client_with_agent.py)
|
||||
- Client maintains full conversation history locally
|
||||
- Full message history sent with each request
|
||||
- Works with any AG-UI server (stateful or stateless)
|
||||
|
||||
The `ChatAgent` wrapper (used in client_with_agent.py) collects messages from local storage and sends the full history to `AGUIChatClient`, which then forwards everything to the server.
|
||||
|
||||
### Tool/Function Calling
|
||||
|
||||
The AG-UI protocol supports **hybrid tool execution** - both client-side AND server-side tools can coexist in the same conversation.
|
||||
|
||||
**The Hybrid Pattern** (client_with_agent.py):
|
||||
```
|
||||
Client defines: Server defines:
|
||||
- get_weather() - get_current_time()
|
||||
- read_sensors() - get_server_forecast()
|
||||
|
||||
User: "What's the weather in SF and what time is it?"
|
||||
↓
|
||||
ChatAgent sends: full history + tool definitions for get_weather, read_sensors
|
||||
↓
|
||||
Server LLM decides: "I need get_weather('SF') and get_current_time()"
|
||||
↓
|
||||
Server executes get_current_time() → "2025-11-11 14:30:00 UTC"
|
||||
Server sends function call request → get_weather('SF')
|
||||
↓
|
||||
ChatAgent intercepts get_weather call → executes locally
|
||||
↓
|
||||
Client sends result → "Sunny, 72°F"
|
||||
↓
|
||||
Server combines both results → "It's sunny and 72°F in SF, and the current time is 2:30 PM UTC"
|
||||
↓
|
||||
Client receives final response
|
||||
```
|
||||
|
||||
**How it works:**
|
||||
|
||||
1. **Client-Side Tools** (`client_with_agent.py`):
|
||||
- Tools defined in ChatAgent's `tools` parameter execute locally
|
||||
- Tool metadata (name, description, schema) sent to server for planning
|
||||
- When server requests client tool → client intercepts → executes locally → sends result
|
||||
|
||||
2. **Server-Side Tools**:
|
||||
- Defined in server agent's configuration
|
||||
- Server executes directly without client involvement
|
||||
- Results included in server's response
|
||||
|
||||
3. **Hybrid Pattern (Both Together)**:
|
||||
- Server LLM sees ALL tool definitions (client + server)
|
||||
- Decides which to use based on task
|
||||
- Server tools execute server-side
|
||||
- Client tools execute client-side
|
||||
|
||||
**Direct AGUIChatClient Usage** (client_advanced.py):
|
||||
Even without ChatAgent wrapper, client-side tools work:
|
||||
- Tools passed in ChatOptions execute locally
|
||||
- Server can also have its own tools
|
||||
- Hybrid execution works automatically
|
||||
|
||||
## What is AG-UI?
|
||||
|
||||
AG-UI is a protocol that enables:
|
||||
@@ -35,13 +164,13 @@ The AG-UI server hosts your AI agent and exposes it via HTTP endpoints using Fas
|
||||
### Install Required Packages
|
||||
|
||||
```bash
|
||||
pip install agent-framework-ag-ui agent-framework-core fastapi uvicorn
|
||||
pip install agent-framework-ag-ui
|
||||
```
|
||||
|
||||
Or using uv:
|
||||
|
||||
```bash
|
||||
uv pip install agent-framework-ag-ui agent-framework-core fastapi uvicorn
|
||||
uv pip install agent-framework-ag-ui
|
||||
```
|
||||
|
||||
### Server Code
|
||||
@@ -57,17 +186,20 @@ import os
|
||||
|
||||
from agent_framework import ChatAgent
|
||||
from agent_framework.azure import AzureOpenAIChatClient
|
||||
from agent_framework_ag_ui import add_agent_framework_fastapi_endpoint
|
||||
from agent_framework.ag_ui import add_agent_framework_fastapi_endpoint
|
||||
from fastapi import FastAPI
|
||||
|
||||
# Read required configuration
|
||||
endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT")
|
||||
deployment_name = os.environ.get("AZURE_OPENAI_DEPLOYMENT_NAME")
|
||||
api_key = os.environ.get("AZURE_OPENAI_API_KEY")
|
||||
|
||||
if not endpoint:
|
||||
raise ValueError("AZURE_OPENAI_ENDPOINT environment variable is required")
|
||||
if not deployment_name:
|
||||
raise ValueError("AZURE_OPENAI_DEPLOYMENT_NAME environment variable is required")
|
||||
if not api_key:
|
||||
raise ValueError("AZURE_OPENAI_API_KEY environment variable is required")
|
||||
|
||||
# Create the AI agent
|
||||
agent = ChatAgent(
|
||||
@@ -76,6 +208,7 @@ agent = ChatAgent(
|
||||
chat_client=AzureOpenAIChatClient(
|
||||
endpoint=endpoint,
|
||||
deployment_name=deployment_name,
|
||||
api_key=api_key,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -137,12 +270,14 @@ The server will start listening on `http://127.0.0.1:5100`.
|
||||
|
||||
## Step 2: Creating an AG-UI Client
|
||||
|
||||
The AG-UI client connects to the remote server and displays streaming responses.
|
||||
The AG-UI client connects to the remote server and displays streaming responses. The `AGUIChatClient` is a built-in implementation that integrates with the Agent Framework's standard chat interface.
|
||||
|
||||
### Install Required Packages
|
||||
|
||||
The `AGUIChatClient` is included in the `agent-framework-ag-ui` package (already installed if you installed the server packages).
|
||||
|
||||
```bash
|
||||
pip install httpx
|
||||
pip install agent-framework-ag-ui
|
||||
```
|
||||
|
||||
### Client Code
|
||||
@@ -152,122 +287,61 @@ Create a file named `client.py`:
|
||||
```python
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""AG-UI client example."""
|
||||
"""AG-UI client example using AGUIChatClient."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from typing import AsyncIterator
|
||||
|
||||
import httpx
|
||||
|
||||
|
||||
class AGUIClient:
|
||||
"""Simple AG-UI protocol client."""
|
||||
|
||||
def __init__(self, server_url: str):
|
||||
"""Initialize the client.
|
||||
|
||||
Args:
|
||||
server_url: The AG-UI server endpoint URL
|
||||
"""
|
||||
self.server_url = server_url
|
||||
self.thread_id: str | None = None
|
||||
|
||||
async def send_message(self, message: str) -> AsyncIterator[dict]:
|
||||
"""Send a message and stream the response.
|
||||
|
||||
Args:
|
||||
message: The user message to send
|
||||
|
||||
Yields:
|
||||
AG-UI events from the server
|
||||
"""
|
||||
# Prepare the request
|
||||
request_data = {
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": message},
|
||||
]
|
||||
}
|
||||
|
||||
# Include thread_id if we have one (for conversation continuity)
|
||||
if self.thread_id:
|
||||
request_data["thread_id"] = self.thread_id
|
||||
|
||||
# Stream the response
|
||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||
async with client.stream(
|
||||
"POST",
|
||||
self.server_url,
|
||||
json=request_data,
|
||||
headers={"Accept": "text/event-stream"},
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
|
||||
async for line in response.aiter_lines():
|
||||
# Parse Server-Sent Events format
|
||||
if line.startswith("data: "):
|
||||
data = line[6:] # Remove "data: " prefix
|
||||
try:
|
||||
event = json.loads(data)
|
||||
yield event
|
||||
|
||||
# Capture thread_id from RUN_STARTED event
|
||||
if event.get("type") == "RUN_STARTED" and not self.thread_id:
|
||||
self.thread_id = event.get("threadId")
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
from agent_framework import TextContent
|
||||
from agent_framework.ag_ui import AGUIChatClient
|
||||
|
||||
|
||||
async def main():
|
||||
"""Main client loop."""
|
||||
"""Main client loop demonstrating AGUIChatClient usage."""
|
||||
# Get server URL from environment or use default
|
||||
server_url = os.environ.get("AGUI_SERVER_URL", "http://127.0.0.1:5100/")
|
||||
print(f"Connecting to AG-UI server at: {server_url}\n")
|
||||
|
||||
client = AGUIClient(server_url)
|
||||
# Create client with context manager for automatic cleanup
|
||||
async with AGUIChatClient(endpoint=server_url) as client:
|
||||
thread_id: str | None = None
|
||||
|
||||
try:
|
||||
while True:
|
||||
# Get user input
|
||||
message = input("\nUser (:q or quit to exit): ")
|
||||
if not message.strip():
|
||||
print("Request cannot be empty.")
|
||||
continue
|
||||
try:
|
||||
while True:
|
||||
# Get user input
|
||||
message = input("\nUser (:q or quit to exit): ")
|
||||
if not message.strip():
|
||||
print("Request cannot be empty.")
|
||||
continue
|
||||
|
||||
if message.lower() in (":q", "quit"):
|
||||
break
|
||||
if message.lower() in (":q", "quit"):
|
||||
break
|
||||
|
||||
# Send message and display streaming response
|
||||
print("\n", end="")
|
||||
async for event in client.send_message(message):
|
||||
event_type = event.get("type", "")
|
||||
# Send message and stream the response
|
||||
print("\nAssistant: ", end="", flush=True)
|
||||
|
||||
if event_type == "RUN_STARTED":
|
||||
thread_id = event.get("threadId", "")
|
||||
run_id = event.get("runId", "")
|
||||
print(f"\033[93m[Run Started - Thread: {thread_id}, Run: {run_id}]\033[0m")
|
||||
# Use metadata to maintain conversation continuity
|
||||
metadata = {"thread_id": thread_id} if thread_id else None
|
||||
|
||||
elif event_type == "TEXT_MESSAGE_CONTENT":
|
||||
# Stream text content in cyan
|
||||
print(f"\033[96m{event.get('delta', '')}\033[0m", end="", flush=True)
|
||||
async for update in client.get_streaming_response(message, metadata=metadata):
|
||||
# Extract thread ID from first update
|
||||
if not thread_id and update.additional_properties:
|
||||
thread_id = update.additional_properties.get("thread_id")
|
||||
if thread_id:
|
||||
print(f"\n[Thread: {thread_id}]")
|
||||
print("Assistant: ", end="", flush=True)
|
||||
|
||||
elif event_type == "RUN_FINISHED":
|
||||
thread_id = event.get("threadId", "")
|
||||
run_id = event.get("runId", "")
|
||||
print(f"\n\033[92m[Run Finished - Thread: {thread_id}, Run: {run_id}]\033[0m")
|
||||
# Stream text content as it arrives
|
||||
for content in update.contents:
|
||||
if isinstance(content, TextContent) and content.text:
|
||||
print(content.text, end="", flush=True)
|
||||
|
||||
elif event_type == "RUN_ERROR":
|
||||
error_message = event.get("message", "Unknown error")
|
||||
print(f"\n\033[91m[Run Error - Message: {error_message}]\033[0m")
|
||||
print() # New line after response
|
||||
|
||||
print()
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\nExiting...")
|
||||
except Exception as e:
|
||||
print(f"\n\033[91mAn error occurred: {e}\033[0m")
|
||||
except KeyboardInterrupt:
|
||||
print("\n\nExiting...")
|
||||
except Exception as e:
|
||||
print(f"\nAn error occurred: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -276,17 +350,13 @@ if __name__ == "__main__":
|
||||
|
||||
### Key Concepts
|
||||
|
||||
- **Server-Sent Events (SSE)**: The protocol uses SSE format (`data: {json}\n\n`)
|
||||
- **Event Types**: Different events provide metadata and content (all event types use UPPERCASE with underscores):
|
||||
- `RUN_STARTED`: Signals the agent has started processing
|
||||
- `TEXT_MESSAGE_START`: Signals the start of a text message from the agent
|
||||
- `TEXT_MESSAGE_CONTENT`: Incremental text streamed from the agent (with `delta` field)
|
||||
- `TEXT_MESSAGE_END`: Signals the end of a text message
|
||||
- `RUN_FINISHED`: Signals successful completion
|
||||
- `RUN_ERROR`: Error information if something goes wrong
|
||||
- **Field Naming**: Event fields use camelCase (e.g., `threadId`, `runId`, `messageId`) when accessing JSON events
|
||||
- **Thread Management**: The `threadId` maintains conversation context across requests
|
||||
- **Client-Side Instructions**: System messages are sent from the client
|
||||
- **`AGUIChatClient`**: Built-in client that implements the Agent Framework's `BaseChatClient` interface
|
||||
- **Automatic Event Handling**: The client automatically converts AG-UI events to Agent Framework types
|
||||
- **Thread Management**: Pass `thread_id` in metadata to maintain conversation context across requests
|
||||
- **Streaming Responses**: Use `get_streaming_response()` for real-time streaming or `get_response()` for non-streaming
|
||||
- **Context Manager**: Use `async with` for automatic cleanup of HTTP connections
|
||||
- **Standard Interface**: Works with all Agent Framework patterns (ChatAgent, tools, etc.)
|
||||
- **Hybrid Tool Execution**: Supports both client-side and server-side tools executing together in the same conversation
|
||||
|
||||
### Configure and Run the Client
|
||||
|
||||
@@ -312,327 +382,13 @@ Connecting to AG-UI server at: http://127.0.0.1:5100/
|
||||
|
||||
User (:q or quit to exit): What is the capital of France?
|
||||
|
||||
[Run Started - Thread: abc123, Run: xyz789]
|
||||
The capital of France is Paris. It is known for its rich history, culture,
|
||||
[Thread: abc123]
|
||||
Assistant: The capital of France is Paris. It is known for its rich history, culture,
|
||||
and iconic landmarks such as the Eiffel Tower and the Louvre Museum.
|
||||
[Run Finished - Thread: abc123, Run: xyz789]
|
||||
|
||||
User (:q or quit to exit): Tell me a fun fact about space
|
||||
|
||||
[Run Started - Thread: abc123, Run: def456]
|
||||
Here's a fun fact: A day on Venus is longer than its year! Venus takes
|
||||
about 243 Earth days to rotate once on its axis, but only about 225 Earth
|
||||
days to orbit the Sun.
|
||||
[Run Finished - Thread: abc123, Run: def456]
|
||||
|
||||
User (:q or quit to exit): :q
|
||||
```
|
||||
|
||||
### Color-Coded Output
|
||||
|
||||
The client displays different content types with distinct colors:
|
||||
- **Yellow**: Run started notifications
|
||||
- **Cyan**: Agent text responses (streamed in real-time)
|
||||
- **Green**: Run completion notifications
|
||||
- **Red**: Error messages
|
||||
|
||||
## Testing with curl (Optional)
|
||||
|
||||
Before running the client, you can test the server manually using curl:
|
||||
|
||||
```bash
|
||||
curl -N http://127.0.0.1:5100/ \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Accept: text/event-stream" \
|
||||
-d '{
|
||||
"messages": [
|
||||
{"role": "user", "content": "What is the capital of France?"}
|
||||
]
|
||||
}'
|
||||
```
|
||||
|
||||
You should see Server-Sent Events streaming back:
|
||||
|
||||
```
|
||||
data: {"type":"RUN_STARTED","threadId":"...","runId":"..."}
|
||||
|
||||
data: {"type":"TEXT_MESSAGE_START","messageId":"...","role":"assistant"}
|
||||
|
||||
data: {"type":"TEXT_MESSAGE_CONTENT","messageId":"...","delta":"The"}
|
||||
|
||||
data: {"type":"TEXT_MESSAGE_CONTENT","messageId":"...","delta":" capital"}
|
||||
|
||||
...
|
||||
|
||||
data: {"type":"TEXT_MESSAGE_END","messageId":"..."}
|
||||
|
||||
data: {"type":"RUN_FINISHED","threadId":"...","runId":"..."}
|
||||
```
|
||||
|
||||
## How It Works
|
||||
|
||||
### Server-Side Flow
|
||||
|
||||
1. Client sends HTTP POST request with messages
|
||||
2. FastAPI endpoint receives the request
|
||||
3. `AgentFrameworkAgent` wrapper orchestrates the execution
|
||||
4. Agent processes the messages using Agent Framework
|
||||
5. `AgentFrameworkEventBridge` converts agent updates to AG-UI events
|
||||
6. Responses are streamed back as Server-Sent Events (SSE)
|
||||
7. Connection closes when the run completes
|
||||
|
||||
### Client-Side Flow
|
||||
|
||||
1. Client sends HTTP POST request to server endpoint
|
||||
2. Server responds with SSE stream
|
||||
3. Client parses incoming `data:` lines as JSON events
|
||||
4. Each event is displayed based on its type
|
||||
5. `threadId` is captured for conversation continuity
|
||||
6. Stream completes when `RUN_FINISHED` event arrives
|
||||
|
||||
### Protocol Details
|
||||
|
||||
The AG-UI protocol uses:
|
||||
- **HTTP POST** for sending requests
|
||||
- **Server-Sent Events (SSE)** for streaming responses
|
||||
- **JSON** for event serialization
|
||||
- **Thread IDs** for maintaining conversation context
|
||||
- **Run IDs** for tracking individual executions
|
||||
- **Event type naming**: UPPERCASE with underscores (e.g., `RUN_STARTED`, `TEXT_MESSAGE_CONTENT`)
|
||||
- **Field naming**: camelCase (e.g., `threadId`, `runId`, `messageId`)
|
||||
|
||||
## Advanced Features
|
||||
|
||||
The Python AG-UI implementation supports all 7 AG-UI features:
|
||||
|
||||
### 1. Backend Tool Rendering
|
||||
|
||||
Add tools to your agent for backend execution:
|
||||
|
||||
```python
|
||||
from typing import Any
|
||||
|
||||
from agent_framework import ChatAgent, ai_function
|
||||
from agent_framework.azure import AzureOpenAIChatClient
|
||||
|
||||
|
||||
@ai_function
|
||||
def get_weather(location: str) -> dict[str, Any]:
|
||||
"""Get weather for a location."""
|
||||
return {"temperature": 72, "conditions": "sunny"}
|
||||
|
||||
|
||||
agent = ChatAgent(
|
||||
name="weather_agent",
|
||||
instructions="Use tools to help users.",
|
||||
chat_client=AzureOpenAIChatClient(
|
||||
endpoint="https://your-resource.openai.azure.com/",
|
||||
deployment_name="gpt-4o-mini",
|
||||
),
|
||||
tools=[get_weather],
|
||||
)
|
||||
```
|
||||
|
||||
The client will receive `TOOL_CALL_START`, `TOOL_CALL_ARGS`, `TOOL_CALL_END`, and `TOOL_CALL_RESULT` events.
|
||||
|
||||
### 2. Human in the Loop
|
||||
|
||||
Request user confirmation before executing tools:
|
||||
|
||||
```python
|
||||
from fastapi import FastAPI
|
||||
from agent_framework import ChatAgent
|
||||
from agent_framework.azure import AzureOpenAIChatClient
|
||||
from agent_framework_ag_ui import AgentFrameworkAgent, add_agent_framework_fastapi_endpoint
|
||||
|
||||
agent = ChatAgent(
|
||||
name="my_agent",
|
||||
instructions="You are a helpful assistant.",
|
||||
chat_client=AzureOpenAIChatClient(
|
||||
endpoint="https://your-resource.openai.azure.com/",
|
||||
deployment_name="gpt-4o-mini",
|
||||
),
|
||||
)
|
||||
|
||||
wrapped_agent = AgentFrameworkAgent(
|
||||
agent=agent,
|
||||
require_confirmation=True, # Enable human-in-the-loop
|
||||
)
|
||||
|
||||
app = FastAPI()
|
||||
add_agent_framework_fastapi_endpoint(app, wrapped_agent, "/")
|
||||
```
|
||||
|
||||
The client receives tool approval request events and can send approval responses.
|
||||
|
||||
### 3. State Management
|
||||
|
||||
Share state between client and server:
|
||||
|
||||
```python
|
||||
wrapped_agent = AgentFrameworkAgent(
|
||||
agent=agent,
|
||||
state_schema={
|
||||
"location": {"type": "string"},
|
||||
"preferences": {"type": "object"},
|
||||
},
|
||||
)
|
||||
```
|
||||
|
||||
Events include `STATE_SNAPSHOT` and `STATE_DELTA` for bidirectional sync.
|
||||
|
||||
### 4. Predictive State Updates
|
||||
|
||||
Stream tool arguments as optimistic state updates:
|
||||
|
||||
```python
|
||||
wrapped_agent = AgentFrameworkAgent(
|
||||
agent=agent,
|
||||
predict_state_config={
|
||||
"location": {"tool": "get_weather", "tool_argument": "location"}
|
||||
},
|
||||
require_confirmation=False, # Auto-update without confirmation
|
||||
)
|
||||
```
|
||||
|
||||
State updates stream in real-time as the LLM generates tool arguments.
|
||||
|
||||
## Common Patterns
|
||||
|
||||
### Custom Server Configuration
|
||||
|
||||
```python
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# Add CORS for web clients
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
add_agent_framework_fastapi_endpoint(app, agent, "/agent")
|
||||
```
|
||||
|
||||
### Multiple Agents
|
||||
|
||||
```python
|
||||
app = FastAPI()
|
||||
|
||||
weather_agent = ChatAgent(name="weather", ...)
|
||||
finance_agent = ChatAgent(name="finance", ...)
|
||||
|
||||
add_agent_framework_fastapi_endpoint(app, weather_agent, "/weather")
|
||||
add_agent_framework_fastapi_endpoint(app, finance_agent, "/finance")
|
||||
```
|
||||
|
||||
### Custom Client Timeout
|
||||
|
||||
```python
|
||||
async with httpx.AsyncClient(timeout=300.0) as client:
|
||||
async with client.stream("POST", server_url, ...) as response:
|
||||
async for line in response.aiter_lines():
|
||||
# Process events
|
||||
pass
|
||||
```
|
||||
|
||||
### Error Handling
|
||||
|
||||
```python
|
||||
try:
|
||||
async for event in client.send_message(message):
|
||||
if event.get("type") == "RUN_ERROR":
|
||||
error_msg = event.get("message", "Unknown error")
|
||||
print(f"Error: {error_msg}")
|
||||
# Handle error appropriately
|
||||
except httpx.HTTPError as e:
|
||||
print(f"HTTP error: {e}")
|
||||
except Exception as e:
|
||||
print(f"Unexpected error: {e}")
|
||||
```
|
||||
|
||||
### Conversation Continuity
|
||||
|
||||
The client automatically maintains `threadId` across requests:
|
||||
|
||||
```python
|
||||
client = AGUIClient(server_url)
|
||||
|
||||
# First message
|
||||
async for event in client.send_message("Hello"):
|
||||
# Client captures threadId from RUN_STARTED
|
||||
pass
|
||||
|
||||
# Second message - uses same threadId
|
||||
async for event in client.send_message("Continue our conversation"):
|
||||
# Conversation context is maintained
|
||||
pass
|
||||
```
|
||||
|
||||
## AG-UI Event Reference
|
||||
|
||||
### Core Events
|
||||
|
||||
| Event Type | Description | Key Fields |
|
||||
|------------|-------------|------------|
|
||||
| `RUN_STARTED` | Agent execution started | `threadId`, `runId` |
|
||||
| `RUN_FINISHED` | Agent execution completed | `threadId`, `runId` |
|
||||
| `RUN_ERROR` | Agent execution error | `message` |
|
||||
|
||||
### Text Message Events
|
||||
|
||||
| Event Type | Description | Key Fields |
|
||||
|------------|-------------|------------|
|
||||
| `TEXT_MESSAGE_START` | Start of agent text message | `messageId`, `role` |
|
||||
| `TEXT_MESSAGE_CONTENT` | Streaming text content | `messageId`, `delta` |
|
||||
| `TEXT_MESSAGE_END` | End of agent text message | `messageId` |
|
||||
|
||||
### Tool Events
|
||||
|
||||
| Event Type | Description | Key Fields |
|
||||
|------------|-------------|------------|
|
||||
| `TOOL_CALL_START` | Tool call initiated | `toolCallId`, `toolCallName` |
|
||||
| `TOOL_CALL_ARGS` | Tool arguments streaming | `toolCallId`, `delta` |
|
||||
| `TOOL_CALL_END` | Tool call complete | `toolCallId` |
|
||||
| `TOOL_CALL_RESULT` | Tool execution result | `toolCallId`, `content` |
|
||||
|
||||
### State Events
|
||||
|
||||
| Event Type | Description | Key Fields |
|
||||
|------------|-------------|------------|
|
||||
| `STATE_SNAPSHOT` | Complete state | `snapshot` |
|
||||
| `STATE_DELTA` | State changes (JSON Patch) | `delta` |
|
||||
|
||||
### Other Events
|
||||
|
||||
| Event Type | Description | Key Fields |
|
||||
|------------|-------------|------------|
|
||||
| `MESSAGES_SNAPSHOT` | Conversation history | `messages` |
|
||||
| `CUSTOM` | Custom event data | `name`, `value` |
|
||||
|
||||
## Next Steps
|
||||
|
||||
Now that you understand the basics of AG-UI, you can:
|
||||
|
||||
- **Add Tools**: Create custom `@ai_function` tools for your domain
|
||||
- **Web Integration**: Build React/Vue frontends using the AG-UI protocol
|
||||
- **State Management**: Implement shared state for generative UI applications
|
||||
- **Human-in-the-Loop**: Add approval workflows for sensitive operations
|
||||
- **Deployment**: Deploy to Azure Container Apps or Azure App Service
|
||||
- **Multi-Agent Systems**: Coordinate multiple specialized agents
|
||||
- **Monitoring**: Add logging and OpenTelemetry for observability
|
||||
|
||||
## Additional Resources
|
||||
|
||||
- [AG-UI Examples](../agent_framework_ag_ui_examples/README.md): Complete working examples for all 7 features
|
||||
- [Agent Framework Documentation](../../core/README.md): Learn more about creating agents
|
||||
- [AG-UI Protocol Spec](https://docs.ag-ui.com/): Official protocol documentation
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Connection Refused
|
||||
|
||||
@@ -1,121 +1,71 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""AG-UI client example."""
|
||||
"""AG-UI client example using AGUIChatClient.
|
||||
|
||||
This example demonstrates how to use the AGUIChatClient to connect to
|
||||
a remote AG-UI server and interact with it using the Agent Framework's
|
||||
standard chat interface.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
import httpx
|
||||
|
||||
|
||||
class AGUIClient:
|
||||
"""Simple AG-UI protocol client."""
|
||||
|
||||
def __init__(self, server_url: str):
|
||||
"""Initialize the client.
|
||||
|
||||
Args:
|
||||
server_url: The AG-UI server endpoint URL
|
||||
"""
|
||||
self.server_url = server_url
|
||||
self.thread_id: str | None = None
|
||||
|
||||
async def send_message(self, message: str) -> AsyncIterator[dict]:
|
||||
"""Send a message and stream the response.
|
||||
|
||||
Args:
|
||||
message: The user message to send
|
||||
|
||||
Yields:
|
||||
AG-UI events from the server
|
||||
"""
|
||||
# Prepare the request
|
||||
request_data: dict[str, object] = {
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": message},
|
||||
]
|
||||
}
|
||||
|
||||
# Include thread_id if we have one (for conversation continuity)
|
||||
if self.thread_id:
|
||||
request_data["thread_id"] = self.thread_id
|
||||
|
||||
# Stream the response
|
||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||
async with client.stream(
|
||||
"POST",
|
||||
self.server_url,
|
||||
json=request_data,
|
||||
headers={"Accept": "text/event-stream"},
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
|
||||
async for line in response.aiter_lines():
|
||||
# Parse Server-Sent Events format
|
||||
if line.startswith("data: "):
|
||||
data = line[6:] # Remove "data: " prefix
|
||||
try:
|
||||
event = json.loads(data)
|
||||
yield event
|
||||
|
||||
# Capture thread_id from RUN_STARTED event
|
||||
if event.get("type") == "RUN_STARTED" and not self.thread_id:
|
||||
self.thread_id = event.get("threadId")
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
from agent_framework_ag_ui import AGUIChatClient
|
||||
|
||||
|
||||
async def main():
|
||||
"""Main client loop."""
|
||||
"""Main client loop demonstrating AGUIChatClient usage."""
|
||||
# Get server URL from environment or use default
|
||||
server_url = os.environ.get("AGUI_SERVER_URL", "http://127.0.0.1:5100/")
|
||||
print(f"Connecting to AG-UI server at: {server_url}\n")
|
||||
print("Using AGUIChatClient with automatic thread management and Agent Framework integration.\n")
|
||||
|
||||
client = AGUIClient(server_url)
|
||||
# Create client with context manager for automatic cleanup
|
||||
async with AGUIChatClient(endpoint=server_url) as client:
|
||||
thread_id: str | None = None
|
||||
|
||||
try:
|
||||
while True:
|
||||
# Get user input
|
||||
message = input("\nUser (:q or quit to exit): ")
|
||||
if not message.strip():
|
||||
print("Request cannot be empty.")
|
||||
continue
|
||||
try:
|
||||
while True:
|
||||
# Get user input
|
||||
message = input("\nUser (:q or quit to exit): ")
|
||||
if not message.strip():
|
||||
print("Request cannot be empty.")
|
||||
continue
|
||||
|
||||
if message.lower() in (":q", "quit"):
|
||||
break
|
||||
if message.lower() in (":q", "quit"):
|
||||
break
|
||||
|
||||
# Send message and display streaming response
|
||||
print("\n", end="")
|
||||
async for event in client.send_message(message):
|
||||
event_type = event.get("type", "")
|
||||
# Send message and stream the response
|
||||
print("\nAssistant: ", end="", flush=True)
|
||||
|
||||
if event_type == "RUN_STARTED":
|
||||
thread_id = event.get("threadId", "")
|
||||
run_id = event.get("runId", "")
|
||||
print(f"\033[93m[Run Started - Thread: {thread_id}, Run: {run_id}]\033[0m")
|
||||
# Use metadata to maintain conversation continuity
|
||||
metadata = {"thread_id": thread_id} if thread_id else None
|
||||
|
||||
elif event_type == "TEXT_MESSAGE_CONTENT":
|
||||
# Stream text content in cyan
|
||||
print(f"\033[96m{event.get('delta', '')}\033[0m", end="", flush=True)
|
||||
async for update in client.get_streaming_response(message, metadata=metadata):
|
||||
# Extract and display thread ID from first update
|
||||
if not thread_id and update.additional_properties:
|
||||
thread_id = update.additional_properties.get("thread_id")
|
||||
if thread_id:
|
||||
print(f"\n\033[93m[Thread: {thread_id}]\033[0m", end="", flush=True)
|
||||
print("\nAssistant: ", end="", flush=True)
|
||||
|
||||
elif event_type == "RUN_FINISHED":
|
||||
thread_id = event.get("threadId", "")
|
||||
run_id = event.get("runId", "")
|
||||
print(f"\n\033[92m[Run Finished - Thread: {thread_id}, Run: {run_id}]\033[0m")
|
||||
# Display text content as it streams
|
||||
from agent_framework import TextContent
|
||||
|
||||
elif event_type == "RUN_ERROR":
|
||||
error_message = event.get("message", "Unknown error")
|
||||
print(f"\n\033[91m[Run Error - Message: {error_message}]\033[0m")
|
||||
for content in update.contents:
|
||||
if isinstance(content, TextContent) and content.text:
|
||||
print(f"\033[96m{content.text}\033[0m", end="", flush=True)
|
||||
|
||||
print()
|
||||
# Display finish reason if present
|
||||
if update.finish_reason:
|
||||
print(f"\n\033[92m[Finished: {update.finish_reason}]\033[0m", end="", flush=True)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\nExiting...")
|
||||
except Exception as e:
|
||||
print(f"\n\033[91mAn error occurred: {e}\033[0m")
|
||||
print() # New line after response
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\nExiting...")
|
||||
except Exception as e:
|
||||
print(f"\n\033[91mAn error occurred: {e}\033[0m")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -0,0 +1,235 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Advanced AG-UI client example with tools and features.
|
||||
|
||||
This example demonstrates advanced AGUIChatClient features including:
|
||||
- Tool/function calling
|
||||
- Non-streaming responses
|
||||
- Multiple conversation turns
|
||||
- Error handling
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from agent_framework import ai_function
|
||||
|
||||
from agent_framework_ag_ui import AGUIChatClient
|
||||
|
||||
|
||||
@ai_function
|
||||
def get_weather(location: str) -> str:
|
||||
"""Get the current weather for a location.
|
||||
|
||||
Args:
|
||||
location: The city or location name
|
||||
"""
|
||||
# Simulate weather lookup
|
||||
weather_data = {
|
||||
"seattle": "Rainy, 55°F",
|
||||
"san francisco": "Foggy, 62°F",
|
||||
"new york": "Sunny, 68°F",
|
||||
"london": "Cloudy, 52°F",
|
||||
}
|
||||
return weather_data.get(location.lower(), f"Weather data not available for {location}")
|
||||
|
||||
|
||||
@ai_function
|
||||
def calculate(a: float, b: float, operation: str) -> str:
|
||||
"""Perform basic arithmetic operations.
|
||||
|
||||
Args:
|
||||
a: First number
|
||||
b: Second number
|
||||
operation: Operation to perform (add, subtract, multiply, divide)
|
||||
"""
|
||||
try:
|
||||
if operation == "add":
|
||||
result = a + b
|
||||
elif operation == "subtract":
|
||||
result = a - b
|
||||
elif operation == "multiply":
|
||||
result = a * b
|
||||
elif operation == "divide":
|
||||
result = a / b
|
||||
else:
|
||||
return f"Unsupported operation: {operation}"
|
||||
return f"The result is: {result}"
|
||||
except Exception as e:
|
||||
return f"Error calculating: {e}"
|
||||
|
||||
|
||||
async def streaming_example(client: AGUIChatClient, thread_id: str | None = None):
|
||||
"""Demonstrate streaming responses."""
|
||||
print("\n" + "=" * 60)
|
||||
print("STREAMING EXAMPLE")
|
||||
print("=" * 60)
|
||||
|
||||
metadata = {"thread_id": thread_id} if thread_id else None
|
||||
|
||||
print("\nUser: Tell me a short joke\n")
|
||||
print("Assistant: ", end="", flush=True)
|
||||
|
||||
async for update in client.get_streaming_response("Tell me a short joke", metadata=metadata):
|
||||
if not thread_id and update.additional_properties:
|
||||
thread_id = update.additional_properties.get("thread_id")
|
||||
|
||||
from agent_framework import TextContent
|
||||
|
||||
for content in update.contents:
|
||||
if isinstance(content, TextContent) and content.text:
|
||||
print(content.text, end="", flush=True)
|
||||
|
||||
print("\n")
|
||||
return thread_id
|
||||
|
||||
|
||||
async def non_streaming_example(client: AGUIChatClient, thread_id: str | None = None):
|
||||
"""Demonstrate non-streaming responses."""
|
||||
print("\n" + "=" * 60)
|
||||
print("NON-STREAMING EXAMPLE")
|
||||
print("=" * 60)
|
||||
|
||||
metadata = {"thread_id": thread_id} if thread_id else None
|
||||
|
||||
print("\nUser: What is 2 + 2?\n")
|
||||
|
||||
response = await client.get_response("What is 2 + 2?", metadata=metadata)
|
||||
|
||||
print(f"Assistant: {response.text}")
|
||||
|
||||
if response.additional_properties:
|
||||
thread_id = response.additional_properties.get("thread_id")
|
||||
print(f"\n[Thread: {thread_id}]")
|
||||
|
||||
return thread_id
|
||||
|
||||
|
||||
async def tool_example(client: AGUIChatClient, thread_id: str | None = None):
|
||||
"""Demonstrate sending tool definitions to the server.
|
||||
|
||||
IMPORTANT: When using AGUIChatClient directly (without ChatAgent wrapper):
|
||||
- Tools are sent as DEFINITIONS only
|
||||
- No automatic client-side execution (no function invocation middleware)
|
||||
- Server must have matching tool implementations to execute them
|
||||
|
||||
For CLIENT-SIDE tool execution (like .NET AGUIClient sample):
|
||||
- Use ChatAgent wrapper with tools
|
||||
- See client_with_agent.py for the hybrid pattern
|
||||
- ChatAgent middleware intercepts and executes client tools locally
|
||||
- Server can have its own tools that execute server-side
|
||||
- Both client and server tools work together in same conversation
|
||||
|
||||
This example sends tool definitions and assumes server-side execution.
|
||||
"""
|
||||
print("\n" + "=" * 60)
|
||||
print("TOOL DEFINITION EXAMPLE")
|
||||
print("=" * 60)
|
||||
|
||||
metadata = {"thread_id": thread_id} if thread_id else None
|
||||
|
||||
print("\nUser: What's the weather in Seattle?\n")
|
||||
print("Sending tool definitions to server...")
|
||||
print("(Server must be configured with matching tools to execute them)\n")
|
||||
|
||||
response = await client.get_response(
|
||||
"What's the weather in Seattle?", tools=[get_weather, calculate], metadata=metadata
|
||||
)
|
||||
|
||||
print(f"Assistant: {response.text}")
|
||||
|
||||
# Show tool calls if any
|
||||
from agent_framework import FunctionCallContent
|
||||
|
||||
tool_called = False
|
||||
for message in response.messages:
|
||||
for content in message.contents:
|
||||
if isinstance(content, FunctionCallContent):
|
||||
print(f"\n[Tool Called: {content.name}]")
|
||||
tool_called = True
|
||||
|
||||
if not tool_called:
|
||||
print("\n[Note: No tools were called - server may not be configured for tool execution]")
|
||||
|
||||
if response.additional_properties:
|
||||
thread_id = response.additional_properties.get("thread_id")
|
||||
|
||||
return thread_id
|
||||
|
||||
|
||||
async def conversation_example(client: AGUIChatClient):
|
||||
"""Demonstrate multi-turn conversation.
|
||||
|
||||
Note: Conversation continuity depends on the server maintaining thread state.
|
||||
Some servers may require explicit message history to be sent with each request.
|
||||
"""
|
||||
print("\n" + "=" * 60)
|
||||
print("MULTI-TURN CONVERSATION EXAMPLE")
|
||||
print("=" * 60)
|
||||
print("\nNote: This example uses thread_id for context. Server must support thread-based state.\n")
|
||||
|
||||
# First turn
|
||||
print("User: My name is Alice\n")
|
||||
response1 = await client.get_response("My name is Alice")
|
||||
print(f"Assistant: {response1.text}")
|
||||
thread_id = response1.additional_properties.get("thread_id")
|
||||
print(f"\n[Thread: {thread_id}]")
|
||||
|
||||
# Second turn - using same thread
|
||||
print("\nUser: What's my name?\n")
|
||||
response2 = await client.get_response("What's my name?", metadata={"thread_id": thread_id})
|
||||
print(f"Assistant: {response2.text}")
|
||||
|
||||
# Check if context was maintained
|
||||
if "alice" not in response2.text.lower():
|
||||
print("\n[Note: Server may not maintain thread context - consider using ChatAgent for history management]")
|
||||
|
||||
# Third turn
|
||||
print("\nUser: Can you also tell me what 10 * 5 is?\n")
|
||||
response3 = await client.get_response(
|
||||
"Can you also tell me what 10 * 5 is?", metadata={"thread_id": thread_id}, tools=[calculate]
|
||||
)
|
||||
print(f"Assistant: {response3.text}")
|
||||
|
||||
|
||||
async def main():
|
||||
"""Run all examples."""
|
||||
# Get server URL from environment or use default
|
||||
server_url = os.environ.get("AGUI_SERVER_URL", "http://127.0.0.1:5100/")
|
||||
|
||||
print("=" * 60)
|
||||
print("AG-UI Chat Client Advanced Examples")
|
||||
print("=" * 60)
|
||||
print(f"\nServer: {server_url}")
|
||||
print("\nThese examples demonstrate various AGUIChatClient features:")
|
||||
print(" 1. Streaming responses")
|
||||
print(" 2. Non-streaming responses")
|
||||
print(" 3. Tool/function calling")
|
||||
print(" 4. Multi-turn conversations")
|
||||
|
||||
try:
|
||||
async with AGUIChatClient(endpoint=server_url) as client:
|
||||
# Run examples in sequence
|
||||
thread_id = await streaming_example(client)
|
||||
thread_id = await non_streaming_example(client, thread_id)
|
||||
await tool_example(client, thread_id)
|
||||
|
||||
# Separate conversation with new thread
|
||||
await conversation_example(client)
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("All examples completed successfully!")
|
||||
print("=" * 60)
|
||||
|
||||
except ConnectionError as e:
|
||||
print(f"\n\033[91mConnection Error: {e}\033[0m")
|
||||
print("\nMake sure an AG-UI server is running at the specified endpoint.")
|
||||
except Exception as e:
|
||||
print(f"\n\033[91mError: {e}\033[0m")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -0,0 +1,186 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Example showing ChatAgent with AGUIChatClient for hybrid tool execution.
|
||||
|
||||
This demonstrates the HYBRID pattern matching .NET AGUIClient implementation:
|
||||
|
||||
1. AgentThread Pattern (like .NET):
|
||||
- Create thread with agent.get_new_thread()
|
||||
- Pass thread to agent.run_stream() on each turn
|
||||
- Thread automatically maintains conversation history via message_store
|
||||
|
||||
2. Hybrid Tool Execution:
|
||||
- AGUIChatClient has @use_function_invocation decorator
|
||||
- Client-side tools (get_weather) can execute locally when server requests them
|
||||
- Server may also have its own tools that execute server-side
|
||||
- Both work together: server LLM decides which tool to call, decorator handles client execution
|
||||
|
||||
This matches .NET pattern: thread maintains state, tools execute on appropriate side.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
|
||||
from agent_framework import ChatAgent, FunctionCallContent, FunctionResultContent, TextContent, ai_function
|
||||
|
||||
from agent_framework_ag_ui import AGUIChatClient
|
||||
|
||||
# Enable debug logging
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ai_function(description="Get the current weather for a location.")
|
||||
def get_weather(location: str) -> str:
|
||||
"""Get the current weather for a location.
|
||||
|
||||
Args:
|
||||
location: The city or location name
|
||||
"""
|
||||
print(f"[CLIENT] get_weather tool called with location: {location}")
|
||||
weather_data = {
|
||||
"seattle": "Rainy, 55°F",
|
||||
"san francisco": "Foggy, 62°F",
|
||||
"new york": "Sunny, 68°F",
|
||||
"london": "Cloudy, 52°F",
|
||||
}
|
||||
result = weather_data.get(location.lower(), f"Weather data not available for {location}")
|
||||
print(f"[CLIENT] get_weather returning: {result}")
|
||||
return result
|
||||
|
||||
|
||||
async def main():
|
||||
"""Demonstrate ChatAgent + AGUIChatClient hybrid tool execution.
|
||||
|
||||
This matches the .NET pattern from Program.cs where:
|
||||
- AIAgent agent = chatClient.CreateAIAgent(tools: [...])
|
||||
- AgentThread thread = agent.GetNewThread()
|
||||
- RunStreamingAsync(messages, thread)
|
||||
|
||||
Python equivalent:
|
||||
- agent = ChatAgent(chat_client=AGUIChatClient(...), tools=[...])
|
||||
- thread = agent.get_new_thread() # Creates thread with message_store
|
||||
- agent.run_stream(message, thread=thread) # Thread accumulates history
|
||||
"""
|
||||
server_url = os.environ.get("AGUI_SERVER_URL", "http://127.0.0.1:5100/")
|
||||
|
||||
print("=" * 70)
|
||||
print("ChatAgent + AGUIChatClient: Hybrid Tool Execution")
|
||||
print("=" * 70)
|
||||
print(f"\nServer: {server_url}")
|
||||
print("\nThis example demonstrates:")
|
||||
print(" 1. AgentThread maintains conversation state (like .NET)")
|
||||
print(" 2. Client-side tools execute locally via @use_function_invocation")
|
||||
print(" 3. Server may have additional tools that execute server-side")
|
||||
print(" 4. HYBRID: Client and server tools work together simultaneously\n")
|
||||
|
||||
try:
|
||||
# Create remote client in async context manager
|
||||
async with AGUIChatClient(endpoint=server_url) as remote_client:
|
||||
# Wrap in ChatAgent for conversation history management
|
||||
agent = ChatAgent(
|
||||
name="remote_assistant",
|
||||
instructions="You are a helpful assistant. Remember user information across the conversation.",
|
||||
chat_client=remote_client,
|
||||
tools=[get_weather],
|
||||
)
|
||||
|
||||
# Create a thread to maintain conversation state (like .NET AgentThread)
|
||||
thread = agent.get_new_thread()
|
||||
|
||||
print("=" * 70)
|
||||
print("CONVERSATION WITH HISTORY")
|
||||
print("=" * 70)
|
||||
|
||||
# Turn 1: Introduce
|
||||
print("\nUser: My name is Alice and I live in Seattle\n")
|
||||
async for chunk in agent.run_stream("My name is Alice and I live in Seattle", thread=thread):
|
||||
if chunk.text:
|
||||
print(chunk.text, end="", flush=True)
|
||||
print("\n")
|
||||
|
||||
# Turn 2: Ask about name (tests history)
|
||||
print("User: What's my name?\n")
|
||||
async for chunk in agent.run_stream("What's my name?", thread=thread):
|
||||
if chunk.text:
|
||||
print(chunk.text, end="", flush=True)
|
||||
print("\n")
|
||||
|
||||
# Turn 3: Ask about location (tests history)
|
||||
print("User: Where do I live?\n")
|
||||
async for chunk in agent.run_stream("Where do I live?", thread=thread):
|
||||
if chunk.text:
|
||||
print(chunk.text, end="", flush=True)
|
||||
print("\n")
|
||||
|
||||
# Turn 4: Test client-side tool (get_weather is client-side)
|
||||
print("User: What's the weather forecast for today in Seattle?\n")
|
||||
async for chunk in agent.run_stream("What's the weather forecast for today in Seattle?", thread=thread):
|
||||
if chunk.text:
|
||||
print(chunk.text, end="", flush=True)
|
||||
print("\n")
|
||||
|
||||
# Turn 5: Test server-side tool (get_time_zone is server-side only)
|
||||
print("User: What time zone is Seattle in?\n")
|
||||
async for chunk in agent.run_stream("What time zone is Seattle in?", thread=thread):
|
||||
if chunk.text:
|
||||
print(chunk.text, end="", flush=True)
|
||||
print("\n")
|
||||
|
||||
# Show thread state
|
||||
if thread.message_store:
|
||||
|
||||
def _preview_for_message(m) -> str:
|
||||
# Prefer plain text when present
|
||||
if getattr(m, "text", ""):
|
||||
t = m.text
|
||||
return (t[:60] + "...") if len(t) > 60 else t
|
||||
# Build from contents when no direct text
|
||||
parts: list[str] = []
|
||||
for c in getattr(m, "contents", []) or []:
|
||||
if isinstance(c, FunctionCallContent):
|
||||
args = c.arguments
|
||||
if isinstance(args, dict):
|
||||
try:
|
||||
import json as _json
|
||||
|
||||
args_str = _json.dumps(args)
|
||||
except Exception:
|
||||
args_str = str(args)
|
||||
else:
|
||||
args_str = str(args or "{}")
|
||||
parts.append(f"tool_call {c.name} {args_str}")
|
||||
elif isinstance(c, FunctionResultContent):
|
||||
parts.append(f"tool_result[{c.call_id}]: {str(c.result)[:40]}")
|
||||
elif isinstance(c, TextContent):
|
||||
if c.text:
|
||||
parts.append(c.text)
|
||||
else:
|
||||
typename = getattr(c, "type", c.__class__.__name__)
|
||||
parts.append(f"<{typename}>")
|
||||
preview = " | ".join(parts) if parts else ""
|
||||
return (preview[:60] + "...") if len(preview) > 60 else preview
|
||||
|
||||
messages = await thread.message_store.list_messages()
|
||||
print(f"\n[THREAD STATE] {len(messages)} messages in thread's message_store")
|
||||
for i, msg in enumerate(messages[-6:], 1): # Show last 6
|
||||
role = msg.role.value if hasattr(msg.role, "value") else str(msg.role)
|
||||
text_preview = _preview_for_message(msg)
|
||||
print(f" {i}. [{role}]: {text_preview}")
|
||||
|
||||
except ConnectionError as e:
|
||||
print(f"\n\033[91mConnection Error: {e}\033[0m")
|
||||
print("\nMake sure an AG-UI server is running at the specified endpoint.")
|
||||
except Exception as e:
|
||||
print(f"\n\033[91mError: {e}\033[0m")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -1,18 +1,26 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""AG-UI server example."""
|
||||
"""AG-UI server example with server-side tools."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
from agent_framework import ChatAgent
|
||||
from agent_framework import ChatAgent, ai_function
|
||||
from agent_framework.ag_ui import add_agent_framework_fastapi_endpoint
|
||||
from agent_framework.azure import AzureOpenAIChatClient
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import FastAPI
|
||||
|
||||
from agent_framework_ag_ui import add_agent_framework_fastapi_endpoint
|
||||
|
||||
load_dotenv()
|
||||
|
||||
# Enable debug logging
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Read required configuration
|
||||
endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT")
|
||||
deployment_name = os.environ.get("AZURE_OPENAI_CHAT_DEPLOYMENT_NAME")
|
||||
@@ -22,14 +30,43 @@ if not endpoint:
|
||||
if not deployment_name:
|
||||
raise ValueError("AZURE_OPENAI_CHAT_DEPLOYMENT_NAME environment variable is required")
|
||||
|
||||
# Create the AI agent
|
||||
|
||||
# Server-side tool (executes on server)
|
||||
@ai_function(description="Get the time zone for a location.")
|
||||
def get_time_zone(location: str) -> str:
|
||||
"""Get the time zone for a location.
|
||||
|
||||
Args:
|
||||
location: The city or location name
|
||||
"""
|
||||
print(f"[SERVER] get_time_zone tool called with location: {location}")
|
||||
timezone_data = {
|
||||
"seattle": "Pacific Time (UTC-8)",
|
||||
"san francisco": "Pacific Time (UTC-8)",
|
||||
"new york": "Eastern Time (UTC-5)",
|
||||
"london": "Greenwich Mean Time (UTC+0)",
|
||||
}
|
||||
result = timezone_data.get(location.lower(), f"Time zone data not available for {location}")
|
||||
print(f"[SERVER] get_time_zone returning: {result}")
|
||||
return result
|
||||
|
||||
|
||||
# Create the AI agent with ONLY server-side tools
|
||||
# IMPORTANT: Do NOT include tools that the client provides!
|
||||
# In this example:
|
||||
# - get_time_zone: SERVER-ONLY tool (only server has this)
|
||||
# - get_weather: CLIENT-ONLY tool (client provides this, server should NOT include it)
|
||||
# The client will send get_weather tool metadata so the LLM knows about it,
|
||||
# and @use_function_invocation on AGUIChatClient will execute it client-side.
|
||||
# This matches the .NET AG-UI hybrid execution pattern.
|
||||
agent = ChatAgent(
|
||||
name="AGUIAssistant",
|
||||
instructions="You are a helpful assistant.",
|
||||
instructions="You are a helpful assistant. Use get_weather for weather and get_time_zone for time zones.",
|
||||
chat_client=AzureOpenAIChatClient(
|
||||
endpoint=endpoint,
|
||||
deployment_name=deployment_name,
|
||||
),
|
||||
tools=[get_time_zone], # ONLY server-side tools
|
||||
)
|
||||
|
||||
# Create FastAPI app
|
||||
@@ -41,4 +78,4 @@ add_agent_framework_fastapi_endpoint(app, agent, "/")
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(app, host="127.0.0.1", port=5100)
|
||||
uvicorn.run(app, host="127.0.0.1", port=5100, log_level="debug", access_log=True)
|
||||
|
||||
@@ -0,0 +1,317 @@
|
||||
"""Tests for AGUIChatClient."""
|
||||
|
||||
import json
|
||||
|
||||
from agent_framework import ChatMessage, ChatOptions, FunctionCallContent, Role, ai_function
|
||||
|
||||
from agent_framework_ag_ui._client import AGUIChatClient, ServerFunctionCallContent
|
||||
|
||||
|
||||
class TestAGUIChatClient:
|
||||
"""Test suite for AGUIChatClient."""
|
||||
|
||||
async def test_client_initialization(self) -> None:
|
||||
"""Test client initialization."""
|
||||
client = AGUIChatClient(endpoint="http://localhost:8888/")
|
||||
|
||||
assert client._http_service is not None
|
||||
assert client._http_service.endpoint.startswith("http://localhost:8888")
|
||||
|
||||
async def test_client_context_manager(self) -> None:
|
||||
"""Test client as async context manager."""
|
||||
async with AGUIChatClient(endpoint="http://localhost:8888/") as client:
|
||||
assert client is not None
|
||||
|
||||
async def test_extract_state_from_messages_no_state(self) -> None:
|
||||
"""Test state extraction when no state is present."""
|
||||
client = AGUIChatClient(endpoint="http://localhost:8888/")
|
||||
messages = [
|
||||
ChatMessage(role="user", text="Hello"),
|
||||
ChatMessage(role="assistant", text="Hi there"),
|
||||
]
|
||||
|
||||
result_messages, state = client._extract_state_from_messages(messages)
|
||||
|
||||
assert result_messages == messages
|
||||
assert state is None
|
||||
|
||||
async def test_extract_state_from_messages_with_state(self) -> None:
|
||||
"""Test state extraction from last message."""
|
||||
import base64
|
||||
|
||||
client = AGUIChatClient(endpoint="http://localhost:8888/")
|
||||
|
||||
state_data = {"key": "value", "count": 42}
|
||||
state_json = json.dumps(state_data)
|
||||
state_b64 = base64.b64encode(state_json.encode("utf-8")).decode("utf-8")
|
||||
|
||||
from agent_framework import DataContent
|
||||
|
||||
messages = [
|
||||
ChatMessage(role="user", text="Hello"),
|
||||
ChatMessage(
|
||||
role="user",
|
||||
contents=[DataContent(uri=f"data:application/json;base64,{state_b64}")],
|
||||
),
|
||||
]
|
||||
|
||||
result_messages, state = client._extract_state_from_messages(messages)
|
||||
|
||||
assert len(result_messages) == 1
|
||||
assert result_messages[0].text == "Hello"
|
||||
assert state == state_data
|
||||
|
||||
async def test_extract_state_invalid_json(self) -> None:
|
||||
"""Test state extraction with invalid JSON."""
|
||||
import base64
|
||||
|
||||
client = AGUIChatClient(endpoint="http://localhost:8888/")
|
||||
|
||||
invalid_json = "not valid json"
|
||||
state_b64 = base64.b64encode(invalid_json.encode("utf-8")).decode("utf-8")
|
||||
|
||||
from agent_framework import DataContent
|
||||
|
||||
messages = [
|
||||
ChatMessage(
|
||||
role="user",
|
||||
contents=[DataContent(uri=f"data:application/json;base64,{state_b64}")],
|
||||
),
|
||||
]
|
||||
|
||||
result_messages, state = client._extract_state_from_messages(messages)
|
||||
|
||||
assert result_messages == messages
|
||||
assert state is None
|
||||
|
||||
async def test_convert_messages_to_agui_format(self) -> None:
|
||||
"""Test message conversion to AG-UI format."""
|
||||
client = AGUIChatClient(endpoint="http://localhost:8888/")
|
||||
messages = [
|
||||
ChatMessage(role=Role.USER, text="What is the weather?"),
|
||||
ChatMessage(role=Role.ASSISTANT, text="Let me check.", message_id="msg_123"),
|
||||
]
|
||||
|
||||
agui_messages = client._convert_messages_to_agui_format(messages)
|
||||
|
||||
assert len(agui_messages) == 2
|
||||
assert agui_messages[0]["role"] == "user"
|
||||
assert agui_messages[0]["content"] == "What is the weather?"
|
||||
assert agui_messages[1]["role"] == "assistant"
|
||||
assert agui_messages[1]["content"] == "Let me check."
|
||||
assert agui_messages[1]["id"] == "msg_123"
|
||||
|
||||
async def test_get_thread_id_from_metadata(self) -> None:
|
||||
"""Test thread ID extraction from metadata."""
|
||||
client = AGUIChatClient(endpoint="http://localhost:8888/")
|
||||
chat_options = ChatOptions(metadata={"thread_id": "existing_thread_123"})
|
||||
|
||||
thread_id = client._get_thread_id(chat_options)
|
||||
|
||||
assert thread_id == "existing_thread_123"
|
||||
|
||||
async def test_get_thread_id_generation(self) -> None:
|
||||
"""Test automatic thread ID generation."""
|
||||
client = AGUIChatClient(endpoint="http://localhost:8888/")
|
||||
chat_options = ChatOptions()
|
||||
|
||||
thread_id = client._get_thread_id(chat_options)
|
||||
|
||||
assert thread_id.startswith("thread_")
|
||||
assert len(thread_id) > 7
|
||||
|
||||
async def test_get_streaming_response(self, monkeypatch) -> None:
|
||||
"""Test streaming response method."""
|
||||
mock_events = [
|
||||
{"type": "RUN_STARTED", "threadId": "thread_1", "runId": "run_1"},
|
||||
{"type": "TEXT_MESSAGE_CONTENT", "messageId": "msg_1", "delta": "Hello"},
|
||||
{"type": "TEXT_MESSAGE_CONTENT", "messageId": "msg_1", "delta": " world"},
|
||||
{"type": "RUN_FINISHED", "threadId": "thread_1", "runId": "run_1"},
|
||||
]
|
||||
|
||||
async def mock_post_run(*args, **kwargs):
|
||||
for event in mock_events:
|
||||
yield event
|
||||
|
||||
client = AGUIChatClient(endpoint="http://localhost:8888/")
|
||||
monkeypatch.setattr(client._http_service, "post_run", mock_post_run)
|
||||
|
||||
messages = [ChatMessage(role="user", text="Test message")]
|
||||
chat_options = ChatOptions()
|
||||
|
||||
updates = []
|
||||
async for update in client._inner_get_streaming_response(messages=messages, chat_options=chat_options):
|
||||
updates.append(update)
|
||||
|
||||
assert len(updates) == 4
|
||||
assert updates[0].additional_properties["thread_id"] == "thread_1"
|
||||
assert updates[1].contents[0].text == "Hello"
|
||||
assert updates[2].contents[0].text == " world"
|
||||
|
||||
async def test_get_response_non_streaming(self, monkeypatch) -> None:
|
||||
"""Test non-streaming response method."""
|
||||
mock_events = [
|
||||
{"type": "RUN_STARTED", "threadId": "thread_1", "runId": "run_1"},
|
||||
{"type": "TEXT_MESSAGE_CONTENT", "messageId": "msg_1", "delta": "Complete response"},
|
||||
{"type": "RUN_FINISHED", "threadId": "thread_1", "runId": "run_1"},
|
||||
]
|
||||
|
||||
async def mock_post_run(*args, **kwargs):
|
||||
for event in mock_events:
|
||||
yield event
|
||||
|
||||
client = AGUIChatClient(endpoint="http://localhost:8888/")
|
||||
monkeypatch.setattr(client._http_service, "post_run", mock_post_run)
|
||||
|
||||
messages = [ChatMessage(role="user", text="Test message")]
|
||||
chat_options = ChatOptions()
|
||||
|
||||
response = await client._inner_get_response(messages=messages, chat_options=chat_options)
|
||||
|
||||
assert response is not None
|
||||
assert len(response.messages) > 0
|
||||
assert "Complete response" in response.text
|
||||
|
||||
async def test_tool_handling(self, monkeypatch) -> None:
|
||||
"""Test that client tool metadata is sent to server.
|
||||
|
||||
Client tool metadata (name, description, schema) is sent to server for planning.
|
||||
When server requests a client function, @use_function_invocation decorator
|
||||
intercepts and executes it locally. This matches .NET AG-UI implementation.
|
||||
"""
|
||||
from agent_framework import ai_function
|
||||
|
||||
@ai_function
|
||||
def test_tool(param: str) -> str:
|
||||
"""Test tool."""
|
||||
return "result"
|
||||
|
||||
mock_events = [
|
||||
{"type": "RUN_STARTED", "threadId": "thread_1", "runId": "run_1"},
|
||||
{"type": "RUN_FINISHED", "threadId": "thread_1", "runId": "run_1"},
|
||||
]
|
||||
|
||||
async def mock_post_run(*args, **kwargs):
|
||||
# Client tool metadata should be sent to server
|
||||
tools = kwargs.get("tools")
|
||||
assert tools is not None
|
||||
assert len(tools) == 1
|
||||
assert tools[0]["name"] == "test_tool"
|
||||
assert tools[0]["description"] == "Test tool."
|
||||
assert "parameters" in tools[0]
|
||||
for event in mock_events:
|
||||
yield event
|
||||
|
||||
client = AGUIChatClient(endpoint="http://localhost:8888/")
|
||||
monkeypatch.setattr(client._http_service, "post_run", mock_post_run)
|
||||
|
||||
messages = [ChatMessage(role="user", text="Test with tools")]
|
||||
chat_options = ChatOptions(tools=[test_tool])
|
||||
|
||||
response = await client._inner_get_response(messages=messages, chat_options=chat_options)
|
||||
|
||||
assert response is not None
|
||||
|
||||
async def test_server_tool_calls_unwrapped_after_invocation(self, monkeypatch) -> None:
|
||||
"""Ensure server-side tool calls are exposed as FunctionCallContent after processing."""
|
||||
|
||||
mock_events = [
|
||||
{"type": "RUN_STARTED", "threadId": "thread_1", "runId": "run_1"},
|
||||
{"type": "TOOL_CALL_START", "toolCallId": "call_1", "toolName": "get_time_zone"},
|
||||
{"type": "TOOL_CALL_ARGS", "toolCallId": "call_1", "delta": '{"location": "Seattle"}'},
|
||||
{"type": "RUN_FINISHED", "threadId": "thread_1", "runId": "run_1"},
|
||||
]
|
||||
|
||||
async def mock_post_run(*args, **kwargs):
|
||||
for event in mock_events:
|
||||
yield event
|
||||
|
||||
client = AGUIChatClient(endpoint="http://localhost:8888/")
|
||||
monkeypatch.setattr(client._http_service, "post_run", mock_post_run)
|
||||
|
||||
messages = [ChatMessage(role="user", text="Test server tool execution")]
|
||||
chat_options = ChatOptions()
|
||||
|
||||
updates = []
|
||||
async for update in client.get_streaming_response(messages, chat_options=chat_options):
|
||||
updates.append(update)
|
||||
|
||||
function_calls = [
|
||||
content for update in updates for content in update.contents if isinstance(content, FunctionCallContent)
|
||||
]
|
||||
assert function_calls
|
||||
assert function_calls[0].name == "get_time_zone"
|
||||
assert not any(
|
||||
isinstance(content, ServerFunctionCallContent) for update in updates for content in update.contents
|
||||
)
|
||||
|
||||
async def test_server_tool_calls_not_executed_locally(self, monkeypatch) -> None:
|
||||
"""Server tools should not trigger local function invocation even when client tools exist."""
|
||||
|
||||
@ai_function
|
||||
def client_tool() -> str:
|
||||
"""Client tool stub."""
|
||||
return "client"
|
||||
|
||||
mock_events = [
|
||||
{"type": "RUN_STARTED", "threadId": "thread_1", "runId": "run_1"},
|
||||
{"type": "TOOL_CALL_START", "toolCallId": "call_1", "toolName": "get_time_zone"},
|
||||
{"type": "TOOL_CALL_ARGS", "toolCallId": "call_1", "delta": '{"location": "Seattle"}'},
|
||||
{"type": "RUN_FINISHED", "threadId": "thread_1", "runId": "run_1"},
|
||||
]
|
||||
|
||||
async def mock_post_run(*args, **kwargs):
|
||||
for event in mock_events:
|
||||
yield event
|
||||
|
||||
async def fake_auto_invoke(*args, **kwargs):
|
||||
function_call = kwargs.get("function_call_content") or args[0]
|
||||
raise AssertionError(f"Unexpected local execution of server tool: {getattr(function_call, 'name', '?')}")
|
||||
|
||||
monkeypatch.setattr("agent_framework._tools._auto_invoke_function", fake_auto_invoke)
|
||||
|
||||
client = AGUIChatClient(endpoint="http://localhost:8888/")
|
||||
monkeypatch.setattr(client._http_service, "post_run", mock_post_run)
|
||||
|
||||
messages = [ChatMessage(role="user", text="Test server tool execution")]
|
||||
chat_options = ChatOptions(tool_choice="auto", tools=[client_tool])
|
||||
|
||||
async for _ in client.get_streaming_response(messages, chat_options=chat_options):
|
||||
pass
|
||||
|
||||
async def test_state_transmission(self, monkeypatch) -> None:
|
||||
"""Test state is properly transmitted to server."""
|
||||
import base64
|
||||
|
||||
state_data = {"user_id": "123", "session": "abc"}
|
||||
state_json = json.dumps(state_data)
|
||||
state_b64 = base64.b64encode(state_json.encode("utf-8")).decode("utf-8")
|
||||
|
||||
from agent_framework import DataContent
|
||||
|
||||
messages = [
|
||||
ChatMessage(role="user", text="Hello"),
|
||||
ChatMessage(
|
||||
role="user",
|
||||
contents=[DataContent(uri=f"data:application/json;base64,{state_b64}")],
|
||||
),
|
||||
]
|
||||
|
||||
mock_events = [
|
||||
{"type": "RUN_STARTED", "threadId": "thread_1", "runId": "run_1"},
|
||||
{"type": "RUN_FINISHED", "threadId": "thread_1", "runId": "run_1"},
|
||||
]
|
||||
|
||||
async def mock_post_run(*args, **kwargs):
|
||||
assert kwargs.get("state") == state_data
|
||||
for event in mock_events:
|
||||
yield event
|
||||
|
||||
client = AGUIChatClient(endpoint="http://localhost:8888/")
|
||||
monkeypatch.setattr(client._http_service, "post_run", mock_post_run)
|
||||
|
||||
chat_options = ChatOptions()
|
||||
|
||||
response = await client._inner_get_response(messages=messages, chat_options=chat_options)
|
||||
|
||||
assert response is not None
|
||||
@@ -0,0 +1,287 @@
|
||||
"""Tests for AG-UI event converter."""
|
||||
|
||||
from agent_framework import FinishReason, Role
|
||||
|
||||
from agent_framework_ag_ui._event_converters import AGUIEventConverter
|
||||
|
||||
|
||||
class TestAGUIEventConverter:
|
||||
"""Test suite for AGUIEventConverter."""
|
||||
|
||||
def test_run_started_event(self) -> None:
|
||||
"""Test conversion of RUN_STARTED event."""
|
||||
converter = AGUIEventConverter()
|
||||
event = {
|
||||
"type": "RUN_STARTED",
|
||||
"threadId": "thread_123",
|
||||
"runId": "run_456",
|
||||
}
|
||||
|
||||
update = converter.convert_event(event)
|
||||
|
||||
assert update is not None
|
||||
assert update.role == Role.ASSISTANT
|
||||
assert update.additional_properties["thread_id"] == "thread_123"
|
||||
assert update.additional_properties["run_id"] == "run_456"
|
||||
assert converter.thread_id == "thread_123"
|
||||
assert converter.run_id == "run_456"
|
||||
|
||||
def test_text_message_start_event(self) -> None:
|
||||
"""Test conversion of TEXT_MESSAGE_START event."""
|
||||
converter = AGUIEventConverter()
|
||||
event = {
|
||||
"type": "TEXT_MESSAGE_START",
|
||||
"messageId": "msg_789",
|
||||
}
|
||||
|
||||
update = converter.convert_event(event)
|
||||
|
||||
assert update is not None
|
||||
assert update.role == Role.ASSISTANT
|
||||
assert update.message_id == "msg_789"
|
||||
assert converter.current_message_id == "msg_789"
|
||||
|
||||
def test_text_message_content_event(self) -> None:
|
||||
"""Test conversion of TEXT_MESSAGE_CONTENT event."""
|
||||
converter = AGUIEventConverter()
|
||||
event = {
|
||||
"type": "TEXT_MESSAGE_CONTENT",
|
||||
"messageId": "msg_1",
|
||||
"delta": "Hello",
|
||||
}
|
||||
|
||||
update = converter.convert_event(event)
|
||||
|
||||
assert update is not None
|
||||
assert update.role == Role.ASSISTANT
|
||||
assert update.message_id == "msg_1"
|
||||
assert len(update.contents) == 1
|
||||
assert update.contents[0].text == "Hello"
|
||||
|
||||
def test_text_message_streaming(self) -> None:
|
||||
"""Test streaming text across multiple TEXT_MESSAGE_CONTENT events."""
|
||||
converter = AGUIEventConverter()
|
||||
events = [
|
||||
{"type": "TEXT_MESSAGE_CONTENT", "messageId": "msg_1", "delta": "Hello"},
|
||||
{"type": "TEXT_MESSAGE_CONTENT", "messageId": "msg_1", "delta": " world"},
|
||||
{"type": "TEXT_MESSAGE_CONTENT", "messageId": "msg_1", "delta": "!"},
|
||||
]
|
||||
|
||||
updates = [converter.convert_event(event) for event in events]
|
||||
|
||||
assert all(update is not None for update in updates)
|
||||
assert all(update.message_id == "msg_1" for update in updates)
|
||||
assert updates[0].contents[0].text == "Hello"
|
||||
assert updates[1].contents[0].text == " world"
|
||||
assert updates[2].contents[0].text == "!"
|
||||
|
||||
def test_text_message_end_event(self) -> None:
|
||||
"""Test conversion of TEXT_MESSAGE_END event."""
|
||||
converter = AGUIEventConverter()
|
||||
event = {
|
||||
"type": "TEXT_MESSAGE_END",
|
||||
"messageId": "msg_1",
|
||||
}
|
||||
|
||||
update = converter.convert_event(event)
|
||||
|
||||
assert update is None
|
||||
|
||||
def test_tool_call_start_event(self) -> None:
|
||||
"""Test conversion of TOOL_CALL_START event."""
|
||||
converter = AGUIEventConverter()
|
||||
event = {
|
||||
"type": "TOOL_CALL_START",
|
||||
"toolCallId": "call_123",
|
||||
"toolName": "get_weather",
|
||||
}
|
||||
|
||||
update = converter.convert_event(event)
|
||||
|
||||
assert update is not None
|
||||
assert update.role == Role.ASSISTANT
|
||||
assert len(update.contents) == 1
|
||||
assert update.contents[0].call_id == "call_123"
|
||||
assert update.contents[0].name == "get_weather"
|
||||
assert update.contents[0].arguments == ""
|
||||
assert converter.current_tool_call_id == "call_123"
|
||||
assert converter.current_tool_name == "get_weather"
|
||||
|
||||
def test_tool_call_start_with_tool_call_name(self) -> None:
|
||||
"""Ensure TOOL_CALL_START with toolCallName still sets the tool name."""
|
||||
converter = AGUIEventConverter()
|
||||
event = {
|
||||
"type": "TOOL_CALL_START",
|
||||
"toolCallId": "call_abc",
|
||||
"toolCallName": "get_weather",
|
||||
}
|
||||
|
||||
update = converter.convert_event(event)
|
||||
|
||||
assert update is not None
|
||||
assert update.contents[0].name == "get_weather"
|
||||
assert converter.current_tool_name == "get_weather"
|
||||
|
||||
def test_tool_call_start_with_tool_call_name_snake_case(self) -> None:
|
||||
"""Support tool_call_name snake_case field for backwards compatibility."""
|
||||
converter = AGUIEventConverter()
|
||||
event = {
|
||||
"type": "TOOL_CALL_START",
|
||||
"toolCallId": "call_snake",
|
||||
"tool_call_name": "get_weather",
|
||||
}
|
||||
|
||||
update = converter.convert_event(event)
|
||||
|
||||
assert update is not None
|
||||
assert update.contents[0].name == "get_weather"
|
||||
assert converter.current_tool_name == "get_weather"
|
||||
|
||||
def test_tool_call_args_streaming(self) -> None:
|
||||
"""Test streaming tool arguments across multiple TOOL_CALL_ARGS events."""
|
||||
converter = AGUIEventConverter()
|
||||
converter.current_tool_call_id = "call_123"
|
||||
converter.current_tool_name = "search"
|
||||
|
||||
events = [
|
||||
{"type": "TOOL_CALL_ARGS", "delta": '{"query": "'},
|
||||
{"type": "TOOL_CALL_ARGS", "delta": 'latest news"}'},
|
||||
]
|
||||
|
||||
updates = [converter.convert_event(event) for event in events]
|
||||
|
||||
assert all(update is not None for update in updates)
|
||||
assert updates[0].contents[0].arguments == '{"query": "'
|
||||
assert updates[1].contents[0].arguments == 'latest news"}'
|
||||
assert converter.accumulated_tool_args == '{"query": "latest news"}'
|
||||
|
||||
def test_tool_call_end_event(self) -> None:
|
||||
"""Test conversion of TOOL_CALL_END event."""
|
||||
converter = AGUIEventConverter()
|
||||
converter.accumulated_tool_args = '{"location": "Seattle"}'
|
||||
|
||||
event = {
|
||||
"type": "TOOL_CALL_END",
|
||||
"toolCallId": "call_123",
|
||||
}
|
||||
|
||||
update = converter.convert_event(event)
|
||||
|
||||
assert update is None
|
||||
assert converter.accumulated_tool_args == ""
|
||||
|
||||
def test_tool_call_result_event(self) -> None:
|
||||
"""Test conversion of TOOL_CALL_RESULT event."""
|
||||
converter = AGUIEventConverter()
|
||||
event = {
|
||||
"type": "TOOL_CALL_RESULT",
|
||||
"toolCallId": "call_123",
|
||||
"result": {"temperature": 22, "condition": "sunny"},
|
||||
}
|
||||
|
||||
update = converter.convert_event(event)
|
||||
|
||||
assert update is not None
|
||||
assert update.role == Role.TOOL
|
||||
assert len(update.contents) == 1
|
||||
assert update.contents[0].call_id == "call_123"
|
||||
assert update.contents[0].result == {"temperature": 22, "condition": "sunny"}
|
||||
|
||||
def test_run_finished_event(self) -> None:
|
||||
"""Test conversion of RUN_FINISHED event."""
|
||||
converter = AGUIEventConverter()
|
||||
converter.thread_id = "thread_123"
|
||||
converter.run_id = "run_456"
|
||||
|
||||
event = {
|
||||
"type": "RUN_FINISHED",
|
||||
"threadId": "thread_123",
|
||||
"runId": "run_456",
|
||||
}
|
||||
|
||||
update = converter.convert_event(event)
|
||||
|
||||
assert update is not None
|
||||
assert update.role == Role.ASSISTANT
|
||||
assert update.finish_reason == FinishReason.STOP
|
||||
assert update.additional_properties["thread_id"] == "thread_123"
|
||||
assert update.additional_properties["run_id"] == "run_456"
|
||||
|
||||
def test_run_error_event(self) -> None:
|
||||
"""Test conversion of RUN_ERROR event."""
|
||||
converter = AGUIEventConverter()
|
||||
converter.thread_id = "thread_123"
|
||||
converter.run_id = "run_456"
|
||||
|
||||
event = {
|
||||
"type": "RUN_ERROR",
|
||||
"message": "Connection timeout",
|
||||
}
|
||||
|
||||
update = converter.convert_event(event)
|
||||
|
||||
assert update is not None
|
||||
assert update.role == Role.ASSISTANT
|
||||
assert update.finish_reason == FinishReason.CONTENT_FILTER
|
||||
assert len(update.contents) == 1
|
||||
assert update.contents[0].message == "Connection timeout"
|
||||
assert update.contents[0].error_code == "RUN_ERROR"
|
||||
|
||||
def test_unknown_event_type(self) -> None:
|
||||
"""Test handling of unknown event types."""
|
||||
converter = AGUIEventConverter()
|
||||
event = {
|
||||
"type": "UNKNOWN_EVENT",
|
||||
"data": "some data",
|
||||
}
|
||||
|
||||
update = converter.convert_event(event)
|
||||
|
||||
assert update is None
|
||||
|
||||
def test_full_conversation_flow(self) -> None:
|
||||
"""Test complete conversation flow with multiple event types."""
|
||||
converter = AGUIEventConverter()
|
||||
|
||||
events = [
|
||||
{"type": "RUN_STARTED", "threadId": "thread_1", "runId": "run_1"},
|
||||
{"type": "TEXT_MESSAGE_START", "messageId": "msg_1"},
|
||||
{"type": "TEXT_MESSAGE_CONTENT", "messageId": "msg_1", "delta": "I'll check"},
|
||||
{"type": "TEXT_MESSAGE_CONTENT", "messageId": "msg_1", "delta": " the weather."},
|
||||
{"type": "TEXT_MESSAGE_END", "messageId": "msg_1"},
|
||||
{"type": "TOOL_CALL_START", "toolCallId": "call_1", "toolName": "get_weather"},
|
||||
{"type": "TOOL_CALL_ARGS", "delta": '{"location": "Seattle"}'},
|
||||
{"type": "TOOL_CALL_END", "toolCallId": "call_1"},
|
||||
{"type": "TOOL_CALL_RESULT", "toolCallId": "call_1", "result": "Sunny, 72°F"},
|
||||
{"type": "TEXT_MESSAGE_START", "messageId": "msg_2"},
|
||||
{"type": "TEXT_MESSAGE_CONTENT", "messageId": "msg_2", "delta": "It's sunny!"},
|
||||
{"type": "TEXT_MESSAGE_END", "messageId": "msg_2"},
|
||||
{"type": "RUN_FINISHED", "threadId": "thread_1", "runId": "run_1"},
|
||||
]
|
||||
|
||||
updates = [converter.convert_event(event) for event in events]
|
||||
non_none_updates = [u for u in updates if u is not None]
|
||||
|
||||
assert len(non_none_updates) == 10
|
||||
assert converter.thread_id == "thread_1"
|
||||
assert converter.run_id == "run_1"
|
||||
|
||||
def test_multiple_tool_calls(self) -> None:
|
||||
"""Test handling multiple tool calls in sequence."""
|
||||
converter = AGUIEventConverter()
|
||||
|
||||
events = [
|
||||
{"type": "TOOL_CALL_START", "toolCallId": "call_1", "toolName": "search"},
|
||||
{"type": "TOOL_CALL_ARGS", "delta": '{"query": "weather"}'},
|
||||
{"type": "TOOL_CALL_END", "toolCallId": "call_1"},
|
||||
{"type": "TOOL_CALL_START", "toolCallId": "call_2", "toolName": "fetch"},
|
||||
{"type": "TOOL_CALL_ARGS", "delta": '{"url": "http://api.weather.com"}'},
|
||||
{"type": "TOOL_CALL_END", "toolCallId": "call_2"},
|
||||
]
|
||||
|
||||
updates = [converter.convert_event(event) for event in events]
|
||||
non_none_updates = [u for u in updates if u is not None]
|
||||
|
||||
assert len(non_none_updates) == 4
|
||||
assert non_none_updates[0].contents[0].name == "search"
|
||||
assert non_none_updates[2].contents[0].name == "fetch"
|
||||
@@ -0,0 +1,238 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Tests for AGUIHttpService."""
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from agent_framework_ag_ui._http_service import AGUIHttpService
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_http_client():
|
||||
"""Create a mock httpx.AsyncClient."""
|
||||
client = AsyncMock(spec=httpx.AsyncClient)
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_events():
|
||||
"""Sample AG-UI events for testing."""
|
||||
return [
|
||||
{"type": "RUN_STARTED", "threadId": "thread_123", "runId": "run_456"},
|
||||
{"type": "TEXT_MESSAGE_START", "messageId": "msg_1", "role": "assistant"},
|
||||
{"type": "TEXT_MESSAGE_CONTENT", "messageId": "msg_1", "delta": "Hello"},
|
||||
{"type": "TEXT_MESSAGE_CONTENT", "messageId": "msg_1", "delta": " world"},
|
||||
{"type": "TEXT_MESSAGE_END", "messageId": "msg_1"},
|
||||
{"type": "RUN_FINISHED", "threadId": "thread_123", "runId": "run_456"},
|
||||
]
|
||||
|
||||
|
||||
def create_sse_response(events: list[dict]) -> str:
|
||||
"""Create SSE formatted response from events."""
|
||||
lines = []
|
||||
for event in events:
|
||||
lines.append(f"data: {json.dumps(event)}\n")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
async def test_http_service_initialization():
|
||||
"""Test AGUIHttpService initialization."""
|
||||
# Test with default client
|
||||
service = AGUIHttpService("http://localhost:8888/")
|
||||
assert service.endpoint == "http://localhost:8888"
|
||||
assert service._owns_client is True
|
||||
assert isinstance(service.http_client, httpx.AsyncClient)
|
||||
await service.close()
|
||||
|
||||
# Test with custom client
|
||||
custom_client = httpx.AsyncClient()
|
||||
service = AGUIHttpService("http://localhost:8888/", http_client=custom_client)
|
||||
assert service._owns_client is False
|
||||
assert service.http_client is custom_client
|
||||
# Shouldn't close the custom client
|
||||
await service.close()
|
||||
await custom_client.aclose()
|
||||
|
||||
|
||||
async def test_http_service_strips_trailing_slash():
|
||||
"""Test that endpoint trailing slash is stripped."""
|
||||
service = AGUIHttpService("http://localhost:8888/")
|
||||
assert service.endpoint == "http://localhost:8888"
|
||||
await service.close()
|
||||
|
||||
|
||||
async def test_post_run_successful_streaming(mock_http_client, sample_events):
|
||||
"""Test successful streaming of events."""
|
||||
|
||||
# Create async generator for lines
|
||||
async def mock_aiter_lines():
|
||||
sse_data = create_sse_response(sample_events)
|
||||
for line in sse_data.split("\n"):
|
||||
if line:
|
||||
yield line
|
||||
|
||||
# Create mock response
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status_code = 200
|
||||
# aiter_lines is called as a method, so it should return a new generator each time
|
||||
mock_response.aiter_lines = mock_aiter_lines
|
||||
|
||||
# Setup mock streaming context manager
|
||||
mock_stream_context = AsyncMock()
|
||||
mock_stream_context.__aenter__.return_value = mock_response
|
||||
mock_stream_context.__aexit__.return_value = None
|
||||
mock_http_client.stream.return_value = mock_stream_context
|
||||
|
||||
service = AGUIHttpService("http://localhost:8888/", http_client=mock_http_client)
|
||||
|
||||
events = []
|
||||
async for event in service.post_run(
|
||||
thread_id="thread_123", run_id="run_456", messages=[{"role": "user", "content": "Hello"}]
|
||||
):
|
||||
events.append(event)
|
||||
|
||||
assert len(events) == len(sample_events)
|
||||
assert events[0]["type"] == "RUN_STARTED"
|
||||
assert events[-1]["type"] == "RUN_FINISHED"
|
||||
|
||||
# Verify request was made correctly
|
||||
mock_http_client.stream.assert_called_once()
|
||||
call_args = mock_http_client.stream.call_args
|
||||
assert call_args.args[0] == "POST"
|
||||
assert call_args.args[1] == "http://localhost:8888"
|
||||
assert call_args.kwargs["headers"] == {"Accept": "text/event-stream"}
|
||||
|
||||
|
||||
async def test_post_run_with_state_and_tools(mock_http_client):
|
||||
"""Test posting run with state and tools."""
|
||||
|
||||
async def mock_aiter_lines():
|
||||
return
|
||||
yield # Make it an async generator
|
||||
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.aiter_lines = mock_aiter_lines
|
||||
|
||||
mock_stream_context = AsyncMock()
|
||||
mock_stream_context.__aenter__.return_value = mock_response
|
||||
mock_stream_context.__aexit__.return_value = None
|
||||
mock_http_client.stream.return_value = mock_stream_context
|
||||
|
||||
service = AGUIHttpService("http://localhost:8888/", http_client=mock_http_client)
|
||||
|
||||
state = {"user_context": {"name": "Alice"}}
|
||||
tools = [{"type": "function", "function": {"name": "test_tool"}}]
|
||||
|
||||
async for _ in service.post_run(thread_id="thread_123", run_id="run_456", messages=[], state=state, tools=tools):
|
||||
pass
|
||||
|
||||
# Verify state and tools were included in request
|
||||
call_args = mock_http_client.stream.call_args
|
||||
request_data = call_args.kwargs["json"]
|
||||
assert request_data["state"] == state
|
||||
assert request_data["tools"] == tools
|
||||
|
||||
|
||||
async def test_post_run_http_error(mock_http_client):
|
||||
"""Test handling of HTTP errors."""
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 500
|
||||
mock_response.text = "Internal Server Error"
|
||||
|
||||
def raise_http_error():
|
||||
raise httpx.HTTPStatusError("Server error", request=Mock(), response=mock_response)
|
||||
|
||||
mock_response_async = AsyncMock()
|
||||
mock_response_async.raise_for_status = raise_http_error
|
||||
|
||||
mock_stream_context = AsyncMock()
|
||||
mock_stream_context.__aenter__.return_value = mock_response_async
|
||||
mock_stream_context.__aexit__.return_value = None
|
||||
mock_http_client.stream.return_value = mock_stream_context
|
||||
|
||||
service = AGUIHttpService("http://localhost:8888/", http_client=mock_http_client)
|
||||
|
||||
with pytest.raises(httpx.HTTPStatusError):
|
||||
async for _ in service.post_run(thread_id="thread_123", run_id="run_456", messages=[]):
|
||||
pass
|
||||
|
||||
|
||||
async def test_post_run_invalid_json(mock_http_client):
|
||||
"""Test handling of invalid JSON in SSE stream."""
|
||||
invalid_sse = "data: {invalid json}\n\ndata: " + json.dumps({"type": "RUN_FINISHED"}) + "\n"
|
||||
|
||||
async def mock_aiter_lines():
|
||||
for line in invalid_sse.split("\n"):
|
||||
if line:
|
||||
yield line
|
||||
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.aiter_lines = mock_aiter_lines
|
||||
|
||||
mock_stream_context = AsyncMock()
|
||||
mock_stream_context.__aenter__.return_value = mock_response
|
||||
mock_stream_context.__aexit__.return_value = None
|
||||
mock_http_client.stream.return_value = mock_stream_context
|
||||
|
||||
service = AGUIHttpService("http://localhost:8888/", http_client=mock_http_client)
|
||||
|
||||
events = []
|
||||
async for event in service.post_run(thread_id="thread_123", run_id="run_456", messages=[]):
|
||||
events.append(event)
|
||||
|
||||
# Should skip invalid JSON and continue with valid events
|
||||
assert len(events) == 1
|
||||
assert events[0]["type"] == "RUN_FINISHED"
|
||||
|
||||
|
||||
async def test_context_manager():
|
||||
"""Test context manager functionality."""
|
||||
async with AGUIHttpService("http://localhost:8888/") as service:
|
||||
assert service.http_client is not None
|
||||
assert service._owns_client is True
|
||||
|
||||
# Client should be closed after exiting context
|
||||
|
||||
|
||||
async def test_context_manager_with_external_client():
|
||||
"""Test context manager doesn't close external client."""
|
||||
external_client = httpx.AsyncClient()
|
||||
|
||||
async with AGUIHttpService("http://localhost:8888/", http_client=external_client) as service:
|
||||
assert service.http_client is external_client
|
||||
assert service._owns_client is False
|
||||
|
||||
# External client should still be open
|
||||
# (caller's responsibility to close)
|
||||
await external_client.aclose()
|
||||
|
||||
|
||||
async def test_post_run_empty_response(mock_http_client):
|
||||
"""Test handling of empty response stream."""
|
||||
|
||||
async def mock_aiter_lines():
|
||||
return
|
||||
yield # Make it an async generator
|
||||
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.aiter_lines = mock_aiter_lines
|
||||
|
||||
mock_stream_context = AsyncMock()
|
||||
mock_stream_context.__aenter__.return_value = mock_response
|
||||
mock_stream_context.__aexit__.return_value = None
|
||||
mock_http_client.stream.return_value = mock_stream_context
|
||||
|
||||
service = AGUIHttpService("http://localhost:8888/", http_client=mock_http_client)
|
||||
|
||||
events = []
|
||||
async for event in service.post_run(thread_id="thread_123", run_id="run_456", messages=[]):
|
||||
events.append(event)
|
||||
|
||||
assert len(events) == 0
|
||||
@@ -63,10 +63,9 @@ def test_agui_tool_result_to_agent_framework():
|
||||
assert isinstance(message.contents[0], TextContent)
|
||||
assert message.contents[0].text == '{"accepted": true, "steps": []}'
|
||||
|
||||
assert hasattr(message, "metadata")
|
||||
assert message.metadata is not None
|
||||
assert message.metadata.get("is_tool_result") is True
|
||||
assert message.metadata.get("tool_call_id") == "call_123"
|
||||
assert message.additional_properties is not None
|
||||
assert message.additional_properties.get("is_tool_result") is True
|
||||
assert message.additional_properties.get("tool_call_id") == "call_123"
|
||||
|
||||
|
||||
def test_agui_multiple_messages_to_agent_framework():
|
||||
@@ -159,6 +158,36 @@ def test_agui_message_without_id():
|
||||
assert messages[0].message_id is None
|
||||
|
||||
|
||||
def test_agui_with_tool_calls_to_agent_framework():
|
||||
"""Assistant message with tool_calls is converted to FunctionCallContent."""
|
||||
agui_msg = {
|
||||
"role": "assistant",
|
||||
"content": "Calling tool",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call-123",
|
||||
"type": "function",
|
||||
"function": {"name": "get_weather", "arguments": {"location": "Seattle"}},
|
||||
}
|
||||
],
|
||||
"id": "msg-789",
|
||||
}
|
||||
|
||||
messages = agui_messages_to_agent_framework([agui_msg])
|
||||
|
||||
assert len(messages) == 1
|
||||
msg = messages[0]
|
||||
assert msg.role == Role.ASSISTANT
|
||||
assert msg.message_id == "msg-789"
|
||||
# First content is text, second is the function call
|
||||
assert isinstance(msg.contents[0], TextContent)
|
||||
assert msg.contents[0].text == "Calling tool"
|
||||
assert isinstance(msg.contents[1], FunctionCallContent)
|
||||
assert msg.contents[1].call_id == "call-123"
|
||||
assert msg.contents[1].name == "get_weather"
|
||||
assert msg.contents[1].arguments == {"location": "Seattle"}
|
||||
|
||||
|
||||
def test_agent_framework_to_agui_with_tool_calls():
|
||||
"""Test converting Agent Framework message with tool calls to AG-UI."""
|
||||
msg = ChatMessage(
|
||||
@@ -198,13 +227,15 @@ def test_agent_framework_to_agui_multiple_text_contents():
|
||||
|
||||
|
||||
def test_agent_framework_to_agui_no_message_id():
|
||||
"""Test message without message_id."""
|
||||
"""Test message without message_id - should auto-generate ID."""
|
||||
msg = ChatMessage(role=Role.USER, contents=[TextContent(text="Hello")])
|
||||
|
||||
messages = agent_framework_messages_to_agui([msg])
|
||||
|
||||
assert len(messages) == 1
|
||||
assert "id" not in messages[0]
|
||||
assert "id" in messages[0] # ID should be auto-generated
|
||||
assert messages[0]["id"] # ID should not be empty
|
||||
assert len(messages[0]["id"]) > 0 # ID should be a valid string
|
||||
|
||||
|
||||
def test_agent_framework_to_agui_system_role():
|
||||
|
||||
@@ -0,0 +1,82 @@
|
||||
"""Tests for AG-UI orchestrators."""
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
|
||||
from agent_framework import AgentRunResponseUpdate, TextContent, ai_function
|
||||
from agent_framework._tools import FunctionInvocationConfiguration
|
||||
|
||||
from agent_framework_ag_ui._agent import AgentConfig
|
||||
from agent_framework_ag_ui._orchestrators import DefaultOrchestrator, ExecutionContext
|
||||
|
||||
|
||||
@ai_function
|
||||
def server_tool() -> str:
|
||||
"""Server-executable tool."""
|
||||
return "server"
|
||||
|
||||
|
||||
class DummyAgent:
|
||||
"""Minimal agent stub to capture run_stream parameters."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.chat_options = SimpleNamespace(tools=[server_tool], response_format=None)
|
||||
self.tools = [server_tool]
|
||||
self.chat_client = SimpleNamespace(
|
||||
function_invocation_configuration=FunctionInvocationConfiguration(),
|
||||
)
|
||||
self.seen_tools: list[Any] | None = None
|
||||
|
||||
async def run_stream(
|
||||
self,
|
||||
messages: list[Any],
|
||||
*,
|
||||
thread: Any,
|
||||
tools: list[Any] | None = None,
|
||||
) -> AsyncGenerator[AgentRunResponseUpdate, None]:
|
||||
self.seen_tools = tools
|
||||
yield AgentRunResponseUpdate(contents=[TextContent(text="ok")], role="assistant")
|
||||
|
||||
|
||||
async def test_default_orchestrator_merges_client_tools() -> None:
|
||||
"""Client tool declarations are merged with server tools before running agent."""
|
||||
|
||||
agent = DummyAgent()
|
||||
orchestrator = DefaultOrchestrator()
|
||||
|
||||
input_data = {
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "input_text", "text": "Hello"}],
|
||||
}
|
||||
],
|
||||
"tools": [
|
||||
{
|
||||
"name": "get_weather",
|
||||
"description": "Client weather lookup.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"location": {"type": "string"}},
|
||||
"required": ["location"],
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
context = ExecutionContext(
|
||||
input_data=input_data,
|
||||
agent=agent,
|
||||
config=AgentConfig(),
|
||||
)
|
||||
|
||||
events = []
|
||||
async for event in orchestrator.run(context):
|
||||
events.append(event)
|
||||
|
||||
assert agent.seen_tools is not None
|
||||
tool_names = [getattr(tool, "name", "?") for tool in agent.seen_tools]
|
||||
assert "server_tool" in tool_names
|
||||
assert "get_weather" in tool_names
|
||||
assert agent.chat_client.function_invocation_configuration.additional_tools
|
||||
@@ -197,3 +197,109 @@ def test_make_json_safe_fallback():
|
||||
result = make_json_safe(obj)
|
||||
# Objects with __dict__ return their __dict__ dict
|
||||
assert isinstance(result, dict)
|
||||
|
||||
|
||||
def test_convert_tools_to_agui_format_with_ai_function():
|
||||
"""Test converting AIFunction to AG-UI format."""
|
||||
from agent_framework import ai_function
|
||||
|
||||
from agent_framework_ag_ui._utils import convert_tools_to_agui_format
|
||||
|
||||
@ai_function
|
||||
def test_func(param: str, count: int = 5) -> str:
|
||||
"""Test function."""
|
||||
return f"{param} {count}"
|
||||
|
||||
result = convert_tools_to_agui_format([test_func])
|
||||
|
||||
assert result is not None
|
||||
assert len(result) == 1
|
||||
assert result[0]["name"] == "test_func"
|
||||
assert result[0]["description"] == "Test function."
|
||||
assert "parameters" in result[0]
|
||||
assert "properties" in result[0]["parameters"]
|
||||
|
||||
|
||||
def test_convert_tools_to_agui_format_with_callable():
|
||||
"""Test converting plain callable to AG-UI format."""
|
||||
from agent_framework_ag_ui._utils import convert_tools_to_agui_format
|
||||
|
||||
def plain_func(x: int) -> int:
|
||||
"""A plain function."""
|
||||
return x * 2
|
||||
|
||||
result = convert_tools_to_agui_format([plain_func])
|
||||
|
||||
assert result is not None
|
||||
assert len(result) == 1
|
||||
assert result[0]["name"] == "plain_func"
|
||||
assert result[0]["description"] == "A plain function."
|
||||
assert "parameters" in result[0]
|
||||
|
||||
|
||||
def test_convert_tools_to_agui_format_with_dict():
|
||||
"""Test converting dict tool to AG-UI format."""
|
||||
from agent_framework_ag_ui._utils import convert_tools_to_agui_format
|
||||
|
||||
tool_dict = {
|
||||
"name": "custom_tool",
|
||||
"description": "Custom tool",
|
||||
"parameters": {"type": "object"},
|
||||
}
|
||||
|
||||
result = convert_tools_to_agui_format([tool_dict])
|
||||
|
||||
assert result is not None
|
||||
assert len(result) == 1
|
||||
assert result[0] == tool_dict
|
||||
|
||||
|
||||
def test_convert_tools_to_agui_format_with_none():
|
||||
"""Test converting None tools."""
|
||||
from agent_framework_ag_ui._utils import convert_tools_to_agui_format
|
||||
|
||||
result = convert_tools_to_agui_format(None)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_convert_tools_to_agui_format_with_single_tool():
|
||||
"""Test converting single tool (not in list)."""
|
||||
from agent_framework import ai_function
|
||||
|
||||
from agent_framework_ag_ui._utils import convert_tools_to_agui_format
|
||||
|
||||
@ai_function
|
||||
def single_tool(arg: str) -> str:
|
||||
"""Single tool."""
|
||||
return arg
|
||||
|
||||
result = convert_tools_to_agui_format(single_tool)
|
||||
|
||||
assert result is not None
|
||||
assert len(result) == 1
|
||||
assert result[0]["name"] == "single_tool"
|
||||
|
||||
|
||||
def test_convert_tools_to_agui_format_with_multiple_tools():
|
||||
"""Test converting multiple tools."""
|
||||
from agent_framework import ai_function
|
||||
|
||||
from agent_framework_ag_ui._utils import convert_tools_to_agui_format
|
||||
|
||||
@ai_function
|
||||
def tool1(x: int) -> int:
|
||||
"""Tool 1."""
|
||||
return x
|
||||
|
||||
@ai_function
|
||||
def tool2(y: str) -> str:
|
||||
"""Tool 2."""
|
||||
return y
|
||||
|
||||
result = convert_tools_to_agui_format([tool1, tool2])
|
||||
|
||||
assert result is not None
|
||||
assert len(result) == 2
|
||||
assert result[0]["name"] == "tool1"
|
||||
assert result[1]["name"] == "tool2"
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import importlib
|
||||
from typing import Any
|
||||
|
||||
PACKAGE_NAME = "agent_framework_ag_ui"
|
||||
PACKAGE_EXTRA = "ag-ui"
|
||||
_IMPORTS = [
|
||||
"__version__",
|
||||
"AgentFrameworkAgent",
|
||||
"add_agent_framework_fastapi_endpoint",
|
||||
"AGUIChatClient",
|
||||
"AGUIEventConverter",
|
||||
"AGUIHttpService",
|
||||
"ConfirmationStrategy",
|
||||
"DefaultConfirmationStrategy",
|
||||
"TaskPlannerConfirmationStrategy",
|
||||
"RecipeConfirmationStrategy",
|
||||
"DocumentWriterConfirmationStrategy",
|
||||
]
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any:
|
||||
if name in _IMPORTS:
|
||||
try:
|
||||
return getattr(importlib.import_module(PACKAGE_NAME), name)
|
||||
except ModuleNotFoundError as exc:
|
||||
raise ModuleNotFoundError(
|
||||
f"The '{PACKAGE_EXTRA}' extra is not installed, please do `pip install agent-framework-{PACKAGE_EXTRA}`"
|
||||
) from exc
|
||||
raise AttributeError(f"Module {PACKAGE_NAME} has no attribute {name}.")
|
||||
|
||||
|
||||
def __dir__() -> list[str]:
|
||||
return _IMPORTS
|
||||
@@ -0,0 +1,29 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from agent_framework_ag_ui import (
|
||||
AgentFrameworkAgent,
|
||||
AGUIChatClient,
|
||||
AGUIEventConverter,
|
||||
AGUIHttpService,
|
||||
ConfirmationStrategy,
|
||||
DefaultConfirmationStrategy,
|
||||
DocumentWriterConfirmationStrategy,
|
||||
RecipeConfirmationStrategy,
|
||||
TaskPlannerConfirmationStrategy,
|
||||
__version__,
|
||||
add_agent_framework_fastapi_endpoint,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AGUIChatClient",
|
||||
"AGUIEventConverter",
|
||||
"AGUIHttpService",
|
||||
"AgentFrameworkAgent",
|
||||
"ConfirmationStrategy",
|
||||
"DefaultConfirmationStrategy",
|
||||
"DocumentWriterConfirmationStrategy",
|
||||
"RecipeConfirmationStrategy",
|
||||
"TaskPlannerConfirmationStrategy",
|
||||
"__version__",
|
||||
"add_agent_framework_fastapi_endpoint",
|
||||
]
|
||||
@@ -42,13 +42,14 @@ dependencies = [
|
||||
[project.optional-dependencies]
|
||||
all = [
|
||||
"agent-framework-a2a",
|
||||
"agent-framework-ag-ui",
|
||||
"agent-framework-anthropic",
|
||||
"agent-framework-azure-ai",
|
||||
"agent-framework-copilotstudio",
|
||||
"agent-framework-mem0",
|
||||
"agent-framework-redis",
|
||||
"agent-framework-devui",
|
||||
"agent-framework-mem0",
|
||||
"agent-framework-purview",
|
||||
"agent-framework-anthropic",
|
||||
"agent-framework-redis",
|
||||
]
|
||||
|
||||
[tool.uv]
|
||||
|
||||
Generated
+3433
-3415
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user