Python: [BREAKING] Simplify API: ChatAgent -> Agent, ChatMessage -> Message (#3747)

* [BREAKING] Rename ChatAgent -> Agent, ChatMessage -> Message, ChatClientProtocol -> SupportsChatGetResponse

Simplify the public API by removing redundant 'Chat' prefix from core types:
- ChatAgent -> Agent
- RawChatAgent -> RawAgent
- ChatMessage -> Message
- ChatClientProtocol -> SupportsChatGetResponse

Also renamed internal WorkflowMessage (was Message in _runner_context) to avoid collision.

No backward compatibility aliases - this is a clean breaking change.

* [BREAKING] Rename Agent chat_client parameter to client

* Fix rebase issues: WorkflowMessage references and broken markdown links

* Fix formatting and lint issues from code quality checks

* Fix import ordering in workflow sample files

* fixed rebase

* Fix test failures: use WorkflowMessage and A2AMessage after ChatMessage→Message rename

- Replace Message(data=..., source_id=...) with WorkflowMessage(...) in workflow tests
- Fix isinstance check in A2A agent to use A2AMessage instead of Message
- Fix import in test_workflow_observability.py (Message→WorkflowMessage)

* Fix lint, fmt, and sample errors after ChatMessage→Message rename

- Auto-fix 70+ ruff lint issues across samples (ChatMessage→Message refs)
- Fix HostedVectorStoreContent→Content.from_hosted_vector_store in file search sample
- Fix _normalize_messages→normalize_messages in custom agent sample
- Fix context.terminate→raise MiddlewareTermination in middleware samples
- Fix with_update_hook→with_transform_hook in override middleware sample
- Add TOptions_co import back to custom_chat_client sample
- Add noqa for FastAPI File() default in chatkit sample
- Fix B023 loop variable capture in weather agent sample

* fix: update Agent constructor calls from chat_client to client in declaration-only tool tests

* fix: add register_cleanup to devui lazy-loading proxy and type stub

* fixed tests and updated new pieces

* fix agui typevar

* fix merge errors

* fix merge conflicts

* fiux merge

* Remove unused links

---------

Co-authored-by: Evan Mattson <evan.mattson@microsoft.com>
This commit is contained in:
Eduard van Valkenburg
2026-02-11 00:04:32 +01:00
committed by GitHub
Unverified
parent a4c9e43afb
commit 0521f5bed8
418 changed files with 5385 additions and 5389 deletions
@@ -26,13 +26,13 @@ Example:
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from agent_framework import ChatAgent, HostedCodeInterpreterTool, HostedWebSearchTool
from agent_framework import Agent, HostedCodeInterpreterTool, HostedWebSearchTool
from agent_framework.azure import AzureAIAgentClient
from azure.identity.aio import AzureCliCredential
@asynccontextmanager
async def create_gaia_agent() -> AsyncIterator[ChatAgent]:
async def create_gaia_agent() -> AsyncIterator[Agent]:
"""Create an Azure AI agent configured for GAIA benchmark tasks.
The agent is configured with:
@@ -40,7 +40,7 @@ async def create_gaia_agent() -> AsyncIterator[ChatAgent]:
- Code Interpreter tool for calculations and data analysis
Yields:
ChatAgent: A configured agent ready to run GAIA tasks.
Agent: A configured agent ready to run GAIA tasks.
Example:
async with create_gaia_agent() as agent:
@@ -25,12 +25,12 @@ Example:
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from agent_framework import ChatAgent, HostedCodeInterpreterTool, HostedWebSearchTool
from agent_framework import Agent, HostedCodeInterpreterTool, HostedWebSearchTool
from agent_framework.openai import OpenAIResponsesClient
@asynccontextmanager
async def create_gaia_agent() -> AsyncIterator[ChatAgent]:
async def create_gaia_agent() -> AsyncIterator[Agent]:
"""Create an OpenAI agent configured for GAIA benchmark tasks.
Uses OpenAI Responses API for enhanced capabilities.
@@ -40,16 +40,16 @@ async def create_gaia_agent() -> AsyncIterator[ChatAgent]:
- Code Interpreter tool for calculations and data analysis
Yields:
ChatAgent: A configured agent ready to run GAIA tasks.
Agent: A configured agent ready to run GAIA tasks.
Example:
async with create_gaia_agent() as agent:
result = await agent.run("What is the capital of France?")
print(result.text)
"""
chat_client = OpenAIResponsesClient()
client = OpenAIResponsesClient()
async with chat_client.as_agent(
async with client.as_agent(
name="GaiaAgent",
instructions="Solve tasks to your best ability. Use Web Search to find "
"information and Code Interpreter to perform calculations and data analysis.",
+2 -2
View File
@@ -49,8 +49,8 @@ async def math_agent(task: TaskType, llm: LLM) -> float:
"""A function that solves a math problem and returns the evaluation score."""
async with (
MCPStdioTool(name="calculator", command="uvx", args=["mcp-server-calculator"]) as mcp_server,
ChatAgent(
chat_client=OpenAIChatClient(
Agent(
client=OpenAIChatClient(
model_id=llm.model,
api_key="your-api-key",
base_url=llm.endpoint,
@@ -20,7 +20,7 @@ import string
from typing import TypedDict, cast
import sympy # type: ignore[import-untyped,reportMissingImports]
from agent_framework import AgentResponse, ChatAgent, MCPStdioTool
from agent_framework import Agent, AgentResponse, MCPStdioTool
from agent_framework.lab.lightning import AgentFrameworkTracer
from agent_framework.openai import OpenAIChatClient
from agentlightning import LLM, Dataset, Trainer, rollout
@@ -166,8 +166,8 @@ async def math_agent(task: MathProblem, llm: LLM) -> float:
# MCPStdioTool provides calculator functionality via MCP protocol
async with (
MCPStdioTool(name="calculator", command="uvx", args=["mcp-server-calculator"]) as mcp_server,
ChatAgent(
chat_client=OpenAIChatClient(
Agent(
client=OpenAIChatClient(
model_id=llm.model, # This is the model being trained
api_key=os.getenv("OPENAI_API_KEY") or "dummy", # Can be dummy when connecting to training LLM
base_url=llm.endpoint, # vLLM server endpoint provided by agent-lightning
@@ -9,7 +9,7 @@ import pytest
agentlightning = pytest.importorskip("agentlightning")
from agent_framework import AgentExecutor, AgentResponse, ChatAgent, WorkflowBuilder, Workflow
from agent_framework import AgentExecutor, AgentResponse, Agent, WorkflowBuilder, Workflow
from agent_framework_lab_lightning import AgentFrameworkTracer
from agent_framework.openai import OpenAIChatClient
from agentlightning import TracerTraceToTriplet
@@ -80,14 +80,14 @@ def workflow_two_agents():
),
):
# Create the two agents
analyzer_agent = ChatAgent(
chat_client=first_chat_client,
analyzer_agent = Agent(
client=first_chat_client,
name="DataAnalyzer",
instructions="You are a data analyst. Analyze the given data and provide insights.",
)
advisor_agent = ChatAgent(
chat_client=second_chat_client,
advisor_agent = Agent(
client=second_chat_client,
name="InvestmentAdvisor",
instructions="You are an investment advisor. Based on analysis results, provide recommendations.",
)
+5 -5
View File
@@ -138,21 +138,21 @@ export OPENAI_BASE_URL="https://your-custom-endpoint.com/v1"
```python
from agent_framework.lab.tau2 import TaskRunner
from agent_framework import ChatAgent
from agent_framework import Agent
class CustomTaskRunner(TaskRunner):
def assistant_agent(self, assistant_chat_client):
# Override to customize the assistant agent
return ChatAgent(
chat_client=assistant_chat_client,
return Agent(
client=assistant_chat_client,
instructions="Your custom system prompt here",
# Add custom tools, temperature, etc.
)
def user_simulator(self, user_chat_client, task):
# Override to customize the user simulator
return ChatAgent(
chat_client=user_chat_client,
return Agent(
client=user_chat_client,
instructions="Custom user simulator prompt",
)
```
@@ -2,7 +2,7 @@
from typing import Any
from agent_framework._types import ChatMessage, Content
from agent_framework._types import Content, Message
from loguru import logger
@@ -11,7 +11,7 @@ def _get_role_value(role: Any) -> str:
return role.value if hasattr(role, "value") else str(role)
def flip_messages(messages: list[ChatMessage]) -> list[ChatMessage]:
def flip_messages(messages: list[Message]) -> list[Message]:
"""Flip message roles between assistant and user for role-playing scenarios.
Used in agent simulations where the assistant's messages become user inputs
@@ -30,7 +30,7 @@ def flip_messages(messages: list[ChatMessage]) -> list[ChatMessage]:
# Flip assistant to user
contents = filter_out_function_calls(msg.contents)
if contents:
flipped_msg = ChatMessage(
flipped_msg = Message(
role="user",
# The function calls will cause 400 when role is user
contents=contents,
@@ -40,7 +40,7 @@ def flip_messages(messages: list[ChatMessage]) -> list[ChatMessage]:
flipped_messages.append(flipped_msg)
elif role_value == "user":
# Flip user to assistant
flipped_msg = ChatMessage(
flipped_msg = Message(
role="assistant", contents=msg.contents, author_name=msg.author_name, message_id=msg.message_id
)
flipped_messages.append(flipped_msg)
@@ -53,7 +53,7 @@ def flip_messages(messages: list[ChatMessage]) -> list[ChatMessage]:
return flipped_messages
def log_messages(messages: list[ChatMessage]) -> None:
def log_messages(messages: list[Message]) -> None:
"""Log messages with colored output based on role and content type.
Provides visual debugging by color-coding different message roles and
@@ -5,7 +5,7 @@ from collections.abc import Sequence
from typing import Any
import tiktoken
from agent_framework import ChatMessage, ChatMessageStore
from agent_framework import ChatMessageStore, Message
from loguru import logger
@@ -19,7 +19,7 @@ class SlidingWindowChatMessageStore(ChatMessageStore):
def __init__(
self,
messages: Sequence[ChatMessage] | None = None,
messages: Sequence[Message] | None = None,
max_tokens: int = 3800,
system_message: str | None = None,
tool_definitions: Any | None = None,
@@ -32,17 +32,17 @@ class SlidingWindowChatMessageStore(ChatMessageStore):
# An estimation based on a commonly used vocab table
self.encoding = tiktoken.get_encoding("o200k_base")
async def add_messages(self, messages: Sequence[ChatMessage]) -> None:
async def add_messages(self, messages: Sequence[Message]) -> None:
await super().add_messages(messages)
self.truncated_messages = self.messages.copy()
self.truncate_messages()
async def list_messages(self) -> list[ChatMessage]:
async def list_messages(self) -> list[Message]:
"""Get the current list of messages, which may be truncated."""
return self.truncated_messages
async def list_all_messages(self) -> list[ChatMessage]:
async def list_all_messages(self) -> list[Message]:
"""Get all messages from the store including the truncated ones."""
return self.messages
@@ -7,17 +7,19 @@ from typing import Any
import numpy as np
from agent_framework._tools import FunctionTool
from agent_framework._types import ChatMessage
from agent_framework._types import Message
from loguru import logger
from pydantic import BaseModel
from tau2.data_model.message import ( # type: ignore[import-untyped]
AssistantMessage,
Message,
SystemMessage,
ToolCall,
ToolMessage,
UserMessage,
)
from tau2.data_model.message import (
Message as Tau2Message,
)
from tau2.data_model.tasks import EnvFunctionCall, InitializationData # type: ignore[import-untyped]
from tau2.environment.environment import Environment # type: ignore[import-untyped]
from tau2.environment.tool import Tool # type: ignore[import-untyped]
@@ -45,7 +47,7 @@ def convert_tau2_tool_to_function_tool(tau2_tool: Tool) -> FunctionTool[Any, Any
)
def convert_agent_framework_messages_to_tau2_messages(messages: list[ChatMessage]) -> list[Message]:
def convert_agent_framework_messages_to_tau2_messages(messages: list[Message]) -> list[Tau2Message]:
"""Convert agent framework ChatMessages to tau2 Message objects.
Handles role mapping, text extraction, function calls, and function results.
@@ -119,13 +121,13 @@ def patch_env_set_state() -> None:
self: Any,
initialization_data: InitializationData | None,
initialization_actions: list[EnvFunctionCall] | None,
message_history: list[Message],
message_history: list[Tau2Message],
) -> None:
if self.solo_mode and any(isinstance(message, UserMessage) for message in message_history):
raise ValueError("User messages are not allowed in solo mode")
def get_actions_from_messages(
messages: list[Message],
messages: list[Tau2Message],
) -> list[tuple[ToolCall, ToolMessage]]:
"""Get the actions from the messages."""
messages = deepcopy(messages)[::-1]
@@ -6,14 +6,14 @@ import uuid
from typing import cast
from agent_framework import (
Agent,
AgentExecutor,
AgentExecutorRequest,
AgentExecutorResponse,
AgentResponse,
ChatAgent,
ChatClientProtocol,
ChatMessage,
FunctionExecutor,
Message,
SupportsChatGetResponse,
Workflow,
WorkflowBuilder,
WorkflowContext,
@@ -67,10 +67,10 @@ class TaskRunner:
# State tracking
step_count: int
full_conversation: list[ChatMessage]
full_conversation: list[Message]
termination_reason: TerminationReason | None
full_reward_info: RewardInfo | None
_final_user_message: list[ChatMessage] | None
_final_user_message: list[Message] | None
_assistant_executor: AgentExecutor | None
_user_executor: AgentExecutor | None
@@ -159,7 +159,7 @@ class TaskRunner:
"""Check if user wants to stop the conversation."""
return STOP in text or TRANSFER in text or OUT_OF_SCOPE in text
def assistant_agent(self, assistant_chat_client: ChatClientProtocol) -> ChatAgent:
def assistant_agent(self, assistant_chat_client: SupportsChatGetResponse) -> Agent:
"""Create an assistant agent.
Users can override this method to provide a custom assistant agent.
@@ -196,8 +196,8 @@ class TaskRunner:
# - Access to all domain tools (booking, cancellation, etc.)
# - Sliding window memory to handle long conversations within token limits
# - Temperature-controlled response generation
return ChatAgent(
chat_client=assistant_chat_client,
return Agent(
client=assistant_chat_client,
instructions=assistant_system_prompt,
tools=tools,
temperature=self.assistant_sampling_temperature,
@@ -208,7 +208,7 @@ class TaskRunner:
),
)
def user_simulator(self, user_simuator_chat_client: ChatClientProtocol, task: Task) -> ChatAgent:
def user_simulator(self, user_simuator_chat_client: SupportsChatGetResponse, task: Task) -> Agent:
"""Create a user simulator agent.
Users can override this method to provide a custom user simulator agent.
@@ -230,8 +230,8 @@ class TaskRunner:
{task.user_scenario.instructions}
</scenario>"""
return ChatAgent(
chat_client=user_simuator_chat_client,
return Agent(
client=user_simuator_chat_client,
instructions=user_sim_system_prompt,
temperature=0.0,
# No sliding window for user simulator to maintain full conversation context
@@ -268,7 +268,7 @@ class TaskRunner:
target_id=USER_SIMULATOR_ID if is_from_agent else ASSISTANT_AGENT_ID,
)
def build_conversation_workflow(self, assistant_agent: ChatAgent, user_simulator_agent: ChatAgent) -> Workflow:
def build_conversation_workflow(self, assistant_agent: Agent, user_simulator_agent: Agent) -> Workflow:
"""Build the conversation workflow.
Users can override this method to provide a custom conversation workflow.
@@ -304,9 +304,9 @@ class TaskRunner:
async def run(
self,
task: Task,
assistant_chat_client: ChatClientProtocol,
user_simulator_chat_client: ChatClientProtocol,
) -> list[ChatMessage]:
assistant_chat_client: SupportsChatGetResponse,
user_simulator_chat_client: SupportsChatGetResponse,
) -> list[Message]:
"""Run a tau2 task using workflow-based agent orchestration.
This method orchestrates a complex multi-agent simulation:
@@ -323,7 +323,7 @@ class TaskRunner:
user_simulator_chat_client: LLM client for the user simulator
Returns:
Complete conversation history as ChatMessage list for evaluation
Complete conversation history as Message list for evaluation
"""
logger.info(f"Starting workflow agent for task {task.id}: {task.description.purpose}") # type: ignore[unused-ignore]
logger.info(f"Assistant chat client: {assistant_chat_client}")
@@ -340,11 +340,11 @@ class TaskRunner:
# Matches tau2's expected conversation start pattern
logger.info(f"Starting workflow with hardcoded greeting: '{DEFAULT_FIRST_AGENT_MESSAGE}'")
first_message = ChatMessage(role="assistant", text=DEFAULT_FIRST_AGENT_MESSAGE)
first_message = Message(role="assistant", text=DEFAULT_FIRST_AGENT_MESSAGE)
initial_greeting = AgentExecutorResponse(
executor_id=ASSISTANT_AGENT_ID,
agent_response=AgentResponse(messages=[first_message]),
full_conversation=[ChatMessage(role="assistant", text=DEFAULT_FIRST_AGENT_MESSAGE)],
full_conversation=[Message(role="assistant", text=DEFAULT_FIRST_AGENT_MESSAGE)],
)
# STEP 4: Execute the workflow and collect results
@@ -371,7 +371,7 @@ class TaskRunner:
return full_conversation
def evaluate(
self, task_input: Task, conversation: list[ChatMessage], termination_reason: TerminationReason | None
self, task_input: Task, conversation: list[Message], termination_reason: TerminationReason | None
) -> float:
"""Evaluate agent performance using tau2's comprehensive evaluation system.
@@ -2,14 +2,14 @@
from unittest.mock import patch
from agent_framework._types import ChatMessage, Content
from agent_framework._types import Content, Message
from agent_framework_lab_tau2._message_utils import flip_messages, log_messages
def test_flip_messages_user_to_assistant():
"""Test flipping user message to assistant."""
messages = [
ChatMessage(
Message(
role="user",
contents=[Content.from_text(text="Hello assistant")],
author_name="User1",
@@ -29,7 +29,7 @@ def test_flip_messages_user_to_assistant():
def test_flip_messages_assistant_to_user():
"""Test flipping assistant message to user."""
messages = [
ChatMessage(
Message(
role="assistant",
contents=[Content.from_text(text="Hello user")],
author_name="Assistant1",
@@ -51,7 +51,7 @@ def test_flip_messages_assistant_with_function_calls_filtered():
function_call = Content.from_function_call(call_id="call_123", name="test_function", arguments={"param": "value"})
messages = [
ChatMessage(
Message(
role="assistant",
contents=[
Content.from_text(text="I'll call a function"),
@@ -78,7 +78,7 @@ def test_flip_messages_assistant_with_only_function_calls_skipped():
function_call = Content.from_function_call(call_id="call_456", name="another_function", arguments={"key": "value"})
messages = [
ChatMessage(role="assistant", contents=[function_call], message_id="msg_004") # Only function call, no text
Message(role="assistant", contents=[function_call], message_id="msg_004") # Only function call, no text
]
flipped = flip_messages(messages)
@@ -91,7 +91,7 @@ def test_flip_messages_tool_messages_skipped():
"""Test that tool messages are skipped."""
function_result = Content.from_function_result(call_id="call_789", result={"success": True})
messages = [ChatMessage(role="tool", contents=[function_result])]
messages = [Message(role="tool", contents=[function_result])]
flipped = flip_messages(messages)
@@ -101,9 +101,7 @@ def test_flip_messages_tool_messages_skipped():
def test_flip_messages_system_messages_preserved():
"""Test that system messages are preserved as-is."""
messages = [
ChatMessage(role="system", contents=[Content.from_text(text="System instruction")], message_id="sys_001")
]
messages = [Message(role="system", contents=[Content.from_text(text="System instruction")], message_id="sys_001")]
flipped = flip_messages(messages)
@@ -120,11 +118,11 @@ def test_flip_messages_mixed_conversation():
function_result = Content.from_function_result(call_id="call_mixed", result="function result")
messages = [
ChatMessage(role="system", contents=[Content.from_text(text="System prompt")]),
ChatMessage(role="user", contents=[Content.from_text(text="User question")]),
ChatMessage(role="assistant", contents=[Content.from_text(text="Assistant response"), function_call]),
ChatMessage(role="tool", contents=[function_result]),
ChatMessage(role="assistant", contents=[Content.from_text(text="Final response")]),
Message(role="system", contents=[Content.from_text(text="System prompt")]),
Message(role="user", contents=[Content.from_text(text="User question")]),
Message(role="assistant", contents=[Content.from_text(text="Assistant response"), function_call]),
Message(role="tool", contents=[function_result]),
Message(role="assistant", contents=[Content.from_text(text="Final response")]),
]
flipped = flip_messages(messages)
@@ -159,7 +157,7 @@ def test_flip_messages_empty_list():
def test_flip_messages_preserves_metadata():
"""Test that message metadata is preserved during flipping."""
messages = [
ChatMessage(
Message(
role="user",
contents=[Content.from_text(text="Test message")],
author_name="TestUser",
@@ -178,8 +176,8 @@ def test_flip_messages_preserves_metadata():
def test_log_messages_text_content(mock_logger):
"""Test logging messages with text content."""
messages = [
ChatMessage(role="user", contents=[Content.from_text(text="Hello")]),
ChatMessage(role="assistant", contents=[Content.from_text(text="Hi there!")]),
Message(role="user", contents=[Content.from_text(text="Hello")]),
Message(role="assistant", contents=[Content.from_text(text="Hi there!")]),
]
log_messages(messages)
@@ -193,7 +191,7 @@ def test_log_messages_function_call(mock_logger):
"""Test logging messages with function calls."""
function_call = Content.from_function_call(call_id="call_log", name="log_function", arguments={"param": "value"})
messages = [ChatMessage(role="assistant", contents=[function_call])]
messages = [Message(role="assistant", contents=[function_call])]
log_messages(messages)
@@ -209,7 +207,7 @@ def test_log_messages_function_result(mock_logger):
"""Test logging messages with function results."""
function_result = Content.from_function_result(call_id="call_result", result="success")
messages = [ChatMessage(role="tool", contents=[function_result])]
messages = [Message(role="tool", contents=[function_result])]
log_messages(messages)
@@ -223,10 +221,10 @@ def test_log_messages_function_result(mock_logger):
def test_log_messages_different_roles(mock_logger):
"""Test logging messages with different roles get different colors."""
messages = [
ChatMessage(role="system", contents=[Content.from_text(text="System")]),
ChatMessage(role="user", contents=[Content.from_text(text="User")]),
ChatMessage(role="assistant", contents=[Content.from_text(text="Assistant")]),
ChatMessage(role="tool", contents=[Content.from_text(text="Tool")]),
Message(role="system", contents=[Content.from_text(text="System")]),
Message(role="user", contents=[Content.from_text(text="User")]),
Message(role="assistant", contents=[Content.from_text(text="Assistant")]),
Message(role="tool", contents=[Content.from_text(text="Tool")]),
]
log_messages(messages)
@@ -250,7 +248,7 @@ def test_log_messages_different_roles(mock_logger):
@patch("agent_framework_lab_tau2._message_utils.logger")
def test_log_messages_escapes_html(mock_logger):
"""Test that HTML-like characters are properly escaped in log output."""
messages = [ChatMessage(role="user", contents=[Content.from_text(text="Message with <tag> content")])]
messages = [Message(role="user", contents=[Content.from_text(text="Message with <tag> content")])]
log_messages(messages)
@@ -266,7 +264,7 @@ def test_log_messages_mixed_content_types(mock_logger):
function_call = Content.from_function_call(call_id="mixed_call", name="mixed_function", arguments={"key": "value"})
messages = [
ChatMessage(
Message(
role="assistant",
contents=[Content.from_text(text="I'll call a function"), function_call, Content.from_text(text="Done!")],
)
@@ -4,7 +4,7 @@
from unittest.mock import patch
from agent_framework._types import ChatMessage, Content
from agent_framework._types import Content, Message
from agent_framework_lab_tau2._sliding_window import SlidingWindowChatMessageStore
@@ -36,8 +36,8 @@ def test_initialization_with_parameters():
def test_initialization_with_messages():
"""Test initializing with existing messages."""
messages = [
ChatMessage(role="user", contents=[Content.from_text(text="Hello")]),
ChatMessage(role="assistant", contents=[Content.from_text(text="Hi there!")]),
Message(role="user", contents=[Content.from_text(text="Hello")]),
Message(role="assistant", contents=[Content.from_text(text="Hi there!")]),
]
sliding_window = SlidingWindowChatMessageStore(messages=messages, max_tokens=1000)
@@ -51,8 +51,8 @@ async def test_add_messages_simple():
sliding_window = SlidingWindowChatMessageStore(max_tokens=10000) # Large limit
new_messages = [
ChatMessage(role="user", contents=[Content.from_text(text="What's the weather?")]),
ChatMessage(role="assistant", contents=[Content.from_text(text="I can help with that.")]),
Message(role="user", contents=[Content.from_text(text="What's the weather?")]),
Message(role="assistant", contents=[Content.from_text(text="I can help with that.")]),
]
await sliding_window.add_messages(new_messages)
@@ -69,7 +69,7 @@ async def test_list_all_messages_vs_list_messages():
# Add many messages to trigger truncation
messages = [
ChatMessage(role="user", contents=[Content.from_text(text=f"Message {i} with some content")]) for i in range(10)
Message(role="user", contents=[Content.from_text(text=f"Message {i} with some content")]) for i in range(10)
]
await sliding_window.add_messages(messages)
@@ -87,7 +87,7 @@ async def test_list_all_messages_vs_list_messages():
def test_get_token_count_basic():
"""Test basic token counting."""
sliding_window = SlidingWindowChatMessageStore(max_tokens=1000)
sliding_window.truncated_messages = [ChatMessage(role="user", contents=[Content.from_text(text="Hello")])]
sliding_window.truncated_messages = [Message(role="user", contents=[Content.from_text(text="Hello")])]
token_count = sliding_window.get_token_count()
@@ -104,7 +104,7 @@ def test_get_token_count_with_system_message():
token_count_empty = sliding_window.get_token_count()
# Add a message
sliding_window.truncated_messages = [ChatMessage(role="user", contents=[Content.from_text(text="Hello")])]
sliding_window.truncated_messages = [Message(role="user", contents=[Content.from_text(text="Hello")])]
token_count_with_message = sliding_window.get_token_count()
# With message should be more tokens
@@ -117,7 +117,7 @@ def test_get_token_count_function_call():
function_call = Content.from_function_call(call_id="call_123", name="test_function", arguments={"param": "value"})
sliding_window = SlidingWindowChatMessageStore(max_tokens=1000)
sliding_window.truncated_messages = [ChatMessage(role="assistant", contents=[function_call])]
sliding_window.truncated_messages = [Message(role="assistant", contents=[function_call])]
token_count = sliding_window.get_token_count()
assert token_count > 0
@@ -128,7 +128,7 @@ def test_get_token_count_function_result():
function_result = Content.from_function_result(call_id="call_123", result={"success": True, "data": "result"})
sliding_window = SlidingWindowChatMessageStore(max_tokens=1000)
sliding_window.truncated_messages = [ChatMessage(role="tool", contents=[function_result])]
sliding_window.truncated_messages = [Message(role="tool", contents=[function_result])]
token_count = sliding_window.get_token_count()
assert token_count > 0
@@ -141,17 +141,17 @@ def test_truncate_messages_removes_old_messages(mock_logger):
# Create messages that will exceed the limit
messages = [
ChatMessage(
Message(
role="user",
contents=[Content.from_text(text="This is a very long message that should exceed the token limit")],
),
ChatMessage(
Message(
role="assistant",
contents=[
Content.from_text(text="This is another very long message that should also exceed the token limit")
],
),
ChatMessage(role="user", contents=[Content.from_text(text="Short msg")]),
Message(role="user", contents=[Content.from_text(text="Short msg")]),
]
sliding_window.truncated_messages = messages.copy()
@@ -170,10 +170,8 @@ def test_truncate_messages_removes_leading_tool_messages(mock_logger):
sliding_window = SlidingWindowChatMessageStore(max_tokens=10000) # Large limit
# Create messages starting with tool message
tool_message = ChatMessage(
role="tool", contents=[Content.from_function_result(call_id="call_123", result="result")]
)
user_message = ChatMessage(role="user", contents=[Content.from_text(text="Hello")])
tool_message = Message(role="tool", contents=[Content.from_function_result(call_id="call_123", result="result")])
user_message = Message(role="user", contents=[Content.from_text(text="Hello")])
sliding_window.truncated_messages = [tool_message, user_message]
sliding_window.truncate_messages()
@@ -231,13 +229,13 @@ async def test_real_world_scenario():
# Simulate a conversation
conversation = [
ChatMessage(role="user", contents=[Content.from_text(text="Hello, how are you?")]),
ChatMessage(
Message(role="user", contents=[Content.from_text(text="Hello, how are you?")]),
Message(
role="assistant",
contents=[Content.from_text(text="I'm doing well, thank you! How can I help you today?")],
),
ChatMessage(role="user", contents=[Content.from_text(text="Can you tell me about the weather?")]),
ChatMessage(
Message(role="user", contents=[Content.from_text(text="Can you tell me about the weather?")]),
Message(
role="assistant",
contents=[
Content.from_text(
@@ -246,8 +244,8 @@ async def test_real_world_scenario():
)
],
),
ChatMessage(role="user", contents=[Content.from_text(text="What about telling me a joke instead?")]),
ChatMessage(
Message(role="user", contents=[Content.from_text(text="What about telling me a joke instead?")]),
Message(
role="assistant",
contents=[
Content.from_text(text="Sure! Why don't scientists trust atoms? Because they make up everything!")
@@ -6,7 +6,7 @@ import urllib.request
from pathlib import Path
import pytest
from agent_framework import ChatMessage, Content, FunctionTool
from agent_framework import Content, FunctionTool, Message
from agent_framework_lab_tau2._tau2_utils import (
convert_agent_framework_messages_to_tau2_messages,
convert_tau2_tool_to_function_tool,
@@ -91,7 +91,7 @@ def test_convert_tau2_tool_to_function_tool_multiple_tools(tau2_airline_environm
def test_convert_agent_framework_messages_to_tau2_messages_system():
"""Test converting system message."""
messages = [ChatMessage(role="system", contents=[Content.from_text(text="System instruction")])]
messages = [Message(role="system", contents=[Content.from_text(text="System instruction")])]
tau2_messages = convert_agent_framework_messages_to_tau2_messages(messages)
@@ -103,7 +103,7 @@ def test_convert_agent_framework_messages_to_tau2_messages_system():
def test_convert_agent_framework_messages_to_tau2_messages_user():
"""Test converting user message."""
messages = [ChatMessage(role="user", contents=[Content.from_text(text="Hello assistant")])]
messages = [Message(role="user", contents=[Content.from_text(text="Hello assistant")])]
tau2_messages = convert_agent_framework_messages_to_tau2_messages(messages)
@@ -116,7 +116,7 @@ def test_convert_agent_framework_messages_to_tau2_messages_user():
def test_convert_agent_framework_messages_to_tau2_messages_assistant():
"""Test converting assistant message."""
messages = [ChatMessage(role="assistant", contents=[Content.from_text(text="Hello user")])]
messages = [Message(role="assistant", contents=[Content.from_text(text="Hello user")])]
tau2_messages = convert_agent_framework_messages_to_tau2_messages(messages)
@@ -131,7 +131,7 @@ def test_convert_agent_framework_messages_to_tau2_messages_with_function_call():
"""Test converting message with function call."""
function_call = Content.from_function_call(call_id="call_123", name="test_function", arguments={"param": "value"})
messages = [ChatMessage(role="assistant", contents=[Content.from_text(text="I'll call a function"), function_call])]
messages = [Message(role="assistant", contents=[Content.from_text(text="I'll call a function"), function_call])]
tau2_messages = convert_agent_framework_messages_to_tau2_messages(messages)
@@ -153,7 +153,7 @@ def test_convert_agent_framework_messages_to_tau2_messages_with_function_result(
"""Test converting message with function result."""
function_result = Content.from_function_result(call_id="call_123", result={"success": True, "data": "result data"})
messages = [ChatMessage(role="tool", contents=[function_result])]
messages = [Message(role="tool", contents=[function_result])]
tau2_messages = convert_agent_framework_messages_to_tau2_messages(messages)
@@ -173,7 +173,7 @@ def test_convert_agent_framework_messages_to_tau2_messages_with_error():
call_id="call_456", result="Error occurred", exception=Exception("Test error")
)
messages = [ChatMessage(role="tool", contents=[function_result])]
messages = [Message(role="tool", contents=[function_result])]
tau2_messages = convert_agent_framework_messages_to_tau2_messages(messages)
@@ -185,7 +185,7 @@ def test_convert_agent_framework_messages_to_tau2_messages_with_error():
def test_convert_agent_framework_messages_to_tau2_messages_multiple_text_contents():
"""Test converting message with multiple text contents."""
messages = [
ChatMessage(role="user", contents=[Content.from_text(text="First part"), Content.from_text(text="Second part")])
Message(role="user", contents=[Content.from_text(text="First part"), Content.from_text(text="Second part")])
]
tau2_messages = convert_agent_framework_messages_to_tau2_messages(messages)
@@ -202,11 +202,11 @@ def test_convert_agent_framework_messages_to_tau2_messages_complex_scenario():
function_result = Content.from_function_result(call_id="call_789", result={"output": "tool result"})
messages = [
ChatMessage(role="system", contents=[Content.from_text(text="System prompt")]),
ChatMessage(role="user", contents=[Content.from_text(text="User request")]),
ChatMessage(role="assistant", contents=[Content.from_text(text="I'll help you"), function_call]),
ChatMessage(role="tool", contents=[function_result]),
ChatMessage(role="assistant", contents=[Content.from_text(text="Based on the result...")]),
Message(role="system", contents=[Content.from_text(text="System prompt")]),
Message(role="user", contents=[Content.from_text(text="User request")]),
Message(role="assistant", contents=[Content.from_text(text="I'll help you"), function_call]),
Message(role="tool", contents=[function_result]),
Message(role="assistant", contents=[Content.from_text(text="Based on the result...")]),
]
tau2_messages = convert_agent_framework_messages_to_tau2_messages(messages)