mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Fix Python pyright package scoping and typing remediation (#4426)
* Fix Python pyright package scoping and typing remediation Implements issue #4407 by removing the root pyright include, adding package-level pyright includes, and resolving pyright/mypy typing issues across Python packages. Also cleans unnecessary casts and applies line-level, rule-specific ignores where external libraries are too dynamic. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Reduce pyright cost in handoff cloning Simplify cloned_options construction in HandoffAgentExecutor to avoid expensive TypedDict narrowing/inference in _handoff.py, which was causing pyright to spend a long time in orchestrations. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * fix types * Fix lint and type-check regressions Resolve current Python package check failures across lint, pyright, and mypy after recent code changes, including purview/declarative pyright issues and multiple ruff simplification findings. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * fixed hooks * Stabilize package tests and test tasks Resolve cross-package non-integration test failures, simplify streaming type flow, harden locale/culture handling, and standardize package test poe tasks to exclude integration tests where applicable. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * lots of small fixes * Fix current Python test regressions Address current failing unit tests in azure-ai, bedrock, and azure-cosmos while keeping Bedrock parsing logic inline (no new static helper methods). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * small fixes * small fixes * removed pydantic from json * final updates * fix core * fix tests * fix obser --------- Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
4a043c6c66
commit
55ddd841b7
@@ -73,7 +73,7 @@ def register_cleanup(entity: Any, *hooks: Callable[[], Any]) -> None:
|
||||
)
|
||||
|
||||
|
||||
def _get_registered_cleanup_hooks(entity: Any) -> list[Callable[[], Any]]:
|
||||
def _get_registered_cleanup_hooks(entity: Any) -> list[Callable[[], Any]]: # type: ignore[reportUnusedFunction]
|
||||
"""Get cleanup hooks registered for an entity (internal use).
|
||||
|
||||
Args:
|
||||
@@ -193,7 +193,7 @@ def serve(
|
||||
if entities:
|
||||
logger.info(f"Registering {len(entities)} in-memory entities")
|
||||
# Store entities for later registration during server startup
|
||||
server._pending_entities = entities
|
||||
server.set_pending_entities(entities)
|
||||
|
||||
app = server.get_app()
|
||||
|
||||
|
||||
@@ -11,12 +11,14 @@ from __future__ import annotations
|
||||
import time
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import MutableSequence
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
from agent_framework import AgentSession, Message
|
||||
from agent_framework._workflows._checkpoint import InMemoryCheckpointStorage, WorkflowCheckpoint
|
||||
from openai.types.conversations import Conversation, ConversationDeletedResource
|
||||
from openai.types.conversations.conversation_item import ConversationItem
|
||||
from openai.types.conversations.message import Content as OpenAIContent
|
||||
from openai.types.conversations.message import Message as OpenAIMessage
|
||||
from openai.types.conversations.text_content import TextContent
|
||||
from openai.types.responses import (
|
||||
@@ -300,12 +302,17 @@ class InMemoryConversationStore(ConversationStore):
|
||||
stored_messages: list[Message] = conv_data["messages"]
|
||||
|
||||
# Convert items to Messages and add to storage
|
||||
chat_messages = []
|
||||
chat_messages: list[Message] = []
|
||||
for item in items:
|
||||
# Simple conversion - assume text content for now
|
||||
role = item.get("role", "user")
|
||||
content = item.get("content", [])
|
||||
text = content[0].get("text", "") if content else ""
|
||||
first_content = cast(
|
||||
dict[str, Any],
|
||||
content[0] if content and isinstance(content, list) and isinstance(content[0], dict) else {},
|
||||
)
|
||||
text_obj = first_content.get("text", "")
|
||||
text = text_obj if isinstance(text_obj, str) else str(text_obj)
|
||||
|
||||
chat_msg = Message(role=role, text=text) # type: ignore[arg-type]
|
||||
chat_messages.append(chat_msg)
|
||||
@@ -318,23 +325,18 @@ class InMemoryConversationStore(ConversationStore):
|
||||
for msg in chat_messages:
|
||||
item_id = f"item_{uuid.uuid4().hex}"
|
||||
|
||||
# Extract role - handle both string and enum
|
||||
role_str = msg.role if hasattr(msg.role, "value") else str(msg.role)
|
||||
role = cast(MessageRole, role_str) # Safe: Agent Framework roles match OpenAI roles
|
||||
|
||||
# Convert Message contents to OpenAI TextContent format
|
||||
message_content = []
|
||||
message_content: MutableSequence[OpenAIContent] = []
|
||||
for content_item in msg.contents:
|
||||
if content_item.type == "text":
|
||||
# Extract text from TextContent object
|
||||
text_value = getattr(content_item, "text", "")
|
||||
message_content.append(TextContent(type="text", text=text_value))
|
||||
message_content.append(TextContent(type="text", text=content_item.text or ""))
|
||||
|
||||
# Create Message object (concrete type from ConversationItem union)
|
||||
message = OpenAIMessage(
|
||||
id=item_id,
|
||||
type="message", # Required discriminator for union
|
||||
role=role,
|
||||
role=cast(MessageRole, msg.role), # Safe: Agent Framework roles match OpenAI roles,
|
||||
content=message_content,
|
||||
status="completed", # Required field
|
||||
)
|
||||
@@ -383,8 +385,8 @@ class InMemoryConversationStore(ConversationStore):
|
||||
# A single Message may produce multiple ConversationItems
|
||||
# (e.g., a message with both text and a function call)
|
||||
message_contents: list[TextContent | ResponseInputImage | ResponseInputFile] = []
|
||||
function_calls = []
|
||||
function_results = []
|
||||
function_calls: list[ResponseFunctionToolCallItem] = []
|
||||
function_results: list[ResponseFunctionToolCallOutputItem] = []
|
||||
|
||||
for content in msg.contents:
|
||||
content_type = getattr(content, "type", None)
|
||||
@@ -628,7 +630,7 @@ class InMemoryConversationStore(ConversationStore):
|
||||
|
||||
async def list_conversations_by_metadata(self, metadata_filter: dict[str, str]) -> list[Conversation]:
|
||||
"""Filter conversations by metadata (e.g., agent_id)."""
|
||||
results = []
|
||||
results: list[Conversation] = []
|
||||
for conv_data in self._conversations.values():
|
||||
conv_meta = conv_data.get("metadata", {}).copy() # Copy to avoid mutating original
|
||||
|
||||
@@ -704,7 +706,8 @@ class CheckpointConversationManager:
|
||||
ValueError: If conversation not found
|
||||
"""
|
||||
# Access internal conversations dict (we know it's InMemoryConversationStore)
|
||||
conv_data = self._store._conversations.get(conversation_id)
|
||||
conversations_dict = cast(dict[str, dict[str, Any]], getattr(self._store, "_conversations", {}))
|
||||
conv_data = conversations_dict.get(conversation_id)
|
||||
if not conv_data:
|
||||
raise ValueError(f"Conversation {conversation_id} not found")
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from .models._discovery_models import Deployment, DeploymentConfig, DeploymentEvent
|
||||
@@ -175,7 +176,7 @@ class DeploymentManager:
|
||||
|
||||
# Check required resource providers are registered
|
||||
required_providers = ["Microsoft.App", "Microsoft.ContainerRegistry", "Microsoft.OperationalInsights"]
|
||||
unregistered_providers = []
|
||||
unregistered_providers: list[str] = []
|
||||
|
||||
# Get list of registered providers
|
||||
provider_check = await asyncio.create_subprocess_exec(
|
||||
@@ -195,7 +196,12 @@ class DeploymentManager:
|
||||
import json
|
||||
|
||||
try:
|
||||
registered = json.loads(stdout.decode())
|
||||
registered_raw = json.loads(stdout.decode())
|
||||
registered: list[str] = []
|
||||
if isinstance(registered_raw, list):
|
||||
for item_obj in cast(list[object], registered_raw):
|
||||
if isinstance(item_obj, str):
|
||||
registered.append(item_obj)
|
||||
for provider in required_providers:
|
||||
if provider not in registered:
|
||||
unregistered_providers.append(provider)
|
||||
@@ -385,7 +391,7 @@ CMD ["devui", "/app/entity", "--mode", "{config.ui_mode}", "--host", "0.0.0.0",
|
||||
)
|
||||
|
||||
# Stream output line by line
|
||||
output_lines = []
|
||||
output_lines: list[str] = []
|
||||
try:
|
||||
if not process.stdout:
|
||||
raise ValueError("Failed to capture process output")
|
||||
@@ -473,8 +479,11 @@ CMD ["devui", "/app/entity", "--mode", "{config.ui_mode}", "--host", "0.0.0.0",
|
||||
for url in urls:
|
||||
# Strip common trailing punctuation to ensure clean URL parsing
|
||||
url_clean = url.rstrip(".,;:!?'\")}]")
|
||||
host = urlparse(url_clean).hostname
|
||||
if host and (host == "azurecontainerapps.io" or host.endswith(".azurecontainerapps.io")):
|
||||
parsed_url = urlparse(str(url_clean))
|
||||
host = parsed_url.hostname
|
||||
if isinstance(host, str) and (
|
||||
host == "azurecontainerapps.io" or host.endswith(".azurecontainerapps.io")
|
||||
):
|
||||
await event_queue.put(
|
||||
DeploymentEvent(type="deploy.progress", message="Deployment URL generated!")
|
||||
)
|
||||
|
||||
@@ -11,7 +11,7 @@ import logging
|
||||
import sys
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
@@ -141,7 +141,7 @@ class EntityDiscovery:
|
||||
self._loaded_objects[entity_id] = entity_obj
|
||||
|
||||
# Check module-level registry for cleanup hooks
|
||||
from . import _get_registered_cleanup_hooks
|
||||
from . import _get_registered_cleanup_hooks # type: ignore[reportPrivateUsage]
|
||||
|
||||
registered_hooks = _get_registered_cleanup_hooks(entity_obj)
|
||||
if registered_hooks:
|
||||
@@ -299,7 +299,7 @@ class EntityDiscovery:
|
||||
self._loaded_objects[entity_id] = entity_object
|
||||
|
||||
# Check module-level registry for cleanup hooks
|
||||
from . import _get_registered_cleanup_hooks
|
||||
from . import _get_registered_cleanup_hooks # type: ignore[reportPrivateUsage]
|
||||
|
||||
registered_hooks = _get_registered_cleanup_hooks(entity_object)
|
||||
if registered_hooks:
|
||||
@@ -379,6 +379,8 @@ class EntityDiscovery:
|
||||
deployment_supported = True
|
||||
deployment_reason = "Ready for deployment (pending path verification)"
|
||||
|
||||
class_name = type(entity_object).__name__
|
||||
|
||||
# Create EntityInfo with Agent Framework specifics
|
||||
return EntityInfo(
|
||||
id=entity_id,
|
||||
@@ -400,9 +402,7 @@ class EntityDiscovery:
|
||||
deployment_reason=deployment_reason,
|
||||
metadata={
|
||||
"source": "agent_framework_object",
|
||||
"class_name": entity_object.__class__.__name__
|
||||
if hasattr(entity_object, "__class__")
|
||||
else str(type(entity_object)),
|
||||
"class_name": class_name,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -854,7 +854,7 @@ class EntityDiscovery:
|
||||
"module_path": module_path,
|
||||
"entity_type": obj_type,
|
||||
"source": source,
|
||||
"class_name": obj.__class__.__name__ if hasattr(obj, "__class__") else str(type(obj)),
|
||||
"class_name": type(obj).__name__,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -874,47 +874,63 @@ class EntityDiscovery:
|
||||
Returns:
|
||||
List of tool/executor names
|
||||
"""
|
||||
tools = []
|
||||
tools: list[str] = []
|
||||
|
||||
try:
|
||||
if obj_type == "agent":
|
||||
# For agents, check default_options.get("tools")
|
||||
chat_options = getattr(obj, "default_options", None)
|
||||
chat_options_tools = None
|
||||
if chat_options:
|
||||
chat_options_tools = chat_options.get("tools")
|
||||
chat_options_tools: object | None = None
|
||||
if isinstance(chat_options, dict):
|
||||
chat_options_dict = cast(dict[str, Any], chat_options)
|
||||
chat_options_tools = chat_options_dict.get("tools")
|
||||
|
||||
if chat_options_tools:
|
||||
for tool in chat_options_tools:
|
||||
if hasattr(tool, "__name__"):
|
||||
tools.append(tool.__name__)
|
||||
elif hasattr(tool, "name"):
|
||||
tools.append(tool.name)
|
||||
if chat_options_tools is not None:
|
||||
tool_iterable: list[object] = (
|
||||
cast(list[object], chat_options_tools)
|
||||
if isinstance(chat_options_tools, list)
|
||||
else [chat_options_tools]
|
||||
)
|
||||
for tool_obj in tool_iterable:
|
||||
tool_name = getattr(tool_obj, "__name__", None)
|
||||
if isinstance(tool_name, str):
|
||||
tools.append(tool_name)
|
||||
continue
|
||||
|
||||
named_tool = getattr(tool_obj, "name", None)
|
||||
if isinstance(named_tool, str):
|
||||
tools.append(named_tool)
|
||||
else:
|
||||
tools.append(str(tool))
|
||||
tools.append(str(tool_obj))
|
||||
else:
|
||||
# Fallback to direct tools attribute
|
||||
agent_tools = getattr(obj, "tools", None)
|
||||
if agent_tools:
|
||||
for tool in agent_tools:
|
||||
if hasattr(tool, "__name__"):
|
||||
tools.append(tool.__name__)
|
||||
elif hasattr(tool, "name"):
|
||||
tools.append(tool.name)
|
||||
if isinstance(agent_tools, list):
|
||||
for tool_obj in cast(list[object], agent_tools):
|
||||
tool_name = getattr(tool_obj, "__name__", None)
|
||||
if isinstance(tool_name, str):
|
||||
tools.append(tool_name)
|
||||
continue
|
||||
|
||||
named_tool = getattr(tool_obj, "name", None)
|
||||
if isinstance(named_tool, str):
|
||||
tools.append(named_tool)
|
||||
else:
|
||||
tools.append(str(tool))
|
||||
tools.append(str(tool_obj))
|
||||
|
||||
elif obj_type == "workflow":
|
||||
# For workflows, extract executor names
|
||||
if hasattr(obj, "get_executors_list"):
|
||||
executor_objects = obj.get_executors_list()
|
||||
tools = [getattr(ex, "id", str(ex)) for ex in executor_objects]
|
||||
if isinstance(executor_objects, list):
|
||||
for executor_obj in cast(list[object], executor_objects):
|
||||
tools.append(str(getattr(executor_obj, "id", executor_obj)))
|
||||
elif hasattr(obj, "executors"):
|
||||
executors = obj.executors
|
||||
if isinstance(executors, list):
|
||||
tools = [getattr(ex, "id", str(ex)) for ex in executors]
|
||||
for executor_obj in cast(list[object], executors):
|
||||
tools.append(str(getattr(executor_obj, "id", executor_obj)))
|
||||
elif isinstance(executors, dict):
|
||||
tools = list(executors.keys())
|
||||
executors_dict = cast(dict[str, Any], executors)
|
||||
for key_obj in executors_dict:
|
||||
tools.append(str(key_obj))
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error extracting tools from {obj_type} {type(obj)}: {e}")
|
||||
|
||||
@@ -7,7 +7,7 @@ from __future__ import annotations
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
from agent_framework import Content, SupportsAgentRun, Workflow
|
||||
|
||||
@@ -24,7 +24,8 @@ logger = logging.getLogger(__name__)
|
||||
def _get_event_type(event: Any) -> str | None:
|
||||
"""Safely get the type of an event, handling both objects and dicts."""
|
||||
if isinstance(event, dict):
|
||||
return event.get("type")
|
||||
event_type = cast(dict[str, Any], event).get("type")
|
||||
return event_type if isinstance(event_type, str) else None
|
||||
return getattr(event, "type", None)
|
||||
|
||||
|
||||
@@ -71,7 +72,8 @@ class AgentFrameworkExecutor:
|
||||
from opentelemetry.sdk.trace import TracerProvider
|
||||
|
||||
# Only set up if no provider exists yet
|
||||
if not hasattr(trace, "_TRACER_PROVIDER") or trace._TRACER_PROVIDER is None:
|
||||
current_provider = trace.get_tracer_provider()
|
||||
if current_provider.__class__.__name__ == "ProxyTracerProvider":
|
||||
resource = Resource.create({
|
||||
"service.name": "agent-framework-server",
|
||||
"service.version": "1.0.0",
|
||||
@@ -94,21 +96,29 @@ class AgentFrameworkExecutor:
|
||||
|
||||
# Configure if instrumentation is enabled (via enable_instrumentation() or env var)
|
||||
if OBSERVABILITY_SETTINGS.ENABLED:
|
||||
# Only configure providers if not already executed
|
||||
if not OBSERVABILITY_SETTINGS._executed_setup:
|
||||
# Call configure_otel_providers to set up exporters.
|
||||
# If OTEL_EXPORTER_OTLP_ENDPOINT is set, exporters will be created automatically.
|
||||
# If not set, no exporters are created (no console spam), but DevUI's
|
||||
# TracerProvider from _setup_instrumentation_provider() remains active for local capture.
|
||||
configure_otel_providers(enable_sensitive_data=OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED)
|
||||
logger.info("Enabled Agent Framework observability")
|
||||
else:
|
||||
logger.debug("Agent Framework observability already configured")
|
||||
# Call configure_otel_providers to set up exporters.
|
||||
# If OTEL_EXPORTER_OTLP_ENDPOINT is set, exporters will be created automatically.
|
||||
# If not set, no exporters are created (no console spam), but DevUI's
|
||||
# TracerProvider from _setup_instrumentation_provider() remains active for local capture.
|
||||
configure_otel_providers(enable_sensitive_data=OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED)
|
||||
logger.info("Enabled Agent Framework observability")
|
||||
else:
|
||||
logger.debug("Instrumentation not enabled, skipping observability setup")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to enable Agent Framework observability: {e}")
|
||||
|
||||
def _get_request_conversation_id(self, request: AgentFrameworkRequest) -> str | None:
|
||||
"""Read conversation id using public request fields."""
|
||||
if isinstance(request.conversation, str):
|
||||
return request.conversation
|
||||
|
||||
if isinstance(request.conversation, dict):
|
||||
conversation_id = request.conversation.get("id")
|
||||
if isinstance(conversation_id, str):
|
||||
return conversation_id
|
||||
|
||||
return None
|
||||
|
||||
async def _ensure_mcp_connections(self, agent: Any) -> None:
|
||||
"""Ensure MCP tool connections are healthy before agent execution.
|
||||
|
||||
@@ -317,7 +327,7 @@ class AgentFrameworkExecutor:
|
||||
|
||||
# Get session from conversation parameter (OpenAI standard!)
|
||||
session = None
|
||||
conversation_id = request._get_conversation_id()
|
||||
conversation_id = self._get_request_conversation_id(request)
|
||||
if conversation_id:
|
||||
session = self.conversation_store.get_session(conversation_id)
|
||||
if session:
|
||||
@@ -344,7 +354,7 @@ class AgentFrameworkExecutor:
|
||||
if session:
|
||||
run_kwargs["session"] = session
|
||||
|
||||
stream = agent.run(user_message, **run_kwargs)
|
||||
stream = cast(Any, agent.run(user_message, **run_kwargs))
|
||||
async for update in stream:
|
||||
for trace_event in trace_collector.get_pending_events():
|
||||
yield trace_event
|
||||
@@ -388,7 +398,7 @@ class AgentFrameworkExecutor:
|
||||
entity_id = request.get_entity_id() or "unknown"
|
||||
|
||||
# Get or create session conversation for checkpoint storage
|
||||
conversation_id = request._get_conversation_id()
|
||||
conversation_id = self._get_request_conversation_id(request)
|
||||
if not conversation_id:
|
||||
# Create default session if not provided
|
||||
import time
|
||||
@@ -463,11 +473,14 @@ class AgentFrameworkExecutor:
|
||||
logger.info(f"Resuming workflow with HIL responses for {len(hil_responses)} request(s)")
|
||||
|
||||
# Unwrap primitive responses if they're wrapped in {response: value} format
|
||||
unwrapped_responses = {}
|
||||
unwrapped_responses: dict[str, Any] = {}
|
||||
for request_id, response_value in hil_responses.items():
|
||||
if isinstance(response_value, dict) and "response" in response_value:
|
||||
response_value = response_value["response"]
|
||||
unwrapped_responses[request_id] = response_value
|
||||
normalized_response: Any = response_value
|
||||
if isinstance(response_value, dict):
|
||||
response_dict = cast(dict[str, Any], response_value)
|
||||
if "response" in response_dict:
|
||||
normalized_response = response_dict["response"]
|
||||
unwrapped_responses[request_id] = normalized_response
|
||||
|
||||
hil_responses = unwrapped_responses
|
||||
|
||||
@@ -568,7 +581,8 @@ class AgentFrameworkExecutor:
|
||||
|
||||
# Handle OpenAI ResponseInputParam (List[ResponseInputItemParam])
|
||||
if isinstance(input_data, list):
|
||||
return self._convert_openai_input_to_chat_message(input_data, Message, Role)
|
||||
input_items: Any = cast(Any, input_data)
|
||||
return self._convert_openai_input_to_chat_message(input_items, Message, Role)
|
||||
|
||||
# Fallback for other formats
|
||||
return self._extract_user_message_fallback(input_data)
|
||||
@@ -593,27 +607,31 @@ class AgentFrameworkExecutor:
|
||||
for item in input_items:
|
||||
# Handle dict format (from JSON)
|
||||
if isinstance(item, dict):
|
||||
item_type = item.get("type")
|
||||
item_dict = cast(dict[str, Any], item)
|
||||
item_type = item_dict.get("type")
|
||||
if item_type == "message":
|
||||
# Extract content from OpenAI message
|
||||
message_content = item.get("content", [])
|
||||
message_content = item_dict.get("content", [])
|
||||
|
||||
# Handle both string content and list content
|
||||
if isinstance(message_content, str):
|
||||
contents.append(Content.from_text(text=message_content))
|
||||
elif isinstance(message_content, list):
|
||||
for content_item in message_content:
|
||||
message_content_items: Any = cast(Any, message_content)
|
||||
for content_item in message_content_items:
|
||||
# Handle dict content items
|
||||
if isinstance(content_item, dict):
|
||||
content_type = content_item.get("type")
|
||||
content_dict = cast(dict[str, Any], content_item)
|
||||
content_type = content_dict.get("type")
|
||||
|
||||
if content_type == "input_text":
|
||||
text = content_item.get("text", "")
|
||||
contents.append(Content.from_text(text=text))
|
||||
text = content_dict.get("text", "")
|
||||
if isinstance(text, str):
|
||||
contents.append(Content.from_text(text=text))
|
||||
|
||||
elif content_type == "input_image":
|
||||
image_url = content_item.get("image_url", "")
|
||||
if image_url:
|
||||
image_url = content_dict.get("image_url", "")
|
||||
if isinstance(image_url, str) and image_url:
|
||||
# Extract media type from data URI if possible
|
||||
# Parse media type from data URL, fallback to image/png
|
||||
if image_url.startswith("data:"):
|
||||
@@ -631,9 +649,12 @@ class AgentFrameworkExecutor:
|
||||
|
||||
elif content_type == "input_file":
|
||||
# Handle file input
|
||||
file_data = content_item.get("file_data")
|
||||
file_url = content_item.get("file_url")
|
||||
filename = content_item.get("filename", "")
|
||||
file_data = content_dict.get("file_data")
|
||||
file_url = content_dict.get("file_url")
|
||||
filename = content_dict.get("filename", "")
|
||||
|
||||
if not isinstance(filename, str):
|
||||
filename = ""
|
||||
|
||||
# Determine media type from filename
|
||||
media_type = "application/octet-stream" # default
|
||||
@@ -656,8 +677,10 @@ class AgentFrameworkExecutor:
|
||||
|
||||
# Use file_data or file_url
|
||||
# Include filename in additional_properties for OpenAI/Azure file handling
|
||||
additional_props = {"filename": filename} if filename else None
|
||||
if file_data:
|
||||
additional_props: dict[str, Any] | None = (
|
||||
{"filename": filename} if filename else None
|
||||
)
|
||||
if isinstance(file_data, str) and file_data:
|
||||
# Assume file_data is base64, create data URI
|
||||
data_uri = f"data:{media_type};base64,{file_data}"
|
||||
contents.append(
|
||||
@@ -667,7 +690,7 @@ class AgentFrameworkExecutor:
|
||||
additional_properties=additional_props,
|
||||
)
|
||||
)
|
||||
elif file_url:
|
||||
elif isinstance(file_url, str) and file_url:
|
||||
contents.append(
|
||||
Content.from_uri(
|
||||
uri=file_url,
|
||||
@@ -679,15 +702,35 @@ class AgentFrameworkExecutor:
|
||||
elif content_type == "function_approval_response":
|
||||
# Handle function approval response (DevUI extension)
|
||||
try:
|
||||
request_id = content_item.get("request_id", "")
|
||||
approved = content_item.get("approved", False)
|
||||
function_call_data = content_item.get("function_call", {})
|
||||
request_id = content_dict.get("request_id", "")
|
||||
approved = content_dict.get("approved", False)
|
||||
function_call_data = content_dict.get("function_call", {})
|
||||
|
||||
if not isinstance(request_id, str):
|
||||
request_id = ""
|
||||
if not isinstance(approved, bool):
|
||||
approved = False
|
||||
if not isinstance(function_call_data, dict):
|
||||
function_call_data = {}
|
||||
|
||||
function_call_data_dict = cast(dict[str, Any], function_call_data)
|
||||
|
||||
function_call_id = function_call_data_dict.get("id", "")
|
||||
function_call_name = function_call_data_dict.get("name", "")
|
||||
function_call_args = function_call_data_dict.get("arguments", {})
|
||||
|
||||
if not isinstance(function_call_id, str):
|
||||
function_call_id = ""
|
||||
if not isinstance(function_call_name, str):
|
||||
function_call_name = ""
|
||||
if not isinstance(function_call_args, dict):
|
||||
function_call_args = {}
|
||||
|
||||
# Create FunctionCallContent from the function_call data
|
||||
function_call = Content.from_function_call(
|
||||
call_id=function_call_data.get("id", ""),
|
||||
name=function_call_data.get("name", ""),
|
||||
arguments=function_call_data.get("arguments", {}),
|
||||
call_id=function_call_id,
|
||||
name=function_call_name,
|
||||
arguments=cast(dict[str, Any], function_call_args),
|
||||
)
|
||||
|
||||
# Create FunctionApprovalResponseContent with correct signature
|
||||
@@ -739,12 +782,14 @@ class AgentFrameworkExecutor:
|
||||
if isinstance(input_data, str):
|
||||
return input_data
|
||||
if isinstance(input_data, dict):
|
||||
typed_input_data = cast(dict[str, Any], input_data)
|
||||
# Try common field names
|
||||
for field in ["message", "text", "input", "content", "query"]:
|
||||
if field in input_data:
|
||||
return str(input_data[field])
|
||||
if field in typed_input_data:
|
||||
value = typed_input_data[field]
|
||||
return value if isinstance(value, str) else str(value)
|
||||
# Fallback to JSON string
|
||||
return json.dumps(input_data)
|
||||
return json.dumps(typed_input_data)
|
||||
return str(input_data)
|
||||
|
||||
def _is_openai_multimodal_format(self, input_data: Any) -> bool:
|
||||
@@ -758,8 +803,12 @@ class AgentFrameworkExecutor:
|
||||
"""
|
||||
if not isinstance(input_data, list) or not input_data:
|
||||
return False
|
||||
first_item = input_data[0]
|
||||
return isinstance(first_item, dict) and first_item.get("type") == "message"
|
||||
input_data_items: Any = cast(Any, input_data)
|
||||
first_item = input_data_items[0]
|
||||
if not isinstance(first_item, dict):
|
||||
return False
|
||||
first_type = cast(dict[str, Any], first_item).get("type")
|
||||
return isinstance(first_type, str) and first_type == "message"
|
||||
|
||||
async def _parse_workflow_input(self, workflow: Any, raw_input: Any) -> Any:
|
||||
"""Parse input based on workflow's expected input type.
|
||||
@@ -775,7 +824,7 @@ class AgentFrameworkExecutor:
|
||||
# Handle JSON string input (from frontend api.ts JSON.stringify)
|
||||
if isinstance(raw_input, str):
|
||||
try:
|
||||
parsed = json.loads(raw_input)
|
||||
parsed: Any = json.loads(raw_input)
|
||||
raw_input = parsed
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
# Plain text string, continue with string handling
|
||||
@@ -789,14 +838,14 @@ class AgentFrameworkExecutor:
|
||||
|
||||
# Handle structured input (dict)
|
||||
if isinstance(raw_input, dict):
|
||||
return self._parse_structured_workflow_input(workflow, raw_input)
|
||||
return self._parse_structured_workflow_input(workflow, cast(dict[str, Any], raw_input))
|
||||
|
||||
# Handle string input
|
||||
return self._parse_raw_workflow_input(workflow, str(raw_input))
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error parsing workflow input: {e}")
|
||||
return raw_input
|
||||
return cast(Any, raw_input)
|
||||
|
||||
def _get_start_executor_message_types(self, workflow: Any) -> tuple[Any | None, list[Any]]:
|
||||
"""Return start executor and its declared input types."""
|
||||
@@ -823,7 +872,8 @@ class AgentFrameworkExecutor:
|
||||
try:
|
||||
handlers = start_executor._handlers
|
||||
if isinstance(handlers, dict):
|
||||
message_types = list(handlers.keys())
|
||||
handlers_dict: Any = cast(Any, handlers)
|
||||
message_types = list(handlers_dict.keys())
|
||||
except Exception as exc: # pragma: no cover - defensive logging path
|
||||
logger.debug(f"Failed to read executor handlers: {exc}")
|
||||
|
||||
@@ -847,7 +897,8 @@ class AgentFrameworkExecutor:
|
||||
parsed = json.loads(input_data)
|
||||
# Only use parsed value if it's a list (ResponseInputParam format expected for HIL)
|
||||
if isinstance(parsed, list):
|
||||
input_data = parsed
|
||||
parsed_list: Any = cast(Any, parsed)
|
||||
input_data = parsed_list
|
||||
else:
|
||||
# Parsed to dict, string, or primitive - not HIL response format
|
||||
return None
|
||||
@@ -864,19 +915,32 @@ class AgentFrameworkExecutor:
|
||||
if not isinstance(input_data, list):
|
||||
return None
|
||||
|
||||
for item in input_data:
|
||||
if isinstance(item, dict) and item.get("type") == "message":
|
||||
message_content = item.get("content", [])
|
||||
input_items: Any = cast(Any, input_data)
|
||||
for item in input_items:
|
||||
if isinstance(item, dict):
|
||||
item_dict = cast(dict[str, Any], item)
|
||||
if item_dict.get("type") != "message":
|
||||
continue
|
||||
message_content = item_dict.get("content", [])
|
||||
|
||||
if isinstance(message_content, list):
|
||||
for content_item in message_content:
|
||||
message_content_items: Any = cast(Any, message_content)
|
||||
for content_item in message_content_items:
|
||||
if isinstance(content_item, dict):
|
||||
content_type = content_item.get("type")
|
||||
content_dict = cast(dict[str, Any], content_item)
|
||||
content_type = content_dict.get("type")
|
||||
|
||||
if content_type == "workflow_hil_response":
|
||||
# Extract responses dict
|
||||
# dict.get() returns Any, so we explicitly type it
|
||||
responses: dict[str, Any] = content_item.get("responses", {}) # type: ignore[assignment]
|
||||
responses_raw = content_dict.get("responses", {})
|
||||
if not isinstance(responses_raw, dict):
|
||||
continue
|
||||
|
||||
responses_dict: Any = cast(Any, responses_raw)
|
||||
responses = {
|
||||
str(response_key): response_value
|
||||
for response_key, response_value in responses_dict.items()
|
||||
}
|
||||
logger.info(f"Found workflow HIL responses: {list(responses.keys())}")
|
||||
return responses
|
||||
|
||||
@@ -1000,11 +1064,12 @@ class AgentFrameworkExecutor:
|
||||
return
|
||||
|
||||
# Find the source executor in the workflow
|
||||
if not hasattr(workflow, "executors") or not isinstance(workflow.executors, dict):
|
||||
executors = getattr(workflow, "executors", None)
|
||||
if not isinstance(executors, dict):
|
||||
logger.debug("Workflow doesn't have executors dict")
|
||||
return
|
||||
|
||||
source_executor = workflow.executors.get(source_executor_id)
|
||||
source_executor = cast(dict[str, Any], executors).get(source_executor_id)
|
||||
if not source_executor:
|
||||
logger.debug(f"Could not find executor '{source_executor_id}' in workflow")
|
||||
return
|
||||
|
||||
@@ -11,7 +11,7 @@ import uuid
|
||||
from collections import OrderedDict
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from typing import Any, Union
|
||||
from typing import Any, Union, cast
|
||||
from uuid import uuid4
|
||||
|
||||
from agent_framework import Content, Message
|
||||
@@ -61,6 +61,17 @@ EventType = Union[
|
||||
]
|
||||
|
||||
|
||||
def _to_str_dict(value: Any) -> dict[str, Any] | None:
|
||||
"""Cast arbitrary dict-like payload to a string-keyed dictionary."""
|
||||
if not isinstance(value, dict):
|
||||
return None
|
||||
return cast(dict[str, Any], value)
|
||||
|
||||
|
||||
def _stringify_name(value: Any) -> str:
|
||||
return value if isinstance(value, str) else str(value)
|
||||
|
||||
|
||||
def _serialize_content_recursive(value: Any) -> Any:
|
||||
"""Recursively serialize Agent Framework Content objects to JSON-compatible values.
|
||||
|
||||
@@ -88,16 +99,21 @@ def _serialize_content_recursive(value: Any) -> Any:
|
||||
|
||||
# Handle dictionaries - recursively process values
|
||||
if isinstance(value, dict):
|
||||
return {key: _serialize_content_recursive(val) for key, val in value.items()}
|
||||
value_dict = cast(dict[str, Any], value)
|
||||
return {str(key): _serialize_content_recursive(val) for key, val in value_dict.items()}
|
||||
|
||||
# Handle lists and tuples - recursively process elements
|
||||
if isinstance(value, (list, tuple)):
|
||||
serialized = [_serialize_content_recursive(item) for item in value]
|
||||
sequence_items: Any = cast(Any, value)
|
||||
serialized: list[Any] = [_serialize_content_recursive(item) for item in sequence_items]
|
||||
# For single-item lists containing text Content, extract just the text
|
||||
# This handles the MCP case where result = [Content.from_text(text="Hello")]
|
||||
# and we want output = "Hello" not output = '[{"type": "text", "text": "Hello"}]'
|
||||
if len(serialized) == 1 and isinstance(serialized[0], dict) and serialized[0].get("type") == "text":
|
||||
return serialized[0].get("text", "")
|
||||
if len(serialized) == 1:
|
||||
first_item = _to_str_dict(serialized[0])
|
||||
if first_item and first_item.get("type") == "text":
|
||||
text_value = first_item.get("text", "")
|
||||
return text_value if isinstance(text_value, str) else str(text_value)
|
||||
return serialized
|
||||
|
||||
# For other objects with model_dump(), try that
|
||||
@@ -156,8 +172,10 @@ class MessageMapper:
|
||||
context = self._get_or_create_context(request)
|
||||
|
||||
# Handle error events
|
||||
if isinstance(raw_event, dict) and raw_event.get("type") == "error":
|
||||
return [await self._create_error_event(raw_event.get("message", "Unknown error"), context)]
|
||||
raw_event_dict = _to_str_dict(raw_event)
|
||||
if raw_event_dict and raw_event_dict.get("type") == "error":
|
||||
message = raw_event_dict.get("message", "Unknown error")
|
||||
return [await self._create_error_event(_stringify_name(message), context)]
|
||||
|
||||
# Handle ResponseTraceEvent objects from our trace collector
|
||||
from .models import ResponseTraceEvent
|
||||
@@ -185,15 +203,12 @@ class MessageMapper:
|
||||
# Handle WorkflowEvent with type='output' or 'data' wrapping AgentResponseUpdate
|
||||
# This must be checked BEFORE generic WorkflowEvent check
|
||||
# Note: AgentExecutor uses type='output' for streaming updates
|
||||
if (
|
||||
isinstance(raw_event, WorkflowEvent)
|
||||
and raw_event.type in ("output", "data")
|
||||
and raw_event.data
|
||||
and isinstance(raw_event.data, AgentResponseUpdate)
|
||||
):
|
||||
# Preserve executor_id in context for proper output routing
|
||||
context["current_executor_id"] = raw_event.executor_id
|
||||
return await self._convert_agent_update(raw_event.data, context)
|
||||
if isinstance(raw_event, WorkflowEvent) and raw_event.type in ("output", "data"):
|
||||
event_data = getattr(cast(Any, raw_event), "data", None)
|
||||
if isinstance(event_data, AgentResponseUpdate):
|
||||
# Preserve executor_id in context for proper output routing
|
||||
context["current_executor_id"] = getattr(cast(Any, raw_event), "executor_id", None)
|
||||
return await self._convert_agent_update(event_data, context)
|
||||
|
||||
# Handle complete agent response (AgentResponse) - for non-streaming agent execution
|
||||
if isinstance(raw_event, AgentResponse):
|
||||
@@ -210,10 +225,11 @@ class MessageMapper:
|
||||
except ImportError as e:
|
||||
logger.warning(f"Could not import Agent Framework types: {e}")
|
||||
# Fallback to attribute-based detection
|
||||
if hasattr(raw_event, "contents"):
|
||||
return await self._convert_agent_update(raw_event, context)
|
||||
if hasattr(raw_event, "__class__") and "Event" in raw_event.__class__.__name__:
|
||||
return await self._convert_workflow_event(raw_event, context)
|
||||
candidate_event = cast(Any, raw_event)
|
||||
if hasattr(candidate_event, "contents"):
|
||||
return await self._convert_agent_update(candidate_event, context)
|
||||
if "Event" in type(candidate_event).__name__:
|
||||
return await self._convert_workflow_event(candidate_event, context)
|
||||
|
||||
# Unknown event type
|
||||
return [await self._create_unknown_event(raw_event, context)]
|
||||
@@ -256,32 +272,36 @@ class MessageMapper:
|
||||
item = getattr(event, "item", None)
|
||||
if item:
|
||||
# Handle both object and dict formats
|
||||
item_type = item.get("type") if isinstance(item, dict) else getattr(item, "type", None)
|
||||
item_dict = _to_str_dict(item)
|
||||
item_type = item_dict.get("type") if item_dict is not None else getattr(item, "type", None)
|
||||
|
||||
# Track function calls to accumulate their arguments
|
||||
if item_type == "function_call":
|
||||
# Handle both object and dict formats
|
||||
if isinstance(item, dict):
|
||||
call_id = item.get("call_id") or item.get("id")
|
||||
if call_id:
|
||||
item_dict = _to_str_dict(item)
|
||||
if item_dict is not None:
|
||||
call_id_value = item_dict.get("call_id") or item_dict.get("id")
|
||||
if call_id_value:
|
||||
call_id = str(call_id_value)
|
||||
function_calls[call_id] = {
|
||||
"id": item.get("id", call_id),
|
||||
"id": str(item_dict.get("id", call_id)),
|
||||
"call_id": call_id,
|
||||
"name": item.get("name", ""),
|
||||
"arguments": item.get("arguments", ""),
|
||||
"name": _stringify_name(item_dict.get("name", "")),
|
||||
"arguments": _stringify_name(item_dict.get("arguments", "")),
|
||||
"type": "function_call",
|
||||
"status": item.get("status", "completed"),
|
||||
"status": _stringify_name(item_dict.get("status", "completed")),
|
||||
}
|
||||
else:
|
||||
call_id = getattr(item, "call_id", None) or getattr(item, "id", None)
|
||||
if call_id:
|
||||
call_id_value = getattr(item, "call_id", None) or getattr(item, "id", None)
|
||||
if call_id_value:
|
||||
call_id = str(call_id_value)
|
||||
function_calls[call_id] = {
|
||||
"id": getattr(item, "id", call_id),
|
||||
"id": str(getattr(item, "id", call_id)),
|
||||
"call_id": call_id,
|
||||
"name": getattr(item, "name", ""),
|
||||
"arguments": getattr(item, "arguments", ""),
|
||||
"name": _stringify_name(getattr(item, "name", "")),
|
||||
"arguments": _stringify_name(getattr(item, "arguments", "")),
|
||||
"type": "function_call",
|
||||
"status": getattr(item, "status", "completed"),
|
||||
"status": _stringify_name(getattr(item, "status", "completed")),
|
||||
}
|
||||
|
||||
# Other output items (message, etc.) - track for later
|
||||
@@ -299,8 +319,9 @@ class MessageMapper:
|
||||
|
||||
# Handle function result complete events
|
||||
elif event_type == "response.function_result.complete":
|
||||
call_id = getattr(event, "call_id", None)
|
||||
if call_id:
|
||||
call_id_value = getattr(event, "call_id", None)
|
||||
if call_id_value:
|
||||
call_id = str(call_id_value)
|
||||
function_results[call_id] = {
|
||||
"type": "function_call_output",
|
||||
"call_id": call_id,
|
||||
@@ -322,7 +343,7 @@ class MessageMapper:
|
||||
|
||||
# Build final text message from accumulated deltas
|
||||
# Combine all text parts (usually there's just one message)
|
||||
all_text_parts = []
|
||||
all_text_parts: list[str] = []
|
||||
for _item_id, parts in text_parts_by_message.items():
|
||||
all_text_parts.extend(parts)
|
||||
|
||||
@@ -493,14 +514,14 @@ class MessageMapper:
|
||||
return value.value
|
||||
|
||||
# Handle lists/tuples/sets - recursively serialize elements
|
||||
if isinstance(value, (list, tuple)):
|
||||
return [self._serialize_value(item) for item in value]
|
||||
if isinstance(value, set):
|
||||
return [self._serialize_value(item) for item in value]
|
||||
if isinstance(value, (list, tuple, set)):
|
||||
value_items: Any = cast(Any, value)
|
||||
return [self._serialize_value(item) for item in value_items]
|
||||
|
||||
# Handle dicts - recursively serialize values
|
||||
if isinstance(value, dict):
|
||||
return {k: self._serialize_value(v) for k, v in value.items()}
|
||||
value_dict = cast(dict[str, Any], value)
|
||||
return {str(k): self._serialize_value(v) for k, v in value_dict.items()}
|
||||
|
||||
# Handle SerializationMixin (like Message) - call to_dict()
|
||||
if hasattr(value, "to_dict") and callable(getattr(value, "to_dict", None)):
|
||||
@@ -551,14 +572,15 @@ class MessageMapper:
|
||||
|
||||
# Handle dict first (most common)
|
||||
if isinstance(request_data, dict):
|
||||
return {k: self._serialize_value(v) for k, v in request_data.items()}
|
||||
request_dict = cast(dict[str, Any], request_data)
|
||||
return {str(k): self._serialize_value(v) for k, v in request_dict.items()}
|
||||
|
||||
# Handle dataclasses with nested SerializationMixin objects
|
||||
# We can't use asdict() directly because it doesn't handle Message
|
||||
if is_dataclass(request_data) and not isinstance(request_data, type):
|
||||
try:
|
||||
# Manually serialize each field to handle nested SerializationMixin
|
||||
result = {}
|
||||
result: dict[str, Any] = {}
|
||||
for field in fields(request_data):
|
||||
field_value = getattr(request_data, field.name)
|
||||
result[field.name] = self._serialize_value(field_value)
|
||||
@@ -900,8 +922,9 @@ class MessageMapper:
|
||||
text = str(output_data)
|
||||
elif isinstance(output_data, list):
|
||||
# Handle list of Message objects (from Magentic yield_output([final_answer]))
|
||||
text_parts = []
|
||||
for item in output_data:
|
||||
text_parts: list[str] = []
|
||||
output_items_list: Any = cast(Any, output_data)
|
||||
for item in output_items_list:
|
||||
if isinstance(item, Message):
|
||||
item_text = getattr(item, "text", None)
|
||||
if item_text:
|
||||
@@ -912,17 +935,17 @@ class MessageMapper:
|
||||
text_parts.append(item)
|
||||
else:
|
||||
try:
|
||||
text_parts.append(json.dumps(item, indent=2))
|
||||
text_parts.append(json.dumps(self._serialize_value(item), indent=2))
|
||||
except (TypeError, ValueError):
|
||||
text_parts.append(str(item))
|
||||
text = "\n".join(text_parts) if text_parts else str(output_data)
|
||||
text = "\n".join(text_parts) if text_parts else str(cast(Any, output_data))
|
||||
elif isinstance(output_data, str):
|
||||
# String output
|
||||
text = output_data
|
||||
else:
|
||||
# Object/dict → JSON string
|
||||
try:
|
||||
text = json.dumps(output_data, indent=2)
|
||||
text = json.dumps(self._serialize_value(output_data), indent=2)
|
||||
except (TypeError, ValueError):
|
||||
# Fallback to string representation if not JSON serializable
|
||||
text = str(output_data)
|
||||
@@ -1420,10 +1443,10 @@ class MessageMapper:
|
||||
None - no event emitted (usage goes in final Response.usage)
|
||||
"""
|
||||
# Extract usage from UsageContent.usage_details (UsageDetails object)
|
||||
details = content.usage_details or {}
|
||||
total_tokens = details.get("total_token_count", 0)
|
||||
prompt_tokens = details.get("input_token_count", 0)
|
||||
completion_tokens = details.get("output_token_count", 0)
|
||||
details = _to_str_dict(getattr(content, "usage_details", None)) or {}
|
||||
total_tokens = int(details.get("total_token_count", 0) or 0)
|
||||
prompt_tokens = int(details.get("input_token_count", 0) or 0)
|
||||
completion_tokens = int(details.get("output_token_count", 0) or 0)
|
||||
|
||||
# Accumulate for final Response.usage
|
||||
request_id = context.get("request_id", "default")
|
||||
|
||||
@@ -22,6 +22,26 @@ from ..models import AgentFrameworkRequest, OpenAIResponse
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _extract_error_details(body: Any) -> tuple[str | None, str | None, str | None]:
|
||||
"""Extract typed OpenAI error fields from error body payload."""
|
||||
if not isinstance(body, dict):
|
||||
return None, None, None
|
||||
|
||||
error_dict: dict[str, Any] = body.get("error") # type: ignore[assignment, reportUnknownVariableType]
|
||||
if not isinstance(error_dict, dict):
|
||||
return None, None, None
|
||||
|
||||
message = error_dict.get("message")
|
||||
error_type = error_dict.get("type")
|
||||
code = error_dict.get("code")
|
||||
|
||||
return (
|
||||
message if isinstance(message, str) else None,
|
||||
error_type if isinstance(error_type, str) else None,
|
||||
code if isinstance(code, str) else None,
|
||||
)
|
||||
|
||||
|
||||
class OpenAIExecutor:
|
||||
"""Executor for OpenAI Responses API - mirrors AgentFrameworkExecutor interface.
|
||||
|
||||
@@ -138,68 +158,64 @@ class OpenAIExecutor:
|
||||
except AuthenticationError as e:
|
||||
# 401 - Invalid API key or authentication issue
|
||||
logger.error(f"OpenAI authentication error: {e}", exc_info=True)
|
||||
error_body = e.body if hasattr(e, "body") else {}
|
||||
error_data = error_body.get("error", {}) if isinstance(error_body, dict) else {}
|
||||
message, error_type, code = _extract_error_details(e.body if hasattr(e, "body") else None)
|
||||
yield {
|
||||
"type": "response.failed",
|
||||
"response": {
|
||||
"id": f"resp_{os.urandom(16).hex()}",
|
||||
"status": "failed",
|
||||
"error": {
|
||||
"message": error_data.get("message", str(e)),
|
||||
"type": error_data.get("type", "authentication_error"),
|
||||
"code": error_data.get("code", "invalid_api_key"),
|
||||
"message": message or str(e),
|
||||
"type": error_type or "authentication_error",
|
||||
"code": code or "invalid_api_key",
|
||||
},
|
||||
},
|
||||
}
|
||||
except PermissionDeniedError as e:
|
||||
# 403 - Permission denied
|
||||
logger.error(f"OpenAI permission denied: {e}", exc_info=True)
|
||||
error_body = e.body if hasattr(e, "body") else {}
|
||||
error_data = error_body.get("error", {}) if isinstance(error_body, dict) else {}
|
||||
message, error_type, code = _extract_error_details(e.body if hasattr(e, "body") else None)
|
||||
yield {
|
||||
"type": "response.failed",
|
||||
"response": {
|
||||
"id": f"resp_{os.urandom(16).hex()}",
|
||||
"status": "failed",
|
||||
"error": {
|
||||
"message": error_data.get("message", str(e)),
|
||||
"type": error_data.get("type", "permission_denied"),
|
||||
"code": error_data.get("code", "insufficient_permissions"),
|
||||
"message": message or str(e),
|
||||
"type": error_type or "permission_denied",
|
||||
"code": code or "insufficient_permissions",
|
||||
},
|
||||
},
|
||||
}
|
||||
except RateLimitError as e:
|
||||
# 429 - Rate limit exceeded
|
||||
logger.error(f"OpenAI rate limit exceeded: {e}", exc_info=True)
|
||||
error_body = e.body if hasattr(e, "body") else {}
|
||||
error_data = error_body.get("error", {}) if isinstance(error_body, dict) else {}
|
||||
message, error_type, code = _extract_error_details(e.body if hasattr(e, "body") else None)
|
||||
yield {
|
||||
"type": "response.failed",
|
||||
"response": {
|
||||
"id": f"resp_{os.urandom(16).hex()}",
|
||||
"status": "failed",
|
||||
"error": {
|
||||
"message": error_data.get("message", str(e)),
|
||||
"type": error_data.get("type", "rate_limit_error"),
|
||||
"code": error_data.get("code", "rate_limit_exceeded"),
|
||||
"message": message or str(e),
|
||||
"type": error_type or "rate_limit_error",
|
||||
"code": code or "rate_limit_exceeded",
|
||||
},
|
||||
},
|
||||
}
|
||||
except APIStatusError as e:
|
||||
# Other OpenAI API errors
|
||||
logger.error(f"OpenAI API error: {e}", exc_info=True)
|
||||
error_body = e.body if hasattr(e, "body") else {}
|
||||
error_data = error_body.get("error", {}) if isinstance(error_body, dict) else {}
|
||||
message, error_type, code = _extract_error_details(e.body if hasattr(e, "body") else None)
|
||||
yield {
|
||||
"type": "response.failed",
|
||||
"response": {
|
||||
"id": f"resp_{os.urandom(16).hex()}",
|
||||
"status": "failed",
|
||||
"error": {
|
||||
"message": error_data.get("message", str(e)),
|
||||
"type": error_data.get("type", "api_error"),
|
||||
"code": error_data.get("code", "unknown_error"),
|
||||
"message": message or str(e),
|
||||
"type": error_type or "api_error",
|
||||
"code": code or "unknown_error",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -31,6 +31,29 @@ from .models._discovery_models import Deployment, DeploymentConfig, DiscoveryRes
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _extract_error_details(body: object) -> tuple[str | None, str | None, str | None]:
|
||||
"""Extract typed OpenAI-style error payload fields."""
|
||||
if not isinstance(body, dict):
|
||||
return None, None, None
|
||||
|
||||
body_dict = cast(dict[str, object], body)
|
||||
error_obj = body_dict.get("error")
|
||||
if not isinstance(error_obj, dict):
|
||||
return None, None, None
|
||||
|
||||
error_dict = cast(dict[str, object], error_obj)
|
||||
message = error_dict.get("message")
|
||||
error_type = error_dict.get("type")
|
||||
code = error_dict.get("code")
|
||||
|
||||
return (
|
||||
message if isinstance(message, str) else None,
|
||||
error_type if isinstance(error_type, str) else None,
|
||||
code if isinstance(code, str) else None,
|
||||
)
|
||||
|
||||
|
||||
# Get package version
|
||||
try:
|
||||
__version__ = importlib.metadata.version("agent-framework-devui")
|
||||
@@ -83,6 +106,10 @@ class DevServer:
|
||||
self._pending_entities: list[Any] | None = None
|
||||
self._running_tasks: dict[str, asyncio.Task[Any]] = {} # Track running response tasks for cancellation
|
||||
|
||||
def set_pending_entities(self, entities: list[Any]) -> None:
|
||||
"""Set in-memory entities to register on startup."""
|
||||
self._pending_entities = entities
|
||||
|
||||
def _is_dev_mode(self) -> bool:
|
||||
"""Check if running in developer mode.
|
||||
|
||||
@@ -378,6 +405,8 @@ class DevServer:
|
||||
# Token valid, proceed
|
||||
return await call_next(request)
|
||||
|
||||
_ = auth_middleware
|
||||
|
||||
self._register_routes(app)
|
||||
self._mount_ui(app)
|
||||
|
||||
@@ -452,7 +481,7 @@ class DevServer:
|
||||
if entity_info.type == "workflow" and entity_obj:
|
||||
# Entity object already loaded by load_entity() above
|
||||
# Get workflow structure
|
||||
workflow_dump = None
|
||||
workflow_dump: dict[str, Any] | str | None = None
|
||||
if hasattr(entity_obj, "to_dict") and callable(getattr(entity_obj, "to_dict", None)):
|
||||
try:
|
||||
workflow_dump = entity_obj.to_dict() # type: ignore[attr-defined]
|
||||
@@ -475,7 +504,11 @@ class DevServer:
|
||||
except Exception:
|
||||
workflow_dump = raw_dump
|
||||
else:
|
||||
workflow_dump = parsed_dump if isinstance(parsed_dump, dict) else raw_dump
|
||||
if isinstance(parsed_dump, dict):
|
||||
parsed_dump_dict = cast(dict[str, Any], parsed_dump)
|
||||
workflow_dump = {str(k): v for k, v in parsed_dump_dict.items()}
|
||||
else:
|
||||
workflow_dump = raw_dump
|
||||
else:
|
||||
workflow_dump = raw_dump
|
||||
elif hasattr(entity_obj, "__dict__"):
|
||||
@@ -838,34 +871,31 @@ class DevServer:
|
||||
except AuthenticationError as e:
|
||||
# 401 - Invalid API key or authentication issue
|
||||
logger.error(f"OpenAI authentication error creating conversation: {e}")
|
||||
error_body = e.body if hasattr(e, "body") else {}
|
||||
error_data = error_body.get("error", {}) if isinstance(error_body, dict) else {}
|
||||
message, error_type, code = _extract_error_details(e.body if hasattr(e, "body") else None)
|
||||
error = OpenAIError.create(
|
||||
message=error_data.get("message", str(e)),
|
||||
type=error_data.get("type", "authentication_error"),
|
||||
code=error_data.get("code", "invalid_api_key"),
|
||||
message=message or str(e),
|
||||
type=error_type or "authentication_error",
|
||||
code=code or "invalid_api_key",
|
||||
)
|
||||
return JSONResponse(status_code=401, content=error.to_dict())
|
||||
except PermissionDeniedError as e:
|
||||
# 403 - Permission denied
|
||||
logger.error(f"OpenAI permission denied creating conversation: {e}")
|
||||
error_body = e.body if hasattr(e, "body") else {}
|
||||
error_data = error_body.get("error", {}) if isinstance(error_body, dict) else {}
|
||||
message, error_type, code = _extract_error_details(e.body if hasattr(e, "body") else None)
|
||||
error = OpenAIError.create(
|
||||
message=error_data.get("message", str(e)),
|
||||
type=error_data.get("type", "permission_denied"),
|
||||
code=error_data.get("code", "insufficient_permissions"),
|
||||
message=message or str(e),
|
||||
type=error_type or "permission_denied",
|
||||
code=code or "insufficient_permissions",
|
||||
)
|
||||
return JSONResponse(status_code=403, content=error.to_dict())
|
||||
except APIStatusError as e:
|
||||
# Other OpenAI API errors (rate limit, etc.)
|
||||
logger.error(f"OpenAI API error creating conversation: {e}")
|
||||
error_body = e.body if hasattr(e, "body") else {}
|
||||
error_data = error_body.get("error", {}) if isinstance(error_body, dict) else {}
|
||||
message, error_type, code = _extract_error_details(e.body if hasattr(e, "body") else None)
|
||||
error = OpenAIError.create(
|
||||
message=error_data.get("message", str(e)),
|
||||
type=error_data.get("type", "api_error"),
|
||||
code=error_data.get("code", "unknown_error"),
|
||||
message=message or str(e),
|
||||
type=error_type or "api_error",
|
||||
code=code or "unknown_error",
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=e.status_code if hasattr(e, "status_code") else 500, content=error.to_dict()
|
||||
@@ -902,7 +932,7 @@ class DevServer:
|
||||
executor = await self._ensure_executor()
|
||||
|
||||
# Build filter criteria
|
||||
filters = {}
|
||||
filters: dict[str, str] = {}
|
||||
if agent_id:
|
||||
filters["agent_id"] = agent_id
|
||||
if entity_id:
|
||||
@@ -997,15 +1027,16 @@ class DevServer:
|
||||
conversation_id, limit=limit, after=after, order=order
|
||||
)
|
||||
# Handle both Pydantic models and dicts (some stores return raw dicts)
|
||||
serialized_items = []
|
||||
serialized_items: list[dict[str, Any]] = []
|
||||
for item in items:
|
||||
if hasattr(item, "model_dump"):
|
||||
serialized_items.append(item.model_dump())
|
||||
elif isinstance(item, dict):
|
||||
serialized_items.append(item)
|
||||
item_dict = cast(dict[str, Any], item)
|
||||
serialized_items.append({str(k): v for k, v in item_dict.items()})
|
||||
else:
|
||||
logger.warning(f"Unexpected item type: {type(item)}, converting to dict")
|
||||
serialized_items.append(dict(item))
|
||||
serialized_items.append({str(k): v for k, v in dict(item).items()})
|
||||
|
||||
# Get stored traces for context inspection (DevUI extension)
|
||||
traces = executor.conversation_store.get_traces(conversation_id)
|
||||
@@ -1038,9 +1069,14 @@ class DevServer:
|
||||
if not item:
|
||||
raise HTTPException(status_code=404, detail="Item not found")
|
||||
# Handle both Pydantic models and dicts
|
||||
result: dict[str, Any] = (
|
||||
item.model_dump() if hasattr(item, "model_dump") else cast(dict[str, Any], item)
|
||||
)
|
||||
result: dict[str, Any]
|
||||
if hasattr(item, "model_dump"):
|
||||
result = item.model_dump()
|
||||
elif isinstance(item, dict):
|
||||
item_dict = cast(dict[str, Any], item)
|
||||
result = {str(k): v for k, v in item_dict.items()}
|
||||
else:
|
||||
result = {"value": item}
|
||||
return result
|
||||
except HTTPException:
|
||||
raise
|
||||
@@ -1085,16 +1121,42 @@ class DevServer:
|
||||
# Checkpoints are exposed as conversation items with type="checkpoint"
|
||||
# ============================================================================
|
||||
|
||||
registered_route_handlers = (
|
||||
health_check,
|
||||
get_meta,
|
||||
discover_entities,
|
||||
get_entity_info,
|
||||
reload_entity,
|
||||
create_deployment,
|
||||
list_deployments,
|
||||
get_deployment,
|
||||
delete_deployment,
|
||||
deploy_entity,
|
||||
create_response,
|
||||
cancel_response,
|
||||
create_conversation,
|
||||
list_conversations,
|
||||
retrieve_conversation,
|
||||
update_conversation,
|
||||
delete_conversation,
|
||||
create_conversation_items,
|
||||
list_conversation_items,
|
||||
retrieve_conversation_item,
|
||||
delete_conversation_item,
|
||||
)
|
||||
_ = registered_route_handlers
|
||||
|
||||
async def _stream_execution(
|
||||
self, executor: AgentFrameworkExecutor, request: AgentFrameworkRequest
|
||||
) -> AsyncGenerator[str]:
|
||||
"""Stream execution directly through executor."""
|
||||
try:
|
||||
# Collect events for final response.completed event
|
||||
events = []
|
||||
events: list[Any] = []
|
||||
|
||||
# Get conversation_id for trace storage
|
||||
conversation_id = request._get_conversation_id()
|
||||
conversation_getter = getattr(request, "_get_conversation_id", None)
|
||||
conversation_id = conversation_getter() if callable(conversation_getter) else None
|
||||
|
||||
# Stream all events
|
||||
async for event in executor.execute_streaming(request):
|
||||
@@ -1104,7 +1166,7 @@ class DevServer:
|
||||
if conversation_id and hasattr(event, "type") and event.type == "response.trace.completed":
|
||||
try:
|
||||
trace_data = event.data if hasattr(event, "data") else None
|
||||
if trace_data:
|
||||
if trace_data and isinstance(conversation_id, str):
|
||||
executor.conversation_store.add_trace(conversation_id, trace_data)
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to store trace event: {e}")
|
||||
@@ -1136,8 +1198,9 @@ class DevServer:
|
||||
# We need to increment from that
|
||||
last_seq = 0
|
||||
for event in reversed(events):
|
||||
if hasattr(event, "sequence_number") and event.sequence_number is not None:
|
||||
last_seq = event.sequence_number
|
||||
sequence_number = getattr(event, "sequence_number", None)
|
||||
if isinstance(sequence_number, int):
|
||||
last_seq = sequence_number
|
||||
break
|
||||
|
||||
completed_event = ResponseCompletedEvent(
|
||||
|
||||
@@ -5,13 +5,37 @@
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import Any, TypedDict, cast
|
||||
|
||||
from typing_extensions import NotRequired
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Type aliases for better readability
|
||||
SessionData = dict[str, Any]
|
||||
RequestRecord = dict[str, Any]
|
||||
|
||||
class RequestRecord(TypedDict):
|
||||
"""Tracked execution request data."""
|
||||
|
||||
id: str
|
||||
timestamp: datetime
|
||||
entity_id: str
|
||||
executor: str
|
||||
input: Any
|
||||
model_id: str
|
||||
stream: bool
|
||||
execution_time: NotRequired[float]
|
||||
status: NotRequired[str]
|
||||
|
||||
|
||||
class SessionData(TypedDict):
|
||||
"""Stored session state."""
|
||||
|
||||
id: str
|
||||
created_at: datetime
|
||||
requests: list[RequestRecord]
|
||||
context: dict[str, Any]
|
||||
active: bool
|
||||
|
||||
|
||||
SessionSummary = dict[str, Any]
|
||||
|
||||
|
||||
@@ -95,7 +119,7 @@ class SessionManager:
|
||||
"stream": True,
|
||||
}
|
||||
session["requests"].append(request_record)
|
||||
return str(request_record["id"])
|
||||
return request_record["id"]
|
||||
|
||||
def update_request_record(self, session_id: str, request_id: str, updates: dict[str, Any]) -> None:
|
||||
"""Update a request record in a session.
|
||||
@@ -111,7 +135,8 @@ class SessionManager:
|
||||
|
||||
for request in session["requests"]:
|
||||
if request["id"] == request_id:
|
||||
request.update(updates)
|
||||
request_data = cast(dict[str, Any], request)
|
||||
request_data.update(updates)
|
||||
break
|
||||
|
||||
def get_session_history(self, session_id: str) -> SessionSummary | None:
|
||||
@@ -138,7 +163,7 @@ class SessionManager:
|
||||
"timestamp": req["timestamp"].isoformat(),
|
||||
"entity_id": req["entity_id"],
|
||||
"executor": req["executor"],
|
||||
"model": req["model"],
|
||||
"model": req["model_id"],
|
||||
"input_length": len(str(req["input"])) if req["input"] else 0,
|
||||
"execution_time": req.get("execution_time"),
|
||||
"status": req.get("status", "unknown"),
|
||||
@@ -153,7 +178,7 @@ class SessionManager:
|
||||
Returns:
|
||||
List of active session summaries
|
||||
"""
|
||||
active_sessions = []
|
||||
active_sessions: list[SessionSummary] = []
|
||||
|
||||
for session_id, session in self.sessions.items():
|
||||
if session["active"]:
|
||||
@@ -178,7 +203,7 @@ class SessionManager:
|
||||
"""
|
||||
cutoff_time = datetime.now().timestamp() - (max_age_hours * 3600)
|
||||
|
||||
sessions_to_remove = []
|
||||
sessions_to_remove: list[str] = []
|
||||
for session_id, session in self.sessions.items():
|
||||
if session["created_at"].timestamp() < cutoff_time:
|
||||
sessions_to_remove.append(session_id)
|
||||
|
||||
@@ -7,12 +7,20 @@ import json
|
||||
import logging
|
||||
from dataclasses import fields, is_dataclass
|
||||
from types import UnionType
|
||||
from typing import Any, Union, get_args, get_origin, get_type_hints
|
||||
from typing import Any, Union, cast, get_args, get_origin, get_type_hints
|
||||
|
||||
from agent_framework import Message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _string_key_dict(value: object) -> dict[str, Any] | None:
|
||||
"""Cast value to a dict."""
|
||||
if not isinstance(value, dict):
|
||||
return None
|
||||
return cast(dict[str, Any], value)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Agent Metadata Extraction
|
||||
# ============================================================================
|
||||
@@ -39,18 +47,21 @@ def extract_agent_metadata(entity_object: Any) -> dict[str, Any]:
|
||||
# Try to get instructions
|
||||
if hasattr(entity_object, "default_options"):
|
||||
chat_opts = entity_object.default_options
|
||||
if isinstance(chat_opts, dict):
|
||||
if "instructions" in chat_opts:
|
||||
metadata["instructions"] = chat_opts.get("instructions")
|
||||
chat_opts_dict = _string_key_dict(chat_opts)
|
||||
if chat_opts_dict is not None:
|
||||
if "instructions" in chat_opts_dict:
|
||||
metadata["instructions"] = chat_opts_dict.get("instructions")
|
||||
elif hasattr(chat_opts, "instructions"):
|
||||
metadata["instructions"] = chat_opts.instructions
|
||||
|
||||
# Try to get model - check both default_options and client
|
||||
if hasattr(entity_object, "default_options"):
|
||||
chat_opts = entity_object.default_options
|
||||
if isinstance(chat_opts, dict):
|
||||
if chat_opts.get("model_id"):
|
||||
metadata["model"] = chat_opts.get("model_id")
|
||||
chat_opts_dict = _string_key_dict(chat_opts)
|
||||
if chat_opts_dict is not None:
|
||||
model_id = chat_opts_dict.get("model_id")
|
||||
if model_id:
|
||||
metadata["model"] = model_id
|
||||
elif hasattr(chat_opts, "model_id") and chat_opts.model_id:
|
||||
metadata["model"] = chat_opts.model_id
|
||||
if metadata["model"] is None and hasattr(entity_object, "client") and hasattr(entity_object.client, "model_id"):
|
||||
@@ -112,7 +123,7 @@ def extract_executor_message_types(executor: Any) -> list[Any]:
|
||||
try:
|
||||
handlers = executor._handlers
|
||||
if isinstance(handlers, dict):
|
||||
message_types = list(handlers.keys())
|
||||
message_types = list(handlers.keys()) # type: ignore[arg-type] # pyright: ignore[reportUnknownArgumentType]
|
||||
except Exception as exc: # pragma: no cover - defensive logging path
|
||||
logger.debug(f"Failed to read executor handlers: {exc}")
|
||||
|
||||
@@ -366,11 +377,10 @@ def extract_response_type_from_executor(executor: Any, request_type: type) -> ty
|
||||
_, second_param_type = param_items[1] if len(param_items) > 1 else (None, None)
|
||||
|
||||
# Check if first param matches request_type
|
||||
first_matches_request = first_param_type == request_type or (
|
||||
hasattr(first_param_type, "__name__")
|
||||
and hasattr(request_type, "__name__")
|
||||
and first_param_type.__name__ == request_type.__name__
|
||||
)
|
||||
first_matches_request = first_param_type == request_type
|
||||
if not first_matches_request and isinstance(first_param_type, type):
|
||||
request_type_name = request_type.__name__
|
||||
first_matches_request = first_param_type.__name__ == request_type_name
|
||||
|
||||
# Verify we have a matching request type and valid response type (must be a type class)
|
||||
if first_matches_request and second_param_type is not None and isinstance(second_param_type, type):
|
||||
@@ -432,7 +442,7 @@ def generate_input_schema(input_type: type) -> dict[str, Any]:
|
||||
return generate_schema_from_dataclass(input_type)
|
||||
|
||||
# 5. Fallback to string
|
||||
type_name = getattr(input_type, "__name__", str(input_type))
|
||||
type_name = input_type.__name__ if isinstance(input_type, type) else str(cast(Any, input_type))
|
||||
return {"type": "string", "description": f"Input type: {type_name}"}
|
||||
|
||||
|
||||
@@ -466,8 +476,9 @@ def parse_input_for_type(input_data: Any, target_type: type) -> Any:
|
||||
return _parse_string_input(input_data, target_type)
|
||||
|
||||
# Handle dict input
|
||||
if isinstance(input_data, dict):
|
||||
return _parse_dict_input(input_data, target_type)
|
||||
parsed_dict = _string_key_dict(input_data)
|
||||
if parsed_dict is not None:
|
||||
return _parse_dict_input(parsed_dict, target_type)
|
||||
|
||||
# Fallback: return original
|
||||
return input_data
|
||||
|
||||
@@ -2,8 +2,11 @@
|
||||
|
||||
"""Discovery API models for entity information."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
from collections.abc import Callable
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
@@ -57,7 +60,7 @@ class EntityInfo(BaseModel):
|
||||
class DiscoveryResponse(BaseModel):
|
||||
"""Response model for entity discovery."""
|
||||
|
||||
entities: list[EntityInfo] = Field(default_factory=list)
|
||||
entities: list[EntityInfo] = Field(default_factory=cast(Callable[..., list[EntityInfo]], list))
|
||||
|
||||
|
||||
# ============================================================================
|
||||
|
||||
@@ -94,7 +94,7 @@ include = "../../shared_tasks.toml"
|
||||
|
||||
[tool.poe.tasks]
|
||||
mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_devui"
|
||||
test = "pytest --cov=agent_framework_devui --cov-report=term-missing:skip-covered tests"
|
||||
test = "pytest -m \"not integration\" --cov=agent_framework_devui --cov-report=term-missing:skip-covered tests"
|
||||
|
||||
[build-system]
|
||||
requires = ["flit-core >= 3.11,<4.0"]
|
||||
|
||||
Reference in New Issue
Block a user