mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: DevUI fixes : Add multimodal input support for workflows and refactor chat input (#2593)
* show app version in devui
  .NET: Python: Improved Versioning for DevUI
  Fixes #2059

* feat: Add multimodal input support for workflows and refactor chat input

  This PR adds support for multimodal content (images, files) in workflow inputs and refactors the chat input into a reusable component.

  ## Multimodal Workflow Support
  - Add `isChatMessageSchema()` to detect ChatMessage input schemas
  - Update `RunWorkflowButton` to use `ChatMessageInput` for ChatMessage workflows
  - Wrap multimodal content in OpenAI message format for backend processing
  - Add `_is_openai_multimodal_format()` to detect OpenAI ResponseInputParam
  - Update `_parse_workflow_input()` to route multimodal input through existing `_convert_input_to_chat_message()` converter

  ## Reusable ChatMessageInput Component
  - Extract chat input logic from agent-view into `ChatMessageInput` component
  - Support file upload, drag & drop, paste handling, and attachments
  - Add `useDragDrop` hook for parent-level drag handling with full-area drop zones
  - Refactor agent-view to use the new shared component

  ## Other Improvements
  - Add `isStreaming` prop to executor nodes for animation control
  - Clean up unused imports and state variables in agent-view
  - Add tests for multimodal workflow input handling

  Fixes workflow input not receiving images when using AgentExecutor nodes.

* add self loop edge, fix #2470

* fix test
This commit is contained in:
committed by
GitHub
Unverified
parent
6835161f2d
commit
411ee7a60f
@@ -134,9 +134,12 @@ class ConversationStore(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_item(self, conversation_id: str, item_id: str) -> ConversationItem | None:
|
||||
async def get_item(self, conversation_id: str, item_id: str) -> ConversationItem | None:
|
||||
"""Get a specific conversation item by ID.
|
||||
|
||||
Supports checkpoint items - will load full checkpoint state from storage.
|
||||
For checkpoints, the full state is included in metadata.full_checkpoint.
|
||||
|
||||
Args:
|
||||
conversation_id: Conversation ID
|
||||
item_id: Item ID
|
||||
@@ -162,7 +165,7 @@ class ConversationStore(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def list_conversations_by_metadata(self, metadata_filter: dict[str, str]) -> list[Conversation]:
|
||||
async def list_conversations_by_metadata(self, metadata_filter: dict[str, str]) -> list[Conversation]:
|
||||
"""Filter conversations by metadata (e.g., agent_id).
|
||||
|
||||
Args:
|
||||
@@ -444,7 +447,15 @@ class InMemoryConversationStore(ConversationStore):
|
||||
# Get all checkpoints for this conversation
|
||||
checkpoints = await checkpoint_storage.list_checkpoints()
|
||||
for checkpoint in checkpoints:
|
||||
# Create a conversation item for each checkpoint
|
||||
# Create a conversation item for each checkpoint with summary metadata
|
||||
# Full checkpoint state is NOT included here (too large for list view)
|
||||
# Use get_item() to retrieve full checkpoint details
|
||||
# Calculate approximate size of checkpoint
|
||||
import json
|
||||
|
||||
checkpoint_json = json.dumps(checkpoint.to_dict())
|
||||
checkpoint_size = len(checkpoint_json.encode("utf-8"))
|
||||
|
||||
checkpoint_item = {
|
||||
"id": f"checkpoint_{checkpoint.checkpoint_id}",
|
||||
"type": "checkpoint",
|
||||
@@ -452,6 +463,15 @@ class InMemoryConversationStore(ConversationStore):
|
||||
"workflow_id": checkpoint.workflow_id,
|
||||
"timestamp": checkpoint.timestamp,
|
||||
"status": "completed",
|
||||
"metadata": {
|
||||
# Summary metrics for list view
|
||||
"iteration_count": checkpoint.iteration_count,
|
||||
"pending_hil_count": len(checkpoint.pending_request_info_events),
|
||||
"has_pending_hil": len(checkpoint.pending_request_info_events) > 0,
|
||||
"message_count": sum(len(msgs) for msgs in checkpoint.messages.values()),
|
||||
"size_bytes": checkpoint_size,
|
||||
"version": checkpoint.version,
|
||||
},
|
||||
}
|
||||
items.append(cast(ConversationItem, checkpoint_item))
|
||||
|
||||
@@ -472,24 +492,91 @@ class InMemoryConversationStore(ConversationStore):
|
||||
|
||||
return paginated_items, has_more
|
||||
|
||||
async def get_item(self, conversation_id: str, item_id: str) -> ConversationItem | None:
    """Get a specific conversation item by ID.

    Supports checkpoint items - will load full checkpoint state from storage.
    For checkpoints, the full state is included in metadata.full_checkpoint.

    Args:
        conversation_id: Conversation ID
        item_id: Item ID. An ID starting with "checkpoint_" that is not present
            in the item index is resolved against the conversation's
            checkpoint storage.

    Returns:
        The indexed item if one exists; otherwise a checkpoint item built from
        checkpoint storage; otherwise None (unknown conversation, no
        checkpoint storage, checkpoint not found, or a non-checkpoint ID).
    """
    import json

    # First check item index for messages, function calls, etc. (O(1) lookup)
    indexed_item = self._item_index.get(conversation_id, {}).get(item_id)
    if indexed_item:
        return indexed_item

    # Only checkpoint IDs can be resolved outside the index
    checkpoint_prefix = "checkpoint_"
    if not item_id.startswith(checkpoint_prefix):
        return None

    conv_data = self._conversations.get(conversation_id)
    if not conv_data:
        return None

    checkpoint_storage = conv_data.get("checkpoint_storage")
    if not checkpoint_storage:
        return None

    # Load full checkpoint from storage (ID without the "checkpoint_" prefix)
    checkpoint = await checkpoint_storage.load_checkpoint(item_id[len(checkpoint_prefix) :])
    if not checkpoint:
        return None

    # Serialize once and reuse the result for both the size metric and the
    # full_checkpoint payload instead of calling to_dict() twice.
    checkpoint_state = checkpoint.to_dict()
    checkpoint_size = len(json.dumps(checkpoint_state).encode("utf-8"))
    pending_hil_count = len(checkpoint.pending_request_info_events)

    # Build checkpoint item with FULL state in metadata
    checkpoint_item = {
        "id": item_id,
        "type": "checkpoint",
        "checkpoint_id": checkpoint.checkpoint_id,
        "workflow_id": checkpoint.workflow_id,
        "timestamp": checkpoint.timestamp,
        "status": "completed",
        "metadata": {
            # Summary metrics (same as list view)
            "iteration_count": checkpoint.iteration_count,
            "pending_hil_count": pending_hil_count,
            "has_pending_hil": pending_hil_count > 0,
            "message_count": sum(len(msgs) for msgs in checkpoint.messages.values()),
            "size_bytes": checkpoint_size,
            "version": checkpoint.version,
            # FULL checkpoint state (lazy loaded)
            "full_checkpoint": checkpoint_state,
        },
    }

    return cast(ConversationItem, checkpoint_item)
|
||||
|
||||
def get_thread(self, conversation_id: str) -> AgentThread | None:
    """Get AgentThread for execution - CRITICAL for agent.run_stream().

    Args:
        conversation_id: Conversation ID

    Returns:
        The value stored under the conversation's "thread" key, or None when
        the conversation is not present in the store.
    """
    stored = self._conversations.get(conversation_id)
    if not stored:
        return None
    return stored["thread"]
|
||||
|
||||
def list_conversations_by_metadata(self, metadata_filter: dict[str, str]) -> list[Conversation]:
|
||||
async def list_conversations_by_metadata(self, metadata_filter: dict[str, str]) -> list[Conversation]:
|
||||
"""Filter conversations by metadata (e.g., agent_id)."""
|
||||
results = []
|
||||
for conv_data in self._conversations.values():
|
||||
conv_meta = conv_data.get("metadata", {})
|
||||
conv_meta = conv_data.get("metadata", {}).copy() # Copy to avoid mutating original
|
||||
|
||||
# Check if all filter items match
|
||||
if all(conv_meta.get(k) == v for k, v in metadata_filter.items()):
|
||||
# Enrich workflow sessions with checkpoint summary
|
||||
if conv_meta.get("type") == "workflow_session":
|
||||
checkpoint_storage = conv_data.get("checkpoint_storage")
|
||||
if checkpoint_storage:
|
||||
checkpoints = await checkpoint_storage.list_checkpoints()
|
||||
latest = checkpoints[0] if checkpoints else None
|
||||
conv_meta["checkpoint_summary"] = {
|
||||
"count": len(checkpoints),
|
||||
"latest_iteration": latest.iteration_count if latest else 0,
|
||||
"has_pending_hil": len(latest.pending_request_info_events) > 0 if latest else False,
|
||||
"pending_hil_count": len(latest.pending_request_info_events) if latest else 0,
|
||||
}
|
||||
|
||||
results.append(
|
||||
Conversation(
|
||||
id=conv_data["id"],
|
||||
@@ -498,6 +585,10 @@ class InMemoryConversationStore(ConversationStore):
|
||||
metadata=conv_meta,
|
||||
)
|
||||
)
|
||||
|
||||
# Sort by created_at descending (most recent first)
|
||||
results.sort(key=lambda c: c.created_at, reverse=True)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
|
||||
@@ -722,6 +722,20 @@ class AgentFrameworkExecutor:
|
||||
return json.dumps(input_data)
|
||||
return str(input_data)
|
||||
|
||||
def _is_openai_multimodal_format(self, input_data: Any) -> bool:
|
||||
"""Check if input is OpenAI ResponseInputParam format (list with message items).
|
||||
|
||||
Args:
|
||||
input_data: Input data to check
|
||||
|
||||
Returns:
|
||||
True if input is OpenAI multimodal format
|
||||
"""
|
||||
if not isinstance(input_data, list) or not input_data:
|
||||
return False
|
||||
first_item = input_data[0]
|
||||
return isinstance(first_item, dict) and first_item.get("type") == "message"
|
||||
|
||||
async def _parse_workflow_input(self, workflow: Any, raw_input: Any) -> Any:
|
||||
"""Parse input based on workflow's expected input type.
|
||||
|
||||
@@ -733,9 +747,26 @@ class AgentFrameworkExecutor:
|
||||
Parsed input appropriate for the workflow
|
||||
"""
|
||||
try:
|
||||
# Handle structured input
|
||||
# Handle JSON string input (from frontend api.ts JSON.stringify)
|
||||
if isinstance(raw_input, str):
|
||||
try:
|
||||
parsed = json.loads(raw_input)
|
||||
raw_input = parsed
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
# Plain text string, continue with string handling
|
||||
pass
|
||||
|
||||
# Check for OpenAI multimodal format (list with type: "message")
|
||||
# This handles ChatMessage inputs with images, files, etc.
|
||||
if self._is_openai_multimodal_format(raw_input):
|
||||
logger.debug("Detected OpenAI multimodal format, converting to ChatMessage")
|
||||
return self._convert_input_to_chat_message(raw_input)
|
||||
|
||||
# Handle structured input (dict)
|
||||
if isinstance(raw_input, dict):
|
||||
return self._parse_structured_workflow_input(workflow, raw_input)
|
||||
|
||||
# Handle string input
|
||||
return self._parse_raw_workflow_input(workflow, str(raw_input))
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -885,6 +885,10 @@ class MessageMapper:
|
||||
context[f"exec_item_{executor_id}"] = item_id
|
||||
context["output_index"] = context.get("output_index", -1) + 1
|
||||
|
||||
# Track current executor for routing Magentic agent events
|
||||
# This allows MagenticAgentDeltaEvent to route to the executor's item
|
||||
context["current_executor_id"] = executor_id
|
||||
|
||||
# Create ExecutorActionItem with proper type
|
||||
executor_item = ExecutorActionItem(
|
||||
type="executor_action",
|
||||
@@ -908,6 +912,10 @@ class MessageMapper:
|
||||
executor_id = getattr(event, "executor_id", "unknown")
|
||||
item_id = context.get(f"exec_item_{executor_id}", f"exec_{executor_id}_unknown")
|
||||
|
||||
# Clear current executor tracking when executor completes
|
||||
if context.get("current_executor_id") == executor_id:
|
||||
context.pop("current_executor_id", None)
|
||||
|
||||
# Create ExecutorActionItem with completed status
|
||||
# ExecutorCompletedEvent uses 'data' field, not 'result'
|
||||
executor_item = ExecutorActionItem(
|
||||
@@ -1059,6 +1067,30 @@ class MessageMapper:
|
||||
text = getattr(event, "text", None)
|
||||
|
||||
if text:
|
||||
# Check if we're inside an executor - route to executor's item
|
||||
# This prevents duplicate timeline entries (executor + inner agent)
|
||||
current_executor_id = context.get("current_executor_id")
|
||||
executor_item_key = f"exec_item_{current_executor_id}" if current_executor_id else None
|
||||
|
||||
if executor_item_key and executor_item_key in context:
|
||||
# Route delta to the executor's item instead of creating a new message item
|
||||
item_id = context[executor_item_key]
|
||||
|
||||
# Emit text delta event routed to the executor's item
|
||||
return [
|
||||
ResponseTextDeltaEvent(
|
||||
type="response.output_text.delta",
|
||||
output_index=context.get("output_index", 0),
|
||||
content_index=0,
|
||||
item_id=item_id,
|
||||
delta=text,
|
||||
logprobs=[],
|
||||
sequence_number=self._next_sequence(context),
|
||||
)
|
||||
]
|
||||
|
||||
# Fallback: No executor context - create separate message item (original behavior)
|
||||
# This handles cases where MagenticAgentDeltaEvent is emitted outside an executor
|
||||
events = []
|
||||
|
||||
# Track Magentic agent messages separately from regular messages
|
||||
@@ -1181,7 +1213,21 @@ class MessageMapper:
|
||||
agent_id = getattr(event, "agent_id", "unknown_agent")
|
||||
message = getattr(event, "message", None)
|
||||
|
||||
# Track Magentic agent messages
|
||||
# Check if we're inside an executor - if so, deltas were already routed there
|
||||
# We don't need to emit a separate message completion event
|
||||
current_executor_id = context.get("current_executor_id")
|
||||
executor_item_key = f"exec_item_{current_executor_id}" if current_executor_id else None
|
||||
|
||||
if executor_item_key and executor_item_key in context:
|
||||
# Deltas were routed to executor item - no separate message item to complete
|
||||
# The executor's output_item.done will mark completion
|
||||
logger.debug(
|
||||
f"MagenticAgentMessageEvent from {agent_id} - "
|
||||
f"deltas routed to executor {current_executor_id}, skipping"
|
||||
)
|
||||
return []
|
||||
|
||||
# Fallback: Handle case where we created a separate message item (no executor context)
|
||||
magentic_key = f"magentic_message_{agent_id}"
|
||||
|
||||
# Check if we were streaming for this agent
|
||||
|
||||
@@ -2,14 +2,17 @@
|
||||
|
||||
"""FastAPI server implementation."""
|
||||
|
||||
import asyncio
|
||||
import importlib.metadata
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import secrets
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator, Awaitable, Callable
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
@@ -26,6 +29,12 @@ from .models._discovery_models import Deployment, DeploymentConfig, DiscoveryRes
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Get package version
|
||||
try:
|
||||
__version__ = importlib.metadata.version("agent-framework-devui")
|
||||
except importlib.metadata.PackageNotFoundError:
|
||||
__version__ = "0.0.0" # Fallback for development mode
|
||||
|
||||
|
||||
# No AuthMiddleware class needed - we'll use the decorator pattern instead
|
||||
|
||||
@@ -70,6 +79,7 @@ class DevServer:
|
||||
self.deployment_manager = DeploymentManager()
|
||||
self._app: FastAPI | None = None
|
||||
self._pending_entities: list[Any] | None = None
|
||||
self._running_tasks: dict[str, asyncio.Task[Any]] = {} # Track running response tasks for cancellation
|
||||
|
||||
def _is_dev_mode(self) -> bool:
|
||||
"""Check if running in developer mode.
|
||||
@@ -293,7 +303,7 @@ class DevServer:
|
||||
app = FastAPI(
|
||||
title="Agent Framework Server",
|
||||
description="OpenAI-compatible API server for Agent Framework and other AI frameworks",
|
||||
version="1.0.0",
|
||||
version=__version__,
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
@@ -388,8 +398,6 @@ class DevServer:
|
||||
"""Get server metadata and configuration."""
|
||||
import os
|
||||
|
||||
from . import __version__
|
||||
|
||||
# Ensure executors are initialized to check capabilities
|
||||
openai_executor = await self._ensure_openai_executor()
|
||||
|
||||
@@ -731,13 +739,18 @@ class DevServer:
|
||||
|
||||
# Execute request
|
||||
if request.stream:
|
||||
# Generate response ID for tracking
|
||||
response_id = f"resp_{uuid.uuid4().hex[:8]}"
|
||||
logger.info(f"[CANCELLATION] Creating response {response_id} for entity {entity_id}")
|
||||
|
||||
return StreamingResponse(
|
||||
self._stream_execution(executor, request),
|
||||
self._stream_with_cancellation(executor, request, response_id),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"Access-Control-Allow-Origin": "*",
|
||||
"X-Response-ID": response_id, # Include ID for debugging/tracking
|
||||
},
|
||||
)
|
||||
return await executor.execute_sync(request)
|
||||
@@ -747,6 +760,30 @@ class DevServer:
|
||||
error = OpenAIError.create(error_msg)
|
||||
return JSONResponse(status_code=500, content=error.to_dict())
|
||||
|
||||
@app.post("/v1/responses/{response_id}/cancel")
async def cancel_response(response_id: str) -> dict[str, Any]:
    """Cancel a running response execution.

    This endpoint allows explicit cancellation of a running stream.
    Note: Cancellation also happens automatically when the client disconnects.

    Args:
        response_id: ID under which the streaming task was registered in
            ``self._running_tasks``.

    Returns:
        A dict with the ``response_id`` and a ``status`` of "cancelled",
        "already_completed", or "not_found".
    """
    logger.info(f"[CANCELLATION] Cancel request received for {response_id}")

    running_task = self._running_tasks.get(response_id)
    if not running_task:
        logger.warning(f"[CANCELLATION] No task found for {response_id}")
        return {"status": "not_found", "response_id": response_id}

    if running_task.done():
        logger.warning(f"[CANCELLATION] Task already completed for {response_id}")
        return {"status": "already_completed", "response_id": response_id}

    logger.info(f"[CANCELLATION] Cancelling task for {response_id}")
    running_task.cancel()
    # Wait briefly for cancellation to propagate; both a completed
    # cancellation and a timeout are reported to the caller as "cancelled".
    try:  # noqa: SIM105
        await asyncio.wait_for(running_task, timeout=0.5)
    except (asyncio.CancelledError, asyncio.TimeoutError):
        pass
    return {"status": "cancelled", "response_id": response_id}
|
||||
|
||||
# ========================================
|
||||
# OpenAI Conversations API (Standard)
|
||||
# ========================================
|
||||
@@ -862,7 +899,7 @@ class DevServer:
|
||||
filters["type"] = type
|
||||
|
||||
# Apply filters
|
||||
conversations = executor.conversation_store.list_conversations_by_metadata(filters)
|
||||
conversations = await executor.conversation_store.list_conversations_by_metadata(filters)
|
||||
|
||||
return {
|
||||
"object": "list",
|
||||
@@ -973,13 +1010,19 @@ class DevServer:
|
||||
|
||||
@app.get("/v1/conversations/{conversation_id}/items/{item_id}")
|
||||
async def retrieve_conversation_item(conversation_id: str, item_id: str) -> dict[str, Any]:
|
||||
"""Get specific conversation item - OpenAI standard."""
|
||||
"""Get specific conversation item - OpenAI standard.
|
||||
|
||||
Supports checkpoint items - returns full checkpoint state in metadata.full_checkpoint.
|
||||
"""
|
||||
try:
|
||||
executor = await self._ensure_executor()
|
||||
item = executor.conversation_store.get_item(conversation_id, item_id)
|
||||
item = await executor.conversation_store.get_item(conversation_id, item_id)
|
||||
if not item:
|
||||
raise HTTPException(status_code=404, detail="Item not found")
|
||||
result: dict[str, Any] = item.model_dump()
|
||||
# Handle both Pydantic models and dicts
|
||||
result: dict[str, Any] = (
|
||||
item.model_dump() if hasattr(item, "model_dump") else cast(dict[str, Any], item)
|
||||
)
|
||||
return result
|
||||
except HTTPException:
|
||||
raise
|
||||
@@ -1139,6 +1182,100 @@ class DevServer:
|
||||
}
|
||||
yield f"data: {json.dumps(error_event)}\n\n"
|
||||
|
||||
async def _stream_with_cancellation(
    self, executor: AgentFrameworkExecutor, request: AgentFrameworkRequest, response_id: str
) -> AsyncGenerator[str, None]:
    """Stream execution with automatic cancellation on client disconnect.

    Wraps ``self._stream_execution`` and:
    1. Registers the current asyncio task in ``self._running_tasks`` under
       ``response_id`` so it can be looked up for cancellation
    2. Treats GeneratorExit (generator closed) as a client disconnect
    3. Cancels the registered task in that case
    4. Re-raises CancelledError so it propagates through the execution chain

    Args:
        executor: Agent Framework executor instance
        request: Request to execute
        response_id: Unique ID for this response/execution

    Yields:
        SSE-formatted event strings from the original stream
    """
    tracked_task = None

    async def _guarded_chunks() -> AsyncGenerator[str, None]:
        """Relay chunks from the execution stream, reporting cancellation."""
        try:
            logger.debug(f"[CANCELLATION] Starting execution for {response_id}")

            async for sse_chunk in self._stream_execution(executor, request):
                # Check if we're being cancelled
                # NOTE(review): Task.cancelled() reports True only once a task
                # has finished as cancelled, so this branch probably never
                # fires while the task is still running - confirm intent.
                running = asyncio.current_task()
                if running and running.cancelled():
                    logger.info(f"[CANCELLATION] Detected cancellation, breaking stream for {response_id}")
                    break
                yield sse_chunk

        except asyncio.CancelledError:
            logger.info(f"[CANCELLATION] Execution cancelled via CancelledError for {response_id}")
            # Emit cancellation event to client (if still connected)
            notice = {
                "type": "response.cancelled",
                "response_id": response_id,
                "message": "Execution cancelled by user",
            }
            yield f"data: {json.dumps(notice)}\n\n"
            raise
        except Exception as e:
            logger.error(f"[CANCELLATION] Error in cancellable execution for {response_id}: {e}")
            raise

    try:
        # Track the task that is driving this generator
        tracked_task = asyncio.current_task()
        if not tracked_task:
            logger.warning(f"[CANCELLATION] No current task found to track for {response_id}")
        else:
            self._running_tasks[response_id] = tracked_task
            logger.debug(f"[CANCELLATION] Tracking task {tracked_task.get_name()} for response {response_id}")

        # Stream the execution
        async for sse_chunk in _guarded_chunks():
            yield sse_chunk

        logger.debug(f"[CANCELLATION] Stream completed normally for {response_id}")

    except GeneratorExit:
        # Client disconnected - this is raised when the generator is closed
        logger.info(f"[CANCELLATION] Client disconnected, initiating cancellation for {response_id}")

        if tracked_task and not tracked_task.done():
            logger.info(f"[CANCELLATION] Cancelling task for disconnected client {response_id}")
            tracked_task.cancel()
            # Give it a moment to cancel gracefully
            # Note: We should NOT use asyncio.shield here as it prevents cancellation
            try:
                await asyncio.wait_for(tracked_task, timeout=1.0)
            except (asyncio.CancelledError, asyncio.TimeoutError):
                logger.debug(f"[CANCELLATION] Task cancelled successfully for {response_id}")
            except Exception as e:
                logger.warning(f"[CANCELLATION] Error during task cancellation for {response_id}: {e}")
        raise  # Re-raise GeneratorExit to properly close the generator

    except asyncio.CancelledError:
        logger.info(f"[CANCELLATION] Stream cancelled for {response_id}")
        raise

    except Exception as e:
        logger.error(f"[CANCELLATION] Unexpected error in stream for {response_id}: {e}")
        raise

    finally:
        # Clean up tracking; log only when an entry was actually removed
        if self._running_tasks.pop(response_id, None) is not None:
            logger.debug(f"[CANCELLATION] Cleaned up task tracking for {response_id}")
|
||||
|
||||
def _mount_ui(self, app: FastAPI) -> None:
|
||||
"""Mount the UI as static files."""
|
||||
from pathlib import Path
|
||||
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user