mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: [BREAKING] Simplify API: ChatAgent -> Agent, ChatMessage -> Message (#3747)
* [BREAKING] Rename ChatAgent -> Agent, ChatMessage -> Message, ChatClientProtocol -> SupportsChatGetResponse Simplify the public API by removing redundant 'Chat' prefix from core types: - ChatAgent -> Agent - RawChatAgent -> RawAgent - ChatMessage -> Message - ChatClientProtocol -> SupportsChatGetResponse Also renamed internal WorkflowMessage (was Message in _runner_context) to avoid collision. No backward compatibility aliases - this is a clean breaking change. * [BREAKING] Rename Agent chat_client parameter to client * Fix rebase issues: WorkflowMessage references and broken markdown links * Fix formatting and lint issues from code quality checks * Fix import ordering in workflow sample files * fixed rebase * Fix test failures: use WorkflowMessage and A2AMessage after ChatMessage→Message rename - Replace Message(data=..., source_id=...) with WorkflowMessage(...) in workflow tests - Fix isinstance check in A2A agent to use A2AMessage instead of Message - Fix import in test_workflow_observability.py (Message→WorkflowMessage) * Fix lint, fmt, and sample errors after ChatMessage→Message rename - Auto-fix 70+ ruff lint issues across samples (ChatMessage→Message refs) - Fix HostedVectorStoreContent→Content.from_hosted_vector_store in file search sample - Fix _normalize_messages→normalize_messages in custom agent sample - Fix context.terminate→raise MiddlewareTermination in middleware samples - Fix with_update_hook→with_transform_hook in override middleware sample - Add TOptions_co import back to custom_chat_client sample - Add noqa for FastAPI File() default in chatkit sample - Fix B023 loop variable capture in weather agent sample * fix: update Agent constructor calls from chat_client to client in declaration-only tool tests * fix: add register_cleanup to devui lazy-loading proxy and type stub * fixed tests and updated new pieces * fix agui typevar * fix merge errors * fix merge conflicts * fiux merge * Remove unused links --------- Co-authored-by: Evan Mattson <evan.mattson@microsoft.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
a4c9e43afb
commit
0521f5bed8
@@ -26,13 +26,13 @@ Example:
|
||||
from collections.abc import AsyncIterator
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from agent_framework import ChatAgent, HostedCodeInterpreterTool, HostedWebSearchTool
|
||||
from agent_framework import Agent, HostedCodeInterpreterTool, HostedWebSearchTool
|
||||
from agent_framework.azure import AzureAIAgentClient
|
||||
from azure.identity.aio import AzureCliCredential
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def create_gaia_agent() -> AsyncIterator[ChatAgent]:
|
||||
async def create_gaia_agent() -> AsyncIterator[Agent]:
|
||||
"""Create an Azure AI agent configured for GAIA benchmark tasks.
|
||||
|
||||
The agent is configured with:
|
||||
@@ -40,7 +40,7 @@ async def create_gaia_agent() -> AsyncIterator[ChatAgent]:
|
||||
- Code Interpreter tool for calculations and data analysis
|
||||
|
||||
Yields:
|
||||
ChatAgent: A configured agent ready to run GAIA tasks.
|
||||
Agent: A configured agent ready to run GAIA tasks.
|
||||
|
||||
Example:
|
||||
async with create_gaia_agent() as agent:
|
||||
|
||||
@@ -25,12 +25,12 @@ Example:
|
||||
from collections.abc import AsyncIterator
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from agent_framework import ChatAgent, HostedCodeInterpreterTool, HostedWebSearchTool
|
||||
from agent_framework import Agent, HostedCodeInterpreterTool, HostedWebSearchTool
|
||||
from agent_framework.openai import OpenAIResponsesClient
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def create_gaia_agent() -> AsyncIterator[ChatAgent]:
|
||||
async def create_gaia_agent() -> AsyncIterator[Agent]:
|
||||
"""Create an OpenAI agent configured for GAIA benchmark tasks.
|
||||
|
||||
Uses OpenAI Responses API for enhanced capabilities.
|
||||
@@ -40,16 +40,16 @@ async def create_gaia_agent() -> AsyncIterator[ChatAgent]:
|
||||
- Code Interpreter tool for calculations and data analysis
|
||||
|
||||
Yields:
|
||||
ChatAgent: A configured agent ready to run GAIA tasks.
|
||||
Agent: A configured agent ready to run GAIA tasks.
|
||||
|
||||
Example:
|
||||
async with create_gaia_agent() as agent:
|
||||
result = await agent.run("What is the capital of France?")
|
||||
print(result.text)
|
||||
"""
|
||||
chat_client = OpenAIResponsesClient()
|
||||
client = OpenAIResponsesClient()
|
||||
|
||||
async with chat_client.as_agent(
|
||||
async with client.as_agent(
|
||||
name="GaiaAgent",
|
||||
instructions="Solve tasks to your best ability. Use Web Search to find "
|
||||
"information and Code Interpreter to perform calculations and data analysis.",
|
||||
|
||||
@@ -49,8 +49,8 @@ async def math_agent(task: TaskType, llm: LLM) -> float:
|
||||
"""A function that solves a math problem and returns the evaluation score."""
|
||||
async with (
|
||||
MCPStdioTool(name="calculator", command="uvx", args=["mcp-server-calculator"]) as mcp_server,
|
||||
ChatAgent(
|
||||
chat_client=OpenAIChatClient(
|
||||
Agent(
|
||||
client=OpenAIChatClient(
|
||||
model_id=llm.model,
|
||||
api_key="your-api-key",
|
||||
base_url=llm.endpoint,
|
||||
|
||||
@@ -20,7 +20,7 @@ import string
|
||||
from typing import TypedDict, cast
|
||||
|
||||
import sympy # type: ignore[import-untyped,reportMissingImports]
|
||||
from agent_framework import AgentResponse, ChatAgent, MCPStdioTool
|
||||
from agent_framework import Agent, AgentResponse, MCPStdioTool
|
||||
from agent_framework.lab.lightning import AgentFrameworkTracer
|
||||
from agent_framework.openai import OpenAIChatClient
|
||||
from agentlightning import LLM, Dataset, Trainer, rollout
|
||||
@@ -166,8 +166,8 @@ async def math_agent(task: MathProblem, llm: LLM) -> float:
|
||||
# MCPStdioTool provides calculator functionality via MCP protocol
|
||||
async with (
|
||||
MCPStdioTool(name="calculator", command="uvx", args=["mcp-server-calculator"]) as mcp_server,
|
||||
ChatAgent(
|
||||
chat_client=OpenAIChatClient(
|
||||
Agent(
|
||||
client=OpenAIChatClient(
|
||||
model_id=llm.model, # This is the model being trained
|
||||
api_key=os.getenv("OPENAI_API_KEY") or "dummy", # Can be dummy when connecting to training LLM
|
||||
base_url=llm.endpoint, # vLLM server endpoint provided by agent-lightning
|
||||
|
||||
@@ -9,7 +9,7 @@ import pytest
|
||||
|
||||
agentlightning = pytest.importorskip("agentlightning")
|
||||
|
||||
from agent_framework import AgentExecutor, AgentResponse, ChatAgent, WorkflowBuilder, Workflow
|
||||
from agent_framework import AgentExecutor, AgentResponse, Agent, WorkflowBuilder, Workflow
|
||||
from agent_framework_lab_lightning import AgentFrameworkTracer
|
||||
from agent_framework.openai import OpenAIChatClient
|
||||
from agentlightning import TracerTraceToTriplet
|
||||
@@ -80,14 +80,14 @@ def workflow_two_agents():
|
||||
),
|
||||
):
|
||||
# Create the two agents
|
||||
analyzer_agent = ChatAgent(
|
||||
chat_client=first_chat_client,
|
||||
analyzer_agent = Agent(
|
||||
client=first_chat_client,
|
||||
name="DataAnalyzer",
|
||||
instructions="You are a data analyst. Analyze the given data and provide insights.",
|
||||
)
|
||||
|
||||
advisor_agent = ChatAgent(
|
||||
chat_client=second_chat_client,
|
||||
advisor_agent = Agent(
|
||||
client=second_chat_client,
|
||||
name="InvestmentAdvisor",
|
||||
instructions="You are an investment advisor. Based on analysis results, provide recommendations.",
|
||||
)
|
||||
|
||||
@@ -138,21 +138,21 @@ export OPENAI_BASE_URL="https://your-custom-endpoint.com/v1"
|
||||
|
||||
```python
|
||||
from agent_framework.lab.tau2 import TaskRunner
|
||||
from agent_framework import ChatAgent
|
||||
from agent_framework import Agent
|
||||
|
||||
class CustomTaskRunner(TaskRunner):
|
||||
def assistant_agent(self, assistant_chat_client):
|
||||
# Override to customize the assistant agent
|
||||
return ChatAgent(
|
||||
chat_client=assistant_chat_client,
|
||||
return Agent(
|
||||
client=assistant_chat_client,
|
||||
instructions="Your custom system prompt here",
|
||||
# Add custom tools, temperature, etc.
|
||||
)
|
||||
|
||||
def user_simulator(self, user_chat_client, task):
|
||||
# Override to customize the user simulator
|
||||
return ChatAgent(
|
||||
chat_client=user_chat_client,
|
||||
return Agent(
|
||||
client=user_chat_client,
|
||||
instructions="Custom user simulator prompt",
|
||||
)
|
||||
```
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from typing import Any
|
||||
|
||||
from agent_framework._types import ChatMessage, Content
|
||||
from agent_framework._types import Content, Message
|
||||
from loguru import logger
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ def _get_role_value(role: Any) -> str:
|
||||
return role.value if hasattr(role, "value") else str(role)
|
||||
|
||||
|
||||
def flip_messages(messages: list[ChatMessage]) -> list[ChatMessage]:
|
||||
def flip_messages(messages: list[Message]) -> list[Message]:
|
||||
"""Flip message roles between assistant and user for role-playing scenarios.
|
||||
|
||||
Used in agent simulations where the assistant's messages become user inputs
|
||||
@@ -30,7 +30,7 @@ def flip_messages(messages: list[ChatMessage]) -> list[ChatMessage]:
|
||||
# Flip assistant to user
|
||||
contents = filter_out_function_calls(msg.contents)
|
||||
if contents:
|
||||
flipped_msg = ChatMessage(
|
||||
flipped_msg = Message(
|
||||
role="user",
|
||||
# The function calls will cause 400 when role is user
|
||||
contents=contents,
|
||||
@@ -40,7 +40,7 @@ def flip_messages(messages: list[ChatMessage]) -> list[ChatMessage]:
|
||||
flipped_messages.append(flipped_msg)
|
||||
elif role_value == "user":
|
||||
# Flip user to assistant
|
||||
flipped_msg = ChatMessage(
|
||||
flipped_msg = Message(
|
||||
role="assistant", contents=msg.contents, author_name=msg.author_name, message_id=msg.message_id
|
||||
)
|
||||
flipped_messages.append(flipped_msg)
|
||||
@@ -53,7 +53,7 @@ def flip_messages(messages: list[ChatMessage]) -> list[ChatMessage]:
|
||||
return flipped_messages
|
||||
|
||||
|
||||
def log_messages(messages: list[ChatMessage]) -> None:
|
||||
def log_messages(messages: list[Message]) -> None:
|
||||
"""Log messages with colored output based on role and content type.
|
||||
|
||||
Provides visual debugging by color-coding different message roles and
|
||||
|
||||
@@ -5,7 +5,7 @@ from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
import tiktoken
|
||||
from agent_framework import ChatMessage, ChatMessageStore
|
||||
from agent_framework import ChatMessageStore, Message
|
||||
from loguru import logger
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ class SlidingWindowChatMessageStore(ChatMessageStore):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
messages: Sequence[ChatMessage] | None = None,
|
||||
messages: Sequence[Message] | None = None,
|
||||
max_tokens: int = 3800,
|
||||
system_message: str | None = None,
|
||||
tool_definitions: Any | None = None,
|
||||
@@ -32,17 +32,17 @@ class SlidingWindowChatMessageStore(ChatMessageStore):
|
||||
# An estimation based on a commonly used vocab table
|
||||
self.encoding = tiktoken.get_encoding("o200k_base")
|
||||
|
||||
async def add_messages(self, messages: Sequence[ChatMessage]) -> None:
|
||||
async def add_messages(self, messages: Sequence[Message]) -> None:
|
||||
await super().add_messages(messages)
|
||||
|
||||
self.truncated_messages = self.messages.copy()
|
||||
self.truncate_messages()
|
||||
|
||||
async def list_messages(self) -> list[ChatMessage]:
|
||||
async def list_messages(self) -> list[Message]:
|
||||
"""Get the current list of messages, which may be truncated."""
|
||||
return self.truncated_messages
|
||||
|
||||
async def list_all_messages(self) -> list[ChatMessage]:
|
||||
async def list_all_messages(self) -> list[Message]:
|
||||
"""Get all messages from the store including the truncated ones."""
|
||||
return self.messages
|
||||
|
||||
|
||||
@@ -7,17 +7,19 @@ from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from agent_framework._tools import FunctionTool
|
||||
from agent_framework._types import ChatMessage
|
||||
from agent_framework._types import Message
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
from tau2.data_model.message import ( # type: ignore[import-untyped]
|
||||
AssistantMessage,
|
||||
Message,
|
||||
SystemMessage,
|
||||
ToolCall,
|
||||
ToolMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from tau2.data_model.message import (
|
||||
Message as Tau2Message,
|
||||
)
|
||||
from tau2.data_model.tasks import EnvFunctionCall, InitializationData # type: ignore[import-untyped]
|
||||
from tau2.environment.environment import Environment # type: ignore[import-untyped]
|
||||
from tau2.environment.tool import Tool # type: ignore[import-untyped]
|
||||
@@ -45,7 +47,7 @@ def convert_tau2_tool_to_function_tool(tau2_tool: Tool) -> FunctionTool[Any, Any
|
||||
)
|
||||
|
||||
|
||||
def convert_agent_framework_messages_to_tau2_messages(messages: list[ChatMessage]) -> list[Message]:
|
||||
def convert_agent_framework_messages_to_tau2_messages(messages: list[Message]) -> list[Tau2Message]:
|
||||
"""Convert agent framework ChatMessages to tau2 Message objects.
|
||||
|
||||
Handles role mapping, text extraction, function calls, and function results.
|
||||
@@ -119,13 +121,13 @@ def patch_env_set_state() -> None:
|
||||
self: Any,
|
||||
initialization_data: InitializationData | None,
|
||||
initialization_actions: list[EnvFunctionCall] | None,
|
||||
message_history: list[Message],
|
||||
message_history: list[Tau2Message],
|
||||
) -> None:
|
||||
if self.solo_mode and any(isinstance(message, UserMessage) for message in message_history):
|
||||
raise ValueError("User messages are not allowed in solo mode")
|
||||
|
||||
def get_actions_from_messages(
|
||||
messages: list[Message],
|
||||
messages: list[Tau2Message],
|
||||
) -> list[tuple[ToolCall, ToolMessage]]:
|
||||
"""Get the actions from the messages."""
|
||||
messages = deepcopy(messages)[::-1]
|
||||
|
||||
@@ -6,14 +6,14 @@ import uuid
|
||||
from typing import cast
|
||||
|
||||
from agent_framework import (
|
||||
Agent,
|
||||
AgentExecutor,
|
||||
AgentExecutorRequest,
|
||||
AgentExecutorResponse,
|
||||
AgentResponse,
|
||||
ChatAgent,
|
||||
ChatClientProtocol,
|
||||
ChatMessage,
|
||||
FunctionExecutor,
|
||||
Message,
|
||||
SupportsChatGetResponse,
|
||||
Workflow,
|
||||
WorkflowBuilder,
|
||||
WorkflowContext,
|
||||
@@ -67,10 +67,10 @@ class TaskRunner:
|
||||
|
||||
# State tracking
|
||||
step_count: int
|
||||
full_conversation: list[ChatMessage]
|
||||
full_conversation: list[Message]
|
||||
termination_reason: TerminationReason | None
|
||||
full_reward_info: RewardInfo | None
|
||||
_final_user_message: list[ChatMessage] | None
|
||||
_final_user_message: list[Message] | None
|
||||
_assistant_executor: AgentExecutor | None
|
||||
_user_executor: AgentExecutor | None
|
||||
|
||||
@@ -159,7 +159,7 @@ class TaskRunner:
|
||||
"""Check if user wants to stop the conversation."""
|
||||
return STOP in text or TRANSFER in text or OUT_OF_SCOPE in text
|
||||
|
||||
def assistant_agent(self, assistant_chat_client: ChatClientProtocol) -> ChatAgent:
|
||||
def assistant_agent(self, assistant_chat_client: SupportsChatGetResponse) -> Agent:
|
||||
"""Create an assistant agent.
|
||||
|
||||
Users can override this method to provide a custom assistant agent.
|
||||
@@ -196,8 +196,8 @@ class TaskRunner:
|
||||
# - Access to all domain tools (booking, cancellation, etc.)
|
||||
# - Sliding window memory to handle long conversations within token limits
|
||||
# - Temperature-controlled response generation
|
||||
return ChatAgent(
|
||||
chat_client=assistant_chat_client,
|
||||
return Agent(
|
||||
client=assistant_chat_client,
|
||||
instructions=assistant_system_prompt,
|
||||
tools=tools,
|
||||
temperature=self.assistant_sampling_temperature,
|
||||
@@ -208,7 +208,7 @@ class TaskRunner:
|
||||
),
|
||||
)
|
||||
|
||||
def user_simulator(self, user_simuator_chat_client: ChatClientProtocol, task: Task) -> ChatAgent:
|
||||
def user_simulator(self, user_simuator_chat_client: SupportsChatGetResponse, task: Task) -> Agent:
|
||||
"""Create a user simulator agent.
|
||||
|
||||
Users can override this method to provide a custom user simulator agent.
|
||||
@@ -230,8 +230,8 @@ class TaskRunner:
|
||||
{task.user_scenario.instructions}
|
||||
</scenario>"""
|
||||
|
||||
return ChatAgent(
|
||||
chat_client=user_simuator_chat_client,
|
||||
return Agent(
|
||||
client=user_simuator_chat_client,
|
||||
instructions=user_sim_system_prompt,
|
||||
temperature=0.0,
|
||||
# No sliding window for user simulator to maintain full conversation context
|
||||
@@ -268,7 +268,7 @@ class TaskRunner:
|
||||
target_id=USER_SIMULATOR_ID if is_from_agent else ASSISTANT_AGENT_ID,
|
||||
)
|
||||
|
||||
def build_conversation_workflow(self, assistant_agent: ChatAgent, user_simulator_agent: ChatAgent) -> Workflow:
|
||||
def build_conversation_workflow(self, assistant_agent: Agent, user_simulator_agent: Agent) -> Workflow:
|
||||
"""Build the conversation workflow.
|
||||
|
||||
Users can override this method to provide a custom conversation workflow.
|
||||
@@ -304,9 +304,9 @@ class TaskRunner:
|
||||
async def run(
|
||||
self,
|
||||
task: Task,
|
||||
assistant_chat_client: ChatClientProtocol,
|
||||
user_simulator_chat_client: ChatClientProtocol,
|
||||
) -> list[ChatMessage]:
|
||||
assistant_chat_client: SupportsChatGetResponse,
|
||||
user_simulator_chat_client: SupportsChatGetResponse,
|
||||
) -> list[Message]:
|
||||
"""Run a tau2 task using workflow-based agent orchestration.
|
||||
|
||||
This method orchestrates a complex multi-agent simulation:
|
||||
@@ -323,7 +323,7 @@ class TaskRunner:
|
||||
user_simulator_chat_client: LLM client for the user simulator
|
||||
|
||||
Returns:
|
||||
Complete conversation history as ChatMessage list for evaluation
|
||||
Complete conversation history as Message list for evaluation
|
||||
"""
|
||||
logger.info(f"Starting workflow agent for task {task.id}: {task.description.purpose}") # type: ignore[unused-ignore]
|
||||
logger.info(f"Assistant chat client: {assistant_chat_client}")
|
||||
@@ -340,11 +340,11 @@ class TaskRunner:
|
||||
# Matches tau2's expected conversation start pattern
|
||||
logger.info(f"Starting workflow with hardcoded greeting: '{DEFAULT_FIRST_AGENT_MESSAGE}'")
|
||||
|
||||
first_message = ChatMessage(role="assistant", text=DEFAULT_FIRST_AGENT_MESSAGE)
|
||||
first_message = Message(role="assistant", text=DEFAULT_FIRST_AGENT_MESSAGE)
|
||||
initial_greeting = AgentExecutorResponse(
|
||||
executor_id=ASSISTANT_AGENT_ID,
|
||||
agent_response=AgentResponse(messages=[first_message]),
|
||||
full_conversation=[ChatMessage(role="assistant", text=DEFAULT_FIRST_AGENT_MESSAGE)],
|
||||
full_conversation=[Message(role="assistant", text=DEFAULT_FIRST_AGENT_MESSAGE)],
|
||||
)
|
||||
|
||||
# STEP 4: Execute the workflow and collect results
|
||||
@@ -371,7 +371,7 @@ class TaskRunner:
|
||||
return full_conversation
|
||||
|
||||
def evaluate(
|
||||
self, task_input: Task, conversation: list[ChatMessage], termination_reason: TerminationReason | None
|
||||
self, task_input: Task, conversation: list[Message], termination_reason: TerminationReason | None
|
||||
) -> float:
|
||||
"""Evaluate agent performance using tau2's comprehensive evaluation system.
|
||||
|
||||
|
||||
@@ -2,14 +2,14 @@
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
from agent_framework._types import ChatMessage, Content
|
||||
from agent_framework._types import Content, Message
|
||||
from agent_framework_lab_tau2._message_utils import flip_messages, log_messages
|
||||
|
||||
|
||||
def test_flip_messages_user_to_assistant():
|
||||
"""Test flipping user message to assistant."""
|
||||
messages = [
|
||||
ChatMessage(
|
||||
Message(
|
||||
role="user",
|
||||
contents=[Content.from_text(text="Hello assistant")],
|
||||
author_name="User1",
|
||||
@@ -29,7 +29,7 @@ def test_flip_messages_user_to_assistant():
|
||||
def test_flip_messages_assistant_to_user():
|
||||
"""Test flipping assistant message to user."""
|
||||
messages = [
|
||||
ChatMessage(
|
||||
Message(
|
||||
role="assistant",
|
||||
contents=[Content.from_text(text="Hello user")],
|
||||
author_name="Assistant1",
|
||||
@@ -51,7 +51,7 @@ def test_flip_messages_assistant_with_function_calls_filtered():
|
||||
function_call = Content.from_function_call(call_id="call_123", name="test_function", arguments={"param": "value"})
|
||||
|
||||
messages = [
|
||||
ChatMessage(
|
||||
Message(
|
||||
role="assistant",
|
||||
contents=[
|
||||
Content.from_text(text="I'll call a function"),
|
||||
@@ -78,7 +78,7 @@ def test_flip_messages_assistant_with_only_function_calls_skipped():
|
||||
function_call = Content.from_function_call(call_id="call_456", name="another_function", arguments={"key": "value"})
|
||||
|
||||
messages = [
|
||||
ChatMessage(role="assistant", contents=[function_call], message_id="msg_004") # Only function call, no text
|
||||
Message(role="assistant", contents=[function_call], message_id="msg_004") # Only function call, no text
|
||||
]
|
||||
|
||||
flipped = flip_messages(messages)
|
||||
@@ -91,7 +91,7 @@ def test_flip_messages_tool_messages_skipped():
|
||||
"""Test that tool messages are skipped."""
|
||||
function_result = Content.from_function_result(call_id="call_789", result={"success": True})
|
||||
|
||||
messages = [ChatMessage(role="tool", contents=[function_result])]
|
||||
messages = [Message(role="tool", contents=[function_result])]
|
||||
|
||||
flipped = flip_messages(messages)
|
||||
|
||||
@@ -101,9 +101,7 @@ def test_flip_messages_tool_messages_skipped():
|
||||
|
||||
def test_flip_messages_system_messages_preserved():
|
||||
"""Test that system messages are preserved as-is."""
|
||||
messages = [
|
||||
ChatMessage(role="system", contents=[Content.from_text(text="System instruction")], message_id="sys_001")
|
||||
]
|
||||
messages = [Message(role="system", contents=[Content.from_text(text="System instruction")], message_id="sys_001")]
|
||||
|
||||
flipped = flip_messages(messages)
|
||||
|
||||
@@ -120,11 +118,11 @@ def test_flip_messages_mixed_conversation():
|
||||
function_result = Content.from_function_result(call_id="call_mixed", result="function result")
|
||||
|
||||
messages = [
|
||||
ChatMessage(role="system", contents=[Content.from_text(text="System prompt")]),
|
||||
ChatMessage(role="user", contents=[Content.from_text(text="User question")]),
|
||||
ChatMessage(role="assistant", contents=[Content.from_text(text="Assistant response"), function_call]),
|
||||
ChatMessage(role="tool", contents=[function_result]),
|
||||
ChatMessage(role="assistant", contents=[Content.from_text(text="Final response")]),
|
||||
Message(role="system", contents=[Content.from_text(text="System prompt")]),
|
||||
Message(role="user", contents=[Content.from_text(text="User question")]),
|
||||
Message(role="assistant", contents=[Content.from_text(text="Assistant response"), function_call]),
|
||||
Message(role="tool", contents=[function_result]),
|
||||
Message(role="assistant", contents=[Content.from_text(text="Final response")]),
|
||||
]
|
||||
|
||||
flipped = flip_messages(messages)
|
||||
@@ -159,7 +157,7 @@ def test_flip_messages_empty_list():
|
||||
def test_flip_messages_preserves_metadata():
|
||||
"""Test that message metadata is preserved during flipping."""
|
||||
messages = [
|
||||
ChatMessage(
|
||||
Message(
|
||||
role="user",
|
||||
contents=[Content.from_text(text="Test message")],
|
||||
author_name="TestUser",
|
||||
@@ -178,8 +176,8 @@ def test_flip_messages_preserves_metadata():
|
||||
def test_log_messages_text_content(mock_logger):
|
||||
"""Test logging messages with text content."""
|
||||
messages = [
|
||||
ChatMessage(role="user", contents=[Content.from_text(text="Hello")]),
|
||||
ChatMessage(role="assistant", contents=[Content.from_text(text="Hi there!")]),
|
||||
Message(role="user", contents=[Content.from_text(text="Hello")]),
|
||||
Message(role="assistant", contents=[Content.from_text(text="Hi there!")]),
|
||||
]
|
||||
|
||||
log_messages(messages)
|
||||
@@ -193,7 +191,7 @@ def test_log_messages_function_call(mock_logger):
|
||||
"""Test logging messages with function calls."""
|
||||
function_call = Content.from_function_call(call_id="call_log", name="log_function", arguments={"param": "value"})
|
||||
|
||||
messages = [ChatMessage(role="assistant", contents=[function_call])]
|
||||
messages = [Message(role="assistant", contents=[function_call])]
|
||||
|
||||
log_messages(messages)
|
||||
|
||||
@@ -209,7 +207,7 @@ def test_log_messages_function_result(mock_logger):
|
||||
"""Test logging messages with function results."""
|
||||
function_result = Content.from_function_result(call_id="call_result", result="success")
|
||||
|
||||
messages = [ChatMessage(role="tool", contents=[function_result])]
|
||||
messages = [Message(role="tool", contents=[function_result])]
|
||||
|
||||
log_messages(messages)
|
||||
|
||||
@@ -223,10 +221,10 @@ def test_log_messages_function_result(mock_logger):
|
||||
def test_log_messages_different_roles(mock_logger):
|
||||
"""Test logging messages with different roles get different colors."""
|
||||
messages = [
|
||||
ChatMessage(role="system", contents=[Content.from_text(text="System")]),
|
||||
ChatMessage(role="user", contents=[Content.from_text(text="User")]),
|
||||
ChatMessage(role="assistant", contents=[Content.from_text(text="Assistant")]),
|
||||
ChatMessage(role="tool", contents=[Content.from_text(text="Tool")]),
|
||||
Message(role="system", contents=[Content.from_text(text="System")]),
|
||||
Message(role="user", contents=[Content.from_text(text="User")]),
|
||||
Message(role="assistant", contents=[Content.from_text(text="Assistant")]),
|
||||
Message(role="tool", contents=[Content.from_text(text="Tool")]),
|
||||
]
|
||||
|
||||
log_messages(messages)
|
||||
@@ -250,7 +248,7 @@ def test_log_messages_different_roles(mock_logger):
|
||||
@patch("agent_framework_lab_tau2._message_utils.logger")
|
||||
def test_log_messages_escapes_html(mock_logger):
|
||||
"""Test that HTML-like characters are properly escaped in log output."""
|
||||
messages = [ChatMessage(role="user", contents=[Content.from_text(text="Message with <tag> content")])]
|
||||
messages = [Message(role="user", contents=[Content.from_text(text="Message with <tag> content")])]
|
||||
|
||||
log_messages(messages)
|
||||
|
||||
@@ -266,7 +264,7 @@ def test_log_messages_mixed_content_types(mock_logger):
|
||||
function_call = Content.from_function_call(call_id="mixed_call", name="mixed_function", arguments={"key": "value"})
|
||||
|
||||
messages = [
|
||||
ChatMessage(
|
||||
Message(
|
||||
role="assistant",
|
||||
contents=[Content.from_text(text="I'll call a function"), function_call, Content.from_text(text="Done!")],
|
||||
)
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
from agent_framework._types import ChatMessage, Content
|
||||
from agent_framework._types import Content, Message
|
||||
from agent_framework_lab_tau2._sliding_window import SlidingWindowChatMessageStore
|
||||
|
||||
|
||||
@@ -36,8 +36,8 @@ def test_initialization_with_parameters():
|
||||
def test_initialization_with_messages():
|
||||
"""Test initializing with existing messages."""
|
||||
messages = [
|
||||
ChatMessage(role="user", contents=[Content.from_text(text="Hello")]),
|
||||
ChatMessage(role="assistant", contents=[Content.from_text(text="Hi there!")]),
|
||||
Message(role="user", contents=[Content.from_text(text="Hello")]),
|
||||
Message(role="assistant", contents=[Content.from_text(text="Hi there!")]),
|
||||
]
|
||||
|
||||
sliding_window = SlidingWindowChatMessageStore(messages=messages, max_tokens=1000)
|
||||
@@ -51,8 +51,8 @@ async def test_add_messages_simple():
|
||||
sliding_window = SlidingWindowChatMessageStore(max_tokens=10000) # Large limit
|
||||
|
||||
new_messages = [
|
||||
ChatMessage(role="user", contents=[Content.from_text(text="What's the weather?")]),
|
||||
ChatMessage(role="assistant", contents=[Content.from_text(text="I can help with that.")]),
|
||||
Message(role="user", contents=[Content.from_text(text="What's the weather?")]),
|
||||
Message(role="assistant", contents=[Content.from_text(text="I can help with that.")]),
|
||||
]
|
||||
|
||||
await sliding_window.add_messages(new_messages)
|
||||
@@ -69,7 +69,7 @@ async def test_list_all_messages_vs_list_messages():
|
||||
|
||||
# Add many messages to trigger truncation
|
||||
messages = [
|
||||
ChatMessage(role="user", contents=[Content.from_text(text=f"Message {i} with some content")]) for i in range(10)
|
||||
Message(role="user", contents=[Content.from_text(text=f"Message {i} with some content")]) for i in range(10)
|
||||
]
|
||||
|
||||
await sliding_window.add_messages(messages)
|
||||
@@ -87,7 +87,7 @@ async def test_list_all_messages_vs_list_messages():
|
||||
def test_get_token_count_basic():
|
||||
"""Test basic token counting."""
|
||||
sliding_window = SlidingWindowChatMessageStore(max_tokens=1000)
|
||||
sliding_window.truncated_messages = [ChatMessage(role="user", contents=[Content.from_text(text="Hello")])]
|
||||
sliding_window.truncated_messages = [Message(role="user", contents=[Content.from_text(text="Hello")])]
|
||||
|
||||
token_count = sliding_window.get_token_count()
|
||||
|
||||
@@ -104,7 +104,7 @@ def test_get_token_count_with_system_message():
|
||||
token_count_empty = sliding_window.get_token_count()
|
||||
|
||||
# Add a message
|
||||
sliding_window.truncated_messages = [ChatMessage(role="user", contents=[Content.from_text(text="Hello")])]
|
||||
sliding_window.truncated_messages = [Message(role="user", contents=[Content.from_text(text="Hello")])]
|
||||
token_count_with_message = sliding_window.get_token_count()
|
||||
|
||||
# With message should be more tokens
|
||||
@@ -117,7 +117,7 @@ def test_get_token_count_function_call():
|
||||
function_call = Content.from_function_call(call_id="call_123", name="test_function", arguments={"param": "value"})
|
||||
|
||||
sliding_window = SlidingWindowChatMessageStore(max_tokens=1000)
|
||||
sliding_window.truncated_messages = [ChatMessage(role="assistant", contents=[function_call])]
|
||||
sliding_window.truncated_messages = [Message(role="assistant", contents=[function_call])]
|
||||
|
||||
token_count = sliding_window.get_token_count()
|
||||
assert token_count > 0
|
||||
@@ -128,7 +128,7 @@ def test_get_token_count_function_result():
|
||||
function_result = Content.from_function_result(call_id="call_123", result={"success": True, "data": "result"})
|
||||
|
||||
sliding_window = SlidingWindowChatMessageStore(max_tokens=1000)
|
||||
sliding_window.truncated_messages = [ChatMessage(role="tool", contents=[function_result])]
|
||||
sliding_window.truncated_messages = [Message(role="tool", contents=[function_result])]
|
||||
|
||||
token_count = sliding_window.get_token_count()
|
||||
assert token_count > 0
|
||||
@@ -141,17 +141,17 @@ def test_truncate_messages_removes_old_messages(mock_logger):
|
||||
|
||||
# Create messages that will exceed the limit
|
||||
messages = [
|
||||
ChatMessage(
|
||||
Message(
|
||||
role="user",
|
||||
contents=[Content.from_text(text="This is a very long message that should exceed the token limit")],
|
||||
),
|
||||
ChatMessage(
|
||||
Message(
|
||||
role="assistant",
|
||||
contents=[
|
||||
Content.from_text(text="This is another very long message that should also exceed the token limit")
|
||||
],
|
||||
),
|
||||
ChatMessage(role="user", contents=[Content.from_text(text="Short msg")]),
|
||||
Message(role="user", contents=[Content.from_text(text="Short msg")]),
|
||||
]
|
||||
|
||||
sliding_window.truncated_messages = messages.copy()
|
||||
@@ -170,10 +170,8 @@ def test_truncate_messages_removes_leading_tool_messages(mock_logger):
|
||||
sliding_window = SlidingWindowChatMessageStore(max_tokens=10000) # Large limit
|
||||
|
||||
# Create messages starting with tool message
|
||||
tool_message = ChatMessage(
|
||||
role="tool", contents=[Content.from_function_result(call_id="call_123", result="result")]
|
||||
)
|
||||
user_message = ChatMessage(role="user", contents=[Content.from_text(text="Hello")])
|
||||
tool_message = Message(role="tool", contents=[Content.from_function_result(call_id="call_123", result="result")])
|
||||
user_message = Message(role="user", contents=[Content.from_text(text="Hello")])
|
||||
|
||||
sliding_window.truncated_messages = [tool_message, user_message]
|
||||
sliding_window.truncate_messages()
|
||||
@@ -231,13 +229,13 @@ async def test_real_world_scenario():
|
||||
|
||||
# Simulate a conversation
|
||||
conversation = [
|
||||
ChatMessage(role="user", contents=[Content.from_text(text="Hello, how are you?")]),
|
||||
ChatMessage(
|
||||
Message(role="user", contents=[Content.from_text(text="Hello, how are you?")]),
|
||||
Message(
|
||||
role="assistant",
|
||||
contents=[Content.from_text(text="I'm doing well, thank you! How can I help you today?")],
|
||||
),
|
||||
ChatMessage(role="user", contents=[Content.from_text(text="Can you tell me about the weather?")]),
|
||||
ChatMessage(
|
||||
Message(role="user", contents=[Content.from_text(text="Can you tell me about the weather?")]),
|
||||
Message(
|
||||
role="assistant",
|
||||
contents=[
|
||||
Content.from_text(
|
||||
@@ -246,8 +244,8 @@ async def test_real_world_scenario():
|
||||
)
|
||||
],
|
||||
),
|
||||
ChatMessage(role="user", contents=[Content.from_text(text="What about telling me a joke instead?")]),
|
||||
ChatMessage(
|
||||
Message(role="user", contents=[Content.from_text(text="What about telling me a joke instead?")]),
|
||||
Message(
|
||||
role="assistant",
|
||||
contents=[
|
||||
Content.from_text(text="Sure! Why don't scientists trust atoms? Because they make up everything!")
|
||||
|
||||
@@ -6,7 +6,7 @@ import urllib.request
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from agent_framework import ChatMessage, Content, FunctionTool
|
||||
from agent_framework import Content, FunctionTool, Message
|
||||
from agent_framework_lab_tau2._tau2_utils import (
|
||||
convert_agent_framework_messages_to_tau2_messages,
|
||||
convert_tau2_tool_to_function_tool,
|
||||
@@ -91,7 +91,7 @@ def test_convert_tau2_tool_to_function_tool_multiple_tools(tau2_airline_environm
|
||||
|
||||
def test_convert_agent_framework_messages_to_tau2_messages_system():
|
||||
"""Test converting system message."""
|
||||
messages = [ChatMessage(role="system", contents=[Content.from_text(text="System instruction")])]
|
||||
messages = [Message(role="system", contents=[Content.from_text(text="System instruction")])]
|
||||
|
||||
tau2_messages = convert_agent_framework_messages_to_tau2_messages(messages)
|
||||
|
||||
@@ -103,7 +103,7 @@ def test_convert_agent_framework_messages_to_tau2_messages_system():
|
||||
|
||||
def test_convert_agent_framework_messages_to_tau2_messages_user():
|
||||
"""Test converting user message."""
|
||||
messages = [ChatMessage(role="user", contents=[Content.from_text(text="Hello assistant")])]
|
||||
messages = [Message(role="user", contents=[Content.from_text(text="Hello assistant")])]
|
||||
|
||||
tau2_messages = convert_agent_framework_messages_to_tau2_messages(messages)
|
||||
|
||||
@@ -116,7 +116,7 @@ def test_convert_agent_framework_messages_to_tau2_messages_user():
|
||||
|
||||
def test_convert_agent_framework_messages_to_tau2_messages_assistant():
|
||||
"""Test converting assistant message."""
|
||||
messages = [ChatMessage(role="assistant", contents=[Content.from_text(text="Hello user")])]
|
||||
messages = [Message(role="assistant", contents=[Content.from_text(text="Hello user")])]
|
||||
|
||||
tau2_messages = convert_agent_framework_messages_to_tau2_messages(messages)
|
||||
|
||||
@@ -131,7 +131,7 @@ def test_convert_agent_framework_messages_to_tau2_messages_with_function_call():
|
||||
"""Test converting message with function call."""
|
||||
function_call = Content.from_function_call(call_id="call_123", name="test_function", arguments={"param": "value"})
|
||||
|
||||
messages = [ChatMessage(role="assistant", contents=[Content.from_text(text="I'll call a function"), function_call])]
|
||||
messages = [Message(role="assistant", contents=[Content.from_text(text="I'll call a function"), function_call])]
|
||||
|
||||
tau2_messages = convert_agent_framework_messages_to_tau2_messages(messages)
|
||||
|
||||
@@ -153,7 +153,7 @@ def test_convert_agent_framework_messages_to_tau2_messages_with_function_result(
|
||||
"""Test converting message with function result."""
|
||||
function_result = Content.from_function_result(call_id="call_123", result={"success": True, "data": "result data"})
|
||||
|
||||
messages = [ChatMessage(role="tool", contents=[function_result])]
|
||||
messages = [Message(role="tool", contents=[function_result])]
|
||||
|
||||
tau2_messages = convert_agent_framework_messages_to_tau2_messages(messages)
|
||||
|
||||
@@ -173,7 +173,7 @@ def test_convert_agent_framework_messages_to_tau2_messages_with_error():
|
||||
call_id="call_456", result="Error occurred", exception=Exception("Test error")
|
||||
)
|
||||
|
||||
messages = [ChatMessage(role="tool", contents=[function_result])]
|
||||
messages = [Message(role="tool", contents=[function_result])]
|
||||
|
||||
tau2_messages = convert_agent_framework_messages_to_tau2_messages(messages)
|
||||
|
||||
@@ -185,7 +185,7 @@ def test_convert_agent_framework_messages_to_tau2_messages_with_error():
|
||||
def test_convert_agent_framework_messages_to_tau2_messages_multiple_text_contents():
|
||||
"""Test converting message with multiple text contents."""
|
||||
messages = [
|
||||
ChatMessage(role="user", contents=[Content.from_text(text="First part"), Content.from_text(text="Second part")])
|
||||
Message(role="user", contents=[Content.from_text(text="First part"), Content.from_text(text="Second part")])
|
||||
]
|
||||
|
||||
tau2_messages = convert_agent_framework_messages_to_tau2_messages(messages)
|
||||
@@ -202,11 +202,11 @@ def test_convert_agent_framework_messages_to_tau2_messages_complex_scenario():
|
||||
function_result = Content.from_function_result(call_id="call_789", result={"output": "tool result"})
|
||||
|
||||
messages = [
|
||||
ChatMessage(role="system", contents=[Content.from_text(text="System prompt")]),
|
||||
ChatMessage(role="user", contents=[Content.from_text(text="User request")]),
|
||||
ChatMessage(role="assistant", contents=[Content.from_text(text="I'll help you"), function_call]),
|
||||
ChatMessage(role="tool", contents=[function_result]),
|
||||
ChatMessage(role="assistant", contents=[Content.from_text(text="Based on the result...")]),
|
||||
Message(role="system", contents=[Content.from_text(text="System prompt")]),
|
||||
Message(role="user", contents=[Content.from_text(text="User request")]),
|
||||
Message(role="assistant", contents=[Content.from_text(text="I'll help you"), function_call]),
|
||||
Message(role="tool", contents=[function_result]),
|
||||
Message(role="assistant", contents=[Content.from_text(text="Based on the result...")]),
|
||||
]
|
||||
|
||||
tau2_messages = convert_agent_framework_messages_to_tau2_messages(messages)
|
||||
|
||||
Reference in New Issue
Block a user