mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: ChatKit sample fixes (#2174)
* sample fixes * Update thread naming
This commit is contained in:
committed by
GitHub
Unverified
parent
15d0bda8a2
commit
a75590eb9b
@@ -16,11 +16,43 @@ from random import randint
|
||||
from typing import Annotated, Any
|
||||
|
||||
import uvicorn
|
||||
|
||||
# Agent Framework imports
|
||||
from agent_framework import AgentRunResponseUpdate, ChatAgent, ChatMessage, FunctionResultContent, Role
|
||||
from agent_framework.azure import AzureOpenAIChatClient
|
||||
|
||||
# Agent Framework ChatKit integration
|
||||
from agent_framework_chatkit import ThreadItemConverter, stream_agent_response
|
||||
|
||||
# Local imports
|
||||
from attachment_store import FileBasedAttachmentStore
|
||||
from azure.identity import AzureCliCredential
|
||||
|
||||
# ChatKit imports
|
||||
from chatkit.actions import Action
|
||||
from chatkit.server import ChatKitServer
|
||||
from chatkit.store import StoreItemType, default_generate_id
|
||||
from chatkit.types import (
|
||||
ThreadItem,
|
||||
ThreadItemDoneEvent,
|
||||
ThreadMetadata,
|
||||
ThreadStreamEvent,
|
||||
UserMessageItem,
|
||||
WidgetItem,
|
||||
)
|
||||
from chatkit.widgets import WidgetRoot
|
||||
from fastapi import FastAPI, File, Request, UploadFile
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import FileResponse, JSONResponse, Response, StreamingResponse
|
||||
from pydantic import Field
|
||||
from store import SQLiteStore
|
||||
from weather_widget import (
|
||||
WeatherData,
|
||||
city_selector_copy_text,
|
||||
render_city_selector_widget,
|
||||
render_weather_widget,
|
||||
weather_widget_copy_text,
|
||||
)
|
||||
|
||||
# ============================================================================
|
||||
# Configuration Constants
|
||||
@@ -56,37 +88,6 @@ logging.basicConfig(
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Agent Framework imports
|
||||
from agent_framework import AgentRunResponseUpdate, ChatAgent, ChatMessage, FunctionResultContent, Role
|
||||
from agent_framework.azure import AzureOpenAIChatClient
|
||||
|
||||
# Agent Framework ChatKit integration
|
||||
from agent_framework_chatkit import ThreadItemConverter, stream_agent_response
|
||||
|
||||
# Local imports
|
||||
from attachment_store import FileBasedAttachmentStore
|
||||
|
||||
# ChatKit imports
|
||||
from chatkit.actions import Action
|
||||
from chatkit.server import ChatKitServer
|
||||
from chatkit.store import StoreItemType, default_generate_id
|
||||
from chatkit.types import (
|
||||
ThreadItemDoneEvent,
|
||||
ThreadMetadata,
|
||||
ThreadStreamEvent,
|
||||
UserMessageItem,
|
||||
WidgetItem,
|
||||
)
|
||||
from chatkit.widgets import WidgetRoot
|
||||
from store import SQLiteStore
|
||||
from weather_widget import (
|
||||
WeatherData,
|
||||
city_selector_copy_text,
|
||||
render_city_selector_widget,
|
||||
render_weather_widget,
|
||||
weather_widget_copy_text,
|
||||
)
|
||||
|
||||
|
||||
class WeatherResponse(str):
|
||||
"""A string response that also carries WeatherData for widget creation."""
|
||||
@@ -238,6 +239,81 @@ class WeatherChatKitServer(ChatKitServer[dict[str, Any]]):
|
||||
"""
|
||||
return await attachment_store.read_attachment_bytes(attachment_id)
|
||||
|
||||
async def _update_thread_title(
|
||||
self, thread: ThreadMetadata, thread_items: list[ThreadItem], context: dict[str, Any]
|
||||
) -> None:
|
||||
"""Update thread title using LLM to generate a concise summary.
|
||||
|
||||
Args:
|
||||
thread: The thread metadata to update.
|
||||
thread_items: All items in the thread.
|
||||
context: The context dictionary.
|
||||
"""
|
||||
logger.info(f"Attempting to update thread title for thread: {thread.id}")
|
||||
|
||||
if not thread_items:
|
||||
logger.debug("No thread items available for title generation")
|
||||
return
|
||||
|
||||
# Collect user messages to understand the conversation topic
|
||||
user_messages: list[str] = []
|
||||
for item in thread_items:
|
||||
if isinstance(item, UserMessageItem) and item.content:
|
||||
for content_part in item.content:
|
||||
if hasattr(content_part, "text") and isinstance(content_part.text, str):
|
||||
user_messages.append(content_part.text)
|
||||
break
|
||||
|
||||
if not user_messages:
|
||||
logger.debug("No user messages found for title generation")
|
||||
return
|
||||
|
||||
logger.debug(f"Found {len(user_messages)} user message(s) for title generation")
|
||||
|
||||
try:
|
||||
# Use the agent's chat client to generate a concise title
|
||||
# Combine first few messages to capture the conversation topic
|
||||
conversation_context = "\n".join(user_messages[:3])
|
||||
|
||||
title_prompt = [
|
||||
ChatMessage(
|
||||
role=Role.USER,
|
||||
text=(
|
||||
f"Generate a very short, concise title (max 40 characters) for a conversation "
|
||||
f"that starts with:\n\n{conversation_context}\n\n"
|
||||
"Respond with ONLY the title, nothing else."
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
# Use the chat client directly for a quick, lightweight call
|
||||
response = await self.weather_agent.chat_client.get_response(
|
||||
messages=title_prompt,
|
||||
temperature=0.3,
|
||||
max_tokens=20,
|
||||
)
|
||||
|
||||
if response.messages and response.messages[-1].text:
|
||||
title = response.messages[-1].text.strip().strip('"').strip("'")
|
||||
# Ensure it's not too long
|
||||
if len(title) > 50:
|
||||
title = title[:47] + "..."
|
||||
|
||||
thread.title = title
|
||||
await self.store.save_thread(thread, context)
|
||||
logger.info(f"Updated thread {thread.id} title to: {title}")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to generate thread title, using fallback: {e}")
|
||||
# Fallback to simple truncation
|
||||
first_message: str = user_messages[0]
|
||||
title: str = first_message[:50].strip()
|
||||
if len(first_message) > 50:
|
||||
title += "..."
|
||||
thread.title = title
|
||||
await self.store.save_thread(thread, context)
|
||||
logger.info(f"Updated thread {thread.id} title to (fallback): {title}")
|
||||
|
||||
async def respond(
|
||||
self,
|
||||
thread: ThreadMetadata,
|
||||
@@ -263,8 +339,19 @@ class WeatherChatKitServer(ChatKitServer[dict[str, Any]]):
|
||||
weather_data: WeatherData | None = None
|
||||
show_city_selector = False
|
||||
|
||||
# Convert ChatKit user message to Agent Framework ChatMessage using ThreadItemConverter
|
||||
agent_messages = await self.converter.to_agent_input(input_user_message)
|
||||
# Load full thread history from the store
|
||||
thread_items_page = await self.store.load_thread_items(
|
||||
thread_id=thread.id,
|
||||
after=None,
|
||||
limit=1000,
|
||||
order="asc",
|
||||
context=context,
|
||||
)
|
||||
thread_items = thread_items_page.data
|
||||
|
||||
# Convert ALL thread items to Agent Framework ChatMessages using ThreadItemConverter
|
||||
# This ensures the agent has the full conversation context
|
||||
agent_messages = await self.converter.to_agent_input(thread_items)
|
||||
|
||||
if not agent_messages:
|
||||
logger.warning("No messages after conversion")
|
||||
@@ -330,6 +417,10 @@ class WeatherChatKitServer(ChatKitServer[dict[str, Any]]):
|
||||
yield widget_event
|
||||
logger.debug("City selector widget streamed successfully")
|
||||
|
||||
# Update thread title based on first user message if not already set
|
||||
if not thread.title or thread.title == "New thread":
|
||||
await self._update_thread_title(thread, thread_items, context)
|
||||
|
||||
logger.info(f"Completed processing message for thread: {thread.id}")
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -8,7 +8,7 @@ cloud storage like S3, Azure Blob Storage, or Google Cloud Storage.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from chatkit.store import AttachmentStore
|
||||
from chatkit.types import Attachment, AttachmentCreateParams, FileAttachment, ImageAttachment
|
||||
@@ -51,7 +51,7 @@ class FileBasedAttachmentStore(AttachmentStore[dict[str, Any]]):
|
||||
self.uploads_dir = Path(uploads_dir)
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.data_store = data_store
|
||||
|
||||
|
||||
# Create uploads directory if it doesn't exist
|
||||
self.uploads_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@@ -65,9 +65,7 @@ class FileBasedAttachmentStore(AttachmentStore[dict[str, Any]]):
|
||||
if file_path.exists():
|
||||
file_path.unlink()
|
||||
|
||||
async def create_attachment(
|
||||
self, input: AttachmentCreateParams, context: dict[str, Any]
|
||||
) -> Attachment:
|
||||
async def create_attachment(self, input: AttachmentCreateParams, context: dict[str, Any]) -> Attachment:
|
||||
"""Create an attachment with upload URL for two-phase upload.
|
||||
|
||||
This creates the attachment metadata and returns upload URLs that
|
||||
@@ -75,7 +73,7 @@ class FileBasedAttachmentStore(AttachmentStore[dict[str, Any]]):
|
||||
"""
|
||||
# Generate unique ID for this attachment
|
||||
attachment_id = self.generate_attachment_id(input.mime_type, context)
|
||||
|
||||
|
||||
# Generate upload URL that points to our FastAPI upload endpoint
|
||||
upload_url = f"{self.base_url}/upload/{attachment_id}"
|
||||
|
||||
@@ -83,7 +81,7 @@ class FileBasedAttachmentStore(AttachmentStore[dict[str, Any]]):
|
||||
if input.mime_type.startswith("image/"):
|
||||
# For images, also provide a preview URL
|
||||
preview_url = f"{self.base_url}/preview/{attachment_id}"
|
||||
|
||||
|
||||
attachment = ImageAttachment(
|
||||
id=attachment_id,
|
||||
type="image",
|
||||
@@ -117,5 +115,5 @@ class FileBasedAttachmentStore(AttachmentStore[dict[str, Any]]):
|
||||
file_path = self.get_file_path(attachment_id)
|
||||
if not file_path.exists():
|
||||
raise FileNotFoundError(f"Attachment {attachment_id} not found on disk")
|
||||
|
||||
|
||||
return file_path.read_bytes()
|
||||
|
||||
@@ -10,7 +10,7 @@ import sqlite3
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from chatkit.store import Store, NotFoundError
|
||||
from chatkit.store import NotFoundError, Store
|
||||
from chatkit.types import (
|
||||
Attachment,
|
||||
Page,
|
||||
@@ -22,16 +22,19 @@ from pydantic import BaseModel
|
||||
|
||||
class ThreadData(BaseModel):
|
||||
"""Model for serializing thread data to SQLite."""
|
||||
|
||||
thread: ThreadMetadata
|
||||
|
||||
|
||||
class ItemData(BaseModel):
|
||||
"""Model for serializing thread item data to SQLite."""
|
||||
|
||||
item: ThreadItem
|
||||
|
||||
|
||||
class AttachmentData(BaseModel):
|
||||
"""Model for serializing attachment data to SQLite."""
|
||||
|
||||
attachment: Attachment
|
||||
|
||||
|
||||
@@ -185,19 +188,13 @@ class SQLiteStore(Store[dict[str, Any]]):
|
||||
params.append(limit + 1)
|
||||
|
||||
items_cursor = conn.execute(query, params).fetchall()
|
||||
items = [
|
||||
ItemData.model_validate_json(row[0]).item for row in items_cursor
|
||||
]
|
||||
items = [ItemData.model_validate_json(row[0]).item for row in items_cursor]
|
||||
|
||||
has_more = len(items) > limit
|
||||
if has_more:
|
||||
items = items[:limit]
|
||||
|
||||
return Page[ThreadItem](
|
||||
data=items,
|
||||
has_more=has_more,
|
||||
after=items[-1].id if items else None
|
||||
)
|
||||
return Page[ThreadItem](data=items, has_more=has_more, after=items[-1].id if items else None)
|
||||
|
||||
async def save_attachment(self, attachment: Attachment, context: dict[str, Any]) -> None:
|
||||
user_id = context.get("user_id", "demo_user")
|
||||
@@ -270,23 +267,15 @@ class SQLiteStore(Store[dict[str, Any]]):
|
||||
params.append(limit + 1)
|
||||
|
||||
threads_cursor = conn.execute(query, params).fetchall()
|
||||
threads = [
|
||||
ThreadData.model_validate_json(row[0]).thread for row in threads_cursor
|
||||
]
|
||||
threads = [ThreadData.model_validate_json(row[0]).thread for row in threads_cursor]
|
||||
|
||||
has_more = len(threads) > limit
|
||||
if has_more:
|
||||
threads = threads[:limit]
|
||||
|
||||
return Page[ThreadMetadata](
|
||||
data=threads,
|
||||
has_more=has_more,
|
||||
after=threads[-1].id if threads else None
|
||||
)
|
||||
return Page[ThreadMetadata](data=threads, has_more=has_more, after=threads[-1].id if threads else None)
|
||||
|
||||
async def add_thread_item(
|
||||
self, thread_id: str, item: ThreadItem, context: dict[str, Any]
|
||||
) -> None:
|
||||
async def add_thread_item(self, thread_id: str, item: ThreadItem, context: dict[str, Any]) -> None:
|
||||
user_id = context.get("user_id", "demo_user")
|
||||
|
||||
with self._create_connection() as conn:
|
||||
@@ -348,9 +337,7 @@ class SQLiteStore(Store[dict[str, Any]]):
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
async def delete_thread_item(
|
||||
self, thread_id: str, item_id: str, context: dict[str, Any]
|
||||
) -> None:
|
||||
async def delete_thread_item(self, thread_id: str, item_id: str, context: dict[str, Any]) -> None:
|
||||
user_id = context.get("user_id", "demo_user")
|
||||
|
||||
with self._create_connection() as conn:
|
||||
|
||||
@@ -29,7 +29,6 @@ POPULAR_CITIES = [
|
||||
CITY_VALUE_TO_NAME = {city["value"]: city["label"] for city in POPULAR_CITIES}
|
||||
|
||||
|
||||
|
||||
def _sun_svg() -> str:
|
||||
"""Generate SVG for sunny weather icon."""
|
||||
color = WEATHER_ICON_COLOR
|
||||
|
||||
Reference in New Issue
Block a user