mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Fix type compatibility check (#1753)
* Fix type compatibility check * Address comments
This commit is contained in:
committed by
GitHub
Unverified
parent
00a78d7bc6
commit
1fbdcf8268
@@ -29,6 +29,7 @@ from ._events import (
|
||||
WorkflowEvent,
|
||||
)
|
||||
from ._message_utils import normalize_messages_input
|
||||
from ._typing_utils import is_type_compatible
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ._workflow import Workflow
|
||||
@@ -93,7 +94,7 @@ class WorkflowAgent(BaseAgent):
|
||||
except KeyError as exc: # Defensive: workflow lacks a configured entry point
|
||||
raise ValueError("Workflow's start executor is not defined.") from exc
|
||||
|
||||
if list[ChatMessage] not in start_executor.input_types:
|
||||
if not any(is_type_compatible(list[ChatMessage], input_type) for input_type in start_executor.input_types):
|
||||
raise ValueError("Workflow's start executor cannot handle list[ChatMessage]")
|
||||
|
||||
super().__init__(id=id, name=name, description=description, **kwargs)
|
||||
|
||||
@@ -169,3 +169,128 @@ def is_instance_of(data: Any, target_type: type | UnionType | Any) -> bool:
|
||||
|
||||
# Fallback: if we reach here, we assume data is an instance of the target_type
|
||||
return isinstance(data, target_type)
|
||||
|
||||
|
||||
def is_type_compatible(source_type: type | UnionType | Any, target_type: type | UnionType | Any) -> bool:
|
||||
"""Check if source_type is compatible with target_type.
|
||||
|
||||
A type is compatible if values of source_type can be assigned to variables of target_type.
|
||||
For example:
|
||||
- list[ChatMessage] is compatible with list[str | ChatMessage]
|
||||
- str is compatible with str | int
|
||||
- int is compatible with Any
|
||||
|
||||
Args:
|
||||
source_type: The type being assigned from
|
||||
target_type: The type being assigned to
|
||||
|
||||
Returns:
|
||||
bool: True if source_type is compatible with target_type, False otherwise
|
||||
"""
|
||||
# Case 0: target_type is Any - always compatible
|
||||
if target_type is Any:
|
||||
return True
|
||||
|
||||
# Case 1: exact type match
|
||||
if source_type == target_type:
|
||||
return True
|
||||
|
||||
source_origin = get_origin(source_type)
|
||||
source_args = get_args(source_type)
|
||||
target_origin = get_origin(target_type)
|
||||
target_args = get_args(target_type)
|
||||
|
||||
# Case 2: target is Union/Optional - source is compatible if it matches any target member
|
||||
if target_origin is Union or target_origin is UnionType:
|
||||
# Special case: if source is also a Union, check that each source member
|
||||
# is compatible with at least one target member
|
||||
if source_origin is Union or source_origin is UnionType:
|
||||
return all(
|
||||
any(is_type_compatible(source_arg, target_arg) for target_arg in target_args)
|
||||
for source_arg in source_args
|
||||
)
|
||||
# If source is not a Union, check if it's compatible with any target member
|
||||
return any(is_type_compatible(source_type, arg) for arg in target_args)
|
||||
|
||||
# Case 3: source is Union (and target is not Union) - each source member must be compatible with target
|
||||
if source_origin is Union or source_origin is UnionType:
|
||||
return all(is_type_compatible(arg, target_type) for arg in source_args)
|
||||
|
||||
# Case 4: both are non-generic types
|
||||
if source_origin is None and target_origin is None:
|
||||
# Only call issubclass if both are actual types, not UnionType or Any
|
||||
if isinstance(source_type, type) and isinstance(target_type, type):
|
||||
try:
|
||||
return issubclass(source_type, target_type)
|
||||
except TypeError:
|
||||
# Handle cases where issubclass doesn't work (e.g., with special forms)
|
||||
return False
|
||||
return source_type == target_type
|
||||
|
||||
# Case 5: different container types are not compatible
|
||||
if source_origin != target_origin:
|
||||
return False
|
||||
|
||||
# Case 6: same container type - check generic arguments
|
||||
if source_origin in [list, set]:
|
||||
if not source_args and not target_args:
|
||||
return True # Both are untyped
|
||||
if not source_args or not target_args:
|
||||
return True # One is untyped - assume compatible
|
||||
# For collections, source element type must be compatible with target element type
|
||||
return is_type_compatible(source_args[0], target_args[0])
|
||||
|
||||
# Case 7: tuple compatibility
|
||||
if source_origin is tuple:
|
||||
if not source_args and not target_args:
|
||||
return True # Both are untyped tuples
|
||||
if not source_args or not target_args:
|
||||
return True # One is untyped - assume compatible
|
||||
|
||||
# Handle Tuple[T, ...] (variable length)
|
||||
if len(source_args) == 2 and source_args[1] is Ellipsis:
|
||||
if len(target_args) == 2 and target_args[1] is Ellipsis:
|
||||
return is_type_compatible(source_args[0], target_args[0])
|
||||
return False # Variable length can't be assigned to fixed length
|
||||
|
||||
if len(target_args) == 2 and target_args[1] is Ellipsis:
|
||||
# Fixed length can be assigned to variable length if element types are compatible
|
||||
return all(is_type_compatible(source_arg, target_args[0]) for source_arg in source_args)
|
||||
|
||||
# Fixed length tuples must have same length and compatible element types
|
||||
if len(source_args) != len(target_args):
|
||||
return False
|
||||
return all(is_type_compatible(s_arg, t_arg) for s_arg, t_arg in zip(source_args, target_args, strict=False))
|
||||
|
||||
# Case 8: dict compatibility
|
||||
if source_origin is dict:
|
||||
if not source_args and not target_args:
|
||||
return True # Both are untyped dicts
|
||||
if not source_args or not target_args:
|
||||
return True # One is untyped - assume compatible
|
||||
if len(source_args) != 2 or len(target_args) != 2:
|
||||
return False # Malformed dict types
|
||||
# Both key and value types must be compatible
|
||||
return is_type_compatible(source_args[0], target_args[0]) and is_type_compatible(source_args[1], target_args[1])
|
||||
|
||||
# Case 9: custom generic classes - check if origins are the same and args are compatible
|
||||
if source_origin and target_origin and source_origin == target_origin:
|
||||
if not source_args and not target_args:
|
||||
return True # Both are untyped generics
|
||||
if not source_args or not target_args:
|
||||
return True # One is untyped - assume compatible
|
||||
if len(source_args) != len(target_args):
|
||||
return False # Different number of type parameters
|
||||
return all(is_type_compatible(s_arg, t_arg) for s_arg, t_arg in zip(source_args, target_args, strict=False))
|
||||
|
||||
# Case 10: fallback - check if source is subclass of target (for non-generic types)
|
||||
if source_origin is None and target_origin is None:
|
||||
try:
|
||||
# Only call issubclass if both are actual types, not UnionType or Any
|
||||
if isinstance(source_type, type) and isinstance(target_type, type):
|
||||
return issubclass(source_type, target_type)
|
||||
return source_type == target_type
|
||||
except TypeError:
|
||||
return False
|
||||
|
||||
return False
|
||||
|
||||
@@ -1,16 +1,15 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence
|
||||
from enum import Enum
|
||||
from types import UnionType
|
||||
from typing import Any, Union, get_args, get_origin
|
||||
from typing import Any
|
||||
|
||||
from ._edge import Edge, EdgeGroup, FanInEdgeGroup
|
||||
from ._executor import Executor
|
||||
from ._request_info_executor import RequestInfoExecutor
|
||||
from ._typing_utils import is_type_compatible
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -308,11 +307,11 @@ class WorkflowGraphValidator:
|
||||
for target_type in target_input_types:
|
||||
if isinstance(edge_group, FanInEdgeGroup):
|
||||
# If the edge is part of an edge group, the target expects a list of data types
|
||||
if self._is_type_compatible(list[source_type], target_type): # type: ignore[valid-type]
|
||||
if is_type_compatible(list[source_type], target_type): # type: ignore[valid-type]
|
||||
compatible = True
|
||||
compatible_pairs.append((list[source_type], target_type)) # type: ignore[valid-type]
|
||||
else:
|
||||
if self._is_type_compatible(source_type, target_type):
|
||||
if is_type_compatible(source_type, target_type):
|
||||
compatible = True
|
||||
compatible_pairs.append((source_type, target_type))
|
||||
|
||||
@@ -527,53 +526,6 @@ class WorkflowGraphValidator:
|
||||
|
||||
# endregion
|
||||
|
||||
# region Type Compatibility Utilities
|
||||
@staticmethod
|
||||
def _is_type_compatible(source_type: type[Any], target_type: type[Any]) -> bool:
|
||||
"""Check if source_type is compatible with target_type."""
|
||||
# Handle Any type
|
||||
if source_type is Any or target_type is Any:
|
||||
return True
|
||||
|
||||
# Handle exact match
|
||||
if source_type == target_type:
|
||||
return True
|
||||
|
||||
# Handle inheritance
|
||||
try:
|
||||
if inspect.isclass(source_type) and inspect.isclass(target_type):
|
||||
return issubclass(source_type, target_type)
|
||||
except TypeError:
|
||||
# Handle generic types that can't be used with issubclass
|
||||
pass
|
||||
|
||||
# Handle Union types
|
||||
source_origin = get_origin(source_type)
|
||||
target_origin = get_origin(target_type)
|
||||
|
||||
if target_origin in (Union, UnionType):
|
||||
target_args = get_args(target_type)
|
||||
return any(WorkflowGraphValidator._is_type_compatible(source_type, arg) for arg in target_args)
|
||||
|
||||
if source_origin in (Union, UnionType):
|
||||
source_args = get_args(source_type)
|
||||
return all(WorkflowGraphValidator._is_type_compatible(arg, target_type) for arg in source_args)
|
||||
|
||||
# Handle generic types
|
||||
if source_origin is not None and target_origin is not None and source_origin == target_origin:
|
||||
source_args = get_args(source_type)
|
||||
target_args = get_args(target_type)
|
||||
if len(source_args) == len(target_args):
|
||||
return all(
|
||||
WorkflowGraphValidator._is_type_compatible(s_arg, t_arg)
|
||||
for s_arg, t_arg in zip(source_args, target_args, strict=True)
|
||||
)
|
||||
|
||||
# No other special compatibility cases
|
||||
return False
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ from dataclasses import dataclass
|
||||
from typing import Any, Generic, TypeVar, Union
|
||||
|
||||
from agent_framework._workflows import RequestInfoMessage, RequestResponse
|
||||
from agent_framework._workflows._typing_utils import is_instance_of
|
||||
from agent_framework._workflows._typing_utils import is_instance_of, is_type_compatible
|
||||
|
||||
|
||||
def test_basic_types() -> None:
|
||||
@@ -133,3 +133,86 @@ def test_edge_cases() -> None:
|
||||
assert is_instance_of({}, dict[str, int]) # Empty dict should be valid
|
||||
assert is_instance_of(None, int | None) # Optional type with None
|
||||
assert not is_instance_of(5, str | None) # Optional type without matching type
|
||||
|
||||
|
||||
def test_type_compatibility_basic() -> None:
|
||||
"""Test basic type compatibility scenarios."""
|
||||
# Exact type match
|
||||
assert is_type_compatible(str, str)
|
||||
assert is_type_compatible(int, int)
|
||||
|
||||
# Any compatibility
|
||||
assert is_type_compatible(str, Any)
|
||||
assert is_type_compatible(list[int], Any)
|
||||
|
||||
# Subclass compatibility
|
||||
class Animal:
|
||||
pass
|
||||
|
||||
class Dog(Animal):
|
||||
pass
|
||||
|
||||
assert is_type_compatible(Dog, Animal)
|
||||
assert not is_type_compatible(Animal, Dog)
|
||||
|
||||
|
||||
def test_type_compatibility_unions() -> None:
|
||||
"""Test type compatibility with Union types."""
|
||||
# Source matches target union member
|
||||
assert is_type_compatible(str, Union[str, int])
|
||||
assert is_type_compatible(int, Union[str, int])
|
||||
assert not is_type_compatible(float, Union[str, int])
|
||||
|
||||
# Source union - all members must be compatible with target
|
||||
assert is_type_compatible(Union[str, int], Union[str, int, float])
|
||||
assert not is_type_compatible(Union[str, int, bytes], Union[str, int])
|
||||
|
||||
|
||||
def test_type_compatibility_collections() -> None:
|
||||
"""Test type compatibility with collection types."""
|
||||
|
||||
# List compatibility - key use case
|
||||
@dataclass
|
||||
class ChatMessage:
|
||||
text: str
|
||||
|
||||
assert is_type_compatible(list[ChatMessage], list[Union[str, ChatMessage]])
|
||||
assert is_type_compatible(list[str], list[Union[str, ChatMessage]])
|
||||
assert not is_type_compatible(list[Union[str, ChatMessage]], list[ChatMessage])
|
||||
|
||||
# Dict compatibility
|
||||
assert is_type_compatible(dict[str, int], dict[str, Union[int, float]])
|
||||
assert not is_type_compatible(dict[str, Union[int, float]], dict[str, int])
|
||||
|
||||
# Set compatibility
|
||||
assert is_type_compatible(set[str], set[Union[str, int]])
|
||||
assert not is_type_compatible(set[Union[str, int]], set[str])
|
||||
|
||||
|
||||
def test_type_compatibility_tuples() -> None:
|
||||
"""Test type compatibility with tuple types."""
|
||||
# Fixed length tuples
|
||||
assert is_type_compatible(tuple[str, int], tuple[Union[str, bytes], Union[int, float]])
|
||||
assert not is_type_compatible(tuple[str, int], tuple[str, int, bool]) # Different lengths
|
||||
|
||||
# Variable length tuples
|
||||
assert is_type_compatible(tuple[str, ...], tuple[Union[str, bytes], ...])
|
||||
assert is_type_compatible(tuple[str, int, bool], tuple[Union[str, int, bool], ...])
|
||||
assert not is_type_compatible(tuple[str, ...], tuple[str, int]) # Variable to fixed
|
||||
|
||||
|
||||
def test_type_compatibility_complex() -> None:
|
||||
"""Test complex nested type compatibility."""
|
||||
|
||||
@dataclass
|
||||
class Message:
|
||||
content: str
|
||||
|
||||
# Complex nested structure
|
||||
source = list[dict[str, Message]]
|
||||
target = list[dict[Union[str, bytes], Union[str, Message]]]
|
||||
assert is_type_compatible(source, target)
|
||||
|
||||
# Incompatible nested structure
|
||||
incompatible_target = list[dict[Union[str, bytes], int]]
|
||||
assert not is_type_compatible(source, incompatible_target)
|
||||
|
||||
@@ -519,7 +519,9 @@ class StateTrackingExecutor(Executor):
|
||||
"""An executor that tracks state in shared state to test context reset behavior."""
|
||||
|
||||
@handler
|
||||
async def handle_message(self, message: StateTrackingMessage, ctx: WorkflowContext[Any, list[Any]]) -> None:
|
||||
async def handle_message(
|
||||
self, message: StateTrackingMessage, ctx: WorkflowContext[StateTrackingMessage, list[str]]
|
||||
) -> None:
|
||||
"""Handle the message and track it in shared state."""
|
||||
# Get existing messages from shared state
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user