Python: Add concrete AGUIChatClient (#2072)

* Add concrete AGUIChatClient

* Update logging docstrings and conventions

* PR feedback

* Updates to support client-side tool calls
This commit is contained in:
Evan Mattson
2025-11-11 23:39:30 +09:00
committed by GitHub
Unverified
parent 93ab43d788
commit 32bd884bfd
26 changed files with 6410 additions and 4037 deletions
+33 -1
View File
@@ -10,6 +10,8 @@ pip install agent-framework-ag-ui
## Quick Start
### Server (Host an AI Agent)
```python
from fastapi import FastAPI
from agent_framework import ChatAgent
@@ -23,6 +25,7 @@ agent = ChatAgent(
chat_client=AzureOpenAIChatClient(
endpoint="https://your-resource.openai.azure.com/",
deployment_name="gpt-4o-mini",
api_key="your-api-key",
),
)
@@ -33,9 +36,38 @@ add_agent_framework_fastapi_endpoint(app, agent, "/")
# Run with: uvicorn main:app --reload
```
### Client (Connect to an AG-UI Server)
```python
import asyncio
from agent_framework import TextContent
from agent_framework_ag_ui import AGUIChatClient
async def main():
async with AGUIChatClient(endpoint="http://localhost:8000/") as client:
# Stream responses
async for update in client.get_streaming_response("Hello!"):
for content in update.contents:
if isinstance(content, TextContent):
print(content.text, end="", flush=True)
print()
asyncio.run(main())
```
The `AGUIChatClient` supports:
- Streaming and non-streaming responses
- Hybrid tool execution (client-side + server-side tools)
- Automatic thread management for conversation continuity
- Integration with `ChatAgent` for client-side history management
## Documentation
- **[Getting Started Tutorial](getting_started/)** - Step-by-step guide to building your first AG-UI server and client
- **[Getting Started Tutorial](getting_started/)** - Step-by-step guide to building AG-UI servers and clients
- Server setup with FastAPI
- Client examples using `AGUIChatClient`
- Hybrid tool execution (client-side + server-side)
- Thread management and conversation continuity
- **[Examples](agent_framework_ag_ui_examples/)** - Complete examples for AG-UI features
## Features
@@ -5,6 +5,7 @@
import importlib.metadata
from ._agent import AgentFrameworkAgent
from ._client import AGUIChatClient
from ._confirmation_strategies import (
ConfirmationStrategy,
DefaultConfirmationStrategy,
@@ -13,6 +14,8 @@ from ._confirmation_strategies import (
TaskPlannerConfirmationStrategy,
)
from ._endpoint import add_agent_framework_fastapi_endpoint
from ._event_converters import AGUIEventConverter
from ._http_service import AGUIHttpService
try:
__version__ = importlib.metadata.version(__name__)
@@ -22,6 +25,9 @@ except importlib.metadata.PackageNotFoundError:
__all__ = [
"AgentFrameworkAgent",
"add_agent_framework_fastapi_endpoint",
"AGUIChatClient",
"AGUIEventConverter",
"AGUIHttpService",
"ConfirmationStrategy",
"DefaultConfirmationStrategy",
"TaskPlannerConfirmationStrategy",
@@ -0,0 +1,407 @@
# Copyright (c) Microsoft. All rights reserved.
"""AG-UI Chat Client implementation."""
import json
import logging
import uuid
from collections.abc import AsyncIterable, MutableSequence
from functools import wraps
from typing import Any, TypeVar, cast
import httpx
from agent_framework import (
AIFunction,
BaseChatClient,
ChatMessage,
ChatOptions,
ChatResponse,
ChatResponseUpdate,
DataContent,
FunctionCallContent,
)
from agent_framework._middleware import use_chat_middleware
from agent_framework._tools import use_function_invocation
from agent_framework._types import BaseContent, Contents
from agent_framework.observability import use_observability
from ._event_converters import AGUIEventConverter
from ._http_service import AGUIHttpService
from ._message_adapters import agent_framework_messages_to_agui
from ._utils import convert_tools_to_agui_format
logger: logging.Logger = logging.getLogger(__name__)
class ServerFunctionCallContent(BaseContent):
"""Wrapper for server function calls to prevent client re-execution.
All function calls from the remote server are server-side executions.
This wrapper prevents @use_function_invocation from trying to execute them again.
"""
function_call_content: FunctionCallContent
def __init__(self, function_call_content: FunctionCallContent) -> None:
"""Initialize with the function call content."""
super().__init__(type="server_function_call")
self.function_call_content = function_call_content
def _unwrap_server_function_call_contents(contents: MutableSequence[Contents | dict[str, Any]]) -> None:
"""Replace ServerFunctionCallContent instances with their underlying call content."""
for idx, content in enumerate(contents):
if isinstance(content, ServerFunctionCallContent):
contents[idx] = content.function_call_content # type: ignore[assignment]
TBaseChatClient = TypeVar("TBaseChatClient", bound=type[BaseChatClient])
def _apply_server_function_call_unwrap(chat_client: TBaseChatClient) -> TBaseChatClient:
"""Class decorator that unwraps server-side function calls after tool handling."""
original_get_streaming_response = chat_client.get_streaming_response
@wraps(original_get_streaming_response)
async def streaming_wrapper(self, *args: Any, **kwargs: Any) -> AsyncIterable[ChatResponseUpdate]:
async for update in original_get_streaming_response(self, *args, **kwargs):
_unwrap_server_function_call_contents(cast(MutableSequence[Contents | dict[str, Any]], update.contents))
yield update
chat_client.get_streaming_response = streaming_wrapper # type: ignore[assignment]
original_get_response = chat_client.get_response
@wraps(original_get_response)
async def response_wrapper(self, *args: Any, **kwargs: Any) -> ChatResponse:
response = await original_get_response(self, *args, **kwargs)
if response.messages:
for message in response.messages:
_unwrap_server_function_call_contents(
cast(MutableSequence[Contents | dict[str, Any]], message.contents)
)
return response
chat_client.get_response = response_wrapper # type: ignore[assignment]
return chat_client
@_apply_server_function_call_unwrap
@use_function_invocation
@use_observability
@use_chat_middleware
class AGUIChatClient(BaseChatClient):
"""Chat client for communicating with AG-UI compliant servers.
This client implements the BaseChatClient interface and automatically handles:
- Thread ID management for conversation continuity
- State synchronization between client and server
- Server-Sent Events (SSE) streaming
- Event conversion to Agent Framework types
Important: Message History Management
This client sends exactly the messages it receives to the server. It does NOT
automatically maintain conversation history. The server must handle history via thread_id.
For stateless servers: Use ChatAgent wrapper which will send full message history on each
request. However, even with ChatAgent, the server must echo back all context for the
agent to maintain history across turns.
Important: Tool Handling (Hybrid Execution - matches .NET)
1. Client tool metadata sent to server - LLM knows about both client and server tools
2. Server has its own tools that execute server-side
3. When LLM calls a client tool, @use_function_invocation executes it locally
4. Both client and server tools work together (hybrid pattern)
The wrapping ChatAgent's @use_function_invocation handles client tool execution
automatically when the server's LLM decides to call them.
Examples:
Direct usage (server manages thread history):
.. code-block:: python
from agent_framework.ag_ui import AGUIChatClient
client = AGUIChatClient(endpoint="http://localhost:8888/")
# First message - thread ID auto-generated
response = await client.get_response("Hello!")
thread_id = response.additional_properties.get("thread_id")
# Second message - server retrieves history using thread_id
response2 = await client.get_response(
"How are you?",
metadata={"thread_id": thread_id}
)
Recommended usage with ChatAgent (client manages history):
.. code-block:: python
from agent_framework import ChatAgent
from agent_framework.ag_ui import AGUIChatClient
client = AGUIChatClient(endpoint="http://localhost:8888/")
agent = ChatAgent(name="assistant", client=client)
thread = await agent.get_new_thread()
# ChatAgent automatically maintains history and sends full context
response = await agent.run("Hello!", thread=thread)
response2 = await agent.run("How are you?", thread=thread)
Streaming usage:
.. code-block:: python
async for update in client.get_streaming_response("Tell me a story"):
if update.contents:
for content in update.contents:
if hasattr(content, "text"):
print(content.text, end="", flush=True)
Context manager:
.. code-block:: python
async with AGUIChatClient(endpoint="http://localhost:8888/") as client:
response = await client.get_response("Hello!")
print(response.messages[0].text)
"""
OTEL_PROVIDER_NAME = "agui"
def __init__(
self,
*,
endpoint: str,
http_client: httpx.AsyncClient | None = None,
timeout: float = 60.0,
additional_properties: dict[str, Any] | None = None,
**kwargs: Any,
) -> None:
"""Initialize the AG-UI chat client.
Args:
endpoint: The AG-UI server endpoint URL (e.g., "http://localhost:8888/")
http_client: Optional httpx.AsyncClient instance. If None, one will be created.
timeout: Request timeout in seconds (default: 60.0)
additional_properties: Additional properties to store
**kwargs: Additional arguments passed to BaseChatClient
"""
super().__init__(additional_properties=additional_properties, **kwargs)
self._http_service = AGUIHttpService(
endpoint=endpoint,
http_client=http_client,
timeout=timeout,
)
async def close(self) -> None:
"""Close the HTTP client."""
await self._http_service.close()
async def __aenter__(self) -> "AGUIChatClient":
"""Enter async context manager."""
return self
async def __aexit__(self, *args: Any) -> None:
"""Exit async context manager."""
await self.close()
def _register_server_tool_placeholder(self, tool_name: str) -> None:
"""Register a declaration-only placeholder so function invocation skips execution."""
config = getattr(self, "function_invocation_configuration", None)
if not config:
return
if any(getattr(tool, "name", None) == tool_name for tool in config.additional_tools):
return
placeholder: AIFunction[Any, Any] = AIFunction(
name=tool_name,
description="Server-managed tool placeholder (AG-UI)",
func=None,
)
config.additional_tools = list(config.additional_tools) + [placeholder]
registered: set[str] = getattr(self, "_registered_server_tools", set())
registered.add(tool_name)
self._registered_server_tools = registered # type: ignore[attr-defined]
from agent_framework._logging import get_logger
logger = get_logger()
logger.debug(f"[AGUIChatClient] Registered server placeholder: {tool_name}")
def _extract_state_from_messages(
self, messages: MutableSequence[ChatMessage]
) -> tuple[list[ChatMessage], dict[str, Any] | None]:
"""Extract state from last message if present.
Args:
messages: List of chat messages
Returns:
Tuple of (messages_without_state, state_dict)
"""
if not messages:
return list(messages), None
last_message = messages[-1]
for content in last_message.contents:
if isinstance(content, DataContent) and content.media_type == "application/json":
try:
uri = content.uri
if uri.startswith("data:application/json;base64,"):
import base64
encoded_data = uri.split(",", 1)[1]
decoded_bytes = base64.b64decode(encoded_data)
state = json.loads(decoded_bytes.decode("utf-8"))
messages_without_state = list(messages[:-1]) if len(messages) > 1 else []
return messages_without_state, state
except (json.JSONDecodeError, ValueError, KeyError) as e:
from agent_framework._logging import get_logger
logger = get_logger()
logger.warning(f"Failed to extract state from message: {e}")
return list(messages), None
def _convert_messages_to_agui_format(self, messages: list[ChatMessage]) -> list[dict[str, Any]]:
"""Convert Agent Framework messages to AG-UI format.
Args:
messages: List of ChatMessage objects
Returns:
List of AG-UI formatted message dictionaries
"""
return agent_framework_messages_to_agui(messages)
def _get_thread_id(self, chat_options: ChatOptions) -> str:
"""Get or generate thread ID from chat options.
Args:
chat_options: Chat options containing metadata
Returns:
Thread ID string
"""
thread_id = None
if chat_options.metadata:
thread_id = chat_options.metadata.get("thread_id")
if not thread_id:
thread_id = f"thread_{uuid.uuid4().hex}"
return thread_id
async def _inner_get_response(
self,
*,
messages: MutableSequence[ChatMessage],
chat_options: ChatOptions,
**kwargs: Any,
) -> ChatResponse:
"""Internal method to get non-streaming response.
Keyword Args:
messages: List of chat messages
chat_options: Chat options for the request
**kwargs: Additional keyword arguments
Returns:
ChatResponse object
"""
return await ChatResponse.from_chat_response_generator(
self._inner_get_streaming_response(
messages=messages,
chat_options=chat_options,
**kwargs,
)
)
async def _inner_get_streaming_response(
self,
*,
messages: MutableSequence[ChatMessage],
chat_options: ChatOptions,
**kwargs: Any,
) -> AsyncIterable[ChatResponseUpdate]:
"""Internal method to get streaming response.
Keyword Args:
messages: List of chat messages
chat_options: Chat options for the request
**kwargs: Additional keyword arguments
Yields:
ChatResponseUpdate objects
"""
messages_to_send, state = self._extract_state_from_messages(messages)
thread_id = self._get_thread_id(chat_options)
run_id = f"run_{uuid.uuid4().hex}"
agui_messages = self._convert_messages_to_agui_format(messages_to_send)
# Send client tools to server so LLM knows about them
# Client tools execute via ChatAgent's @use_function_invocation wrapper
agui_tools = convert_tools_to_agui_format(chat_options.tools)
# Build set of client tool names (matches .NET clientToolSet)
# Used to distinguish client vs server tools in response stream
client_tool_set: set[str] = set()
if chat_options.tools:
for tool in chat_options.tools:
if hasattr(tool, "name"):
client_tool_set.add(tool.name) # type: ignore[arg-type]
self._last_client_tool_set = client_tool_set # type: ignore[attr-defined]
logger.debug(
"[AGUIChatClient] Preparing request",
extra={
"thread_id": thread_id,
"run_id": run_id,
"client_tools": list(client_tool_set),
"messages": [msg.text for msg in messages_to_send if msg.text],
},
)
logger.debug(f"[AGUIChatClient] Client tool set: {client_tool_set}")
converter = AGUIEventConverter()
async for event in self._http_service.post_run(
thread_id=thread_id,
run_id=run_id,
messages=agui_messages,
state=state,
tools=agui_tools,
):
logger.debug(f"[AGUIChatClient] Raw AG-UI event: {event}")
update = converter.convert_event(event)
if update is not None:
logger.debug(
"[AGUIChatClient] Converted update",
extra={"role": update.role, "contents": [type(c).__name__ for c in update.contents]},
)
# Distinguish client vs server tools
for i, content in enumerate(update.contents):
if isinstance(content, FunctionCallContent):
logger.debug(
f"[AGUIChatClient] Function call: {content.name}, in client_tool_set: {content.name in client_tool_set}"
)
if content.name in client_tool_set:
# Client tool - let @use_function_invocation execute it
if not content.additional_properties:
content.additional_properties = {}
content.additional_properties["agui_thread_id"] = thread_id
else:
# Server tool - wrap so @use_function_invocation ignores it
logger.debug(f"[AGUIChatClient] Wrapping server tool: {content.name}")
self._register_server_tool_placeholder(content.name)
update.contents[i] = ServerFunctionCallContent(content) # type: ignore
yield update
@@ -0,0 +1,209 @@
# Copyright (c) Microsoft. All rights reserved.
"""Event converter for AG-UI protocol events to Agent Framework types."""
from typing import Any
from agent_framework import (
ChatResponseUpdate,
ErrorContent,
FinishReason,
FunctionCallContent,
FunctionResultContent,
Role,
TextContent,
)
class AGUIEventConverter:
"""Converter for AG-UI events to Agent Framework types.
Handles conversion of AG-UI protocol events to ChatResponseUpdate objects
while maintaining state, aggregating content, and tracking metadata.
"""
def __init__(self) -> None:
"""Initialize the converter with fresh state."""
self.current_message_id: str | None = None
self.current_tool_call_id: str | None = None
self.current_tool_name: str | None = None
self.accumulated_tool_args: str = ""
self.thread_id: str | None = None
self.run_id: str | None = None
def convert_event(self, event: dict[str, Any]) -> ChatResponseUpdate | None:
"""Convert a single AG-UI event to ChatResponseUpdate.
Args:
event: AG-UI event dictionary
Returns:
ChatResponseUpdate if event produces content, None otherwise
Examples:
RUN_STARTED event:
.. code-block:: python
converter = AGUIEventConverter()
event = {"type": "RUN_STARTED", "threadId": "t1", "runId": "r1"}
update = converter.convert_event(event)
assert update.additional_properties["thread_id"] == "t1"
TEXT_MESSAGE_CONTENT event:
.. code-block:: python
event = {"type": "TEXT_MESSAGE_CONTENT", "messageId": "m1", "delta": "Hello"}
update = converter.convert_event(event)
assert update.contents[0].text == "Hello"
"""
event_type = event.get("type", "")
if event_type == "RUN_STARTED":
return self._handle_run_started(event)
elif event_type == "TEXT_MESSAGE_START":
return self._handle_text_message_start(event)
elif event_type == "TEXT_MESSAGE_CONTENT":
return self._handle_text_message_content(event)
elif event_type == "TEXT_MESSAGE_END":
return self._handle_text_message_end(event)
elif event_type == "TOOL_CALL_START":
return self._handle_tool_call_start(event)
elif event_type == "TOOL_CALL_ARGS":
return self._handle_tool_call_args(event)
elif event_type == "TOOL_CALL_END":
return self._handle_tool_call_end(event)
elif event_type == "TOOL_CALL_RESULT":
return self._handle_tool_call_result(event)
elif event_type == "RUN_FINISHED":
return self._handle_run_finished(event)
elif event_type == "RUN_ERROR":
return self._handle_run_error(event)
return None
def _handle_run_started(self, event: dict[str, Any]) -> ChatResponseUpdate:
"""Handle RUN_STARTED event."""
self.thread_id = event.get("threadId")
self.run_id = event.get("runId")
return ChatResponseUpdate(
role=Role.ASSISTANT,
contents=[],
additional_properties={
"thread_id": self.thread_id,
"run_id": self.run_id,
},
)
def _handle_text_message_start(self, event: dict[str, Any]) -> ChatResponseUpdate | None:
"""Handle TEXT_MESSAGE_START event."""
self.current_message_id = event.get("messageId")
return ChatResponseUpdate(
role=Role.ASSISTANT,
message_id=self.current_message_id,
contents=[],
)
def _handle_text_message_content(self, event: dict[str, Any]) -> ChatResponseUpdate:
"""Handle TEXT_MESSAGE_CONTENT event."""
message_id = event.get("messageId")
delta = event.get("delta", "")
if message_id != self.current_message_id:
self.current_message_id = message_id
return ChatResponseUpdate(
role=Role.ASSISTANT,
message_id=self.current_message_id,
contents=[TextContent(text=delta)],
)
def _handle_text_message_end(self, event: dict[str, Any]) -> ChatResponseUpdate | None:
"""Handle TEXT_MESSAGE_END event."""
return None
def _handle_tool_call_start(self, event: dict[str, Any]) -> ChatResponseUpdate:
"""Handle TOOL_CALL_START event."""
self.current_tool_call_id = event.get("toolCallId")
self.current_tool_name = event.get("toolName") or event.get("toolCallName") or event.get("tool_call_name")
self.accumulated_tool_args = ""
return ChatResponseUpdate(
role=Role.ASSISTANT,
contents=[
FunctionCallContent(
call_id=self.current_tool_call_id or "",
name=self.current_tool_name or "",
arguments="",
)
],
)
def _handle_tool_call_args(self, event: dict[str, Any]) -> ChatResponseUpdate:
"""Handle TOOL_CALL_ARGS event."""
delta = event.get("delta", "")
self.accumulated_tool_args += delta
return ChatResponseUpdate(
role=Role.ASSISTANT,
contents=[
FunctionCallContent(
call_id=self.current_tool_call_id or "",
name=self.current_tool_name or "",
arguments=delta,
)
],
)
def _handle_tool_call_end(self, event: dict[str, Any]) -> ChatResponseUpdate | None:
"""Handle TOOL_CALL_END event."""
self.accumulated_tool_args = ""
return None
def _handle_tool_call_result(self, event: dict[str, Any]) -> ChatResponseUpdate:
"""Handle TOOL_CALL_RESULT event."""
tool_call_id = event.get("toolCallId", "")
result = event.get("result") if event.get("result") is not None else event.get("content")
return ChatResponseUpdate(
role=Role.TOOL,
contents=[
FunctionResultContent(
call_id=tool_call_id,
result=result,
)
],
)
def _handle_run_finished(self, event: dict[str, Any]) -> ChatResponseUpdate:
"""Handle RUN_FINISHED event."""
return ChatResponseUpdate(
role=Role.ASSISTANT,
finish_reason=FinishReason.STOP,
contents=[],
additional_properties={
"thread_id": self.thread_id,
"run_id": self.run_id,
},
)
def _handle_run_error(self, event: dict[str, Any]) -> ChatResponseUpdate:
"""Handle RUN_ERROR event."""
error_message = event.get("message", "Unknown error")
return ChatResponseUpdate(
role=Role.ASSISTANT,
finish_reason=FinishReason.CONTENT_FILTER,
contents=[
ErrorContent(
message=error_message,
error_code="RUN_ERROR",
)
],
additional_properties={
"thread_id": self.thread_id,
"run_id": self.run_id,
},
)
@@ -107,7 +107,7 @@ class AgentFrameworkEventBridge:
# Skip text content if we're about to emit confirm_changes
# The summary should only appear after user confirms
if self.should_stop_after_confirm:
logger.debug(" >>> Skipping text content - waiting for confirm_changes response")
logger.debug("Skipping text content - waiting for confirm_changes response")
# Save the summary text to show after confirmation
self.suppressed_summary += content.text
continue
@@ -156,7 +156,7 @@ class AgentFrameworkEventBridge:
tool_call_name=content.name,
parent_message_id=self.current_message_id,
)
logger.info(f" >>> Emitting ToolCallStartEvent with name='{content.name}', id='{tool_call_id}'")
logger.info(f"Emitting ToolCallStartEvent with name='{content.name}', id='{tool_call_id}'")
events.append(tool_start_event)
# Track tool call for MessagesSnapshotEvent
@@ -186,7 +186,7 @@ class AgentFrameworkEventBridge:
# If it's a dict, convert to JSON
delta_str = json.dumps(content.arguments)
logger.info(f" >>> Emitting ToolCallArgsEvent with delta: {delta_str!r}..., id='{tool_call_id}'")
logger.info(f"Emitting ToolCallArgsEvent with delta: {delta_str!r}..., id='{tool_call_id}'")
args_event = ToolCallArgsEvent(
tool_call_id=tool_call_id,
delta=delta_str,
@@ -211,7 +211,7 @@ class AgentFrameworkEventBridge:
self.streaming_tool_args += json.dumps(content.arguments)
logger.debug(
f" >>> Predictive state: accumulated {len(self.streaming_tool_args)} chars for tool '{self.current_tool_call_name}'"
f"Predictive state: accumulated {len(self.streaming_tool_args)} chars for tool '{self.current_tool_call_name}'"
)
# Try to parse accumulated arguments (may be incomplete JSON)
@@ -262,11 +262,11 @@ class AgentFrameworkEventBridge:
else str(partial_value)
)
logger.info(
f" >>> StateDeltaEvent #{self.state_delta_count} for '{state_key}': "
f"StateDeltaEvent #{self.state_delta_count} for '{state_key}': "
f"op=replace, path=/{state_key}, value={value_preview}"
)
elif self.state_delta_count % 100 == 0:
logger.info(f" >>> StateDeltaEvent #{self.state_delta_count} emitted")
logger.info(f"StateDeltaEvent #{self.state_delta_count} emitted")
events.append(state_delta_event)
self.last_emitted_state[state_key] = partial_value
@@ -312,11 +312,11 @@ class AgentFrameworkEventBridge:
else str(state_value)
)
logger.info(
f" >>> StateDeltaEvent #{self.state_delta_count} for '{state_key}': "
f"StateDeltaEvent #{self.state_delta_count} for '{state_key}': "
f"op=replace, path=/{state_key}, value={value_preview}"
)
elif self.state_delta_count % 100 == 0: # Also log every 100th
logger.info(f" >>> StateDeltaEvent #{self.state_delta_count} emitted")
logger.info(f"StateDeltaEvent #{self.state_delta_count} emitted")
events.append(state_delta_event)
@@ -360,7 +360,7 @@ class AgentFrameworkEventBridge:
],
)
logger.info(
f" >>> Emitting StateDeltaEvent for key '{state_key}', value type: {type(state_value)}"
f"Emitting StateDeltaEvent for key '{state_key}', value type: {type(state_value)}"
)
events.append(state_delta_event)
@@ -376,13 +376,13 @@ class AgentFrameworkEventBridge:
end_event = ToolCallEndEvent(
tool_call_id=content.call_id,
)
logger.info(f" >>> Emitting ToolCallEndEvent for completed tool call '{content.call_id}'")
logger.info(f"Emitting ToolCallEndEvent for completed tool call '{content.call_id}'")
events.append(end_event)
# Log total StateDeltaEvent count for this tool call
if self.state_delta_count > 0:
logger.info(
f" >>> Tool call '{content.call_id}' complete: emitted {self.state_delta_count} StateDeltaEvents total"
f"Tool call '{content.call_id}' complete: emitted {self.state_delta_count} StateDeltaEvents total"
)
# Reset streaming accumulator and counter for next tool call
@@ -410,11 +410,13 @@ class AgentFrameworkEventBridge:
events.append(result_event)
# Track tool result for MessagesSnapshotEvent
# AG-UI protocol expects: { role: "tool", toolCallId: ..., content: ... }
# Use camelCase for Pydantic's alias_generator=to_camel
self.tool_results.append(
{
"id": result_message_id,
"role": "tool",
"tool_call_id": content.call_id,
"toolCallId": content.call_id,
"content": result_content,
}
)
@@ -422,6 +424,9 @@ class AgentFrameworkEventBridge:
# Emit MessagesSnapshotEvent with the complete conversation including tool calls and results
# This is required for CopilotKit's useCopilotAction to detect tool result
if self.pending_tool_calls and self.tool_results:
# Import message adapter
from ._message_adapters import agent_framework_messages_to_agui
# Build assistant message with tool_calls
assistant_message = {
"id": generate_event_id(),
@@ -429,14 +434,19 @@ class AgentFrameworkEventBridge:
"tool_calls": self.pending_tool_calls.copy(), # Copy the accumulated tool calls
}
# Convert Agent Framework messages to AG-UI format (adds required 'id' field)
converted_input_messages = agent_framework_messages_to_agui(self.input_messages)
# Build complete messages array: input messages + assistant message + tool results
all_messages = list(self.input_messages) + [assistant_message] + self.tool_results.copy()
all_messages = converted_input_messages + [assistant_message] + self.tool_results.copy()
# Emit MessagesSnapshotEvent using the proper event type
# Note: messages are dict[str, Any] but Pydantic will validate them as Message types
messages_snapshot_event = MessagesSnapshotEvent(
type=EventType.MESSAGES_SNAPSHOT, messages=all_messages
type=EventType.MESSAGES_SNAPSHOT,
messages=all_messages, # type: ignore[arg-type]
)
logger.info(f" >>> Emitting MessagesSnapshotEvent with {len(all_messages)} messages")
logger.info(f"Emitting MessagesSnapshotEvent with {len(all_messages)} messages")
events.append(messages_snapshot_event)
# After tool execution, emit StateSnapshotEvent if we have pending state updates
@@ -466,7 +476,7 @@ class AgentFrameworkEventBridge:
# If so, emit a confirm_changes tool call for the UI modal
tool_was_predictive = False
logger.debug(
f" >>> Checking predictive state: current_tool='{self.current_tool_call_name}', "
f"Checking predictive state: current_tool='{self.current_tool_call_name}', "
f"predict_config={list(self.predict_state_config.keys()) if self.predict_state_config else 'None'}"
)
for state_key, config in self.predict_state_config.items():
@@ -474,7 +484,7 @@ class AgentFrameworkEventBridge:
# We need to match against self.current_tool_call_name
if self.current_tool_call_name and config["tool"] == self.current_tool_call_name:
logger.info(
f" >>> Tool '{self.current_tool_call_name}' matches predictive config for state key '{state_key}'"
f"Tool '{self.current_tool_call_name}' matches predictive config for state key '{state_key}'"
)
tool_was_predictive = True
break
@@ -483,7 +493,7 @@ class AgentFrameworkEventBridge:
# Emit confirm_changes tool call sequence
confirm_call_id = generate_event_id()
logger.info(" >>> Emitting confirm_changes tool call for predictive update")
logger.info("Emitting confirm_changes tool call for predictive update")
# Track confirm_changes tool call for MessagesSnapshotEvent (so it persists after RUN_FINISHED)
self.pending_tool_calls.append(
@@ -518,6 +528,9 @@ class AgentFrameworkEventBridge:
events.append(confirm_end)
# Emit MessagesSnapshotEvent so confirm_changes persists after RUN_FINISHED
# Import message adapter
from ._message_adapters import agent_framework_messages_to_agui
# Build assistant message with pending confirm_changes tool call
assistant_message = {
"id": generate_event_id(),
@@ -525,23 +538,28 @@ class AgentFrameworkEventBridge:
"tool_calls": self.pending_tool_calls.copy(), # Includes confirm_changes
}
# Convert Agent Framework messages to AG-UI format (adds required 'id' field)
converted_input_messages = agent_framework_messages_to_agui(self.input_messages)
# Build complete messages array: input messages + assistant message + any tool results
all_messages = list(self.input_messages) + [assistant_message] + self.tool_results.copy()
all_messages = converted_input_messages + [assistant_message] + self.tool_results.copy()
# Emit MessagesSnapshotEvent
# Note: messages are dict[str, Any] but Pydantic will validate them as Message types
messages_snapshot_event = MessagesSnapshotEvent(
type=EventType.MESSAGES_SNAPSHOT, messages=all_messages
type=EventType.MESSAGES_SNAPSHOT,
messages=all_messages, # type: ignore[arg-type]
)
logger.info(
f" >>> Emitting MessagesSnapshotEvent for confirm_changes with {len(all_messages)} messages"
f"Emitting MessagesSnapshotEvent for confirm_changes with {len(all_messages)} messages"
)
events.append(messages_snapshot_event)
# Set flag to stop the run after this - we're waiting for user response
self.should_stop_after_confirm = True
logger.info(" >>> Set flag to stop run after confirm_changes")
logger.info("Set flag to stop run after confirm_changes")
elif tool_was_predictive:
logger.info(" >>> Skipping confirm_changes - require_confirmation is False")
logger.info("Skipping confirm_changes - require_confirmation is False")
# Clear pending updates and reset tool name tracker
self.pending_state_updates.clear()
@@ -580,7 +598,7 @@ class AgentFrameworkEventBridge:
# Update current state
self.current_state[state_key] = state_value
logger.info(
f" >>> Emitting StateSnapshotEvent for key '{state_key}', value type: {type(state_value)}"
f"Emitting StateSnapshotEvent for key '{state_key}', value type: {type(state_value)}"
)
# Emit state snapshot
@@ -596,7 +614,7 @@ class AgentFrameworkEventBridge:
tool_call_id=content.function_call.call_id,
)
logger.info(
f" >>> Emitting ToolCallEndEvent for approval-required tool '{content.function_call.call_id}'"
f"Emitting ToolCallEndEvent for approval-required tool '{content.function_call.call_id}'"
)
events.append(end_event)
@@ -615,7 +633,7 @@ class AgentFrameworkEventBridge:
},
},
)
logger.info(f" >>> Emitting function_approval_request custom event for '{content.function_call.name}'")
logger.info(f"Emitting function_approval_request custom event for '{content.function_call.name}'")
events.append(approval_event)
return events
@@ -0,0 +1,157 @@
# Copyright (c) Microsoft. All rights reserved.
"""HTTP service for AG-UI protocol communication."""
import json
import logging
from collections.abc import AsyncIterable
from typing import Any
import httpx
logger = logging.getLogger(__name__)
class AGUIHttpService:
"""HTTP service for AG-UI protocol communication.
Handles HTTP POST requests and Server-Sent Events (SSE) stream parsing
for the AG-UI protocol.
Examples:
Basic usage:
.. code-block:: python
service = AGUIHttpService("http://localhost:8888/")
async for event in service.post_run(
thread_id="thread_123",
run_id="run_456",
messages=[{"role": "user", "content": "Hello"}]
):
print(event["type"])
With context manager:
.. code-block:: python
async with AGUIHttpService("http://localhost:8888/") as service:
async for event in service.post_run(...):
print(event)
"""
def __init__(
self,
endpoint: str,
http_client: httpx.AsyncClient | None = None,
timeout: float = 60.0,
) -> None:
"""Initialize the HTTP service.
Args:
endpoint: AG-UI server endpoint URL (e.g., "http://localhost:8888/")
http_client: Optional httpx AsyncClient. If None, creates a new one.
timeout: Request timeout in seconds (default: 60.0)
"""
self.endpoint = endpoint.rstrip("/")
self._owns_client = http_client is None
self.http_client = http_client or httpx.AsyncClient(timeout=timeout)
async def post_run(
self,
thread_id: str,
run_id: str,
messages: list[dict[str, Any]],
state: dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
) -> AsyncIterable[dict[str, Any]]:
"""Post a run request and stream AG-UI events.
Args:
thread_id: Thread identifier for conversation continuity
run_id: Unique run identifier
messages: List of messages in AG-UI format
state: Optional state object to send to server
tools: Optional list of tools available to the agent
Yields:
AG-UI event dictionaries parsed from SSE stream
Raises:
httpx.HTTPStatusError: If the HTTP request fails
ValueError: If SSE parsing encounters invalid data
Examples:
.. code-block:: python
service = AGUIHttpService("http://localhost:8888/")
async for event in service.post_run(
thread_id="thread_abc",
run_id="run_123",
messages=[{"role": "user", "content": "Hello"}],
state={"user_context": {"name": "Alice"}}
):
if event["type"] == "TEXT_MESSAGE_CONTENT":
print(event["delta"])
"""
# Build request payload
request_data: dict[str, Any] = {
"thread_id": thread_id,
"run_id": run_id,
"messages": messages,
}
if state is not None:
request_data["state"] = state
if tools is not None:
request_data["tools"] = tools
logger.debug(
f"Posting run to {self.endpoint}: thread_id={thread_id}, run_id={run_id}, "
f"messages={len(messages)}, has_state={state is not None}, has_tools={tools is not None}"
)
# Stream the response using SSE
async with self.http_client.stream(
"POST",
self.endpoint,
json=request_data,
headers={"Accept": "text/event-stream"},
) as response:
try:
response.raise_for_status()
except httpx.HTTPStatusError as e:
logger.error(f"HTTP request failed: {e.response.status_code} - {e.response.text}")
raise
async for line in response.aiter_lines():
# Parse Server-Sent Events format
if line.startswith("data: "):
data = line[6:] # Remove "data: " prefix
try:
event = json.loads(data)
logger.debug(f"Received event: {event.get('type', 'UNKNOWN')}")
yield event
except json.JSONDecodeError as e:
logger.warning(f"Failed to parse SSE data: {data}. Error: {e}")
# Continue processing other events instead of failing
continue
async def close(self) -> None:
"""Close the HTTP client if owned by this service.
Only closes the client if it was created by this service instance.
If an external client was provided, it remains the caller's
responsibility to close it.
"""
if self._owns_client and self.http_client:
await self.http_client.aclose()
async def __aenter__(self) -> "AGUIHttpService":
"""Enter async context manager."""
return self
async def __aexit__(self, *args: Any) -> None:
"""Exit async context manager and clean up resources."""
await self.close()
@@ -2,12 +2,13 @@
"""Message format conversion between AG-UI and Agent Framework."""
from typing import Any
from typing import Any, cast
from agent_framework import (
ChatMessage,
FunctionApprovalResponseContent,
FunctionCallContent,
FunctionResultContent,
Role,
TextContent,
)
@@ -46,7 +47,7 @@ def agui_messages_to_agent_framework(messages: list[dict[str, Any]]) -> list[Cha
result_content = msg.get("result", msg.get("content", ""))
chat_msg = ChatMessage(
role=Role.ASSISTANT, # Tool results are assistant messages
role=Role.TOOL, # Tool results must be tool role
contents=[FunctionResultContent(call_id=tool_call_id, result=result_content)],
)
@@ -56,6 +57,42 @@ def agui_messages_to_agent_framework(messages: list[dict[str, Any]]) -> list[Cha
result.append(chat_msg)
continue
# If assistant message includes tool calls, convert to FunctionCallContent(s)
tool_calls = msg.get("tool_calls") or msg.get("toolCalls")
if tool_calls:
contents: list[Any] = []
# Include any assistant text content if present
content_text = msg.get("content")
if isinstance(content_text, str) and content_text:
contents.append(TextContent(text=content_text))
# Convert each tool call entry
for tc in tool_calls:
if not isinstance(tc, dict):
continue
# Cast to typed dict for proper type inference
tc_dict = cast(dict[str, Any], tc)
tc_type = tc_dict.get("type")
if tc_type == "function":
func_data = tc_dict.get("function", {})
func_dict = cast(dict[str, Any], func_data) if isinstance(func_data, dict) else {}
call_id = str(tc_dict.get("id", ""))
name = str(func_dict.get("name", ""))
arguments = func_dict.get("arguments")
contents.append(
FunctionCallContent(
call_id=call_id,
name=name,
arguments=arguments,
)
)
chat_msg = ChatMessage(role=Role.ASSISTANT, contents=contents)
if "id" in msg:
chat_msg.message_id = msg["id"]
result.append(chat_msg)
continue
role_str = msg.get("role", "user")
# Handle tool result messages (with role="tool")
@@ -78,11 +115,11 @@ def agui_messages_to_agent_framework(messages: list[dict[str, Any]]) -> list[Cha
# Backend tool results have non-empty content WITHOUT "accepted" field
if tool_call_id and result_content and not is_approval:
# Backend tool execution - convert to FunctionResultContent
# Tool execution result - convert to FunctionResultContent with correct role
from agent_framework import FunctionResultContent
chat_msg = ChatMessage(
role=Role.ASSISTANT, # Tool results are assistant messages
role=Role.TOOL,
contents=[FunctionResultContent(call_id=tool_call_id, result=result_content)],
)
@@ -97,9 +134,8 @@ def agui_messages_to_agent_framework(messages: list[dict[str, Any]]) -> list[Cha
chat_msg = ChatMessage(
role=Role.USER, # Approval responses are user messages
contents=[TextContent(text=content)],
additional_properties={"is_tool_result": True, "tool_call_id": msg.get("toolCallId", "")},
)
# Mark this as a tool result so we can detect it later
chat_msg.metadata = {"is_tool_result": True, "tool_call_id": msg.get("toolCallId", "")} # type: ignore[attr-defined]
if "id" in msg:
chat_msg.message_id = msg["id"]
@@ -112,7 +148,7 @@ def agui_messages_to_agent_framework(messages: list[dict[str, Any]]) -> list[Cha
# Check if this message contains function approvals
if "function_approvals" in msg and msg["function_approvals"]:
# Convert function approvals to FunctionApprovalResponseContent
contents: list[Any] = []
approval_contents: list[Any] = []
for approval in msg["function_approvals"]:
# Create FunctionCallContent with the modified arguments
func_call = FunctionCallContent(
@@ -127,9 +163,9 @@ def agui_messages_to_agent_framework(messages: list[dict[str, Any]]) -> list[Cha
id=approval.get("id", ""),
function_call=func_call,
)
contents.append(approval_response)
approval_contents.append(approval_response)
chat_msg = ChatMessage(role=role, contents=contents) # type: ignore[arg-type]
chat_msg = ChatMessage(role=role, contents=approval_contents) # type: ignore[arg-type]
else:
# Regular text message
content = msg.get("content", "")
@@ -146,21 +182,44 @@ def agui_messages_to_agent_framework(messages: list[dict[str, Any]]) -> list[Cha
return result
def agent_framework_messages_to_agui(messages: list[ChatMessage]) -> list[dict[str, Any]]:
def agent_framework_messages_to_agui(messages: list[ChatMessage] | list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Convert Agent Framework messages to AG-UI format.
Args:
messages: List of Agent Framework ChatMessage objects
messages: List of Agent Framework ChatMessage objects or AG-UI dicts (already converted)
Returns:
List of AG-UI message dictionaries
"""
from ._utils import generate_event_id
result: list[dict[str, Any]] = []
for msg in messages:
# If already a dict (AG-UI format), ensure it has an ID and normalize keys for Pydantic
if isinstance(msg, dict):
# Always work on a copy to avoid mutating input
normalized_msg = msg.copy()
# Ensure ID exists
if "id" not in normalized_msg:
normalized_msg["id"] = generate_event_id()
# Normalize tool_call_id to toolCallId for Pydantic's alias_generator=to_camel
if normalized_msg.get("role") == "tool":
if "tool_call_id" in normalized_msg:
normalized_msg["toolCallId"] = normalized_msg["tool_call_id"]
del normalized_msg["tool_call_id"]
elif "toolCallId" not in normalized_msg:
# Tool message missing toolCallId - add empty string to satisfy schema
normalized_msg["toolCallId"] = ""
# Always append the normalized copy, not the original
result.append(normalized_msg)
continue
# Convert ChatMessage to AG-UI format
role = _FRAMEWORK_TO_AGUI_ROLE.get(msg.role, "user")
content_text = ""
tool_calls: list[dict[str, Any]] = []
tool_result_call_id: str | None = None
for content in msg.contents:
if isinstance(content, TextContent):
@@ -176,18 +235,32 @@ def agent_framework_messages_to_agui(messages: list[ChatMessage]) -> list[dict[s
},
}
)
elif isinstance(content, FunctionResultContent):
# Tool result content - extract call_id and result
tool_result_call_id = content.call_id
# Serialize result to string
if isinstance(content.result, dict):
import json
content_text = json.dumps(content.result) # type: ignore
elif content.result is not None:
content_text = str(content.result)
agui_msg: dict[str, Any] = {
"id": msg.message_id if msg.message_id else generate_event_id(), # Always include id
"role": role,
"content": content_text,
}
if msg.message_id:
agui_msg["id"] = msg.message_id
if tool_calls:
agui_msg["tool_calls"] = tool_calls
# If this is a tool result message, add toolCallId (using camelCase for Pydantic)
if tool_result_call_id:
agui_msg["toolCallId"] = tool_result_call_id
# Tool result messages should have role="tool"
agui_msg["role"] = "tool"
result.append(agui_msg)
return result
@@ -16,9 +16,9 @@ from ag_ui.core import (
TextMessageEndEvent,
TextMessageStartEvent,
)
from agent_framework import AgentProtocol, AgentThread, TextContent
from agent_framework import AgentProtocol, AgentThread, ChatAgent, TextContent
from ._utils import generate_event_id
from ._utils import convert_agui_tools_to_agent_framework, generate_event_id
if TYPE_CHECKING:
from ._agent import AgentConfig
@@ -142,14 +142,10 @@ class HumanInTheLoopOrchestrator(Orchestrator):
True if last message is a tool result
"""
msg = context.last_message
if not msg or not hasattr(msg, "metadata"):
if not msg:
return False
metadata = getattr(msg, "metadata", None)
if not metadata:
return False
return bool(metadata.get("is_tool_result", False))
return bool(msg.additional_properties.get("is_tool_result", False))
async def run(
self,
@@ -274,8 +270,10 @@ class DefaultOrchestrator(Orchestrator):
current_state: dict[str, Any] = initial_state.copy() if initial_state else {}
# Check if agent uses structured outputs (response_format)
chat_options = getattr(context.agent, "chat_options", None)
response_format = getattr(chat_options, "response_format", None) if chat_options else None
# Use isinstance to narrow type for proper attribute access
response_format = None
if isinstance(context.agent, ChatAgent):
response_format = context.agent.chat_options.response_format
skip_text_content = response_format is not None
# Create event bridge
@@ -334,9 +332,8 @@ class DefaultOrchestrator(Orchestrator):
if context.messages:
await thread.on_new_messages(context.messages)
# Get the last message as the new input
new_message = context.last_message
if not new_message:
# Use the full incoming message batch to preserve tool-call adjacency
if not context.messages:
logger.warning("No messages provided in AG-UI input")
yield event_bridge.create_run_finished_event()
return
@@ -362,11 +359,68 @@ Never replace existing data - always append or merge."""
)
messages_to_run.append(state_context_msg)
messages_to_run.append(new_message)
# Preserve order from client to satisfy provider constraints (assistant tool_calls must
# immediately precede tool result messages). Using the full batch avoids reordering.
messages_to_run.extend(context.messages)
# Handle client tools for hybrid execution
# Client sends tool metadata, server merges with its own tools.
# Client tools have func=None (declaration-only), so @use_function_invocation
# will return the function call without executing (passes back to client).
from agent_framework import BaseChatClient
client_tools = convert_agui_tools_to_agent_framework(context.input_data.get("tools"))
# Extract server tools - use type narrowing when possible
server_tools: list[Any] = []
if isinstance(context.agent, ChatAgent):
server_tools = context.agent.chat_options.tools or []
else:
# AgentProtocol allows duck-typed implementations - fallback to attribute access
# This supports test mocks and custom agent implementations
try:
chat_options_attr = getattr(context.agent, "chat_options", None)
if chat_options_attr is not None:
server_tools = getattr(chat_options_attr, "tools", None) or []
except AttributeError:
pass
# Register client tools as additional (declaration-only) so they are not executed on server
if client_tools:
if isinstance(context.agent, ChatAgent):
# Type-safe path for ChatAgent
chat_client = context.agent.chat_client
if (
isinstance(chat_client, BaseChatClient)
and chat_client.function_invocation_configuration is not None
):
chat_client.function_invocation_configuration.additional_tools = client_tools
logger.debug(
f"[TOOLS] Registered {len(client_tools)} client tools as additional_tools (declaration-only)"
)
else:
# Fallback for AgentProtocol implementations (test mocks, custom agents)
try:
chat_client_attr = getattr(context.agent, "chat_client", None)
if chat_client_attr is not None:
fic = getattr(chat_client_attr, "function_invocation_configuration", None)
if fic is not None:
fic.additional_tools = client_tools # type: ignore[attr-defined]
logger.debug(
f"[TOOLS] Registered {len(client_tools)} client tools as additional_tools (declaration-only)"
)
except AttributeError:
pass
combined_tools: list[Any] = []
if server_tools:
combined_tools.extend(server_tools)
if client_tools:
combined_tools.extend(client_tools)
# Collect all updates to get the final structured output
all_updates: list[Any] = []
async for update in context.agent.run_stream(messages_to_run, thread=thread):
async for update in context.agent.run_stream(messages_to_run, thread=thread, tools=combined_tools or None):
all_updates.append(update)
events = await event_bridge.from_agent_run_update(update)
for event in events:
@@ -374,7 +428,7 @@ Never replace existing data - always append or merge."""
# After agent completes, check if we should stop (waiting for user to confirm changes)
if event_bridge.should_stop_after_confirm:
logger.info(" >>> Stopping run after confirm_changes - waiting for user response")
logger.info("Stopping run after confirm_changes - waiting for user response")
yield event_bridge.create_run_finished_event()
return
@@ -4,10 +4,13 @@
import copy
import uuid
from collections.abc import Callable, MutableMapping, Sequence
from dataclasses import asdict, is_dataclass
from datetime import date, datetime
from typing import Any
from agent_framework import AIFunction, ToolProtocol
def generate_event_id() -> str:
"""Generate a unique event ID."""
@@ -55,3 +58,109 @@ def make_json_safe(obj: Any) -> Any: # noqa: ANN401
if isinstance(obj, dict):
return {key: make_json_safe(value) for key, value in obj.items()} # type: ignore[misc]
return str(obj)
def convert_agui_tools_to_agent_framework(
agui_tools: list[dict[str, Any]] | None,
) -> list[AIFunction[Any, Any]] | None:
"""Convert AG-UI tool definitions to Agent Framework AIFunction declarations.
Creates declaration-only AIFunction instances (no executable implementation).
These are used to tell the LLM about available tools. The actual execution
happens on the client side via @use_function_invocation.
CRITICAL: These tools MUST have func=None so that declaration_only returns True.
This prevents the server from trying to execute client-side tools.
Args:
agui_tools: List of AG-UI tool definitions with name, description, parameters
Returns:
List of AIFunction declarations, or None if no tools provided
"""
if not agui_tools:
return None
result: list[AIFunction[Any, Any]] = []
for tool_def in agui_tools:
# Create declaration-only AIFunction (func=None means no implementation)
# When func=None, the declaration_only property returns True,
# which tells @use_function_invocation to return the function call
# without executing it (so it can be sent back to the client)
func: AIFunction[Any, Any] = AIFunction(
name=tool_def.get("name", ""),
description=tool_def.get("description", ""),
func=None, # CRITICAL: Makes declaration_only=True
input_model=tool_def.get("parameters", {}),
)
result.append(func)
return result
def convert_tools_to_agui_format(
tools: (
ToolProtocol
| Callable[..., Any]
| MutableMapping[str, Any]
| Sequence[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]]
| None
),
) -> list[dict[str, Any]] | None:
"""Convert tools to AG-UI format.
This sends only the metadata (name, description, JSON schema) to the server.
The actual executable implementation stays on the client side.
The @use_function_invocation decorator handles client-side execution when
the server requests a function.
Args:
tools: Tools to convert (single tool or sequence of tools)
Returns:
List of tool specifications in AG-UI format, or None if no tools provided
"""
if not tools:
return None
# Normalize to list
if not isinstance(tools, list):
tool_list: list[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]] = [tools] # type: ignore[list-item]
else:
tool_list = tools # type: ignore[assignment]
results: list[dict[str, Any]] = []
for tool in tool_list:
if isinstance(tool, dict):
# Already in dict format, pass through
results.append(tool) # type: ignore[arg-type]
elif isinstance(tool, AIFunction):
# Convert AIFunction to AG-UI tool format
results.append(
{
"name": tool.name,
"description": tool.description,
"parameters": tool.parameters(),
}
)
elif callable(tool):
# Convert callable to AIFunction first, then to AG-UI format
from agent_framework import ai_function
ai_func = ai_function(tool)
results.append(
{
"name": ai_func.name,
"description": ai_func.description,
"parameters": ai_func.parameters(),
}
)
elif isinstance(tool, ToolProtocol):
# Handle other ToolProtocol implementations
# For now, we'll skip non-AIFunction tools as they may not have
# the parameters() method. This matches .NET behavior which only
# converts AIFunctionDeclaration instances.
continue
return results if results else None
@@ -14,7 +14,7 @@ pip install agent-framework-ag-ui
from fastapi import FastAPI
from agent_framework import ChatAgent
from agent_framework.azure import AzureOpenAIChatClient
from agent_framework_ag_ui import add_agent_framework_fastapi_endpoint
from agent_framework.ag_ui import add_agent_framework_fastapi_endpoint
# Create your agent
agent = ChatAgent(
@@ -104,7 +104,7 @@ State is injected as system messages and updated via predictive state updates:
```python
from agent_framework import ChatAgent
from agent_framework.azure import AzureOpenAIChatClient
from agent_framework_ag_ui import AgentFrameworkAgent
from agent_framework.ag_ui import AgentFrameworkAgent
# Create your agent
agent = ChatAgent(
@@ -141,7 +141,7 @@ Predictive state updates automatically stream tool arguments as optimistic state
```python
from agent_framework import ChatAgent
from agent_framework.azure import AzureOpenAIChatClient
from agent_framework_ag_ui import AgentFrameworkAgent
from agent_framework.ag_ui import AgentFrameworkAgent
# Create your agent
agent = ChatAgent(
@@ -170,7 +170,7 @@ Provide domain-specific confirmation messages:
from typing import Any
from agent_framework import ChatAgent
from agent_framework.azure import AzureOpenAIChatClient
from agent_framework_ag_ui import AgentFrameworkAgent, ConfirmationStrategy
from agent_framework.ag_ui import AgentFrameworkAgent, ConfirmationStrategy
class CustomConfirmationStrategy(ConfirmationStrategy):
def on_approval_accepted(self, steps: list[dict[str, Any]]) -> str:
@@ -216,7 +216,7 @@ def sensitive_action(param: str) -> str:
Add custom execution flows by implementing the Orchestrator pattern:
```python
from agent_framework_ag_ui._orchestrators import Orchestrator, ExecutionContext
from agent_framework.ag_ui._orchestrators import Orchestrator, ExecutionContext
class MyCustomOrchestrator(Orchestrator):
def can_handle(self, context: ExecutionContext) -> bool:
@@ -128,7 +128,7 @@ class TaskStepsAgentWithExecution:
import uuid
logger = logging.getLogger(__name__)
logger.info(">>> TaskStepsAgentWithExecution.run_agent() called - wrapper is active")
logger.info("TaskStepsAgentWithExecution.run_agent() called - wrapper is active")
# First, run the base agent to generate the plan - buffer text messages
final_state: dict[str, Any] | None = None
@@ -138,41 +138,41 @@ class TaskStepsAgentWithExecution:
async for event in self._base_agent.run_agent(input_data):
event_type_str = str(event.type) if hasattr(event, "type") else type(event).__name__
logger.info(f">>> Processing event: {event_type_str}")
logger.info(f"Processing event: {event_type_str}")
match event:
case StateSnapshotEvent(snapshot=snapshot):
final_state = snapshot
logger.info(f">>> Captured STATE_SNAPSHOT event with state: {final_state}")
logger.info(f"Captured STATE_SNAPSHOT event with state: {final_state}")
yield event
case RunFinishedEvent():
run_finished_event = event
logger.info(">>> Captured RUN_FINISHED event - will send after step execution and summary")
logger.info("Captured RUN_FINISHED event - will send after step execution and summary")
case ToolCallStartEvent(tool_call_id=call_id):
tool_call_id = call_id
logger.info(f">>> Captured tool_call_id: {tool_call_id}")
logger.info(f"Captured tool_call_id: {tool_call_id}")
yield event
case TextMessageStartEvent() | TextMessageContentEvent() | TextMessageEndEvent():
buffered_text_events.append(event)
logger.info(f">>> Buffered {event_type_str} from first LLM call")
logger.info(f"Buffered {event_type_str} from first LLM call")
case _:
logger.info(f">>> Yielding event immediately: {event_type_str}")
logger.info(f"Yielding event immediately: {event_type_str}")
yield event
logger.info(f">>> Base agent completed. Final state: {final_state}")
logger.info(f"Base agent completed. Final state: {final_state}")
# Now simulate executing the steps
if final_state and "steps" in final_state:
steps = final_state["steps"]
logger.info(f">>> Starting step execution simulation for {len(steps)} steps")
logger.info(f"Starting step execution simulation for {len(steps)} steps")
for i in range(len(steps)):
logger.info(f">>> Simulating execution of step {i + 1}/{len(steps)}: {steps[i].get('description')}")
logger.info(f"Simulating execution of step {i + 1}/{len(steps)}: {steps[i].get('description')}")
await asyncio.sleep(1.0) # Simulate work
# Update step to completed
steps[i]["status"] = "completed"
logger.info(f">>> Step {i + 1} marked as completed")
logger.info(f"Step {i + 1} marked as completed")
# Send delta event with manual JSON patch format
delta_event = StateDeltaEvent(
@@ -185,7 +185,7 @@ class TaskStepsAgentWithExecution:
}
],
)
logger.info(f">>> Yielding StateDeltaEvent for step {i + 1}")
logger.info(f"Yielding StateDeltaEvent for step {i + 1}")
yield delta_event
# Send final snapshot
@@ -193,11 +193,11 @@ class TaskStepsAgentWithExecution:
type=EventType.STATE_SNAPSHOT,
snapshot={"steps": steps},
)
logger.info(">>> Yielding final StateSnapshotEvent with all steps completed")
logger.info("Yielding final StateSnapshotEvent with all steps completed")
yield final_snapshot
# SECOND LLM call: Stream summary from chat client directly
logger.info(">>> Making SECOND LLM call to generate summary after step execution")
logger.info("Making SECOND LLM call to generate summary after step execution")
# Get the underlying chat agent and client
chat_agent = self._base_agent.agent # type: ignore
@@ -236,7 +236,7 @@ class TaskStepsAgentWithExecution:
)
# Stream the LLM response and manually emit text events
logger.info(">>> Calling chat client for summary")
logger.info("Calling chat client for summary")
message_id = str(uuid.uuid4())
@@ -268,7 +268,7 @@ class TaskStepsAgentWithExecution:
type=EventType.TEXT_MESSAGE_END,
message_id=message_id,
)
logger.info(f">>> Summary complete: {accumulated_text}")
logger.info(f"Summary complete: {accumulated_text}")
# Build complete message for persistence
summary_message = {
@@ -285,7 +285,7 @@ class TaskStepsAgentWithExecution:
messages=final_messages,
)
except Exception as e:
logger.error(f">>> Error generating summary: {e}")
logger.error(f"Error generating summary: {e}")
# Generate a new message ID for the error
error_message_id = str(uuid.uuid4())
# Yield TEXT_MESSAGE_START for error
@@ -306,11 +306,11 @@ class TaskStepsAgentWithExecution:
message_id=error_message_id,
)
else:
logger.warning(f">>> No steps found in final_state to execute. final_state={final_state}")
logger.warning(f"No steps found in final_state to execute. final_state={final_state}")
# Finally send the original RUN_FINISHED event
if run_finished_event:
logger.info(">>> Yielding original RUN_FINISHED event")
logger.info("Yielding original RUN_FINISHED event")
yield run_finished_event
+185 -429
View File
@@ -2,6 +2,135 @@
The AG-UI (Agent UI) protocol provides a standardized way for client applications to interact with AI agents over HTTP. This tutorial demonstrates how to build both server and client applications using the AG-UI protocol with Python.
## Quick Start - Client Examples
If you want to quickly try out the AG-UI client, we provide three ready-to-use examples:
### Basic Interactive Client (`client.py`)
A simple command-line chat client that demonstrates:
- Streaming responses in real-time
- Automatic thread management for conversation continuity
- Direct `AGUIChatClient` usage (caller manages message history)
**Run:**
```bash
python client.py
```
**Note:** This example sends only the current message to the server. The server is responsible for maintaining conversation history using the thread_id.
### Advanced Features Client (`client_advanced.py`)
Demonstrates advanced capabilities:
- Tool/function calling
- Both streaming and non-streaming responses
- Multi-turn conversations
- Error handling patterns
**Run:**
```bash
python client_advanced.py
```
**Note:** This example shows direct `AGUIChatClient` usage. Tool execution and conversation continuity depend on server-side configuration and capabilities.
### ChatAgent Integration (`client_with_agent.py`)
Best practice example using `ChatAgent` wrapper with **AgentThread**
- **AgentThread** maintains conversation state
- Client-side conversation history management via `thread.message_store`
- **Hybrid tool execution**: client-side + server-side tools simultaneously
- Full conversation history sent on each request
- Tool calling with conversation context
**To demonstrate hybrid tools:**
1. **Start server with server-side tool** (Terminal 1):
```bash
# Server has get_time_zone tool
python server.py
```
2. **Run client with client-side tool** (Terminal 2):
```bash
# Client has get_weather tool
python client_with_agent.py
```
All examples require a running AG-UI server (see Step 1 below for setup).
## Understanding AG-UI Architecture
### Thread Management
The AG-UI protocol supports two approaches to conversation history:
1. **Server-Managed Threads** (client.py, client_advanced.py)
- Client sends only the current message + thread_id
- Server maintains full conversation history
- Requires server to support stateful thread storage
- Lighter network payload
2. **Client-Managed History** (client_with_agent.py)
- Client maintains full conversation history locally
- Full message history sent with each request
- Works with any AG-UI server (stateful or stateless)
The `ChatAgent` wrapper (used in client_with_agent.py) collects messages from local storage and sends the full history to `AGUIChatClient`, which then forwards everything to the server.
### Tool/Function Calling
The AG-UI protocol supports **hybrid tool execution** - both client-side AND server-side tools can coexist in the same conversation.
**The Hybrid Pattern** (client_with_agent.py):
```
Client defines: Server defines:
- get_weather() - get_current_time()
- read_sensors() - get_server_forecast()
User: "What's the weather in SF and what time is it?"
ChatAgent sends: full history + tool definitions for get_weather, read_sensors
Server LLM decides: "I need get_weather('SF') and get_current_time()"
Server executes get_current_time() → "2025-11-11 14:30:00 UTC"
Server sends function call request → get_weather('SF')
ChatAgent intercepts get_weather call → executes locally
Client sends result → "Sunny, 72°F"
Server combines both results → "It's sunny and 72°F in SF, and the current time is 2:30 PM UTC"
Client receives final response
```
**How it works:**
1. **Client-Side Tools** (`client_with_agent.py`):
- Tools defined in ChatAgent's `tools` parameter execute locally
- Tool metadata (name, description, schema) sent to server for planning
- When server requests client tool → client intercepts → executes locally → sends result
2. **Server-Side Tools**:
- Defined in server agent's configuration
- Server executes directly without client involvement
- Results included in server's response
3. **Hybrid Pattern (Both Together)**:
- Server LLM sees ALL tool definitions (client + server)
- Decides which to use based on task
- Server tools execute server-side
- Client tools execute client-side
**Direct AGUIChatClient Usage** (client_advanced.py):
Even without ChatAgent wrapper, client-side tools work:
- Tools passed in ChatOptions execute locally
- Server can also have its own tools
- Hybrid execution works automatically
## What is AG-UI?
AG-UI is a protocol that enables:
@@ -35,13 +164,13 @@ The AG-UI server hosts your AI agent and exposes it via HTTP endpoints using Fas
### Install Required Packages
```bash
pip install agent-framework-ag-ui agent-framework-core fastapi uvicorn
pip install agent-framework-ag-ui
```
Or using uv:
```bash
uv pip install agent-framework-ag-ui agent-framework-core fastapi uvicorn
uv pip install agent-framework-ag-ui
```
### Server Code
@@ -57,17 +186,20 @@ import os
from agent_framework import ChatAgent
from agent_framework.azure import AzureOpenAIChatClient
from agent_framework_ag_ui import add_agent_framework_fastapi_endpoint
from agent_framework.ag_ui import add_agent_framework_fastapi_endpoint
from fastapi import FastAPI
# Read required configuration
endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT")
deployment_name = os.environ.get("AZURE_OPENAI_DEPLOYMENT_NAME")
api_key = os.environ.get("AZURE_OPENAI_API_KEY")
if not endpoint:
raise ValueError("AZURE_OPENAI_ENDPOINT environment variable is required")
if not deployment_name:
raise ValueError("AZURE_OPENAI_DEPLOYMENT_NAME environment variable is required")
if not api_key:
raise ValueError("AZURE_OPENAI_API_KEY environment variable is required")
# Create the AI agent
agent = ChatAgent(
@@ -76,6 +208,7 @@ agent = ChatAgent(
chat_client=AzureOpenAIChatClient(
endpoint=endpoint,
deployment_name=deployment_name,
api_key=api_key,
),
)
@@ -137,12 +270,14 @@ The server will start listening on `http://127.0.0.1:5100`.
## Step 2: Creating an AG-UI Client
The AG-UI client connects to the remote server and displays streaming responses.
The AG-UI client connects to the remote server and displays streaming responses. The `AGUIChatClient` is a built-in implementation that integrates with the Agent Framework's standard chat interface.
### Install Required Packages
The `AGUIChatClient` is included in the `agent-framework-ag-ui` package (already installed if you installed the server packages).
```bash
pip install httpx
pip install agent-framework-ag-ui
```
### Client Code
@@ -152,122 +287,61 @@ Create a file named `client.py`:
```python
# Copyright (c) Microsoft. All rights reserved.
"""AG-UI client example."""
"""AG-UI client example using AGUIChatClient."""
import asyncio
import json
import os
from typing import AsyncIterator
import httpx
class AGUIClient:
"""Simple AG-UI protocol client."""
def __init__(self, server_url: str):
"""Initialize the client.
Args:
server_url: The AG-UI server endpoint URL
"""
self.server_url = server_url
self.thread_id: str | None = None
async def send_message(self, message: str) -> AsyncIterator[dict]:
"""Send a message and stream the response.
Args:
message: The user message to send
Yields:
AG-UI events from the server
"""
# Prepare the request
request_data = {
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": message},
]
}
# Include thread_id if we have one (for conversation continuity)
if self.thread_id:
request_data["thread_id"] = self.thread_id
# Stream the response
async with httpx.AsyncClient(timeout=60.0) as client:
async with client.stream(
"POST",
self.server_url,
json=request_data,
headers={"Accept": "text/event-stream"},
) as response:
response.raise_for_status()
async for line in response.aiter_lines():
# Parse Server-Sent Events format
if line.startswith("data: "):
data = line[6:] # Remove "data: " prefix
try:
event = json.loads(data)
yield event
# Capture thread_id from RUN_STARTED event
if event.get("type") == "RUN_STARTED" and not self.thread_id:
self.thread_id = event.get("threadId")
except json.JSONDecodeError:
continue
from agent_framework import TextContent
from agent_framework.ag_ui import AGUIChatClient
async def main():
"""Main client loop."""
"""Main client loop demonstrating AGUIChatClient usage."""
# Get server URL from environment or use default
server_url = os.environ.get("AGUI_SERVER_URL", "http://127.0.0.1:5100/")
print(f"Connecting to AG-UI server at: {server_url}\n")
client = AGUIClient(server_url)
# Create client with context manager for automatic cleanup
async with AGUIChatClient(endpoint=server_url) as client:
thread_id: str | None = None
try:
while True:
# Get user input
message = input("\nUser (:q or quit to exit): ")
if not message.strip():
print("Request cannot be empty.")
continue
try:
while True:
# Get user input
message = input("\nUser (:q or quit to exit): ")
if not message.strip():
print("Request cannot be empty.")
continue
if message.lower() in (":q", "quit"):
break
if message.lower() in (":q", "quit"):
break
# Send message and display streaming response
print("\n", end="")
async for event in client.send_message(message):
event_type = event.get("type", "")
# Send message and stream the response
print("\nAssistant: ", end="", flush=True)
if event_type == "RUN_STARTED":
thread_id = event.get("threadId", "")
run_id = event.get("runId", "")
print(f"\033[93m[Run Started - Thread: {thread_id}, Run: {run_id}]\033[0m")
# Use metadata to maintain conversation continuity
metadata = {"thread_id": thread_id} if thread_id else None
elif event_type == "TEXT_MESSAGE_CONTENT":
# Stream text content in cyan
print(f"\033[96m{event.get('delta', '')}\033[0m", end="", flush=True)
async for update in client.get_streaming_response(message, metadata=metadata):
# Extract thread ID from first update
if not thread_id and update.additional_properties:
thread_id = update.additional_properties.get("thread_id")
if thread_id:
print(f"\n[Thread: {thread_id}]")
print("Assistant: ", end="", flush=True)
elif event_type == "RUN_FINISHED":
thread_id = event.get("threadId", "")
run_id = event.get("runId", "")
print(f"\n\033[92m[Run Finished - Thread: {thread_id}, Run: {run_id}]\033[0m")
# Stream text content as it arrives
for content in update.contents:
if isinstance(content, TextContent) and content.text:
print(content.text, end="", flush=True)
elif event_type == "RUN_ERROR":
error_message = event.get("message", "Unknown error")
print(f"\n\033[91m[Run Error - Message: {error_message}]\033[0m")
print() # New line after response
print()
except KeyboardInterrupt:
print("\n\nExiting...")
except Exception as e:
print(f"\n\033[91mAn error occurred: {e}\033[0m")
except KeyboardInterrupt:
print("\n\nExiting...")
except Exception as e:
print(f"\nAn error occurred: {e}")
if __name__ == "__main__":
@@ -276,17 +350,13 @@ if __name__ == "__main__":
### Key Concepts
- **Server-Sent Events (SSE)**: The protocol uses SSE format (`data: {json}\n\n`)
- **Event Types**: Different events provide metadata and content (all event types use UPPERCASE with underscores):
- `RUN_STARTED`: Signals the agent has started processing
- `TEXT_MESSAGE_START`: Signals the start of a text message from the agent
- `TEXT_MESSAGE_CONTENT`: Incremental text streamed from the agent (with `delta` field)
- `TEXT_MESSAGE_END`: Signals the end of a text message
- `RUN_FINISHED`: Signals successful completion
- `RUN_ERROR`: Error information if something goes wrong
- **Field Naming**: Event fields use camelCase (e.g., `threadId`, `runId`, `messageId`) when accessing JSON events
- **Thread Management**: The `threadId` maintains conversation context across requests
- **Client-Side Instructions**: System messages are sent from the client
- **`AGUIChatClient`**: Built-in client that implements the Agent Framework's `BaseChatClient` interface
- **Automatic Event Handling**: The client automatically converts AG-UI events to Agent Framework types
- **Thread Management**: Pass `thread_id` in metadata to maintain conversation context across requests
- **Streaming Responses**: Use `get_streaming_response()` for real-time streaming or `get_response()` for non-streaming
- **Context Manager**: Use `async with` for automatic cleanup of HTTP connections
- **Standard Interface**: Works with all Agent Framework patterns (ChatAgent, tools, etc.)
- **Hybrid Tool Execution**: Supports both client-side and server-side tools executing together in the same conversation
### Configure and Run the Client
@@ -312,327 +382,13 @@ Connecting to AG-UI server at: http://127.0.0.1:5100/
User (:q or quit to exit): What is the capital of France?
[Run Started - Thread: abc123, Run: xyz789]
The capital of France is Paris. It is known for its rich history, culture,
[Thread: abc123]
Assistant: The capital of France is Paris. It is known for its rich history, culture,
and iconic landmarks such as the Eiffel Tower and the Louvre Museum.
[Run Finished - Thread: abc123, Run: xyz789]
User (:q or quit to exit): Tell me a fun fact about space
[Run Started - Thread: abc123, Run: def456]
Here's a fun fact: A day on Venus is longer than its year! Venus takes
about 243 Earth days to rotate once on its axis, but only about 225 Earth
days to orbit the Sun.
[Run Finished - Thread: abc123, Run: def456]
User (:q or quit to exit): :q
```
### Color-Coded Output
The client displays different content types with distinct colors:
- **Yellow**: Run started notifications
- **Cyan**: Agent text responses (streamed in real-time)
- **Green**: Run completion notifications
- **Red**: Error messages
## Testing with curl (Optional)
Before running the client, you can test the server manually using curl:
```bash
curl -N http://127.0.0.1:5100/ \
-H "Content-Type: application/json" \
-H "Accept: text/event-stream" \
-d '{
"messages": [
{"role": "user", "content": "What is the capital of France?"}
]
}'
```
You should see Server-Sent Events streaming back:
```
data: {"type":"RUN_STARTED","threadId":"...","runId":"..."}
data: {"type":"TEXT_MESSAGE_START","messageId":"...","role":"assistant"}
data: {"type":"TEXT_MESSAGE_CONTENT","messageId":"...","delta":"The"}
data: {"type":"TEXT_MESSAGE_CONTENT","messageId":"...","delta":" capital"}
...
data: {"type":"TEXT_MESSAGE_END","messageId":"..."}
data: {"type":"RUN_FINISHED","threadId":"...","runId":"..."}
```
## How It Works
### Server-Side Flow
1. Client sends HTTP POST request with messages
2. FastAPI endpoint receives the request
3. `AgentFrameworkAgent` wrapper orchestrates the execution
4. Agent processes the messages using Agent Framework
5. `AgentFrameworkEventBridge` converts agent updates to AG-UI events
6. Responses are streamed back as Server-Sent Events (SSE)
7. Connection closes when the run completes
### Client-Side Flow
1. Client sends HTTP POST request to server endpoint
2. Server responds with SSE stream
3. Client parses incoming `data:` lines as JSON events
4. Each event is displayed based on its type
5. `threadId` is captured for conversation continuity
6. Stream completes when `RUN_FINISHED` event arrives
### Protocol Details
The AG-UI protocol uses:
- **HTTP POST** for sending requests
- **Server-Sent Events (SSE)** for streaming responses
- **JSON** for event serialization
- **Thread IDs** for maintaining conversation context
- **Run IDs** for tracking individual executions
- **Event type naming**: UPPERCASE with underscores (e.g., `RUN_STARTED`, `TEXT_MESSAGE_CONTENT`)
- **Field naming**: camelCase (e.g., `threadId`, `runId`, `messageId`)
## Advanced Features
The Python AG-UI implementation supports all 7 AG-UI features:
### 1. Backend Tool Rendering
Add tools to your agent for backend execution:
```python
from typing import Any
from agent_framework import ChatAgent, ai_function
from agent_framework.azure import AzureOpenAIChatClient
@ai_function
def get_weather(location: str) -> dict[str, Any]:
"""Get weather for a location."""
return {"temperature": 72, "conditions": "sunny"}
agent = ChatAgent(
name="weather_agent",
instructions="Use tools to help users.",
chat_client=AzureOpenAIChatClient(
endpoint="https://your-resource.openai.azure.com/",
deployment_name="gpt-4o-mini",
),
tools=[get_weather],
)
```
The client will receive `TOOL_CALL_START`, `TOOL_CALL_ARGS`, `TOOL_CALL_END`, and `TOOL_CALL_RESULT` events.
### 2. Human in the Loop
Request user confirmation before executing tools:
```python
from fastapi import FastAPI
from agent_framework import ChatAgent
from agent_framework.azure import AzureOpenAIChatClient
from agent_framework_ag_ui import AgentFrameworkAgent, add_agent_framework_fastapi_endpoint
agent = ChatAgent(
name="my_agent",
instructions="You are a helpful assistant.",
chat_client=AzureOpenAIChatClient(
endpoint="https://your-resource.openai.azure.com/",
deployment_name="gpt-4o-mini",
),
)
wrapped_agent = AgentFrameworkAgent(
agent=agent,
require_confirmation=True, # Enable human-in-the-loop
)
app = FastAPI()
add_agent_framework_fastapi_endpoint(app, wrapped_agent, "/")
```
The client receives tool approval request events and can send approval responses.
### 3. State Management
Share state between client and server:
```python
wrapped_agent = AgentFrameworkAgent(
agent=agent,
state_schema={
"location": {"type": "string"},
"preferences": {"type": "object"},
},
)
```
Events include `STATE_SNAPSHOT` and `STATE_DELTA` for bidirectional sync.
### 4. Predictive State Updates
Stream tool arguments as optimistic state updates:
```python
wrapped_agent = AgentFrameworkAgent(
agent=agent,
predict_state_config={
"location": {"tool": "get_weather", "tool_argument": "location"}
},
require_confirmation=False, # Auto-update without confirmation
)
```
State updates stream in real-time as the LLM generates tool arguments.
## Common Patterns
### Custom Server Configuration
```python
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
app = FastAPI()
# Add CORS for web clients
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
add_agent_framework_fastapi_endpoint(app, agent, "/agent")
```
### Multiple Agents
```python
app = FastAPI()
weather_agent = ChatAgent(name="weather", ...)
finance_agent = ChatAgent(name="finance", ...)
add_agent_framework_fastapi_endpoint(app, weather_agent, "/weather")
add_agent_framework_fastapi_endpoint(app, finance_agent, "/finance")
```
### Custom Client Timeout
```python
async with httpx.AsyncClient(timeout=300.0) as client:
async with client.stream("POST", server_url, ...) as response:
async for line in response.aiter_lines():
# Process events
pass
```
### Error Handling
```python
try:
async for event in client.send_message(message):
if event.get("type") == "RUN_ERROR":
error_msg = event.get("message", "Unknown error")
print(f"Error: {error_msg}")
# Handle error appropriately
except httpx.HTTPError as e:
print(f"HTTP error: {e}")
except Exception as e:
print(f"Unexpected error: {e}")
```
### Conversation Continuity
The client automatically maintains `threadId` across requests:
```python
client = AGUIClient(server_url)
# First message
async for event in client.send_message("Hello"):
# Client captures threadId from RUN_STARTED
pass
# Second message - uses same threadId
async for event in client.send_message("Continue our conversation"):
# Conversation context is maintained
pass
```
## AG-UI Event Reference
### Core Events
| Event Type | Description | Key Fields |
|------------|-------------|------------|
| `RUN_STARTED` | Agent execution started | `threadId`, `runId` |
| `RUN_FINISHED` | Agent execution completed | `threadId`, `runId` |
| `RUN_ERROR` | Agent execution error | `message` |
### Text Message Events
| Event Type | Description | Key Fields |
|------------|-------------|------------|
| `TEXT_MESSAGE_START` | Start of agent text message | `messageId`, `role` |
| `TEXT_MESSAGE_CONTENT` | Streaming text content | `messageId`, `delta` |
| `TEXT_MESSAGE_END` | End of agent text message | `messageId` |
### Tool Events
| Event Type | Description | Key Fields |
|------------|-------------|------------|
| `TOOL_CALL_START` | Tool call initiated | `toolCallId`, `toolCallName` |
| `TOOL_CALL_ARGS` | Tool arguments streaming | `toolCallId`, `delta` |
| `TOOL_CALL_END` | Tool call complete | `toolCallId` |
| `TOOL_CALL_RESULT` | Tool execution result | `toolCallId`, `content` |
### State Events
| Event Type | Description | Key Fields |
|------------|-------------|------------|
| `STATE_SNAPSHOT` | Complete state | `snapshot` |
| `STATE_DELTA` | State changes (JSON Patch) | `delta` |
### Other Events
| Event Type | Description | Key Fields |
|------------|-------------|------------|
| `MESSAGES_SNAPSHOT` | Conversation history | `messages` |
| `CUSTOM` | Custom event data | `name`, `value` |
## Next Steps
Now that you understand the basics of AG-UI, you can:
- **Add Tools**: Create custom `@ai_function` tools for your domain
- **Web Integration**: Build React/Vue frontends using the AG-UI protocol
- **State Management**: Implement shared state for generative UI applications
- **Human-in-the-Loop**: Add approval workflows for sensitive operations
- **Deployment**: Deploy to Azure Container Apps or Azure App Service
- **Multi-Agent Systems**: Coordinate multiple specialized agents
- **Monitoring**: Add logging and OpenTelemetry for observability
## Additional Resources
- [AG-UI Examples](../agent_framework_ag_ui_examples/README.md): Complete working examples for all 7 features
- [Agent Framework Documentation](../../core/README.md): Learn more about creating agents
- [AG-UI Protocol Spec](https://docs.ag-ui.com/): Official protocol documentation
## Troubleshooting
### Connection Refused
+46 -96
View File
@@ -1,121 +1,71 @@
# Copyright (c) Microsoft. All rights reserved.
"""AG-UI client example."""
"""AG-UI client example using AGUIChatClient.
This example demonstrates how to use the AGUIChatClient to connect to
a remote AG-UI server and interact with it using the Agent Framework's
standard chat interface.
"""
import asyncio
import json
import os
from collections.abc import AsyncIterator
import httpx
class AGUIClient:
"""Simple AG-UI protocol client."""
def __init__(self, server_url: str):
"""Initialize the client.
Args:
server_url: The AG-UI server endpoint URL
"""
self.server_url = server_url
self.thread_id: str | None = None
async def send_message(self, message: str) -> AsyncIterator[dict]:
"""Send a message and stream the response.
Args:
message: The user message to send
Yields:
AG-UI events from the server
"""
# Prepare the request
request_data: dict[str, object] = {
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": message},
]
}
# Include thread_id if we have one (for conversation continuity)
if self.thread_id:
request_data["thread_id"] = self.thread_id
# Stream the response
async with httpx.AsyncClient(timeout=60.0) as client:
async with client.stream(
"POST",
self.server_url,
json=request_data,
headers={"Accept": "text/event-stream"},
) as response:
response.raise_for_status()
async for line in response.aiter_lines():
# Parse Server-Sent Events format
if line.startswith("data: "):
data = line[6:] # Remove "data: " prefix
try:
event = json.loads(data)
yield event
# Capture thread_id from RUN_STARTED event
if event.get("type") == "RUN_STARTED" and not self.thread_id:
self.thread_id = event.get("threadId")
except json.JSONDecodeError:
continue
from agent_framework_ag_ui import AGUIChatClient
async def main():
"""Main client loop."""
"""Main client loop demonstrating AGUIChatClient usage."""
# Get server URL from environment or use default
server_url = os.environ.get("AGUI_SERVER_URL", "http://127.0.0.1:5100/")
print(f"Connecting to AG-UI server at: {server_url}\n")
print("Using AGUIChatClient with automatic thread management and Agent Framework integration.\n")
client = AGUIClient(server_url)
# Create client with context manager for automatic cleanup
async with AGUIChatClient(endpoint=server_url) as client:
thread_id: str | None = None
try:
while True:
# Get user input
message = input("\nUser (:q or quit to exit): ")
if not message.strip():
print("Request cannot be empty.")
continue
try:
while True:
# Get user input
message = input("\nUser (:q or quit to exit): ")
if not message.strip():
print("Request cannot be empty.")
continue
if message.lower() in (":q", "quit"):
break
if message.lower() in (":q", "quit"):
break
# Send message and display streaming response
print("\n", end="")
async for event in client.send_message(message):
event_type = event.get("type", "")
# Send message and stream the response
print("\nAssistant: ", end="", flush=True)
if event_type == "RUN_STARTED":
thread_id = event.get("threadId", "")
run_id = event.get("runId", "")
print(f"\033[93m[Run Started - Thread: {thread_id}, Run: {run_id}]\033[0m")
# Use metadata to maintain conversation continuity
metadata = {"thread_id": thread_id} if thread_id else None
elif event_type == "TEXT_MESSAGE_CONTENT":
# Stream text content in cyan
print(f"\033[96m{event.get('delta', '')}\033[0m", end="", flush=True)
async for update in client.get_streaming_response(message, metadata=metadata):
# Extract and display thread ID from first update
if not thread_id and update.additional_properties:
thread_id = update.additional_properties.get("thread_id")
if thread_id:
print(f"\n\033[93m[Thread: {thread_id}]\033[0m", end="", flush=True)
print("\nAssistant: ", end="", flush=True)
elif event_type == "RUN_FINISHED":
thread_id = event.get("threadId", "")
run_id = event.get("runId", "")
print(f"\n\033[92m[Run Finished - Thread: {thread_id}, Run: {run_id}]\033[0m")
# Display text content as it streams
from agent_framework import TextContent
elif event_type == "RUN_ERROR":
error_message = event.get("message", "Unknown error")
print(f"\n\033[91m[Run Error - Message: {error_message}]\033[0m")
for content in update.contents:
if isinstance(content, TextContent) and content.text:
print(f"\033[96m{content.text}\033[0m", end="", flush=True)
print()
# Display finish reason if present
if update.finish_reason:
print(f"\n\033[92m[Finished: {update.finish_reason}]\033[0m", end="", flush=True)
except KeyboardInterrupt:
print("\n\nExiting...")
except Exception as e:
print(f"\n\033[91mAn error occurred: {e}\033[0m")
print() # New line after response
except KeyboardInterrupt:
print("\n\nExiting...")
except Exception as e:
print(f"\n\033[91mAn error occurred: {e}\033[0m")
if __name__ == "__main__":
@@ -0,0 +1,235 @@
# Copyright (c) Microsoft. All rights reserved.
"""Advanced AG-UI client example with tools and features.
This example demonstrates advanced AGUIChatClient features including:
- Tool/function calling
- Non-streaming responses
- Multiple conversation turns
- Error handling
"""
import asyncio
import os
from agent_framework import ai_function
from agent_framework_ag_ui import AGUIChatClient
@ai_function
def get_weather(location: str) -> str:
"""Get the current weather for a location.
Args:
location: The city or location name
"""
# Simulate weather lookup
weather_data = {
"seattle": "Rainy, 55°F",
"san francisco": "Foggy, 62°F",
"new york": "Sunny, 68°F",
"london": "Cloudy, 52°F",
}
return weather_data.get(location.lower(), f"Weather data not available for {location}")
@ai_function
def calculate(a: float, b: float, operation: str) -> str:
"""Perform basic arithmetic operations.
Args:
a: First number
b: Second number
operation: Operation to perform (add, subtract, multiply, divide)
"""
try:
if operation == "add":
result = a + b
elif operation == "subtract":
result = a - b
elif operation == "multiply":
result = a * b
elif operation == "divide":
result = a / b
else:
return f"Unsupported operation: {operation}"
return f"The result is: {result}"
except Exception as e:
return f"Error calculating: {e}"
async def streaming_example(client: AGUIChatClient, thread_id: str | None = None):
"""Demonstrate streaming responses."""
print("\n" + "=" * 60)
print("STREAMING EXAMPLE")
print("=" * 60)
metadata = {"thread_id": thread_id} if thread_id else None
print("\nUser: Tell me a short joke\n")
print("Assistant: ", end="", flush=True)
async for update in client.get_streaming_response("Tell me a short joke", metadata=metadata):
if not thread_id and update.additional_properties:
thread_id = update.additional_properties.get("thread_id")
from agent_framework import TextContent
for content in update.contents:
if isinstance(content, TextContent) and content.text:
print(content.text, end="", flush=True)
print("\n")
return thread_id
async def non_streaming_example(client: AGUIChatClient, thread_id: str | None = None):
"""Demonstrate non-streaming responses."""
print("\n" + "=" * 60)
print("NON-STREAMING EXAMPLE")
print("=" * 60)
metadata = {"thread_id": thread_id} if thread_id else None
print("\nUser: What is 2 + 2?\n")
response = await client.get_response("What is 2 + 2?", metadata=metadata)
print(f"Assistant: {response.text}")
if response.additional_properties:
thread_id = response.additional_properties.get("thread_id")
print(f"\n[Thread: {thread_id}]")
return thread_id
async def tool_example(client: AGUIChatClient, thread_id: str | None = None):
"""Demonstrate sending tool definitions to the server.
IMPORTANT: When using AGUIChatClient directly (without ChatAgent wrapper):
- Tools are sent as DEFINITIONS only
- No automatic client-side execution (no function invocation middleware)
- Server must have matching tool implementations to execute them
For CLIENT-SIDE tool execution (like .NET AGUIClient sample):
- Use ChatAgent wrapper with tools
- See client_with_agent.py for the hybrid pattern
- ChatAgent middleware intercepts and executes client tools locally
- Server can have its own tools that execute server-side
- Both client and server tools work together in same conversation
This example sends tool definitions and assumes server-side execution.
"""
print("\n" + "=" * 60)
print("TOOL DEFINITION EXAMPLE")
print("=" * 60)
metadata = {"thread_id": thread_id} if thread_id else None
print("\nUser: What's the weather in Seattle?\n")
print("Sending tool definitions to server...")
print("(Server must be configured with matching tools to execute them)\n")
response = await client.get_response(
"What's the weather in Seattle?", tools=[get_weather, calculate], metadata=metadata
)
print(f"Assistant: {response.text}")
# Show tool calls if any
from agent_framework import FunctionCallContent
tool_called = False
for message in response.messages:
for content in message.contents:
if isinstance(content, FunctionCallContent):
print(f"\n[Tool Called: {content.name}]")
tool_called = True
if not tool_called:
print("\n[Note: No tools were called - server may not be configured for tool execution]")
if response.additional_properties:
thread_id = response.additional_properties.get("thread_id")
return thread_id
async def conversation_example(client: AGUIChatClient):
"""Demonstrate multi-turn conversation.
Note: Conversation continuity depends on the server maintaining thread state.
Some servers may require explicit message history to be sent with each request.
"""
print("\n" + "=" * 60)
print("MULTI-TURN CONVERSATION EXAMPLE")
print("=" * 60)
print("\nNote: This example uses thread_id for context. Server must support thread-based state.\n")
# First turn
print("User: My name is Alice\n")
response1 = await client.get_response("My name is Alice")
print(f"Assistant: {response1.text}")
thread_id = response1.additional_properties.get("thread_id")
print(f"\n[Thread: {thread_id}]")
# Second turn - using same thread
print("\nUser: What's my name?\n")
response2 = await client.get_response("What's my name?", metadata={"thread_id": thread_id})
print(f"Assistant: {response2.text}")
# Check if context was maintained
if "alice" not in response2.text.lower():
print("\n[Note: Server may not maintain thread context - consider using ChatAgent for history management]")
# Third turn
print("\nUser: Can you also tell me what 10 * 5 is?\n")
response3 = await client.get_response(
"Can you also tell me what 10 * 5 is?", metadata={"thread_id": thread_id}, tools=[calculate]
)
print(f"Assistant: {response3.text}")
async def main():
"""Run all examples."""
# Get server URL from environment or use default
server_url = os.environ.get("AGUI_SERVER_URL", "http://127.0.0.1:5100/")
print("=" * 60)
print("AG-UI Chat Client Advanced Examples")
print("=" * 60)
print(f"\nServer: {server_url}")
print("\nThese examples demonstrate various AGUIChatClient features:")
print(" 1. Streaming responses")
print(" 2. Non-streaming responses")
print(" 3. Tool/function calling")
print(" 4. Multi-turn conversations")
try:
async with AGUIChatClient(endpoint=server_url) as client:
# Run examples in sequence
thread_id = await streaming_example(client)
thread_id = await non_streaming_example(client, thread_id)
await tool_example(client, thread_id)
# Separate conversation with new thread
await conversation_example(client)
print("\n" + "=" * 60)
print("All examples completed successfully!")
print("=" * 60)
except ConnectionError as e:
print(f"\n\033[91mConnection Error: {e}\033[0m")
print("\nMake sure an AG-UI server is running at the specified endpoint.")
except Exception as e:
print(f"\n\033[91mError: {e}\033[0m")
import traceback
traceback.print_exc()
if __name__ == "__main__":
asyncio.run(main())
@@ -0,0 +1,186 @@
# Copyright (c) Microsoft. All rights reserved.
"""Example showing ChatAgent with AGUIChatClient for hybrid tool execution.
This demonstrates the HYBRID pattern matching .NET AGUIClient implementation:
1. AgentThread Pattern (like .NET):
- Create thread with agent.get_new_thread()
- Pass thread to agent.run_stream() on each turn
- Thread automatically maintains conversation history via message_store
2. Hybrid Tool Execution:
- AGUIChatClient has @use_function_invocation decorator
- Client-side tools (get_weather) can execute locally when server requests them
- Server may also have its own tools that execute server-side
- Both work together: server LLM decides which tool to call, decorator handles client execution
This matches .NET pattern: thread maintains state, tools execute on appropriate side.
"""
import asyncio
import logging
import os
from agent_framework import ChatAgent, FunctionCallContent, FunctionResultContent, TextContent, ai_function
from agent_framework_ag_ui import AGUIChatClient
# Enable debug logging
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
@ai_function(description="Get the current weather for a location.")
def get_weather(location: str) -> str:
"""Get the current weather for a location.
Args:
location: The city or location name
"""
print(f"[CLIENT] get_weather tool called with location: {location}")
weather_data = {
"seattle": "Rainy, 55°F",
"san francisco": "Foggy, 62°F",
"new york": "Sunny, 68°F",
"london": "Cloudy, 52°F",
}
result = weather_data.get(location.lower(), f"Weather data not available for {location}")
print(f"[CLIENT] get_weather returning: {result}")
return result
async def main():
"""Demonstrate ChatAgent + AGUIChatClient hybrid tool execution.
This matches the .NET pattern from Program.cs where:
- AIAgent agent = chatClient.CreateAIAgent(tools: [...])
- AgentThread thread = agent.GetNewThread()
- RunStreamingAsync(messages, thread)
Python equivalent:
- agent = ChatAgent(chat_client=AGUIChatClient(...), tools=[...])
- thread = agent.get_new_thread() # Creates thread with message_store
- agent.run_stream(message, thread=thread) # Thread accumulates history
"""
server_url = os.environ.get("AGUI_SERVER_URL", "http://127.0.0.1:5100/")
print("=" * 70)
print("ChatAgent + AGUIChatClient: Hybrid Tool Execution")
print("=" * 70)
print(f"\nServer: {server_url}")
print("\nThis example demonstrates:")
print(" 1. AgentThread maintains conversation state (like .NET)")
print(" 2. Client-side tools execute locally via @use_function_invocation")
print(" 3. Server may have additional tools that execute server-side")
print(" 4. HYBRID: Client and server tools work together simultaneously\n")
try:
# Create remote client in async context manager
async with AGUIChatClient(endpoint=server_url) as remote_client:
# Wrap in ChatAgent for conversation history management
agent = ChatAgent(
name="remote_assistant",
instructions="You are a helpful assistant. Remember user information across the conversation.",
chat_client=remote_client,
tools=[get_weather],
)
# Create a thread to maintain conversation state (like .NET AgentThread)
thread = agent.get_new_thread()
print("=" * 70)
print("CONVERSATION WITH HISTORY")
print("=" * 70)
# Turn 1: Introduce
print("\nUser: My name is Alice and I live in Seattle\n")
async for chunk in agent.run_stream("My name is Alice and I live in Seattle", thread=thread):
if chunk.text:
print(chunk.text, end="", flush=True)
print("\n")
# Turn 2: Ask about name (tests history)
print("User: What's my name?\n")
async for chunk in agent.run_stream("What's my name?", thread=thread):
if chunk.text:
print(chunk.text, end="", flush=True)
print("\n")
# Turn 3: Ask about location (tests history)
print("User: Where do I live?\n")
async for chunk in agent.run_stream("Where do I live?", thread=thread):
if chunk.text:
print(chunk.text, end="", flush=True)
print("\n")
# Turn 4: Test client-side tool (get_weather is client-side)
print("User: What's the weather forecast for today in Seattle?\n")
async for chunk in agent.run_stream("What's the weather forecast for today in Seattle?", thread=thread):
if chunk.text:
print(chunk.text, end="", flush=True)
print("\n")
# Turn 5: Test server-side tool (get_time_zone is server-side only)
print("User: What time zone is Seattle in?\n")
async for chunk in agent.run_stream("What time zone is Seattle in?", thread=thread):
if chunk.text:
print(chunk.text, end="", flush=True)
print("\n")
# Show thread state
if thread.message_store:
def _preview_for_message(m) -> str:
# Prefer plain text when present
if getattr(m, "text", ""):
t = m.text
return (t[:60] + "...") if len(t) > 60 else t
# Build from contents when no direct text
parts: list[str] = []
for c in getattr(m, "contents", []) or []:
if isinstance(c, FunctionCallContent):
args = c.arguments
if isinstance(args, dict):
try:
import json as _json
args_str = _json.dumps(args)
except Exception:
args_str = str(args)
else:
args_str = str(args or "{}")
parts.append(f"tool_call {c.name} {args_str}")
elif isinstance(c, FunctionResultContent):
parts.append(f"tool_result[{c.call_id}]: {str(c.result)[:40]}")
elif isinstance(c, TextContent):
if c.text:
parts.append(c.text)
else:
typename = getattr(c, "type", c.__class__.__name__)
parts.append(f"<{typename}>")
preview = " | ".join(parts) if parts else ""
return (preview[:60] + "...") if len(preview) > 60 else preview
messages = await thread.message_store.list_messages()
print(f"\n[THREAD STATE] {len(messages)} messages in thread's message_store")
for i, msg in enumerate(messages[-6:], 1): # Show last 6
role = msg.role.value if hasattr(msg.role, "value") else str(msg.role)
text_preview = _preview_for_message(msg)
print(f" {i}. [{role}]: {text_preview}")
except ConnectionError as e:
print(f"\n\033[91mConnection Error: {e}\033[0m")
print("\nMake sure an AG-UI server is running at the specified endpoint.")
except Exception as e:
print(f"\n\033[91mError: {e}\033[0m")
import traceback
traceback.print_exc()
if __name__ == "__main__":
asyncio.run(main())
@@ -1,18 +1,26 @@
# Copyright (c) Microsoft. All rights reserved.
"""AG-UI server example."""
"""AG-UI server example with server-side tools."""
import logging
import os
from agent_framework import ChatAgent
from agent_framework import ChatAgent, ai_function
from agent_framework.ag_ui import add_agent_framework_fastapi_endpoint
from agent_framework.azure import AzureOpenAIChatClient
from dotenv import load_dotenv
from fastapi import FastAPI
from agent_framework_ag_ui import add_agent_framework_fastapi_endpoint
load_dotenv()
# Enable debug logging
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
# Read required configuration
endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT")
deployment_name = os.environ.get("AZURE_OPENAI_CHAT_DEPLOYMENT_NAME")
@@ -22,14 +30,43 @@ if not endpoint:
if not deployment_name:
raise ValueError("AZURE_OPENAI_CHAT_DEPLOYMENT_NAME environment variable is required")
# Create the AI agent
# Server-side tool (executes on server)
@ai_function(description="Get the time zone for a location.")
def get_time_zone(location: str) -> str:
"""Get the time zone for a location.
Args:
location: The city or location name
"""
print(f"[SERVER] get_time_zone tool called with location: {location}")
timezone_data = {
"seattle": "Pacific Time (UTC-8)",
"san francisco": "Pacific Time (UTC-8)",
"new york": "Eastern Time (UTC-5)",
"london": "Greenwich Mean Time (UTC+0)",
}
result = timezone_data.get(location.lower(), f"Time zone data not available for {location}")
print(f"[SERVER] get_time_zone returning: {result}")
return result
# Create the AI agent with ONLY server-side tools
# IMPORTANT: Do NOT include tools that the client provides!
# In this example:
# - get_time_zone: SERVER-ONLY tool (only server has this)
# - get_weather: CLIENT-ONLY tool (client provides this, server should NOT include it)
# The client will send get_weather tool metadata so the LLM knows about it,
# and @use_function_invocation on AGUIChatClient will execute it client-side.
# This matches the .NET AG-UI hybrid execution pattern.
agent = ChatAgent(
name="AGUIAssistant",
instructions="You are a helpful assistant.",
instructions="You are a helpful assistant. Use get_weather for weather and get_time_zone for time zones.",
chat_client=AzureOpenAIChatClient(
endpoint=endpoint,
deployment_name=deployment_name,
),
tools=[get_time_zone], # ONLY server-side tools
)
# Create FastAPI app
@@ -41,4 +78,4 @@ add_agent_framework_fastapi_endpoint(app, agent, "/")
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="127.0.0.1", port=5100)
uvicorn.run(app, host="127.0.0.1", port=5100, log_level="debug", access_log=True)
+317
View File
@@ -0,0 +1,317 @@
"""Tests for AGUIChatClient."""
import json
from agent_framework import ChatMessage, ChatOptions, FunctionCallContent, Role, ai_function
from agent_framework_ag_ui._client import AGUIChatClient, ServerFunctionCallContent
class TestAGUIChatClient:
"""Test suite for AGUIChatClient."""
async def test_client_initialization(self) -> None:
"""Test client initialization."""
client = AGUIChatClient(endpoint="http://localhost:8888/")
assert client._http_service is not None
assert client._http_service.endpoint.startswith("http://localhost:8888")
async def test_client_context_manager(self) -> None:
"""Test client as async context manager."""
async with AGUIChatClient(endpoint="http://localhost:8888/") as client:
assert client is not None
async def test_extract_state_from_messages_no_state(self) -> None:
"""Test state extraction when no state is present."""
client = AGUIChatClient(endpoint="http://localhost:8888/")
messages = [
ChatMessage(role="user", text="Hello"),
ChatMessage(role="assistant", text="Hi there"),
]
result_messages, state = client._extract_state_from_messages(messages)
assert result_messages == messages
assert state is None
async def test_extract_state_from_messages_with_state(self) -> None:
"""Test state extraction from last message."""
import base64
client = AGUIChatClient(endpoint="http://localhost:8888/")
state_data = {"key": "value", "count": 42}
state_json = json.dumps(state_data)
state_b64 = base64.b64encode(state_json.encode("utf-8")).decode("utf-8")
from agent_framework import DataContent
messages = [
ChatMessage(role="user", text="Hello"),
ChatMessage(
role="user",
contents=[DataContent(uri=f"data:application/json;base64,{state_b64}")],
),
]
result_messages, state = client._extract_state_from_messages(messages)
assert len(result_messages) == 1
assert result_messages[0].text == "Hello"
assert state == state_data
async def test_extract_state_invalid_json(self) -> None:
"""Test state extraction with invalid JSON."""
import base64
client = AGUIChatClient(endpoint="http://localhost:8888/")
invalid_json = "not valid json"
state_b64 = base64.b64encode(invalid_json.encode("utf-8")).decode("utf-8")
from agent_framework import DataContent
messages = [
ChatMessage(
role="user",
contents=[DataContent(uri=f"data:application/json;base64,{state_b64}")],
),
]
result_messages, state = client._extract_state_from_messages(messages)
assert result_messages == messages
assert state is None
async def test_convert_messages_to_agui_format(self) -> None:
"""Test message conversion to AG-UI format."""
client = AGUIChatClient(endpoint="http://localhost:8888/")
messages = [
ChatMessage(role=Role.USER, text="What is the weather?"),
ChatMessage(role=Role.ASSISTANT, text="Let me check.", message_id="msg_123"),
]
agui_messages = client._convert_messages_to_agui_format(messages)
assert len(agui_messages) == 2
assert agui_messages[0]["role"] == "user"
assert agui_messages[0]["content"] == "What is the weather?"
assert agui_messages[1]["role"] == "assistant"
assert agui_messages[1]["content"] == "Let me check."
assert agui_messages[1]["id"] == "msg_123"
async def test_get_thread_id_from_metadata(self) -> None:
"""Test thread ID extraction from metadata."""
client = AGUIChatClient(endpoint="http://localhost:8888/")
chat_options = ChatOptions(metadata={"thread_id": "existing_thread_123"})
thread_id = client._get_thread_id(chat_options)
assert thread_id == "existing_thread_123"
async def test_get_thread_id_generation(self) -> None:
"""Test automatic thread ID generation."""
client = AGUIChatClient(endpoint="http://localhost:8888/")
chat_options = ChatOptions()
thread_id = client._get_thread_id(chat_options)
assert thread_id.startswith("thread_")
assert len(thread_id) > 7
async def test_get_streaming_response(self, monkeypatch) -> None:
"""Test streaming response method."""
mock_events = [
{"type": "RUN_STARTED", "threadId": "thread_1", "runId": "run_1"},
{"type": "TEXT_MESSAGE_CONTENT", "messageId": "msg_1", "delta": "Hello"},
{"type": "TEXT_MESSAGE_CONTENT", "messageId": "msg_1", "delta": " world"},
{"type": "RUN_FINISHED", "threadId": "thread_1", "runId": "run_1"},
]
async def mock_post_run(*args, **kwargs):
for event in mock_events:
yield event
client = AGUIChatClient(endpoint="http://localhost:8888/")
monkeypatch.setattr(client._http_service, "post_run", mock_post_run)
messages = [ChatMessage(role="user", text="Test message")]
chat_options = ChatOptions()
updates = []
async for update in client._inner_get_streaming_response(messages=messages, chat_options=chat_options):
updates.append(update)
assert len(updates) == 4
assert updates[0].additional_properties["thread_id"] == "thread_1"
assert updates[1].contents[0].text == "Hello"
assert updates[2].contents[0].text == " world"
async def test_get_response_non_streaming(self, monkeypatch) -> None:
"""Test non-streaming response method."""
mock_events = [
{"type": "RUN_STARTED", "threadId": "thread_1", "runId": "run_1"},
{"type": "TEXT_MESSAGE_CONTENT", "messageId": "msg_1", "delta": "Complete response"},
{"type": "RUN_FINISHED", "threadId": "thread_1", "runId": "run_1"},
]
async def mock_post_run(*args, **kwargs):
for event in mock_events:
yield event
client = AGUIChatClient(endpoint="http://localhost:8888/")
monkeypatch.setattr(client._http_service, "post_run", mock_post_run)
messages = [ChatMessage(role="user", text="Test message")]
chat_options = ChatOptions()
response = await client._inner_get_response(messages=messages, chat_options=chat_options)
assert response is not None
assert len(response.messages) > 0
assert "Complete response" in response.text
async def test_tool_handling(self, monkeypatch) -> None:
"""Test that client tool metadata is sent to server.
Client tool metadata (name, description, schema) is sent to server for planning.
When server requests a client function, @use_function_invocation decorator
intercepts and executes it locally. This matches .NET AG-UI implementation.
"""
from agent_framework import ai_function
@ai_function
def test_tool(param: str) -> str:
"""Test tool."""
return "result"
mock_events = [
{"type": "RUN_STARTED", "threadId": "thread_1", "runId": "run_1"},
{"type": "RUN_FINISHED", "threadId": "thread_1", "runId": "run_1"},
]
async def mock_post_run(*args, **kwargs):
# Client tool metadata should be sent to server
tools = kwargs.get("tools")
assert tools is not None
assert len(tools) == 1
assert tools[0]["name"] == "test_tool"
assert tools[0]["description"] == "Test tool."
assert "parameters" in tools[0]
for event in mock_events:
yield event
client = AGUIChatClient(endpoint="http://localhost:8888/")
monkeypatch.setattr(client._http_service, "post_run", mock_post_run)
messages = [ChatMessage(role="user", text="Test with tools")]
chat_options = ChatOptions(tools=[test_tool])
response = await client._inner_get_response(messages=messages, chat_options=chat_options)
assert response is not None
async def test_server_tool_calls_unwrapped_after_invocation(self, monkeypatch) -> None:
"""Ensure server-side tool calls are exposed as FunctionCallContent after processing."""
mock_events = [
{"type": "RUN_STARTED", "threadId": "thread_1", "runId": "run_1"},
{"type": "TOOL_CALL_START", "toolCallId": "call_1", "toolName": "get_time_zone"},
{"type": "TOOL_CALL_ARGS", "toolCallId": "call_1", "delta": '{"location": "Seattle"}'},
{"type": "RUN_FINISHED", "threadId": "thread_1", "runId": "run_1"},
]
async def mock_post_run(*args, **kwargs):
for event in mock_events:
yield event
client = AGUIChatClient(endpoint="http://localhost:8888/")
monkeypatch.setattr(client._http_service, "post_run", mock_post_run)
messages = [ChatMessage(role="user", text="Test server tool execution")]
chat_options = ChatOptions()
updates = []
async for update in client.get_streaming_response(messages, chat_options=chat_options):
updates.append(update)
function_calls = [
content for update in updates for content in update.contents if isinstance(content, FunctionCallContent)
]
assert function_calls
assert function_calls[0].name == "get_time_zone"
assert not any(
isinstance(content, ServerFunctionCallContent) for update in updates for content in update.contents
)
async def test_server_tool_calls_not_executed_locally(self, monkeypatch) -> None:
"""Server tools should not trigger local function invocation even when client tools exist."""
@ai_function
def client_tool() -> str:
"""Client tool stub."""
return "client"
mock_events = [
{"type": "RUN_STARTED", "threadId": "thread_1", "runId": "run_1"},
{"type": "TOOL_CALL_START", "toolCallId": "call_1", "toolName": "get_time_zone"},
{"type": "TOOL_CALL_ARGS", "toolCallId": "call_1", "delta": '{"location": "Seattle"}'},
{"type": "RUN_FINISHED", "threadId": "thread_1", "runId": "run_1"},
]
async def mock_post_run(*args, **kwargs):
for event in mock_events:
yield event
async def fake_auto_invoke(*args, **kwargs):
function_call = kwargs.get("function_call_content") or args[0]
raise AssertionError(f"Unexpected local execution of server tool: {getattr(function_call, 'name', '?')}")
monkeypatch.setattr("agent_framework._tools._auto_invoke_function", fake_auto_invoke)
client = AGUIChatClient(endpoint="http://localhost:8888/")
monkeypatch.setattr(client._http_service, "post_run", mock_post_run)
messages = [ChatMessage(role="user", text="Test server tool execution")]
chat_options = ChatOptions(tool_choice="auto", tools=[client_tool])
async for _ in client.get_streaming_response(messages, chat_options=chat_options):
pass
async def test_state_transmission(self, monkeypatch) -> None:
"""Test state is properly transmitted to server."""
import base64
state_data = {"user_id": "123", "session": "abc"}
state_json = json.dumps(state_data)
state_b64 = base64.b64encode(state_json.encode("utf-8")).decode("utf-8")
from agent_framework import DataContent
messages = [
ChatMessage(role="user", text="Hello"),
ChatMessage(
role="user",
contents=[DataContent(uri=f"data:application/json;base64,{state_b64}")],
),
]
mock_events = [
{"type": "RUN_STARTED", "threadId": "thread_1", "runId": "run_1"},
{"type": "RUN_FINISHED", "threadId": "thread_1", "runId": "run_1"},
]
async def mock_post_run(*args, **kwargs):
assert kwargs.get("state") == state_data
for event in mock_events:
yield event
client = AGUIChatClient(endpoint="http://localhost:8888/")
monkeypatch.setattr(client._http_service, "post_run", mock_post_run)
chat_options = ChatOptions()
response = await client._inner_get_response(messages=messages, chat_options=chat_options)
assert response is not None
@@ -0,0 +1,287 @@
"""Tests for AG-UI event converter."""
from agent_framework import FinishReason, Role
from agent_framework_ag_ui._event_converters import AGUIEventConverter
class TestAGUIEventConverter:
"""Test suite for AGUIEventConverter."""
def test_run_started_event(self) -> None:
"""Test conversion of RUN_STARTED event."""
converter = AGUIEventConverter()
event = {
"type": "RUN_STARTED",
"threadId": "thread_123",
"runId": "run_456",
}
update = converter.convert_event(event)
assert update is not None
assert update.role == Role.ASSISTANT
assert update.additional_properties["thread_id"] == "thread_123"
assert update.additional_properties["run_id"] == "run_456"
assert converter.thread_id == "thread_123"
assert converter.run_id == "run_456"
def test_text_message_start_event(self) -> None:
"""Test conversion of TEXT_MESSAGE_START event."""
converter = AGUIEventConverter()
event = {
"type": "TEXT_MESSAGE_START",
"messageId": "msg_789",
}
update = converter.convert_event(event)
assert update is not None
assert update.role == Role.ASSISTANT
assert update.message_id == "msg_789"
assert converter.current_message_id == "msg_789"
def test_text_message_content_event(self) -> None:
"""Test conversion of TEXT_MESSAGE_CONTENT event."""
converter = AGUIEventConverter()
event = {
"type": "TEXT_MESSAGE_CONTENT",
"messageId": "msg_1",
"delta": "Hello",
}
update = converter.convert_event(event)
assert update is not None
assert update.role == Role.ASSISTANT
assert update.message_id == "msg_1"
assert len(update.contents) == 1
assert update.contents[0].text == "Hello"
def test_text_message_streaming(self) -> None:
"""Test streaming text across multiple TEXT_MESSAGE_CONTENT events."""
converter = AGUIEventConverter()
events = [
{"type": "TEXT_MESSAGE_CONTENT", "messageId": "msg_1", "delta": "Hello"},
{"type": "TEXT_MESSAGE_CONTENT", "messageId": "msg_1", "delta": " world"},
{"type": "TEXT_MESSAGE_CONTENT", "messageId": "msg_1", "delta": "!"},
]
updates = [converter.convert_event(event) for event in events]
assert all(update is not None for update in updates)
assert all(update.message_id == "msg_1" for update in updates)
assert updates[0].contents[0].text == "Hello"
assert updates[1].contents[0].text == " world"
assert updates[2].contents[0].text == "!"
def test_text_message_end_event(self) -> None:
"""Test conversion of TEXT_MESSAGE_END event."""
converter = AGUIEventConverter()
event = {
"type": "TEXT_MESSAGE_END",
"messageId": "msg_1",
}
update = converter.convert_event(event)
assert update is None
def test_tool_call_start_event(self) -> None:
"""Test conversion of TOOL_CALL_START event."""
converter = AGUIEventConverter()
event = {
"type": "TOOL_CALL_START",
"toolCallId": "call_123",
"toolName": "get_weather",
}
update = converter.convert_event(event)
assert update is not None
assert update.role == Role.ASSISTANT
assert len(update.contents) == 1
assert update.contents[0].call_id == "call_123"
assert update.contents[0].name == "get_weather"
assert update.contents[0].arguments == ""
assert converter.current_tool_call_id == "call_123"
assert converter.current_tool_name == "get_weather"
def test_tool_call_start_with_tool_call_name(self) -> None:
"""Ensure TOOL_CALL_START with toolCallName still sets the tool name."""
converter = AGUIEventConverter()
event = {
"type": "TOOL_CALL_START",
"toolCallId": "call_abc",
"toolCallName": "get_weather",
}
update = converter.convert_event(event)
assert update is not None
assert update.contents[0].name == "get_weather"
assert converter.current_tool_name == "get_weather"
def test_tool_call_start_with_tool_call_name_snake_case(self) -> None:
"""Support tool_call_name snake_case field for backwards compatibility."""
converter = AGUIEventConverter()
event = {
"type": "TOOL_CALL_START",
"toolCallId": "call_snake",
"tool_call_name": "get_weather",
}
update = converter.convert_event(event)
assert update is not None
assert update.contents[0].name == "get_weather"
assert converter.current_tool_name == "get_weather"
def test_tool_call_args_streaming(self) -> None:
"""Test streaming tool arguments across multiple TOOL_CALL_ARGS events."""
converter = AGUIEventConverter()
converter.current_tool_call_id = "call_123"
converter.current_tool_name = "search"
events = [
{"type": "TOOL_CALL_ARGS", "delta": '{"query": "'},
{"type": "TOOL_CALL_ARGS", "delta": 'latest news"}'},
]
updates = [converter.convert_event(event) for event in events]
assert all(update is not None for update in updates)
assert updates[0].contents[0].arguments == '{"query": "'
assert updates[1].contents[0].arguments == 'latest news"}'
assert converter.accumulated_tool_args == '{"query": "latest news"}'
def test_tool_call_end_event(self) -> None:
"""Test conversion of TOOL_CALL_END event."""
converter = AGUIEventConverter()
converter.accumulated_tool_args = '{"location": "Seattle"}'
event = {
"type": "TOOL_CALL_END",
"toolCallId": "call_123",
}
update = converter.convert_event(event)
assert update is None
assert converter.accumulated_tool_args == ""
def test_tool_call_result_event(self) -> None:
"""Test conversion of TOOL_CALL_RESULT event."""
converter = AGUIEventConverter()
event = {
"type": "TOOL_CALL_RESULT",
"toolCallId": "call_123",
"result": {"temperature": 22, "condition": "sunny"},
}
update = converter.convert_event(event)
assert update is not None
assert update.role == Role.TOOL
assert len(update.contents) == 1
assert update.contents[0].call_id == "call_123"
assert update.contents[0].result == {"temperature": 22, "condition": "sunny"}
def test_run_finished_event(self) -> None:
"""Test conversion of RUN_FINISHED event."""
converter = AGUIEventConverter()
converter.thread_id = "thread_123"
converter.run_id = "run_456"
event = {
"type": "RUN_FINISHED",
"threadId": "thread_123",
"runId": "run_456",
}
update = converter.convert_event(event)
assert update is not None
assert update.role == Role.ASSISTANT
assert update.finish_reason == FinishReason.STOP
assert update.additional_properties["thread_id"] == "thread_123"
assert update.additional_properties["run_id"] == "run_456"
def test_run_error_event(self) -> None:
"""Test conversion of RUN_ERROR event."""
converter = AGUIEventConverter()
converter.thread_id = "thread_123"
converter.run_id = "run_456"
event = {
"type": "RUN_ERROR",
"message": "Connection timeout",
}
update = converter.convert_event(event)
assert update is not None
assert update.role == Role.ASSISTANT
assert update.finish_reason == FinishReason.CONTENT_FILTER
assert len(update.contents) == 1
assert update.contents[0].message == "Connection timeout"
assert update.contents[0].error_code == "RUN_ERROR"
def test_unknown_event_type(self) -> None:
"""Test handling of unknown event types."""
converter = AGUIEventConverter()
event = {
"type": "UNKNOWN_EVENT",
"data": "some data",
}
update = converter.convert_event(event)
assert update is None
def test_full_conversation_flow(self) -> None:
"""Test complete conversation flow with multiple event types."""
converter = AGUIEventConverter()
events = [
{"type": "RUN_STARTED", "threadId": "thread_1", "runId": "run_1"},
{"type": "TEXT_MESSAGE_START", "messageId": "msg_1"},
{"type": "TEXT_MESSAGE_CONTENT", "messageId": "msg_1", "delta": "I'll check"},
{"type": "TEXT_MESSAGE_CONTENT", "messageId": "msg_1", "delta": " the weather."},
{"type": "TEXT_MESSAGE_END", "messageId": "msg_1"},
{"type": "TOOL_CALL_START", "toolCallId": "call_1", "toolName": "get_weather"},
{"type": "TOOL_CALL_ARGS", "delta": '{"location": "Seattle"}'},
{"type": "TOOL_CALL_END", "toolCallId": "call_1"},
{"type": "TOOL_CALL_RESULT", "toolCallId": "call_1", "result": "Sunny, 72°F"},
{"type": "TEXT_MESSAGE_START", "messageId": "msg_2"},
{"type": "TEXT_MESSAGE_CONTENT", "messageId": "msg_2", "delta": "It's sunny!"},
{"type": "TEXT_MESSAGE_END", "messageId": "msg_2"},
{"type": "RUN_FINISHED", "threadId": "thread_1", "runId": "run_1"},
]
updates = [converter.convert_event(event) for event in events]
non_none_updates = [u for u in updates if u is not None]
assert len(non_none_updates) == 10
assert converter.thread_id == "thread_1"
assert converter.run_id == "run_1"
def test_multiple_tool_calls(self) -> None:
"""Test handling multiple tool calls in sequence."""
converter = AGUIEventConverter()
events = [
{"type": "TOOL_CALL_START", "toolCallId": "call_1", "toolName": "search"},
{"type": "TOOL_CALL_ARGS", "delta": '{"query": "weather"}'},
{"type": "TOOL_CALL_END", "toolCallId": "call_1"},
{"type": "TOOL_CALL_START", "toolCallId": "call_2", "toolName": "fetch"},
{"type": "TOOL_CALL_ARGS", "delta": '{"url": "http://api.weather.com"}'},
{"type": "TOOL_CALL_END", "toolCallId": "call_2"},
]
updates = [converter.convert_event(event) for event in events]
non_none_updates = [u for u in updates if u is not None]
assert len(non_none_updates) == 4
assert non_none_updates[0].contents[0].name == "search"
assert non_none_updates[2].contents[0].name == "fetch"
@@ -0,0 +1,238 @@
# Copyright (c) Microsoft. All rights reserved.
"""Tests for AGUIHttpService."""
import json
from unittest.mock import AsyncMock, Mock
import httpx
import pytest
from agent_framework_ag_ui._http_service import AGUIHttpService
@pytest.fixture
def mock_http_client():
"""Create a mock httpx.AsyncClient."""
client = AsyncMock(spec=httpx.AsyncClient)
return client
@pytest.fixture
def sample_events():
"""Sample AG-UI events for testing."""
return [
{"type": "RUN_STARTED", "threadId": "thread_123", "runId": "run_456"},
{"type": "TEXT_MESSAGE_START", "messageId": "msg_1", "role": "assistant"},
{"type": "TEXT_MESSAGE_CONTENT", "messageId": "msg_1", "delta": "Hello"},
{"type": "TEXT_MESSAGE_CONTENT", "messageId": "msg_1", "delta": " world"},
{"type": "TEXT_MESSAGE_END", "messageId": "msg_1"},
{"type": "RUN_FINISHED", "threadId": "thread_123", "runId": "run_456"},
]
def create_sse_response(events: list[dict]) -> str:
"""Create SSE formatted response from events."""
lines = []
for event in events:
lines.append(f"data: {json.dumps(event)}\n")
return "\n".join(lines)
async def test_http_service_initialization():
"""Test AGUIHttpService initialization."""
# Test with default client
service = AGUIHttpService("http://localhost:8888/")
assert service.endpoint == "http://localhost:8888"
assert service._owns_client is True
assert isinstance(service.http_client, httpx.AsyncClient)
await service.close()
# Test with custom client
custom_client = httpx.AsyncClient()
service = AGUIHttpService("http://localhost:8888/", http_client=custom_client)
assert service._owns_client is False
assert service.http_client is custom_client
# Shouldn't close the custom client
await service.close()
await custom_client.aclose()
async def test_http_service_strips_trailing_slash():
"""Test that endpoint trailing slash is stripped."""
service = AGUIHttpService("http://localhost:8888/")
assert service.endpoint == "http://localhost:8888"
await service.close()
async def test_post_run_successful_streaming(mock_http_client, sample_events):
"""Test successful streaming of events."""
# Create async generator for lines
async def mock_aiter_lines():
sse_data = create_sse_response(sample_events)
for line in sse_data.split("\n"):
if line:
yield line
# Create mock response
mock_response = AsyncMock()
mock_response.status_code = 200
# aiter_lines is called as a method, so it should return a new generator each time
mock_response.aiter_lines = mock_aiter_lines
# Setup mock streaming context manager
mock_stream_context = AsyncMock()
mock_stream_context.__aenter__.return_value = mock_response
mock_stream_context.__aexit__.return_value = None
mock_http_client.stream.return_value = mock_stream_context
service = AGUIHttpService("http://localhost:8888/", http_client=mock_http_client)
events = []
async for event in service.post_run(
thread_id="thread_123", run_id="run_456", messages=[{"role": "user", "content": "Hello"}]
):
events.append(event)
assert len(events) == len(sample_events)
assert events[0]["type"] == "RUN_STARTED"
assert events[-1]["type"] == "RUN_FINISHED"
# Verify request was made correctly
mock_http_client.stream.assert_called_once()
call_args = mock_http_client.stream.call_args
assert call_args.args[0] == "POST"
assert call_args.args[1] == "http://localhost:8888"
assert call_args.kwargs["headers"] == {"Accept": "text/event-stream"}
async def test_post_run_with_state_and_tools(mock_http_client):
"""Test posting run with state and tools."""
async def mock_aiter_lines():
return
yield # Make it an async generator
mock_response = AsyncMock()
mock_response.status_code = 200
mock_response.aiter_lines = mock_aiter_lines
mock_stream_context = AsyncMock()
mock_stream_context.__aenter__.return_value = mock_response
mock_stream_context.__aexit__.return_value = None
mock_http_client.stream.return_value = mock_stream_context
service = AGUIHttpService("http://localhost:8888/", http_client=mock_http_client)
state = {"user_context": {"name": "Alice"}}
tools = [{"type": "function", "function": {"name": "test_tool"}}]
async for _ in service.post_run(thread_id="thread_123", run_id="run_456", messages=[], state=state, tools=tools):
pass
# Verify state and tools were included in request
call_args = mock_http_client.stream.call_args
request_data = call_args.kwargs["json"]
assert request_data["state"] == state
assert request_data["tools"] == tools
async def test_post_run_http_error(mock_http_client):
"""Test handling of HTTP errors."""
mock_response = Mock()
mock_response.status_code = 500
mock_response.text = "Internal Server Error"
def raise_http_error():
raise httpx.HTTPStatusError("Server error", request=Mock(), response=mock_response)
mock_response_async = AsyncMock()
mock_response_async.raise_for_status = raise_http_error
mock_stream_context = AsyncMock()
mock_stream_context.__aenter__.return_value = mock_response_async
mock_stream_context.__aexit__.return_value = None
mock_http_client.stream.return_value = mock_stream_context
service = AGUIHttpService("http://localhost:8888/", http_client=mock_http_client)
with pytest.raises(httpx.HTTPStatusError):
async for _ in service.post_run(thread_id="thread_123", run_id="run_456", messages=[]):
pass
async def test_post_run_invalid_json(mock_http_client):
"""Test handling of invalid JSON in SSE stream."""
invalid_sse = "data: {invalid json}\n\ndata: " + json.dumps({"type": "RUN_FINISHED"}) + "\n"
async def mock_aiter_lines():
for line in invalid_sse.split("\n"):
if line:
yield line
mock_response = AsyncMock()
mock_response.status_code = 200
mock_response.aiter_lines = mock_aiter_lines
mock_stream_context = AsyncMock()
mock_stream_context.__aenter__.return_value = mock_response
mock_stream_context.__aexit__.return_value = None
mock_http_client.stream.return_value = mock_stream_context
service = AGUIHttpService("http://localhost:8888/", http_client=mock_http_client)
events = []
async for event in service.post_run(thread_id="thread_123", run_id="run_456", messages=[]):
events.append(event)
# Should skip invalid JSON and continue with valid events
assert len(events) == 1
assert events[0]["type"] == "RUN_FINISHED"
async def test_context_manager():
"""Test context manager functionality."""
async with AGUIHttpService("http://localhost:8888/") as service:
assert service.http_client is not None
assert service._owns_client is True
# Client should be closed after exiting context
async def test_context_manager_with_external_client():
"""Test context manager doesn't close external client."""
external_client = httpx.AsyncClient()
async with AGUIHttpService("http://localhost:8888/", http_client=external_client) as service:
assert service.http_client is external_client
assert service._owns_client is False
# External client should still be open
# (caller's responsibility to close)
await external_client.aclose()
async def test_post_run_empty_response(mock_http_client):
"""Test handling of empty response stream."""
async def mock_aiter_lines():
return
yield # Make it an async generator
mock_response = AsyncMock()
mock_response.status_code = 200
mock_response.aiter_lines = mock_aiter_lines
mock_stream_context = AsyncMock()
mock_stream_context.__aenter__.return_value = mock_response
mock_stream_context.__aexit__.return_value = None
mock_http_client.stream.return_value = mock_stream_context
service = AGUIHttpService("http://localhost:8888/", http_client=mock_http_client)
events = []
async for event in service.post_run(thread_id="thread_123", run_id="run_456", messages=[]):
events.append(event)
assert len(events) == 0
@@ -63,10 +63,9 @@ def test_agui_tool_result_to_agent_framework():
assert isinstance(message.contents[0], TextContent)
assert message.contents[0].text == '{"accepted": true, "steps": []}'
assert hasattr(message, "metadata")
assert message.metadata is not None
assert message.metadata.get("is_tool_result") is True
assert message.metadata.get("tool_call_id") == "call_123"
assert message.additional_properties is not None
assert message.additional_properties.get("is_tool_result") is True
assert message.additional_properties.get("tool_call_id") == "call_123"
def test_agui_multiple_messages_to_agent_framework():
@@ -159,6 +158,36 @@ def test_agui_message_without_id():
assert messages[0].message_id is None
def test_agui_with_tool_calls_to_agent_framework():
"""Assistant message with tool_calls is converted to FunctionCallContent."""
agui_msg = {
"role": "assistant",
"content": "Calling tool",
"tool_calls": [
{
"id": "call-123",
"type": "function",
"function": {"name": "get_weather", "arguments": {"location": "Seattle"}},
}
],
"id": "msg-789",
}
messages = agui_messages_to_agent_framework([agui_msg])
assert len(messages) == 1
msg = messages[0]
assert msg.role == Role.ASSISTANT
assert msg.message_id == "msg-789"
# First content is text, second is the function call
assert isinstance(msg.contents[0], TextContent)
assert msg.contents[0].text == "Calling tool"
assert isinstance(msg.contents[1], FunctionCallContent)
assert msg.contents[1].call_id == "call-123"
assert msg.contents[1].name == "get_weather"
assert msg.contents[1].arguments == {"location": "Seattle"}
def test_agent_framework_to_agui_with_tool_calls():
"""Test converting Agent Framework message with tool calls to AG-UI."""
msg = ChatMessage(
@@ -198,13 +227,15 @@ def test_agent_framework_to_agui_multiple_text_contents():
def test_agent_framework_to_agui_no_message_id():
"""Test message without message_id."""
"""Test message without message_id - should auto-generate ID."""
msg = ChatMessage(role=Role.USER, contents=[TextContent(text="Hello")])
messages = agent_framework_messages_to_agui([msg])
assert len(messages) == 1
assert "id" not in messages[0]
assert "id" in messages[0] # ID should be auto-generated
assert messages[0]["id"] # ID should not be empty
assert len(messages[0]["id"]) > 0 # ID should be a valid string
def test_agent_framework_to_agui_system_role():
@@ -0,0 +1,82 @@
"""Tests for AG-UI orchestrators."""
from collections.abc import AsyncGenerator
from types import SimpleNamespace
from typing import Any
from agent_framework import AgentRunResponseUpdate, TextContent, ai_function
from agent_framework._tools import FunctionInvocationConfiguration
from agent_framework_ag_ui._agent import AgentConfig
from agent_framework_ag_ui._orchestrators import DefaultOrchestrator, ExecutionContext
@ai_function
def server_tool() -> str:
"""Server-executable tool."""
return "server"
class DummyAgent:
"""Minimal agent stub to capture run_stream parameters."""
def __init__(self) -> None:
self.chat_options = SimpleNamespace(tools=[server_tool], response_format=None)
self.tools = [server_tool]
self.chat_client = SimpleNamespace(
function_invocation_configuration=FunctionInvocationConfiguration(),
)
self.seen_tools: list[Any] | None = None
async def run_stream(
self,
messages: list[Any],
*,
thread: Any,
tools: list[Any] | None = None,
) -> AsyncGenerator[AgentRunResponseUpdate, None]:
self.seen_tools = tools
yield AgentRunResponseUpdate(contents=[TextContent(text="ok")], role="assistant")
async def test_default_orchestrator_merges_client_tools() -> None:
"""Client tool declarations are merged with server tools before running agent."""
agent = DummyAgent()
orchestrator = DefaultOrchestrator()
input_data = {
"messages": [
{
"role": "user",
"content": [{"type": "input_text", "text": "Hello"}],
}
],
"tools": [
{
"name": "get_weather",
"description": "Client weather lookup.",
"parameters": {
"type": "object",
"properties": {"location": {"type": "string"}},
"required": ["location"],
},
}
],
}
context = ExecutionContext(
input_data=input_data,
agent=agent,
config=AgentConfig(),
)
events = []
async for event in orchestrator.run(context):
events.append(event)
assert agent.seen_tools is not None
tool_names = [getattr(tool, "name", "?") for tool in agent.seen_tools]
assert "server_tool" in tool_names
assert "get_weather" in tool_names
assert agent.chat_client.function_invocation_configuration.additional_tools
+106
View File
@@ -197,3 +197,109 @@ def test_make_json_safe_fallback():
result = make_json_safe(obj)
# Objects with __dict__ return their __dict__ dict
assert isinstance(result, dict)
def test_convert_tools_to_agui_format_with_ai_function():
"""Test converting AIFunction to AG-UI format."""
from agent_framework import ai_function
from agent_framework_ag_ui._utils import convert_tools_to_agui_format
@ai_function
def test_func(param: str, count: int = 5) -> str:
"""Test function."""
return f"{param} {count}"
result = convert_tools_to_agui_format([test_func])
assert result is not None
assert len(result) == 1
assert result[0]["name"] == "test_func"
assert result[0]["description"] == "Test function."
assert "parameters" in result[0]
assert "properties" in result[0]["parameters"]
def test_convert_tools_to_agui_format_with_callable():
"""Test converting plain callable to AG-UI format."""
from agent_framework_ag_ui._utils import convert_tools_to_agui_format
def plain_func(x: int) -> int:
"""A plain function."""
return x * 2
result = convert_tools_to_agui_format([plain_func])
assert result is not None
assert len(result) == 1
assert result[0]["name"] == "plain_func"
assert result[0]["description"] == "A plain function."
assert "parameters" in result[0]
def test_convert_tools_to_agui_format_with_dict():
"""Test converting dict tool to AG-UI format."""
from agent_framework_ag_ui._utils import convert_tools_to_agui_format
tool_dict = {
"name": "custom_tool",
"description": "Custom tool",
"parameters": {"type": "object"},
}
result = convert_tools_to_agui_format([tool_dict])
assert result is not None
assert len(result) == 1
assert result[0] == tool_dict
def test_convert_tools_to_agui_format_with_none():
"""Test converting None tools."""
from agent_framework_ag_ui._utils import convert_tools_to_agui_format
result = convert_tools_to_agui_format(None)
assert result is None
def test_convert_tools_to_agui_format_with_single_tool():
"""Test converting single tool (not in list)."""
from agent_framework import ai_function
from agent_framework_ag_ui._utils import convert_tools_to_agui_format
@ai_function
def single_tool(arg: str) -> str:
"""Single tool."""
return arg
result = convert_tools_to_agui_format(single_tool)
assert result is not None
assert len(result) == 1
assert result[0]["name"] == "single_tool"
def test_convert_tools_to_agui_format_with_multiple_tools():
"""Test converting multiple tools."""
from agent_framework import ai_function
from agent_framework_ag_ui._utils import convert_tools_to_agui_format
@ai_function
def tool1(x: int) -> int:
"""Tool 1."""
return x
@ai_function
def tool2(y: str) -> str:
"""Tool 2."""
return y
result = convert_tools_to_agui_format([tool1, tool2])
assert result is not None
assert len(result) == 2
assert result[0]["name"] == "tool1"
assert result[1]["name"] == "tool2"
@@ -0,0 +1,35 @@
# Copyright (c) Microsoft. All rights reserved.
import importlib
from typing import Any
PACKAGE_NAME = "agent_framework_ag_ui"
PACKAGE_EXTRA = "ag-ui"
_IMPORTS = [
"__version__",
"AgentFrameworkAgent",
"add_agent_framework_fastapi_endpoint",
"AGUIChatClient",
"AGUIEventConverter",
"AGUIHttpService",
"ConfirmationStrategy",
"DefaultConfirmationStrategy",
"TaskPlannerConfirmationStrategy",
"RecipeConfirmationStrategy",
"DocumentWriterConfirmationStrategy",
]
def __getattr__(name: str) -> Any:
if name in _IMPORTS:
try:
return getattr(importlib.import_module(PACKAGE_NAME), name)
except ModuleNotFoundError as exc:
raise ModuleNotFoundError(
f"The '{PACKAGE_EXTRA}' extra is not installed, please do `pip install agent-framework-{PACKAGE_EXTRA}`"
) from exc
raise AttributeError(f"Module {PACKAGE_NAME} has no attribute {name}.")
def __dir__() -> list[str]:
return _IMPORTS
@@ -0,0 +1,29 @@
# Copyright (c) Microsoft. All rights reserved.
from agent_framework_ag_ui import (
AgentFrameworkAgent,
AGUIChatClient,
AGUIEventConverter,
AGUIHttpService,
ConfirmationStrategy,
DefaultConfirmationStrategy,
DocumentWriterConfirmationStrategy,
RecipeConfirmationStrategy,
TaskPlannerConfirmationStrategy,
__version__,
add_agent_framework_fastapi_endpoint,
)
__all__ = [
"AGUIChatClient",
"AGUIEventConverter",
"AGUIHttpService",
"AgentFrameworkAgent",
"ConfirmationStrategy",
"DefaultConfirmationStrategy",
"DocumentWriterConfirmationStrategy",
"RecipeConfirmationStrategy",
"TaskPlannerConfirmationStrategy",
"__version__",
"add_agent_framework_fastapi_endpoint",
]
+4 -3
View File
@@ -42,13 +42,14 @@ dependencies = [
[project.optional-dependencies]
all = [
"agent-framework-a2a",
"agent-framework-ag-ui",
"agent-framework-anthropic",
"agent-framework-azure-ai",
"agent-framework-copilotstudio",
"agent-framework-mem0",
"agent-framework-redis",
"agent-framework-devui",
"agent-framework-mem0",
"agent-framework-purview",
"agent-framework-anthropic",
"agent-framework-redis",
]
[tool.uv]
+3433 -3415
View File
File diff suppressed because it is too large Load Diff