Python: Fix Python pyright package scoping and typing remediation (#4426)

* Fix Python pyright package scoping and typing remediation

Implements issue #4407 by removing the root pyright include, adding package-level pyright includes, and resolving pyright/mypy typing issues across Python packages. Also cleans unnecessary casts and applies line-level, rule-specific ignores where external libraries are too dynamic.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Reduce pyright cost in handoff cloning

Simplify cloned_options construction in HandoffAgentExecutor to avoid expensive TypedDict narrowing/inference in _handoff.py, which was causing pyright to spend a long time in orchestrations.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* fix types

* Fix lint and type-check regressions

Resolve current Python package check failures across lint, pyright, and mypy after recent code changes, including purview/declarative pyright issues and multiple ruff simplification findings.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* fixed hooks

* Stabilize package tests and test tasks

Resolve cross-package non-integration test failures, simplify streaming type flow, harden locale/culture handling, and standardize package test poe tasks to exclude integration tests where applicable.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* lots of small fixes

* Fix current Python test regressions

Address current failing unit tests in azure-ai, bedrock, and azure-cosmos while keeping Bedrock parsing logic inline (no new static helper methods).

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* small fixes

* small fixes

* removed pydantic from json

* final updates

* fix core

* fix tests

* fix obser

---------

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
Eduard van Valkenburg
2026-03-05 16:32:24 +01:00
committed by GitHub
Unverified
parent 4a043c6c66
commit 55ddd841b7
122 changed files with 2328 additions and 2407 deletions
@@ -73,7 +73,7 @@ def register_cleanup(entity: Any, *hooks: Callable[[], Any]) -> None:
)
def _get_registered_cleanup_hooks(entity: Any) -> list[Callable[[], Any]]:
def _get_registered_cleanup_hooks(entity: Any) -> list[Callable[[], Any]]: # type: ignore[reportUnusedFunction]
"""Get cleanup hooks registered for an entity (internal use).
Args:
@@ -193,7 +193,7 @@ def serve(
if entities:
logger.info(f"Registering {len(entities)} in-memory entities")
# Store entities for later registration during server startup
server._pending_entities = entities
server.set_pending_entities(entities)
app = server.get_app()
@@ -11,12 +11,14 @@ from __future__ import annotations
import time
import uuid
from abc import ABC, abstractmethod
from collections.abc import MutableSequence
from typing import Any, Literal, cast
from agent_framework import AgentSession, Message
from agent_framework._workflows._checkpoint import InMemoryCheckpointStorage, WorkflowCheckpoint
from openai.types.conversations import Conversation, ConversationDeletedResource
from openai.types.conversations.conversation_item import ConversationItem
from openai.types.conversations.message import Content as OpenAIContent
from openai.types.conversations.message import Message as OpenAIMessage
from openai.types.conversations.text_content import TextContent
from openai.types.responses import (
@@ -300,12 +302,17 @@ class InMemoryConversationStore(ConversationStore):
stored_messages: list[Message] = conv_data["messages"]
# Convert items to Messages and add to storage
chat_messages = []
chat_messages: list[Message] = []
for item in items:
# Simple conversion - assume text content for now
role = item.get("role", "user")
content = item.get("content", [])
text = content[0].get("text", "") if content else ""
first_content = cast(
dict[str, Any],
content[0] if content and isinstance(content, list) and isinstance(content[0], dict) else {},
)
text_obj = first_content.get("text", "")
text = text_obj if isinstance(text_obj, str) else str(text_obj)
chat_msg = Message(role=role, text=text) # type: ignore[arg-type]
chat_messages.append(chat_msg)
@@ -318,23 +325,18 @@ class InMemoryConversationStore(ConversationStore):
for msg in chat_messages:
item_id = f"item_{uuid.uuid4().hex}"
# Extract role - handle both string and enum
role_str = msg.role if hasattr(msg.role, "value") else str(msg.role)
role = cast(MessageRole, role_str) # Safe: Agent Framework roles match OpenAI roles
# Convert Message contents to OpenAI TextContent format
message_content = []
message_content: MutableSequence[OpenAIContent] = []
for content_item in msg.contents:
if content_item.type == "text":
# Extract text from TextContent object
text_value = getattr(content_item, "text", "")
message_content.append(TextContent(type="text", text=text_value))
message_content.append(TextContent(type="text", text=content_item.text or ""))
# Create Message object (concrete type from ConversationItem union)
message = OpenAIMessage(
id=item_id,
type="message", # Required discriminator for union
role=role,
role=cast(MessageRole, msg.role), # Safe: Agent Framework roles match OpenAI roles,
content=message_content,
status="completed", # Required field
)
@@ -383,8 +385,8 @@ class InMemoryConversationStore(ConversationStore):
# A single Message may produce multiple ConversationItems
# (e.g., a message with both text and a function call)
message_contents: list[TextContent | ResponseInputImage | ResponseInputFile] = []
function_calls = []
function_results = []
function_calls: list[ResponseFunctionToolCallItem] = []
function_results: list[ResponseFunctionToolCallOutputItem] = []
for content in msg.contents:
content_type = getattr(content, "type", None)
@@ -628,7 +630,7 @@ class InMemoryConversationStore(ConversationStore):
async def list_conversations_by_metadata(self, metadata_filter: dict[str, str]) -> list[Conversation]:
"""Filter conversations by metadata (e.g., agent_id)."""
results = []
results: list[Conversation] = []
for conv_data in self._conversations.values():
conv_meta = conv_data.get("metadata", {}).copy() # Copy to avoid mutating original
@@ -704,7 +706,8 @@ class CheckpointConversationManager:
ValueError: If conversation not found
"""
# Access internal conversations dict (we know it's InMemoryConversationStore)
conv_data = self._store._conversations.get(conversation_id)
conversations_dict = cast(dict[str, dict[str, Any]], getattr(self._store, "_conversations", {}))
conv_data = conversations_dict.get(conversation_id)
if not conv_data:
raise ValueError(f"Conversation {conversation_id} not found")
@@ -10,6 +10,7 @@ import uuid
from collections.abc import AsyncGenerator
from datetime import datetime, timezone
from pathlib import Path
from typing import cast
from urllib.parse import urlparse
from .models._discovery_models import Deployment, DeploymentConfig, DeploymentEvent
@@ -175,7 +176,7 @@ class DeploymentManager:
# Check required resource providers are registered
required_providers = ["Microsoft.App", "Microsoft.ContainerRegistry", "Microsoft.OperationalInsights"]
unregistered_providers = []
unregistered_providers: list[str] = []
# Get list of registered providers
provider_check = await asyncio.create_subprocess_exec(
@@ -195,7 +196,12 @@ class DeploymentManager:
import json
try:
registered = json.loads(stdout.decode())
registered_raw = json.loads(stdout.decode())
registered: list[str] = []
if isinstance(registered_raw, list):
for item_obj in cast(list[object], registered_raw):
if isinstance(item_obj, str):
registered.append(item_obj)
for provider in required_providers:
if provider not in registered:
unregistered_providers.append(provider)
@@ -385,7 +391,7 @@ CMD ["devui", "/app/entity", "--mode", "{config.ui_mode}", "--host", "0.0.0.0",
)
# Stream output line by line
output_lines = []
output_lines: list[str] = []
try:
if not process.stdout:
raise ValueError("Failed to capture process output")
@@ -473,8 +479,11 @@ CMD ["devui", "/app/entity", "--mode", "{config.ui_mode}", "--host", "0.0.0.0",
for url in urls:
# Strip common trailing punctuation to ensure clean URL parsing
url_clean = url.rstrip(".,;:!?'\")}]")
host = urlparse(url_clean).hostname
if host and (host == "azurecontainerapps.io" or host.endswith(".azurecontainerapps.io")):
parsed_url = urlparse(str(url_clean))
host = parsed_url.hostname
if isinstance(host, str) and (
host == "azurecontainerapps.io" or host.endswith(".azurecontainerapps.io")
):
await event_queue.put(
DeploymentEvent(type="deploy.progress", message="Deployment URL generated!")
)
@@ -11,7 +11,7 @@ import logging
import sys
import uuid
from pathlib import Path
from typing import Any
from typing import Any, cast
from dotenv import load_dotenv
@@ -141,7 +141,7 @@ class EntityDiscovery:
self._loaded_objects[entity_id] = entity_obj
# Check module-level registry for cleanup hooks
from . import _get_registered_cleanup_hooks
from . import _get_registered_cleanup_hooks # type: ignore[reportPrivateUsage]
registered_hooks = _get_registered_cleanup_hooks(entity_obj)
if registered_hooks:
@@ -299,7 +299,7 @@ class EntityDiscovery:
self._loaded_objects[entity_id] = entity_object
# Check module-level registry for cleanup hooks
from . import _get_registered_cleanup_hooks
from . import _get_registered_cleanup_hooks # type: ignore[reportPrivateUsage]
registered_hooks = _get_registered_cleanup_hooks(entity_object)
if registered_hooks:
@@ -379,6 +379,8 @@ class EntityDiscovery:
deployment_supported = True
deployment_reason = "Ready for deployment (pending path verification)"
class_name = type(entity_object).__name__
# Create EntityInfo with Agent Framework specifics
return EntityInfo(
id=entity_id,
@@ -400,9 +402,7 @@ class EntityDiscovery:
deployment_reason=deployment_reason,
metadata={
"source": "agent_framework_object",
"class_name": entity_object.__class__.__name__
if hasattr(entity_object, "__class__")
else str(type(entity_object)),
"class_name": class_name,
},
)
@@ -854,7 +854,7 @@ class EntityDiscovery:
"module_path": module_path,
"entity_type": obj_type,
"source": source,
"class_name": obj.__class__.__name__ if hasattr(obj, "__class__") else str(type(obj)),
"class_name": type(obj).__name__,
},
)
@@ -874,47 +874,63 @@ class EntityDiscovery:
Returns:
List of tool/executor names
"""
tools = []
tools: list[str] = []
try:
if obj_type == "agent":
# For agents, check default_options.get("tools")
chat_options = getattr(obj, "default_options", None)
chat_options_tools = None
if chat_options:
chat_options_tools = chat_options.get("tools")
chat_options_tools: object | None = None
if isinstance(chat_options, dict):
chat_options_dict = cast(dict[str, Any], chat_options)
chat_options_tools = chat_options_dict.get("tools")
if chat_options_tools:
for tool in chat_options_tools:
if hasattr(tool, "__name__"):
tools.append(tool.__name__)
elif hasattr(tool, "name"):
tools.append(tool.name)
if chat_options_tools is not None:
tool_iterable: list[object] = (
cast(list[object], chat_options_tools)
if isinstance(chat_options_tools, list)
else [chat_options_tools]
)
for tool_obj in tool_iterable:
tool_name = getattr(tool_obj, "__name__", None)
if isinstance(tool_name, str):
tools.append(tool_name)
continue
named_tool = getattr(tool_obj, "name", None)
if isinstance(named_tool, str):
tools.append(named_tool)
else:
tools.append(str(tool))
tools.append(str(tool_obj))
else:
# Fallback to direct tools attribute
agent_tools = getattr(obj, "tools", None)
if agent_tools:
for tool in agent_tools:
if hasattr(tool, "__name__"):
tools.append(tool.__name__)
elif hasattr(tool, "name"):
tools.append(tool.name)
if isinstance(agent_tools, list):
for tool_obj in cast(list[object], agent_tools):
tool_name = getattr(tool_obj, "__name__", None)
if isinstance(tool_name, str):
tools.append(tool_name)
continue
named_tool = getattr(tool_obj, "name", None)
if isinstance(named_tool, str):
tools.append(named_tool)
else:
tools.append(str(tool))
tools.append(str(tool_obj))
elif obj_type == "workflow":
# For workflows, extract executor names
if hasattr(obj, "get_executors_list"):
executor_objects = obj.get_executors_list()
tools = [getattr(ex, "id", str(ex)) for ex in executor_objects]
if isinstance(executor_objects, list):
for executor_obj in cast(list[object], executor_objects):
tools.append(str(getattr(executor_obj, "id", executor_obj)))
elif hasattr(obj, "executors"):
executors = obj.executors
if isinstance(executors, list):
tools = [getattr(ex, "id", str(ex)) for ex in executors]
for executor_obj in cast(list[object], executors):
tools.append(str(getattr(executor_obj, "id", executor_obj)))
elif isinstance(executors, dict):
tools = list(executors.keys())
executors_dict = cast(dict[str, Any], executors)
for key_obj in executors_dict:
tools.append(str(key_obj))
except Exception as e:
logger.debug(f"Error extracting tools from {obj_type} {type(obj)}: {e}")
@@ -7,7 +7,7 @@ from __future__ import annotations
import json
import logging
from collections.abc import AsyncGenerator
from typing import Any
from typing import Any, cast
from agent_framework import Content, SupportsAgentRun, Workflow
@@ -24,7 +24,8 @@ logger = logging.getLogger(__name__)
def _get_event_type(event: Any) -> str | None:
"""Safely get the type of an event, handling both objects and dicts."""
if isinstance(event, dict):
return event.get("type")
event_type = cast(dict[str, Any], event).get("type")
return event_type if isinstance(event_type, str) else None
return getattr(event, "type", None)
@@ -71,7 +72,8 @@ class AgentFrameworkExecutor:
from opentelemetry.sdk.trace import TracerProvider
# Only set up if no provider exists yet
if not hasattr(trace, "_TRACER_PROVIDER") or trace._TRACER_PROVIDER is None:
current_provider = trace.get_tracer_provider()
if current_provider.__class__.__name__ == "ProxyTracerProvider":
resource = Resource.create({
"service.name": "agent-framework-server",
"service.version": "1.0.0",
@@ -94,21 +96,29 @@ class AgentFrameworkExecutor:
# Configure if instrumentation is enabled (via enable_instrumentation() or env var)
if OBSERVABILITY_SETTINGS.ENABLED:
# Only configure providers if not already executed
if not OBSERVABILITY_SETTINGS._executed_setup:
# Call configure_otel_providers to set up exporters.
# If OTEL_EXPORTER_OTLP_ENDPOINT is set, exporters will be created automatically.
# If not set, no exporters are created (no console spam), but DevUI's
# TracerProvider from _setup_instrumentation_provider() remains active for local capture.
configure_otel_providers(enable_sensitive_data=OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED)
logger.info("Enabled Agent Framework observability")
else:
logger.debug("Agent Framework observability already configured")
# Call configure_otel_providers to set up exporters.
# If OTEL_EXPORTER_OTLP_ENDPOINT is set, exporters will be created automatically.
# If not set, no exporters are created (no console spam), but DevUI's
# TracerProvider from _setup_instrumentation_provider() remains active for local capture.
configure_otel_providers(enable_sensitive_data=OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED)
logger.info("Enabled Agent Framework observability")
else:
logger.debug("Instrumentation not enabled, skipping observability setup")
except Exception as e:
logger.warning(f"Failed to enable Agent Framework observability: {e}")
def _get_request_conversation_id(self, request: AgentFrameworkRequest) -> str | None:
"""Read conversation id using public request fields."""
if isinstance(request.conversation, str):
return request.conversation
if isinstance(request.conversation, dict):
conversation_id = request.conversation.get("id")
if isinstance(conversation_id, str):
return conversation_id
return None
async def _ensure_mcp_connections(self, agent: Any) -> None:
"""Ensure MCP tool connections are healthy before agent execution.
@@ -317,7 +327,7 @@ class AgentFrameworkExecutor:
# Get session from conversation parameter (OpenAI standard!)
session = None
conversation_id = request._get_conversation_id()
conversation_id = self._get_request_conversation_id(request)
if conversation_id:
session = self.conversation_store.get_session(conversation_id)
if session:
@@ -344,7 +354,7 @@ class AgentFrameworkExecutor:
if session:
run_kwargs["session"] = session
stream = agent.run(user_message, **run_kwargs)
stream = cast(Any, agent.run(user_message, **run_kwargs))
async for update in stream:
for trace_event in trace_collector.get_pending_events():
yield trace_event
@@ -388,7 +398,7 @@ class AgentFrameworkExecutor:
entity_id = request.get_entity_id() or "unknown"
# Get or create session conversation for checkpoint storage
conversation_id = request._get_conversation_id()
conversation_id = self._get_request_conversation_id(request)
if not conversation_id:
# Create default session if not provided
import time
@@ -463,11 +473,14 @@ class AgentFrameworkExecutor:
logger.info(f"Resuming workflow with HIL responses for {len(hil_responses)} request(s)")
# Unwrap primitive responses if they're wrapped in {response: value} format
unwrapped_responses = {}
unwrapped_responses: dict[str, Any] = {}
for request_id, response_value in hil_responses.items():
if isinstance(response_value, dict) and "response" in response_value:
response_value = response_value["response"]
unwrapped_responses[request_id] = response_value
normalized_response: Any = response_value
if isinstance(response_value, dict):
response_dict = cast(dict[str, Any], response_value)
if "response" in response_dict:
normalized_response = response_dict["response"]
unwrapped_responses[request_id] = normalized_response
hil_responses = unwrapped_responses
@@ -568,7 +581,8 @@ class AgentFrameworkExecutor:
# Handle OpenAI ResponseInputParam (List[ResponseInputItemParam])
if isinstance(input_data, list):
return self._convert_openai_input_to_chat_message(input_data, Message, Role)
input_items: Any = cast(Any, input_data)
return self._convert_openai_input_to_chat_message(input_items, Message, Role)
# Fallback for other formats
return self._extract_user_message_fallback(input_data)
@@ -593,27 +607,31 @@ class AgentFrameworkExecutor:
for item in input_items:
# Handle dict format (from JSON)
if isinstance(item, dict):
item_type = item.get("type")
item_dict = cast(dict[str, Any], item)
item_type = item_dict.get("type")
if item_type == "message":
# Extract content from OpenAI message
message_content = item.get("content", [])
message_content = item_dict.get("content", [])
# Handle both string content and list content
if isinstance(message_content, str):
contents.append(Content.from_text(text=message_content))
elif isinstance(message_content, list):
for content_item in message_content:
message_content_items: Any = cast(Any, message_content)
for content_item in message_content_items:
# Handle dict content items
if isinstance(content_item, dict):
content_type = content_item.get("type")
content_dict = cast(dict[str, Any], content_item)
content_type = content_dict.get("type")
if content_type == "input_text":
text = content_item.get("text", "")
contents.append(Content.from_text(text=text))
text = content_dict.get("text", "")
if isinstance(text, str):
contents.append(Content.from_text(text=text))
elif content_type == "input_image":
image_url = content_item.get("image_url", "")
if image_url:
image_url = content_dict.get("image_url", "")
if isinstance(image_url, str) and image_url:
# Extract media type from data URI if possible
# Parse media type from data URL, fallback to image/png
if image_url.startswith("data:"):
@@ -631,9 +649,12 @@ class AgentFrameworkExecutor:
elif content_type == "input_file":
# Handle file input
file_data = content_item.get("file_data")
file_url = content_item.get("file_url")
filename = content_item.get("filename", "")
file_data = content_dict.get("file_data")
file_url = content_dict.get("file_url")
filename = content_dict.get("filename", "")
if not isinstance(filename, str):
filename = ""
# Determine media type from filename
media_type = "application/octet-stream" # default
@@ -656,8 +677,10 @@ class AgentFrameworkExecutor:
# Use file_data or file_url
# Include filename in additional_properties for OpenAI/Azure file handling
additional_props = {"filename": filename} if filename else None
if file_data:
additional_props: dict[str, Any] | None = (
{"filename": filename} if filename else None
)
if isinstance(file_data, str) and file_data:
# Assume file_data is base64, create data URI
data_uri = f"data:{media_type};base64,{file_data}"
contents.append(
@@ -667,7 +690,7 @@ class AgentFrameworkExecutor:
additional_properties=additional_props,
)
)
elif file_url:
elif isinstance(file_url, str) and file_url:
contents.append(
Content.from_uri(
uri=file_url,
@@ -679,15 +702,35 @@ class AgentFrameworkExecutor:
elif content_type == "function_approval_response":
# Handle function approval response (DevUI extension)
try:
request_id = content_item.get("request_id", "")
approved = content_item.get("approved", False)
function_call_data = content_item.get("function_call", {})
request_id = content_dict.get("request_id", "")
approved = content_dict.get("approved", False)
function_call_data = content_dict.get("function_call", {})
if not isinstance(request_id, str):
request_id = ""
if not isinstance(approved, bool):
approved = False
if not isinstance(function_call_data, dict):
function_call_data = {}
function_call_data_dict = cast(dict[str, Any], function_call_data)
function_call_id = function_call_data_dict.get("id", "")
function_call_name = function_call_data_dict.get("name", "")
function_call_args = function_call_data_dict.get("arguments", {})
if not isinstance(function_call_id, str):
function_call_id = ""
if not isinstance(function_call_name, str):
function_call_name = ""
if not isinstance(function_call_args, dict):
function_call_args = {}
# Create FunctionCallContent from the function_call data
function_call = Content.from_function_call(
call_id=function_call_data.get("id", ""),
name=function_call_data.get("name", ""),
arguments=function_call_data.get("arguments", {}),
call_id=function_call_id,
name=function_call_name,
arguments=cast(dict[str, Any], function_call_args),
)
# Create FunctionApprovalResponseContent with correct signature
@@ -739,12 +782,14 @@ class AgentFrameworkExecutor:
if isinstance(input_data, str):
return input_data
if isinstance(input_data, dict):
typed_input_data = cast(dict[str, Any], input_data)
# Try common field names
for field in ["message", "text", "input", "content", "query"]:
if field in input_data:
return str(input_data[field])
if field in typed_input_data:
value = typed_input_data[field]
return value if isinstance(value, str) else str(value)
# Fallback to JSON string
return json.dumps(input_data)
return json.dumps(typed_input_data)
return str(input_data)
def _is_openai_multimodal_format(self, input_data: Any) -> bool:
@@ -758,8 +803,12 @@ class AgentFrameworkExecutor:
"""
if not isinstance(input_data, list) or not input_data:
return False
first_item = input_data[0]
return isinstance(first_item, dict) and first_item.get("type") == "message"
input_data_items: Any = cast(Any, input_data)
first_item = input_data_items[0]
if not isinstance(first_item, dict):
return False
first_type = cast(dict[str, Any], first_item).get("type")
return isinstance(first_type, str) and first_type == "message"
async def _parse_workflow_input(self, workflow: Any, raw_input: Any) -> Any:
"""Parse input based on workflow's expected input type.
@@ -775,7 +824,7 @@ class AgentFrameworkExecutor:
# Handle JSON string input (from frontend api.ts JSON.stringify)
if isinstance(raw_input, str):
try:
parsed = json.loads(raw_input)
parsed: Any = json.loads(raw_input)
raw_input = parsed
except (json.JSONDecodeError, TypeError):
# Plain text string, continue with string handling
@@ -789,14 +838,14 @@ class AgentFrameworkExecutor:
# Handle structured input (dict)
if isinstance(raw_input, dict):
return self._parse_structured_workflow_input(workflow, raw_input)
return self._parse_structured_workflow_input(workflow, cast(dict[str, Any], raw_input))
# Handle string input
return self._parse_raw_workflow_input(workflow, str(raw_input))
except Exception as e:
logger.warning(f"Error parsing workflow input: {e}")
return raw_input
return cast(Any, raw_input)
def _get_start_executor_message_types(self, workflow: Any) -> tuple[Any | None, list[Any]]:
"""Return start executor and its declared input types."""
@@ -823,7 +872,8 @@ class AgentFrameworkExecutor:
try:
handlers = start_executor._handlers
if isinstance(handlers, dict):
message_types = list(handlers.keys())
handlers_dict: Any = cast(Any, handlers)
message_types = list(handlers_dict.keys())
except Exception as exc: # pragma: no cover - defensive logging path
logger.debug(f"Failed to read executor handlers: {exc}")
@@ -847,7 +897,8 @@ class AgentFrameworkExecutor:
parsed = json.loads(input_data)
# Only use parsed value if it's a list (ResponseInputParam format expected for HIL)
if isinstance(parsed, list):
input_data = parsed
parsed_list: Any = cast(Any, parsed)
input_data = parsed_list
else:
# Parsed to dict, string, or primitive - not HIL response format
return None
@@ -864,19 +915,32 @@ class AgentFrameworkExecutor:
if not isinstance(input_data, list):
return None
for item in input_data:
if isinstance(item, dict) and item.get("type") == "message":
message_content = item.get("content", [])
input_items: Any = cast(Any, input_data)
for item in input_items:
if isinstance(item, dict):
item_dict = cast(dict[str, Any], item)
if item_dict.get("type") != "message":
continue
message_content = item_dict.get("content", [])
if isinstance(message_content, list):
for content_item in message_content:
message_content_items: Any = cast(Any, message_content)
for content_item in message_content_items:
if isinstance(content_item, dict):
content_type = content_item.get("type")
content_dict = cast(dict[str, Any], content_item)
content_type = content_dict.get("type")
if content_type == "workflow_hil_response":
# Extract responses dict
# dict.get() returns Any, so we explicitly type it
responses: dict[str, Any] = content_item.get("responses", {}) # type: ignore[assignment]
responses_raw = content_dict.get("responses", {})
if not isinstance(responses_raw, dict):
continue
responses_dict: Any = cast(Any, responses_raw)
responses = {
str(response_key): response_value
for response_key, response_value in responses_dict.items()
}
logger.info(f"Found workflow HIL responses: {list(responses.keys())}")
return responses
@@ -1000,11 +1064,12 @@ class AgentFrameworkExecutor:
return
# Find the source executor in the workflow
if not hasattr(workflow, "executors") or not isinstance(workflow.executors, dict):
executors = getattr(workflow, "executors", None)
if not isinstance(executors, dict):
logger.debug("Workflow doesn't have executors dict")
return
source_executor = workflow.executors.get(source_executor_id)
source_executor = cast(dict[str, Any], executors).get(source_executor_id)
if not source_executor:
logger.debug(f"Could not find executor '{source_executor_id}' in workflow")
return
@@ -11,7 +11,7 @@ import uuid
from collections import OrderedDict
from collections.abc import Sequence
from datetime import datetime
from typing import Any, Union
from typing import Any, Union, cast
from uuid import uuid4
from agent_framework import Content, Message
@@ -61,6 +61,17 @@ EventType = Union[
]
def _to_str_dict(value: Any) -> dict[str, Any] | None:
"""Cast arbitrary dict-like payload to a string-keyed dictionary."""
if not isinstance(value, dict):
return None
return cast(dict[str, Any], value)
def _stringify_name(value: Any) -> str:
return value if isinstance(value, str) else str(value)
def _serialize_content_recursive(value: Any) -> Any:
"""Recursively serialize Agent Framework Content objects to JSON-compatible values.
@@ -88,16 +99,21 @@ def _serialize_content_recursive(value: Any) -> Any:
# Handle dictionaries - recursively process values
if isinstance(value, dict):
return {key: _serialize_content_recursive(val) for key, val in value.items()}
value_dict = cast(dict[str, Any], value)
return {str(key): _serialize_content_recursive(val) for key, val in value_dict.items()}
# Handle lists and tuples - recursively process elements
if isinstance(value, (list, tuple)):
serialized = [_serialize_content_recursive(item) for item in value]
sequence_items: Any = cast(Any, value)
serialized: list[Any] = [_serialize_content_recursive(item) for item in sequence_items]
# For single-item lists containing text Content, extract just the text
# This handles the MCP case where result = [Content.from_text(text="Hello")]
# and we want output = "Hello" not output = '[{"type": "text", "text": "Hello"}]'
if len(serialized) == 1 and isinstance(serialized[0], dict) and serialized[0].get("type") == "text":
return serialized[0].get("text", "")
if len(serialized) == 1:
first_item = _to_str_dict(serialized[0])
if first_item and first_item.get("type") == "text":
text_value = first_item.get("text", "")
return text_value if isinstance(text_value, str) else str(text_value)
return serialized
# For other objects with model_dump(), try that
@@ -156,8 +172,10 @@ class MessageMapper:
context = self._get_or_create_context(request)
# Handle error events
if isinstance(raw_event, dict) and raw_event.get("type") == "error":
return [await self._create_error_event(raw_event.get("message", "Unknown error"), context)]
raw_event_dict = _to_str_dict(raw_event)
if raw_event_dict and raw_event_dict.get("type") == "error":
message = raw_event_dict.get("message", "Unknown error")
return [await self._create_error_event(_stringify_name(message), context)]
# Handle ResponseTraceEvent objects from our trace collector
from .models import ResponseTraceEvent
@@ -185,15 +203,12 @@ class MessageMapper:
# Handle WorkflowEvent with type='output' or 'data' wrapping AgentResponseUpdate
# This must be checked BEFORE generic WorkflowEvent check
# Note: AgentExecutor uses type='output' for streaming updates
if (
isinstance(raw_event, WorkflowEvent)
and raw_event.type in ("output", "data")
and raw_event.data
and isinstance(raw_event.data, AgentResponseUpdate)
):
# Preserve executor_id in context for proper output routing
context["current_executor_id"] = raw_event.executor_id
return await self._convert_agent_update(raw_event.data, context)
if isinstance(raw_event, WorkflowEvent) and raw_event.type in ("output", "data"):
event_data = getattr(cast(Any, raw_event), "data", None)
if isinstance(event_data, AgentResponseUpdate):
# Preserve executor_id in context for proper output routing
context["current_executor_id"] = getattr(cast(Any, raw_event), "executor_id", None)
return await self._convert_agent_update(event_data, context)
# Handle complete agent response (AgentResponse) - for non-streaming agent execution
if isinstance(raw_event, AgentResponse):
@@ -210,10 +225,11 @@ class MessageMapper:
except ImportError as e:
logger.warning(f"Could not import Agent Framework types: {e}")
# Fallback to attribute-based detection
if hasattr(raw_event, "contents"):
return await self._convert_agent_update(raw_event, context)
if hasattr(raw_event, "__class__") and "Event" in raw_event.__class__.__name__:
return await self._convert_workflow_event(raw_event, context)
candidate_event = cast(Any, raw_event)
if hasattr(candidate_event, "contents"):
return await self._convert_agent_update(candidate_event, context)
if "Event" in type(candidate_event).__name__:
return await self._convert_workflow_event(candidate_event, context)
# Unknown event type
return [await self._create_unknown_event(raw_event, context)]
@@ -256,32 +272,36 @@ class MessageMapper:
item = getattr(event, "item", None)
if item:
# Handle both object and dict formats
item_type = item.get("type") if isinstance(item, dict) else getattr(item, "type", None)
item_dict = _to_str_dict(item)
item_type = item_dict.get("type") if item_dict is not None else getattr(item, "type", None)
# Track function calls to accumulate their arguments
if item_type == "function_call":
# Handle both object and dict formats
if isinstance(item, dict):
call_id = item.get("call_id") or item.get("id")
if call_id:
item_dict = _to_str_dict(item)
if item_dict is not None:
call_id_value = item_dict.get("call_id") or item_dict.get("id")
if call_id_value:
call_id = str(call_id_value)
function_calls[call_id] = {
"id": item.get("id", call_id),
"id": str(item_dict.get("id", call_id)),
"call_id": call_id,
"name": item.get("name", ""),
"arguments": item.get("arguments", ""),
"name": _stringify_name(item_dict.get("name", "")),
"arguments": _stringify_name(item_dict.get("arguments", "")),
"type": "function_call",
"status": item.get("status", "completed"),
"status": _stringify_name(item_dict.get("status", "completed")),
}
else:
call_id = getattr(item, "call_id", None) or getattr(item, "id", None)
if call_id:
call_id_value = getattr(item, "call_id", None) or getattr(item, "id", None)
if call_id_value:
call_id = str(call_id_value)
function_calls[call_id] = {
"id": getattr(item, "id", call_id),
"id": str(getattr(item, "id", call_id)),
"call_id": call_id,
"name": getattr(item, "name", ""),
"arguments": getattr(item, "arguments", ""),
"name": _stringify_name(getattr(item, "name", "")),
"arguments": _stringify_name(getattr(item, "arguments", "")),
"type": "function_call",
"status": getattr(item, "status", "completed"),
"status": _stringify_name(getattr(item, "status", "completed")),
}
# Other output items (message, etc.) - track for later
@@ -299,8 +319,9 @@ class MessageMapper:
# Handle function result complete events
elif event_type == "response.function_result.complete":
call_id = getattr(event, "call_id", None)
if call_id:
call_id_value = getattr(event, "call_id", None)
if call_id_value:
call_id = str(call_id_value)
function_results[call_id] = {
"type": "function_call_output",
"call_id": call_id,
@@ -322,7 +343,7 @@ class MessageMapper:
# Build final text message from accumulated deltas
# Combine all text parts (usually there's just one message)
all_text_parts = []
all_text_parts: list[str] = []
for _item_id, parts in text_parts_by_message.items():
all_text_parts.extend(parts)
@@ -493,14 +514,14 @@ class MessageMapper:
return value.value
# Handle lists/tuples/sets - recursively serialize elements
if isinstance(value, (list, tuple)):
return [self._serialize_value(item) for item in value]
if isinstance(value, set):
return [self._serialize_value(item) for item in value]
if isinstance(value, (list, tuple, set)):
value_items: Any = cast(Any, value)
return [self._serialize_value(item) for item in value_items]
# Handle dicts - recursively serialize values
if isinstance(value, dict):
return {k: self._serialize_value(v) for k, v in value.items()}
value_dict = cast(dict[str, Any], value)
return {str(k): self._serialize_value(v) for k, v in value_dict.items()}
# Handle SerializationMixin (like Message) - call to_dict()
if hasattr(value, "to_dict") and callable(getattr(value, "to_dict", None)):
@@ -551,14 +572,15 @@ class MessageMapper:
# Handle dict first (most common)
if isinstance(request_data, dict):
return {k: self._serialize_value(v) for k, v in request_data.items()}
request_dict = cast(dict[str, Any], request_data)
return {str(k): self._serialize_value(v) for k, v in request_dict.items()}
# Handle dataclasses with nested SerializationMixin objects
# We can't use asdict() directly because it doesn't handle Message
if is_dataclass(request_data) and not isinstance(request_data, type):
try:
# Manually serialize each field to handle nested SerializationMixin
result = {}
result: dict[str, Any] = {}
for field in fields(request_data):
field_value = getattr(request_data, field.name)
result[field.name] = self._serialize_value(field_value)
@@ -900,8 +922,9 @@ class MessageMapper:
text = str(output_data)
elif isinstance(output_data, list):
# Handle list of Message objects (from Magentic yield_output([final_answer]))
text_parts = []
for item in output_data:
text_parts: list[str] = []
output_items_list: Any = cast(Any, output_data)
for item in output_items_list:
if isinstance(item, Message):
item_text = getattr(item, "text", None)
if item_text:
@@ -912,17 +935,17 @@ class MessageMapper:
text_parts.append(item)
else:
try:
text_parts.append(json.dumps(item, indent=2))
text_parts.append(json.dumps(self._serialize_value(item), indent=2))
except (TypeError, ValueError):
text_parts.append(str(item))
text = "\n".join(text_parts) if text_parts else str(output_data)
text = "\n".join(text_parts) if text_parts else str(cast(Any, output_data))
elif isinstance(output_data, str):
# String output
text = output_data
else:
# Object/dict → JSON string
try:
text = json.dumps(output_data, indent=2)
text = json.dumps(self._serialize_value(output_data), indent=2)
except (TypeError, ValueError):
# Fallback to string representation if not JSON serializable
text = str(output_data)
@@ -1420,10 +1443,10 @@ class MessageMapper:
None - no event emitted (usage goes in final Response.usage)
"""
# Extract usage from UsageContent.usage_details (UsageDetails object)
details = content.usage_details or {}
total_tokens = details.get("total_token_count", 0)
prompt_tokens = details.get("input_token_count", 0)
completion_tokens = details.get("output_token_count", 0)
details = _to_str_dict(getattr(content, "usage_details", None)) or {}
total_tokens = int(details.get("total_token_count", 0) or 0)
prompt_tokens = int(details.get("input_token_count", 0) or 0)
completion_tokens = int(details.get("output_token_count", 0) or 0)
# Accumulate for final Response.usage
request_id = context.get("request_id", "default")
@@ -22,6 +22,26 @@ from ..models import AgentFrameworkRequest, OpenAIResponse
logger = logging.getLogger(__name__)
def _extract_error_details(body: Any) -> tuple[str | None, str | None, str | None]:
"""Extract typed OpenAI error fields from error body payload."""
if not isinstance(body, dict):
return None, None, None
error_dict: dict[str, Any] = body.get("error") # type: ignore[assignment, reportUnknownVariableType]
if not isinstance(error_dict, dict):
return None, None, None
message = error_dict.get("message")
error_type = error_dict.get("type")
code = error_dict.get("code")
return (
message if isinstance(message, str) else None,
error_type if isinstance(error_type, str) else None,
code if isinstance(code, str) else None,
)
class OpenAIExecutor:
"""Executor for OpenAI Responses API - mirrors AgentFrameworkExecutor interface.
@@ -138,68 +158,64 @@ class OpenAIExecutor:
except AuthenticationError as e:
# 401 - Invalid API key or authentication issue
logger.error(f"OpenAI authentication error: {e}", exc_info=True)
error_body = e.body if hasattr(e, "body") else {}
error_data = error_body.get("error", {}) if isinstance(error_body, dict) else {}
message, error_type, code = _extract_error_details(e.body if hasattr(e, "body") else None)
yield {
"type": "response.failed",
"response": {
"id": f"resp_{os.urandom(16).hex()}",
"status": "failed",
"error": {
"message": error_data.get("message", str(e)),
"type": error_data.get("type", "authentication_error"),
"code": error_data.get("code", "invalid_api_key"),
"message": message or str(e),
"type": error_type or "authentication_error",
"code": code or "invalid_api_key",
},
},
}
except PermissionDeniedError as e:
# 403 - Permission denied
logger.error(f"OpenAI permission denied: {e}", exc_info=True)
error_body = e.body if hasattr(e, "body") else {}
error_data = error_body.get("error", {}) if isinstance(error_body, dict) else {}
message, error_type, code = _extract_error_details(e.body if hasattr(e, "body") else None)
yield {
"type": "response.failed",
"response": {
"id": f"resp_{os.urandom(16).hex()}",
"status": "failed",
"error": {
"message": error_data.get("message", str(e)),
"type": error_data.get("type", "permission_denied"),
"code": error_data.get("code", "insufficient_permissions"),
"message": message or str(e),
"type": error_type or "permission_denied",
"code": code or "insufficient_permissions",
},
},
}
except RateLimitError as e:
# 429 - Rate limit exceeded
logger.error(f"OpenAI rate limit exceeded: {e}", exc_info=True)
error_body = e.body if hasattr(e, "body") else {}
error_data = error_body.get("error", {}) if isinstance(error_body, dict) else {}
message, error_type, code = _extract_error_details(e.body if hasattr(e, "body") else None)
yield {
"type": "response.failed",
"response": {
"id": f"resp_{os.urandom(16).hex()}",
"status": "failed",
"error": {
"message": error_data.get("message", str(e)),
"type": error_data.get("type", "rate_limit_error"),
"code": error_data.get("code", "rate_limit_exceeded"),
"message": message or str(e),
"type": error_type or "rate_limit_error",
"code": code or "rate_limit_exceeded",
},
},
}
except APIStatusError as e:
# Other OpenAI API errors
logger.error(f"OpenAI API error: {e}", exc_info=True)
error_body = e.body if hasattr(e, "body") else {}
error_data = error_body.get("error", {}) if isinstance(error_body, dict) else {}
message, error_type, code = _extract_error_details(e.body if hasattr(e, "body") else None)
yield {
"type": "response.failed",
"response": {
"id": f"resp_{os.urandom(16).hex()}",
"status": "failed",
"error": {
"message": error_data.get("message", str(e)),
"type": error_data.get("type", "api_error"),
"code": error_data.get("code", "unknown_error"),
"message": message or str(e),
"type": error_type or "api_error",
"code": code or "unknown_error",
},
},
}
@@ -31,6 +31,29 @@ from .models._discovery_models import Deployment, DeploymentConfig, DiscoveryRes
logger = logging.getLogger(__name__)
def _extract_error_details(body: object) -> tuple[str | None, str | None, str | None]:
"""Extract typed OpenAI-style error payload fields."""
if not isinstance(body, dict):
return None, None, None
body_dict = cast(dict[str, object], body)
error_obj = body_dict.get("error")
if not isinstance(error_obj, dict):
return None, None, None
error_dict = cast(dict[str, object], error_obj)
message = error_dict.get("message")
error_type = error_dict.get("type")
code = error_dict.get("code")
return (
message if isinstance(message, str) else None,
error_type if isinstance(error_type, str) else None,
code if isinstance(code, str) else None,
)
# Get package version
try:
__version__ = importlib.metadata.version("agent-framework-devui")
@@ -83,6 +106,10 @@ class DevServer:
self._pending_entities: list[Any] | None = None
self._running_tasks: dict[str, asyncio.Task[Any]] = {} # Track running response tasks for cancellation
def set_pending_entities(self, entities: list[Any]) -> None:
"""Set in-memory entities to register on startup."""
self._pending_entities = entities
def _is_dev_mode(self) -> bool:
"""Check if running in developer mode.
@@ -378,6 +405,8 @@ class DevServer:
# Token valid, proceed
return await call_next(request)
_ = auth_middleware
self._register_routes(app)
self._mount_ui(app)
@@ -452,7 +481,7 @@ class DevServer:
if entity_info.type == "workflow" and entity_obj:
# Entity object already loaded by load_entity() above
# Get workflow structure
workflow_dump = None
workflow_dump: dict[str, Any] | str | None = None
if hasattr(entity_obj, "to_dict") and callable(getattr(entity_obj, "to_dict", None)):
try:
workflow_dump = entity_obj.to_dict() # type: ignore[attr-defined]
@@ -475,7 +504,11 @@ class DevServer:
except Exception:
workflow_dump = raw_dump
else:
workflow_dump = parsed_dump if isinstance(parsed_dump, dict) else raw_dump
if isinstance(parsed_dump, dict):
parsed_dump_dict = cast(dict[str, Any], parsed_dump)
workflow_dump = {str(k): v for k, v in parsed_dump_dict.items()}
else:
workflow_dump = raw_dump
else:
workflow_dump = raw_dump
elif hasattr(entity_obj, "__dict__"):
@@ -838,34 +871,31 @@ class DevServer:
except AuthenticationError as e:
# 401 - Invalid API key or authentication issue
logger.error(f"OpenAI authentication error creating conversation: {e}")
error_body = e.body if hasattr(e, "body") else {}
error_data = error_body.get("error", {}) if isinstance(error_body, dict) else {}
message, error_type, code = _extract_error_details(e.body if hasattr(e, "body") else None)
error = OpenAIError.create(
message=error_data.get("message", str(e)),
type=error_data.get("type", "authentication_error"),
code=error_data.get("code", "invalid_api_key"),
message=message or str(e),
type=error_type or "authentication_error",
code=code or "invalid_api_key",
)
return JSONResponse(status_code=401, content=error.to_dict())
except PermissionDeniedError as e:
# 403 - Permission denied
logger.error(f"OpenAI permission denied creating conversation: {e}")
error_body = e.body if hasattr(e, "body") else {}
error_data = error_body.get("error", {}) if isinstance(error_body, dict) else {}
message, error_type, code = _extract_error_details(e.body if hasattr(e, "body") else None)
error = OpenAIError.create(
message=error_data.get("message", str(e)),
type=error_data.get("type", "permission_denied"),
code=error_data.get("code", "insufficient_permissions"),
message=message or str(e),
type=error_type or "permission_denied",
code=code or "insufficient_permissions",
)
return JSONResponse(status_code=403, content=error.to_dict())
except APIStatusError as e:
# Other OpenAI API errors (rate limit, etc.)
logger.error(f"OpenAI API error creating conversation: {e}")
error_body = e.body if hasattr(e, "body") else {}
error_data = error_body.get("error", {}) if isinstance(error_body, dict) else {}
message, error_type, code = _extract_error_details(e.body if hasattr(e, "body") else None)
error = OpenAIError.create(
message=error_data.get("message", str(e)),
type=error_data.get("type", "api_error"),
code=error_data.get("code", "unknown_error"),
message=message or str(e),
type=error_type or "api_error",
code=code or "unknown_error",
)
return JSONResponse(
status_code=e.status_code if hasattr(e, "status_code") else 500, content=error.to_dict()
@@ -902,7 +932,7 @@ class DevServer:
executor = await self._ensure_executor()
# Build filter criteria
filters = {}
filters: dict[str, str] = {}
if agent_id:
filters["agent_id"] = agent_id
if entity_id:
@@ -997,15 +1027,16 @@ class DevServer:
conversation_id, limit=limit, after=after, order=order
)
# Handle both Pydantic models and dicts (some stores return raw dicts)
serialized_items = []
serialized_items: list[dict[str, Any]] = []
for item in items:
if hasattr(item, "model_dump"):
serialized_items.append(item.model_dump())
elif isinstance(item, dict):
serialized_items.append(item)
item_dict = cast(dict[str, Any], item)
serialized_items.append({str(k): v for k, v in item_dict.items()})
else:
logger.warning(f"Unexpected item type: {type(item)}, converting to dict")
serialized_items.append(dict(item))
serialized_items.append({str(k): v for k, v in dict(item).items()})
# Get stored traces for context inspection (DevUI extension)
traces = executor.conversation_store.get_traces(conversation_id)
@@ -1038,9 +1069,14 @@ class DevServer:
if not item:
raise HTTPException(status_code=404, detail="Item not found")
# Handle both Pydantic models and dicts
result: dict[str, Any] = (
item.model_dump() if hasattr(item, "model_dump") else cast(dict[str, Any], item)
)
result: dict[str, Any]
if hasattr(item, "model_dump"):
result = item.model_dump()
elif isinstance(item, dict):
item_dict = cast(dict[str, Any], item)
result = {str(k): v for k, v in item_dict.items()}
else:
result = {"value": item}
return result
except HTTPException:
raise
@@ -1085,16 +1121,42 @@ class DevServer:
# Checkpoints are exposed as conversation items with type="checkpoint"
# ============================================================================
registered_route_handlers = (
health_check,
get_meta,
discover_entities,
get_entity_info,
reload_entity,
create_deployment,
list_deployments,
get_deployment,
delete_deployment,
deploy_entity,
create_response,
cancel_response,
create_conversation,
list_conversations,
retrieve_conversation,
update_conversation,
delete_conversation,
create_conversation_items,
list_conversation_items,
retrieve_conversation_item,
delete_conversation_item,
)
_ = registered_route_handlers
async def _stream_execution(
self, executor: AgentFrameworkExecutor, request: AgentFrameworkRequest
) -> AsyncGenerator[str]:
"""Stream execution directly through executor."""
try:
# Collect events for final response.completed event
events = []
events: list[Any] = []
# Get conversation_id for trace storage
conversation_id = request._get_conversation_id()
conversation_getter = getattr(request, "_get_conversation_id", None)
conversation_id = conversation_getter() if callable(conversation_getter) else None
# Stream all events
async for event in executor.execute_streaming(request):
@@ -1104,7 +1166,7 @@ class DevServer:
if conversation_id and hasattr(event, "type") and event.type == "response.trace.completed":
try:
trace_data = event.data if hasattr(event, "data") else None
if trace_data:
if trace_data and isinstance(conversation_id, str):
executor.conversation_store.add_trace(conversation_id, trace_data)
except Exception as e:
logger.debug(f"Failed to store trace event: {e}")
@@ -1136,8 +1198,9 @@ class DevServer:
# We need to increment from that
last_seq = 0
for event in reversed(events):
if hasattr(event, "sequence_number") and event.sequence_number is not None:
last_seq = event.sequence_number
sequence_number = getattr(event, "sequence_number", None)
if isinstance(sequence_number, int):
last_seq = sequence_number
break
completed_event = ResponseCompletedEvent(
@@ -5,13 +5,37 @@
import logging
import uuid
from datetime import datetime
from typing import Any
from typing import Any, TypedDict, cast
from typing_extensions import NotRequired
logger = logging.getLogger(__name__)
# Type aliases for better readability
SessionData = dict[str, Any]
RequestRecord = dict[str, Any]
class RequestRecord(TypedDict):
"""Tracked execution request data."""
id: str
timestamp: datetime
entity_id: str
executor: str
input: Any
model_id: str
stream: bool
execution_time: NotRequired[float]
status: NotRequired[str]
class SessionData(TypedDict):
"""Stored session state."""
id: str
created_at: datetime
requests: list[RequestRecord]
context: dict[str, Any]
active: bool
SessionSummary = dict[str, Any]
@@ -95,7 +119,7 @@ class SessionManager:
"stream": True,
}
session["requests"].append(request_record)
return str(request_record["id"])
return request_record["id"]
def update_request_record(self, session_id: str, request_id: str, updates: dict[str, Any]) -> None:
"""Update a request record in a session.
@@ -111,7 +135,8 @@ class SessionManager:
for request in session["requests"]:
if request["id"] == request_id:
request.update(updates)
request_data = cast(dict[str, Any], request)
request_data.update(updates)
break
def get_session_history(self, session_id: str) -> SessionSummary | None:
@@ -138,7 +163,7 @@ class SessionManager:
"timestamp": req["timestamp"].isoformat(),
"entity_id": req["entity_id"],
"executor": req["executor"],
"model": req["model"],
"model": req["model_id"],
"input_length": len(str(req["input"])) if req["input"] else 0,
"execution_time": req.get("execution_time"),
"status": req.get("status", "unknown"),
@@ -153,7 +178,7 @@ class SessionManager:
Returns:
List of active session summaries
"""
active_sessions = []
active_sessions: list[SessionSummary] = []
for session_id, session in self.sessions.items():
if session["active"]:
@@ -178,7 +203,7 @@ class SessionManager:
"""
cutoff_time = datetime.now().timestamp() - (max_age_hours * 3600)
sessions_to_remove = []
sessions_to_remove: list[str] = []
for session_id, session in self.sessions.items():
if session["created_at"].timestamp() < cutoff_time:
sessions_to_remove.append(session_id)
@@ -7,12 +7,20 @@ import json
import logging
from dataclasses import fields, is_dataclass
from types import UnionType
from typing import Any, Union, get_args, get_origin, get_type_hints
from typing import Any, Union, cast, get_args, get_origin, get_type_hints
from agent_framework import Message
logger = logging.getLogger(__name__)
def _string_key_dict(value: object) -> dict[str, Any] | None:
"""Cast value to a dict."""
if not isinstance(value, dict):
return None
return cast(dict[str, Any], value)
# ============================================================================
# Agent Metadata Extraction
# ============================================================================
@@ -39,18 +47,21 @@ def extract_agent_metadata(entity_object: Any) -> dict[str, Any]:
# Try to get instructions
if hasattr(entity_object, "default_options"):
chat_opts = entity_object.default_options
if isinstance(chat_opts, dict):
if "instructions" in chat_opts:
metadata["instructions"] = chat_opts.get("instructions")
chat_opts_dict = _string_key_dict(chat_opts)
if chat_opts_dict is not None:
if "instructions" in chat_opts_dict:
metadata["instructions"] = chat_opts_dict.get("instructions")
elif hasattr(chat_opts, "instructions"):
metadata["instructions"] = chat_opts.instructions
# Try to get model - check both default_options and client
if hasattr(entity_object, "default_options"):
chat_opts = entity_object.default_options
if isinstance(chat_opts, dict):
if chat_opts.get("model_id"):
metadata["model"] = chat_opts.get("model_id")
chat_opts_dict = _string_key_dict(chat_opts)
if chat_opts_dict is not None:
model_id = chat_opts_dict.get("model_id")
if model_id:
metadata["model"] = model_id
elif hasattr(chat_opts, "model_id") and chat_opts.model_id:
metadata["model"] = chat_opts.model_id
if metadata["model"] is None and hasattr(entity_object, "client") and hasattr(entity_object.client, "model_id"):
@@ -112,7 +123,7 @@ def extract_executor_message_types(executor: Any) -> list[Any]:
try:
handlers = executor._handlers
if isinstance(handlers, dict):
message_types = list(handlers.keys())
message_types = list(handlers.keys()) # type: ignore[arg-type] # pyright: ignore[reportUnknownArgumentType]
except Exception as exc: # pragma: no cover - defensive logging path
logger.debug(f"Failed to read executor handlers: {exc}")
@@ -366,11 +377,10 @@ def extract_response_type_from_executor(executor: Any, request_type: type) -> ty
_, second_param_type = param_items[1] if len(param_items) > 1 else (None, None)
# Check if first param matches request_type
first_matches_request = first_param_type == request_type or (
hasattr(first_param_type, "__name__")
and hasattr(request_type, "__name__")
and first_param_type.__name__ == request_type.__name__
)
first_matches_request = first_param_type == request_type
if not first_matches_request and isinstance(first_param_type, type):
request_type_name = request_type.__name__
first_matches_request = first_param_type.__name__ == request_type_name
# Verify we have a matching request type and valid response type (must be a type class)
if first_matches_request and second_param_type is not None and isinstance(second_param_type, type):
@@ -432,7 +442,7 @@ def generate_input_schema(input_type: type) -> dict[str, Any]:
return generate_schema_from_dataclass(input_type)
# 5. Fallback to string
type_name = getattr(input_type, "__name__", str(input_type))
type_name = input_type.__name__ if isinstance(input_type, type) else str(cast(Any, input_type))
return {"type": "string", "description": f"Input type: {type_name}"}
@@ -466,8 +476,9 @@ def parse_input_for_type(input_data: Any, target_type: type) -> Any:
return _parse_string_input(input_data, target_type)
# Handle dict input
if isinstance(input_data, dict):
return _parse_dict_input(input_data, target_type)
parsed_dict = _string_key_dict(input_data)
if parsed_dict is not None:
return _parse_dict_input(parsed_dict, target_type)
# Fallback: return original
return input_data
@@ -2,8 +2,11 @@
"""Discovery API models for entity information."""
from __future__ import annotations
import re
from typing import Any
from typing import Any, cast
from collections.abc import Callable
from pydantic import BaseModel, Field, field_validator
@@ -57,7 +60,7 @@ class EntityInfo(BaseModel):
class DiscoveryResponse(BaseModel):
"""Response model for entity discovery."""
entities: list[EntityInfo] = Field(default_factory=list)
entities: list[EntityInfo] = Field(default_factory=cast(Callable[..., list[EntityInfo]], list))
# ============================================================================
+1 -1
View File
@@ -94,7 +94,7 @@ include = "../../shared_tasks.toml"
[tool.poe.tasks]
mypy = "mypy --config-file $POE_ROOT/pyproject.toml agent_framework_devui"
test = "pytest --cov=agent_framework_devui --cov-report=term-missing:skip-covered tests"
test = "pytest -m \"not integration\" --cov=agent_framework_devui --cov-report=term-missing:skip-covered tests"
[build-system]
requires = ["flit-core >= 3.11,<4.0"]