Python: ChatKit sample fixes (#2174)

* sample fixes

* Update thread naming
This commit is contained in:
Evan Mattson
2025-11-14 08:02:31 +09:00
committed by GitHub
Unverified
parent 15d0bda8a2
commit a75590eb9b
5 changed files with 151 additions and 67 deletions
+124 -33
View File
@@ -16,11 +16,43 @@ from random import randint
from typing import Annotated, Any
import uvicorn
# Agent Framework imports
from agent_framework import AgentRunResponseUpdate, ChatAgent, ChatMessage, FunctionResultContent, Role
from agent_framework.azure import AzureOpenAIChatClient
# Agent Framework ChatKit integration
from agent_framework_chatkit import ThreadItemConverter, stream_agent_response
# Local imports
from attachment_store import FileBasedAttachmentStore
from azure.identity import AzureCliCredential
# ChatKit imports
from chatkit.actions import Action
from chatkit.server import ChatKitServer
from chatkit.store import StoreItemType, default_generate_id
from chatkit.types import (
ThreadItem,
ThreadItemDoneEvent,
ThreadMetadata,
ThreadStreamEvent,
UserMessageItem,
WidgetItem,
)
from chatkit.widgets import WidgetRoot
from fastapi import FastAPI, File, Request, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse, Response, StreamingResponse
from pydantic import Field
from store import SQLiteStore
from weather_widget import (
WeatherData,
city_selector_copy_text,
render_city_selector_widget,
render_weather_widget,
weather_widget_copy_text,
)
# ============================================================================
# Configuration Constants
@@ -56,37 +88,6 @@ logging.basicConfig(
)
logger = logging.getLogger(__name__)
# Agent Framework imports
from agent_framework import AgentRunResponseUpdate, ChatAgent, ChatMessage, FunctionResultContent, Role
from agent_framework.azure import AzureOpenAIChatClient
# Agent Framework ChatKit integration
from agent_framework_chatkit import ThreadItemConverter, stream_agent_response
# Local imports
from attachment_store import FileBasedAttachmentStore
# ChatKit imports
from chatkit.actions import Action
from chatkit.server import ChatKitServer
from chatkit.store import StoreItemType, default_generate_id
from chatkit.types import (
ThreadItemDoneEvent,
ThreadMetadata,
ThreadStreamEvent,
UserMessageItem,
WidgetItem,
)
from chatkit.widgets import WidgetRoot
from store import SQLiteStore
from weather_widget import (
WeatherData,
city_selector_copy_text,
render_city_selector_widget,
render_weather_widget,
weather_widget_copy_text,
)
class WeatherResponse(str):
"""A string response that also carries WeatherData for widget creation."""
@@ -238,6 +239,81 @@ class WeatherChatKitServer(ChatKitServer[dict[str, Any]]):
"""
return await attachment_store.read_attachment_bytes(attachment_id)
async def _update_thread_title(
self, thread: ThreadMetadata, thread_items: list[ThreadItem], context: dict[str, Any]
) -> None:
"""Update thread title using LLM to generate a concise summary.
Args:
thread: The thread metadata to update.
thread_items: All items in the thread.
context: The context dictionary.
"""
logger.info(f"Attempting to update thread title for thread: {thread.id}")
if not thread_items:
logger.debug("No thread items available for title generation")
return
# Collect user messages to understand the conversation topic
user_messages: list[str] = []
for item in thread_items:
if isinstance(item, UserMessageItem) and item.content:
for content_part in item.content:
if hasattr(content_part, "text") and isinstance(content_part.text, str):
user_messages.append(content_part.text)
break
if not user_messages:
logger.debug("No user messages found for title generation")
return
logger.debug(f"Found {len(user_messages)} user message(s) for title generation")
try:
# Use the agent's chat client to generate a concise title
# Combine first few messages to capture the conversation topic
conversation_context = "\n".join(user_messages[:3])
title_prompt = [
ChatMessage(
role=Role.USER,
text=(
f"Generate a very short, concise title (max 40 characters) for a conversation "
f"that starts with:\n\n{conversation_context}\n\n"
"Respond with ONLY the title, nothing else."
),
)
]
# Use the chat client directly for a quick, lightweight call
response = await self.weather_agent.chat_client.get_response(
messages=title_prompt,
temperature=0.3,
max_tokens=20,
)
if response.messages and response.messages[-1].text:
title = response.messages[-1].text.strip().strip('"').strip("'")
# Ensure it's not too long
if len(title) > 50:
title = title[:47] + "..."
thread.title = title
await self.store.save_thread(thread, context)
logger.info(f"Updated thread {thread.id} title to: {title}")
except Exception as e:
logger.warning(f"Failed to generate thread title, using fallback: {e}")
# Fallback to simple truncation
first_message: str = user_messages[0]
title: str = first_message[:50].strip()
if len(first_message) > 50:
title += "..."
thread.title = title
await self.store.save_thread(thread, context)
logger.info(f"Updated thread {thread.id} title to (fallback): {title}")
async def respond(
self,
thread: ThreadMetadata,
@@ -263,8 +339,19 @@ class WeatherChatKitServer(ChatKitServer[dict[str, Any]]):
weather_data: WeatherData | None = None
show_city_selector = False
# Convert ChatKit user message to Agent Framework ChatMessage using ThreadItemConverter
agent_messages = await self.converter.to_agent_input(input_user_message)
# Load full thread history from the store
thread_items_page = await self.store.load_thread_items(
thread_id=thread.id,
after=None,
limit=1000,
order="asc",
context=context,
)
thread_items = thread_items_page.data
# Convert ALL thread items to Agent Framework ChatMessages using ThreadItemConverter
# This ensures the agent has the full conversation context
agent_messages = await self.converter.to_agent_input(thread_items)
if not agent_messages:
logger.warning("No messages after conversion")
@@ -330,6 +417,10 @@ class WeatherChatKitServer(ChatKitServer[dict[str, Any]]):
yield widget_event
logger.debug("City selector widget streamed successfully")
# Update thread title based on first user message if not already set
if not thread.title or thread.title == "New thread":
await self._update_thread_title(thread, thread_items, context)
logger.info(f"Completed processing message for thread: {thread.id}")
except Exception as e:
@@ -8,7 +8,7 @@ cloud storage like S3, Azure Blob Storage, or Google Cloud Storage.
"""
from pathlib import Path
from typing import Any, TYPE_CHECKING
from typing import TYPE_CHECKING, Any
from chatkit.store import AttachmentStore
from chatkit.types import Attachment, AttachmentCreateParams, FileAttachment, ImageAttachment
@@ -51,7 +51,7 @@ class FileBasedAttachmentStore(AttachmentStore[dict[str, Any]]):
self.uploads_dir = Path(uploads_dir)
self.base_url = base_url.rstrip("/")
self.data_store = data_store
# Create uploads directory if it doesn't exist
self.uploads_dir.mkdir(parents=True, exist_ok=True)
@@ -65,9 +65,7 @@ class FileBasedAttachmentStore(AttachmentStore[dict[str, Any]]):
if file_path.exists():
file_path.unlink()
async def create_attachment(
self, input: AttachmentCreateParams, context: dict[str, Any]
) -> Attachment:
async def create_attachment(self, input: AttachmentCreateParams, context: dict[str, Any]) -> Attachment:
"""Create an attachment with upload URL for two-phase upload.
This creates the attachment metadata and returns upload URLs that
@@ -75,7 +73,7 @@ class FileBasedAttachmentStore(AttachmentStore[dict[str, Any]]):
"""
# Generate unique ID for this attachment
attachment_id = self.generate_attachment_id(input.mime_type, context)
# Generate upload URL that points to our FastAPI upload endpoint
upload_url = f"{self.base_url}/upload/{attachment_id}"
@@ -83,7 +81,7 @@ class FileBasedAttachmentStore(AttachmentStore[dict[str, Any]]):
if input.mime_type.startswith("image/"):
# For images, also provide a preview URL
preview_url = f"{self.base_url}/preview/{attachment_id}"
attachment = ImageAttachment(
id=attachment_id,
type="image",
@@ -117,5 +115,5 @@ class FileBasedAttachmentStore(AttachmentStore[dict[str, Any]]):
file_path = self.get_file_path(attachment_id)
if not file_path.exists():
raise FileNotFoundError(f"Attachment {attachment_id} not found on disk")
return file_path.read_bytes()
@@ -10,7 +10,7 @@ import sqlite3
import uuid
from typing import Any
from chatkit.store import Store, NotFoundError
from chatkit.store import NotFoundError, Store
from chatkit.types import (
Attachment,
Page,
@@ -22,16 +22,19 @@ from pydantic import BaseModel
class ThreadData(BaseModel):
"""Model for serializing thread data to SQLite."""
thread: ThreadMetadata
class ItemData(BaseModel):
"""Model for serializing thread item data to SQLite."""
item: ThreadItem
class AttachmentData(BaseModel):
"""Model for serializing attachment data to SQLite."""
attachment: Attachment
@@ -185,19 +188,13 @@ class SQLiteStore(Store[dict[str, Any]]):
params.append(limit + 1)
items_cursor = conn.execute(query, params).fetchall()
items = [
ItemData.model_validate_json(row[0]).item for row in items_cursor
]
items = [ItemData.model_validate_json(row[0]).item for row in items_cursor]
has_more = len(items) > limit
if has_more:
items = items[:limit]
return Page[ThreadItem](
data=items,
has_more=has_more,
after=items[-1].id if items else None
)
return Page[ThreadItem](data=items, has_more=has_more, after=items[-1].id if items else None)
async def save_attachment(self, attachment: Attachment, context: dict[str, Any]) -> None:
user_id = context.get("user_id", "demo_user")
@@ -270,23 +267,15 @@ class SQLiteStore(Store[dict[str, Any]]):
params.append(limit + 1)
threads_cursor = conn.execute(query, params).fetchall()
threads = [
ThreadData.model_validate_json(row[0]).thread for row in threads_cursor
]
threads = [ThreadData.model_validate_json(row[0]).thread for row in threads_cursor]
has_more = len(threads) > limit
if has_more:
threads = threads[:limit]
return Page[ThreadMetadata](
data=threads,
has_more=has_more,
after=threads[-1].id if threads else None
)
return Page[ThreadMetadata](data=threads, has_more=has_more, after=threads[-1].id if threads else None)
async def add_thread_item(
self, thread_id: str, item: ThreadItem, context: dict[str, Any]
) -> None:
async def add_thread_item(self, thread_id: str, item: ThreadItem, context: dict[str, Any]) -> None:
user_id = context.get("user_id", "demo_user")
with self._create_connection() as conn:
@@ -348,9 +337,7 @@ class SQLiteStore(Store[dict[str, Any]]):
)
conn.commit()
async def delete_thread_item(
self, thread_id: str, item_id: str, context: dict[str, Any]
) -> None:
async def delete_thread_item(self, thread_id: str, item_id: str, context: dict[str, Any]) -> None:
user_id = context.get("user_id", "demo_user")
with self._create_connection() as conn:
@@ -29,7 +29,6 @@ POPULAR_CITIES = [
CITY_VALUE_TO_NAME = {city["value"]: city["label"] for city in POPULAR_CITIES}
def _sun_svg() -> str:
"""Generate SVG for sunny weather icon."""
color = WEATHER_ICON_COLOR