Python: DevUI fixes : Add multimodal input support for workflows and refactor chat input (#2593)

* show app version in devui .NET: Python: Improved Versioning for DevUI
Fixes #2059

* feat: Add multimodal input support for workflows and refactor chat input

This PR adds support for multimodal content (images, files) in workflow
inputs and refactors the chat input into a reusable component.

## Multimodal Workflow Support
- Add `isChatMessageSchema()` to detect ChatMessage input schemas
- Update `RunWorkflowButton` to use `ChatMessageInput` for ChatMessage workflows
- Wrap multimodal content in OpenAI message format for backend processing
- Add `_is_openai_multimodal_format()` to detect OpenAI ResponseInputParam
- Update `_parse_workflow_input()` to route multimodal input through
  existing `_convert_input_to_chat_message()` converter

## Reusable ChatMessageInput Component
- Extract chat input logic from agent-view into `ChatMessageInput` component
- Support file upload, drag & drop, paste handling, and attachments
- Add `useDragDrop` hook for parent-level drag handling with full-area
  drop zones
- Refactor agent-view to use the new shared component

## Other Improvements
- Add `isStreaming` prop to executor nodes for animation control
- Clean up unused imports and state variables in agent-view
- Add tests for multimodal workflow input handling

Fixes workflow input not receiving images when using AgentExecutor nodes.

* add self loop edge, fix #2470

* fix test
This commit is contained in:
Victor Dibia
2025-12-03 12:15:51 -08:00
committed by GitHub
Unverified
parent 6835161f2d
commit 411ee7a60f
38 changed files with 3488 additions and 1493 deletions
@@ -134,9 +134,12 @@ class ConversationStore(ABC):
pass
@abstractmethod
def get_item(self, conversation_id: str, item_id: str) -> ConversationItem | None:
async def get_item(self, conversation_id: str, item_id: str) -> ConversationItem | None:
"""Get a specific conversation item by ID.
Supports checkpoint items - will load full checkpoint state from storage.
For checkpoints, the full state is included in metadata.full_checkpoint.
Args:
conversation_id: Conversation ID
item_id: Item ID
@@ -162,7 +165,7 @@ class ConversationStore(ABC):
pass
@abstractmethod
def list_conversations_by_metadata(self, metadata_filter: dict[str, str]) -> list[Conversation]:
async def list_conversations_by_metadata(self, metadata_filter: dict[str, str]) -> list[Conversation]:
"""Filter conversations by metadata (e.g., agent_id).
Args:
@@ -444,7 +447,15 @@ class InMemoryConversationStore(ConversationStore):
# Get all checkpoints for this conversation
checkpoints = await checkpoint_storage.list_checkpoints()
for checkpoint in checkpoints:
# Create a conversation item for each checkpoint
# Create a conversation item for each checkpoint with summary metadata
# Full checkpoint state is NOT included here (too large for list view)
# Use get_item() to retrieve full checkpoint details
# Calculate approximate size of checkpoint
import json
checkpoint_json = json.dumps(checkpoint.to_dict())
checkpoint_size = len(checkpoint_json.encode("utf-8"))
checkpoint_item = {
"id": f"checkpoint_{checkpoint.checkpoint_id}",
"type": "checkpoint",
@@ -452,6 +463,15 @@ class InMemoryConversationStore(ConversationStore):
"workflow_id": checkpoint.workflow_id,
"timestamp": checkpoint.timestamp,
"status": "completed",
"metadata": {
# Summary metrics for list view
"iteration_count": checkpoint.iteration_count,
"pending_hil_count": len(checkpoint.pending_request_info_events),
"has_pending_hil": len(checkpoint.pending_request_info_events) > 0,
"message_count": sum(len(msgs) for msgs in checkpoint.messages.values()),
"size_bytes": checkpoint_size,
"version": checkpoint.version,
},
}
items.append(cast(ConversationItem, checkpoint_item))
@@ -472,24 +492,91 @@ class InMemoryConversationStore(ConversationStore):
return paginated_items, has_more
def get_item(self, conversation_id: str, item_id: str) -> ConversationItem | None:
"""Get a specific conversation item by ID."""
# Use the item index for O(1) lookup
async def get_item(self, conversation_id: str, item_id: str) -> ConversationItem | None:
"""Get a specific conversation item by ID.
Supports checkpoint items - will load full checkpoint state from storage.
For checkpoints, the full state is included in metadata.full_checkpoint.
"""
# First check item index for messages, function calls, etc. (O(1) lookup)
conv_items = self._item_index.get(conversation_id, {})
return conv_items.get(item_id)
item = conv_items.get(item_id)
if item:
return item
# If not found and ID is a checkpoint, load from checkpoint storage
if item_id.startswith("checkpoint_"):
checkpoint_id = item_id[len("checkpoint_") :] # Remove "checkpoint_" prefix
conv_data = self._conversations.get(conversation_id)
if not conv_data:
return None
checkpoint_storage = conv_data.get("checkpoint_storage")
if not checkpoint_storage:
return None
# Load full checkpoint from storage
checkpoint = await checkpoint_storage.load_checkpoint(checkpoint_id)
if not checkpoint:
return None
# Calculate size of checkpoint
import json
checkpoint_json = json.dumps(checkpoint.to_dict())
checkpoint_size = len(checkpoint_json.encode("utf-8"))
# Build checkpoint item with FULL state in metadata
checkpoint_item = {
"id": item_id,
"type": "checkpoint",
"checkpoint_id": checkpoint.checkpoint_id,
"workflow_id": checkpoint.workflow_id,
"timestamp": checkpoint.timestamp,
"status": "completed",
"metadata": {
# Summary metrics (same as list view)
"iteration_count": checkpoint.iteration_count,
"pending_hil_count": len(checkpoint.pending_request_info_events),
"has_pending_hil": len(checkpoint.pending_request_info_events) > 0,
"message_count": sum(len(msgs) for msgs in checkpoint.messages.values()),
"size_bytes": checkpoint_size,
"version": checkpoint.version,
# 🔥 FULL checkpoint state (lazy loaded)
"full_checkpoint": checkpoint.to_dict(),
},
}
return cast(ConversationItem, checkpoint_item)
return None
def get_thread(self, conversation_id: str) -> AgentThread | None:
"""Get AgentThread for execution - CRITICAL for agent.run_stream()."""
conv_data = self._conversations.get(conversation_id)
return conv_data["thread"] if conv_data else None
def list_conversations_by_metadata(self, metadata_filter: dict[str, str]) -> list[Conversation]:
async def list_conversations_by_metadata(self, metadata_filter: dict[str, str]) -> list[Conversation]:
"""Filter conversations by metadata (e.g., agent_id)."""
results = []
for conv_data in self._conversations.values():
conv_meta = conv_data.get("metadata", {})
conv_meta = conv_data.get("metadata", {}).copy() # Copy to avoid mutating original
# Check if all filter items match
if all(conv_meta.get(k) == v for k, v in metadata_filter.items()):
# Enrich workflow sessions with checkpoint summary
if conv_meta.get("type") == "workflow_session":
checkpoint_storage = conv_data.get("checkpoint_storage")
if checkpoint_storage:
checkpoints = await checkpoint_storage.list_checkpoints()
latest = checkpoints[0] if checkpoints else None
conv_meta["checkpoint_summary"] = {
"count": len(checkpoints),
"latest_iteration": latest.iteration_count if latest else 0,
"has_pending_hil": len(latest.pending_request_info_events) > 0 if latest else False,
"pending_hil_count": len(latest.pending_request_info_events) if latest else 0,
}
results.append(
Conversation(
id=conv_data["id"],
@@ -498,6 +585,10 @@ class InMemoryConversationStore(ConversationStore):
metadata=conv_meta,
)
)
# Sort by created_at descending (most recent first)
results.sort(key=lambda c: c.created_at, reverse=True)
return results
@@ -722,6 +722,20 @@ class AgentFrameworkExecutor:
return json.dumps(input_data)
return str(input_data)
def _is_openai_multimodal_format(self, input_data: Any) -> bool:
"""Check if input is OpenAI ResponseInputParam format (list with message items).
Args:
input_data: Input data to check
Returns:
True if input is OpenAI multimodal format
"""
if not isinstance(input_data, list) or not input_data:
return False
first_item = input_data[0]
return isinstance(first_item, dict) and first_item.get("type") == "message"
async def _parse_workflow_input(self, workflow: Any, raw_input: Any) -> Any:
"""Parse input based on workflow's expected input type.
@@ -733,9 +747,26 @@ class AgentFrameworkExecutor:
Parsed input appropriate for the workflow
"""
try:
# Handle structured input
# Handle JSON string input (from frontend api.ts JSON.stringify)
if isinstance(raw_input, str):
try:
parsed = json.loads(raw_input)
raw_input = parsed
except (json.JSONDecodeError, TypeError):
# Plain text string, continue with string handling
pass
# Check for OpenAI multimodal format (list with type: "message")
# This handles ChatMessage inputs with images, files, etc.
if self._is_openai_multimodal_format(raw_input):
logger.debug("Detected OpenAI multimodal format, converting to ChatMessage")
return self._convert_input_to_chat_message(raw_input)
# Handle structured input (dict)
if isinstance(raw_input, dict):
return self._parse_structured_workflow_input(workflow, raw_input)
# Handle string input
return self._parse_raw_workflow_input(workflow, str(raw_input))
except Exception as e:
@@ -885,6 +885,10 @@ class MessageMapper:
context[f"exec_item_{executor_id}"] = item_id
context["output_index"] = context.get("output_index", -1) + 1
# Track current executor for routing Magentic agent events
# This allows MagenticAgentDeltaEvent to route to the executor's item
context["current_executor_id"] = executor_id
# Create ExecutorActionItem with proper type
executor_item = ExecutorActionItem(
type="executor_action",
@@ -908,6 +912,10 @@ class MessageMapper:
executor_id = getattr(event, "executor_id", "unknown")
item_id = context.get(f"exec_item_{executor_id}", f"exec_{executor_id}_unknown")
# Clear current executor tracking when executor completes
if context.get("current_executor_id") == executor_id:
context.pop("current_executor_id", None)
# Create ExecutorActionItem with completed status
# ExecutorCompletedEvent uses 'data' field, not 'result'
executor_item = ExecutorActionItem(
@@ -1059,6 +1067,30 @@ class MessageMapper:
text = getattr(event, "text", None)
if text:
# Check if we're inside an executor - route to executor's item
# This prevents duplicate timeline entries (executor + inner agent)
current_executor_id = context.get("current_executor_id")
executor_item_key = f"exec_item_{current_executor_id}" if current_executor_id else None
if executor_item_key and executor_item_key in context:
# Route delta to the executor's item instead of creating a new message item
item_id = context[executor_item_key]
# Emit text delta event routed to the executor's item
return [
ResponseTextDeltaEvent(
type="response.output_text.delta",
output_index=context.get("output_index", 0),
content_index=0,
item_id=item_id,
delta=text,
logprobs=[],
sequence_number=self._next_sequence(context),
)
]
# Fallback: No executor context - create separate message item (original behavior)
# This handles cases where MagenticAgentDeltaEvent is emitted outside an executor
events = []
# Track Magentic agent messages separately from regular messages
@@ -1181,7 +1213,21 @@ class MessageMapper:
agent_id = getattr(event, "agent_id", "unknown_agent")
message = getattr(event, "message", None)
# Track Magentic agent messages
# Check if we're inside an executor - if so, deltas were already routed there
# We don't need to emit a separate message completion event
current_executor_id = context.get("current_executor_id")
executor_item_key = f"exec_item_{current_executor_id}" if current_executor_id else None
if executor_item_key and executor_item_key in context:
# Deltas were routed to executor item - no separate message item to complete
# The executor's output_item.done will mark completion
logger.debug(
f"MagenticAgentMessageEvent from {agent_id} - "
f"deltas routed to executor {current_executor_id}, skipping"
)
return []
# Fallback: Handle case where we created a separate message item (no executor context)
magentic_key = f"magentic_message_{agent_id}"
# Check if we were streaming for this agent
@@ -2,14 +2,17 @@
"""FastAPI server implementation."""
import asyncio
import importlib.metadata
import inspect
import json
import logging
import os
import secrets
import uuid
from collections.abc import AsyncGenerator, Awaitable, Callable
from contextlib import asynccontextmanager
from typing import Any
from typing import Any, cast
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
@@ -26,6 +29,12 @@ from .models._discovery_models import Deployment, DeploymentConfig, DiscoveryRes
logger = logging.getLogger(__name__)
# Get package version
try:
__version__ = importlib.metadata.version("agent-framework-devui")
except importlib.metadata.PackageNotFoundError:
__version__ = "0.0.0" # Fallback for development mode
# No AuthMiddleware class needed - we'll use the decorator pattern instead
@@ -70,6 +79,7 @@ class DevServer:
self.deployment_manager = DeploymentManager()
self._app: FastAPI | None = None
self._pending_entities: list[Any] | None = None
self._running_tasks: dict[str, asyncio.Task[Any]] = {} # Track running response tasks for cancellation
def _is_dev_mode(self) -> bool:
"""Check if running in developer mode.
@@ -293,7 +303,7 @@ class DevServer:
app = FastAPI(
title="Agent Framework Server",
description="OpenAI-compatible API server for Agent Framework and other AI frameworks",
version="1.0.0",
version=__version__,
lifespan=lifespan,
)
@@ -388,8 +398,6 @@ class DevServer:
"""Get server metadata and configuration."""
import os
from . import __version__
# Ensure executors are initialized to check capabilities
openai_executor = await self._ensure_openai_executor()
@@ -731,13 +739,18 @@ class DevServer:
# Execute request
if request.stream:
# Generate response ID for tracking
response_id = f"resp_{uuid.uuid4().hex[:8]}"
logger.info(f"[CANCELLATION] Creating response {response_id} for entity {entity_id}")
return StreamingResponse(
self._stream_execution(executor, request),
self._stream_with_cancellation(executor, request, response_id),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"Access-Control-Allow-Origin": "*",
"X-Response-ID": response_id, # Include ID for debugging/tracking
},
)
return await executor.execute_sync(request)
@@ -747,6 +760,30 @@ class DevServer:
error = OpenAIError.create(error_msg)
return JSONResponse(status_code=500, content=error.to_dict())
@app.post("/v1/responses/{response_id}/cancel")
async def cancel_response(response_id: str) -> dict[str, Any]:
"""Cancel a running response execution.
This endpoint allows explicit cancellation of a running stream.
Note: Cancellation also happens automatically when the client disconnects.
"""
logger.info(f"[CANCELLATION] Cancel request received for {response_id}")
if task := self._running_tasks.get(response_id):
if not task.done():
logger.info(f"[CANCELLATION] Cancelling task for {response_id}")
task.cancel()
# Wait briefly for cancellation to propagate
try: # noqa: SIM105
await asyncio.wait_for(task, timeout=0.5)
except (asyncio.CancelledError, asyncio.TimeoutError):
pass
return {"status": "cancelled", "response_id": response_id}
logger.warning(f"[CANCELLATION] Task already completed for {response_id}")
return {"status": "already_completed", "response_id": response_id}
logger.warning(f"[CANCELLATION] No task found for {response_id}")
return {"status": "not_found", "response_id": response_id}
# ========================================
# OpenAI Conversations API (Standard)
# ========================================
@@ -862,7 +899,7 @@ class DevServer:
filters["type"] = type
# Apply filters
conversations = executor.conversation_store.list_conversations_by_metadata(filters)
conversations = await executor.conversation_store.list_conversations_by_metadata(filters)
return {
"object": "list",
@@ -973,13 +1010,19 @@ class DevServer:
@app.get("/v1/conversations/{conversation_id}/items/{item_id}")
async def retrieve_conversation_item(conversation_id: str, item_id: str) -> dict[str, Any]:
"""Get specific conversation item - OpenAI standard."""
"""Get specific conversation item - OpenAI standard.
Supports checkpoint items - returns full checkpoint state in metadata.full_checkpoint.
"""
try:
executor = await self._ensure_executor()
item = executor.conversation_store.get_item(conversation_id, item_id)
item = await executor.conversation_store.get_item(conversation_id, item_id)
if not item:
raise HTTPException(status_code=404, detail="Item not found")
result: dict[str, Any] = item.model_dump()
# Handle both Pydantic models and dicts
result: dict[str, Any] = (
item.model_dump() if hasattr(item, "model_dump") else cast(dict[str, Any], item)
)
return result
except HTTPException:
raise
@@ -1139,6 +1182,100 @@ class DevServer:
}
yield f"data: {json.dumps(error_event)}\n\n"
async def _stream_with_cancellation(
self, executor: AgentFrameworkExecutor, request: AgentFrameworkRequest, response_id: str
) -> AsyncGenerator[str, None]:
"""Stream execution with automatic cancellation on client disconnect.
This wrapper adds cancellation support to the execution stream:
1. Tracks the execution as an asyncio Task
2. Detects client disconnection via GeneratorExit
3. Cancels the task when client disconnects
4. Propagates CancelledError through the execution chain
Args:
executor: Agent Framework executor instance
request: Request to execute
response_id: Unique ID for this response/execution
Yields:
SSE-formatted event strings from the original stream
"""
task = None
async def execution_wrapper() -> AsyncGenerator[str, None]:
"""Inner wrapper to handle the actual execution."""
try:
logger.debug(f"[CANCELLATION] Starting execution for {response_id}")
async for chunk in self._stream_execution(executor, request):
# Check if we're being cancelled
current_task = asyncio.current_task()
if current_task and current_task.cancelled():
logger.info(f"[CANCELLATION] Detected cancellation, breaking stream for {response_id}")
break
yield chunk
except asyncio.CancelledError:
logger.info(f"[CANCELLATION] Execution cancelled via CancelledError for {response_id}")
# Emit cancellation event to client (if still connected)
cancelled_event = {
"type": "response.cancelled",
"response_id": response_id,
"message": "Execution cancelled by user",
}
yield f"data: {json.dumps(cancelled_event)}\n\n"
raise
except Exception as e:
logger.error(f"[CANCELLATION] Error in cancellable execution for {response_id}: {e}")
raise
try:
# Get or create the current task and track it
task = asyncio.current_task()
if task:
self._running_tasks[response_id] = task
logger.debug(f"[CANCELLATION] Tracking task {task.get_name()} for response {response_id}")
else:
logger.warning(f"[CANCELLATION] No current task found to track for {response_id}")
# Stream the execution
async for chunk in execution_wrapper():
yield chunk
logger.debug(f"[CANCELLATION] Stream completed normally for {response_id}")
except GeneratorExit:
# Client disconnected - this is raised when the generator is closed
logger.info(f"[CANCELLATION] Client disconnected, initiating cancellation for {response_id}")
if task and not task.done():
logger.info(f"[CANCELLATION] Cancelling task for disconnected client {response_id}")
task.cancel()
# Give it a moment to cancel gracefully
# Note: We should NOT use asyncio.shield here as it prevents cancellation
try:
await asyncio.wait_for(task, timeout=1.0)
except (asyncio.CancelledError, asyncio.TimeoutError):
logger.debug(f"[CANCELLATION] Task cancelled successfully for {response_id}")
except Exception as e:
logger.warning(f"[CANCELLATION] Error during task cancellation for {response_id}: {e}")
raise # Re-raise GeneratorExit to properly close the generator
except asyncio.CancelledError:
logger.info(f"[CANCELLATION] Stream cancelled for {response_id}")
raise
except Exception as e:
logger.error(f"[CANCELLATION] Unexpected error in stream for {response_id}: {e}")
raise
finally:
# Clean up tracking
if response_id in self._running_tasks:
self._running_tasks.pop(response_id)
logger.debug(f"[CANCELLATION] Cleaned up task tracking for {response_id}")
def _mount_ui(self, app: FastAPI) -> None:
"""Mount the UI as static files."""
from pathlib import Path
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long