Merge branch 'main' into feature/python-foundry-hosted-agent-vnext

This commit is contained in:
Tao Chen
2026-04-15 20:59:51 -07:00
Unverified
23 changed files with 1502 additions and 55 deletions
+2
View File
@@ -63,6 +63,8 @@ agent_framework/
- **`SessionContext`** - Context object for session-scoped data during agent runs
- **`ContextProvider`** - Base class for context providers (RAG, memory systems)
- **`HistoryProvider`** - Base class for conversation history storage
- **`InMemoryHistoryProvider`** - Built-in session-state history provider for local runs
- **`FileHistoryProvider`** - JSON Lines file-backed history provider storing one file per session with one message record per line
### Skills (`_skills.py`)
@@ -103,6 +103,7 @@ from ._middleware import (
from ._sessions import (
AgentSession,
ContextProvider,
FileHistoryProvider,
HistoryProvider,
InMemoryHistoryProvider,
SessionContext,
@@ -318,6 +319,7 @@ __all__ = [
"FanInEdgeGroup",
"FanOutEdgeGroup",
"FileCheckpointStorage",
"FileHistoryProvider",
"FinalT",
"FinishReason",
"FinishReasonLiteral",
@@ -47,6 +47,7 @@ class ExperimentalFeature(str, Enum):
"""
EVALS = "EVALS"
FILE_HISTORY = "FILE_HISTORY"
SKILLS = "SKILLS"
@@ -8,16 +8,24 @@ This module provides the core types for the context provider pipeline:
- HistoryProvider: Base class for history storage providers
- AgentSession: Lightweight session state container
- InMemoryHistoryProvider: Built-in in-memory history provider
- FileHistoryProvider: Built-in JSON Lines file history provider
"""
from __future__ import annotations
import asyncio
import copy
import json
import threading
import uuid
import weakref
from abc import abstractmethod
from base64 import urlsafe_b64encode
from collections.abc import Awaitable, Callable, Mapping, Sequence
from typing import TYPE_CHECKING, Any, ClassVar, TypeGuard, cast
from pathlib import Path
from typing import TYPE_CHECKING, Any, ClassVar, TypeAlias, TypeGuard, cast
from ._feature_stage import ExperimentalFeature, experimental
from ._middleware import ChatContext, ChatMiddleware
from ._types import AgentResponse, ChatResponse, Message, ResponseStream
from .exceptions import ChatClientInvalidResponseException
@@ -30,6 +38,17 @@ if TYPE_CHECKING:
# Registry of known types for state deserialization
_STATE_TYPE_REGISTRY: dict[str, type] = {}
JsonDumps: TypeAlias = Callable[[Any], str | bytes]
JsonLoads: TypeAlias = Callable[[str | bytes], Any]
def _default_json_dumps(value: Any) -> str:
return json.dumps(value, ensure_ascii=False)
def _default_json_loads(value: str | bytes) -> Any:
return json.loads(value)
def _is_middleware_sequence(
middleware: MiddlewareTypes | Sequence[MiddlewareTypes],
@@ -837,3 +856,247 @@ class InMemoryHistoryProvider(HistoryProvider):
return
existing = state.get("messages", [])
state["messages"] = [*existing, *messages]
@experimental(feature_id=ExperimentalFeature.FILE_HISTORY)
class FileHistoryProvider(HistoryProvider):
"""File-backed history provider that stores one JSON Lines file per session.
Each persisted message is written as a single JSON object per line. The
provider does not serialize full session snapshots into the file. By default
it uses the standard library ``json`` module, but callers can inject
alternative ``dumps`` and ``loads`` callables compatible with the JSON
Lines format.
Security posture:
Persisted history is stored as plaintext JSONL on the local filesystem.
Treat ``storage_path`` as trusted application storage, not as a secret
store. Encoded fallback filenames and resolved-path validation help
prevent path traversal via ``session_id``, but they do not encrypt file
contents or provide cross-process / cross-host locking. Use OS-level
file permissions, trusted directories, and carefully review what agent
or tool output is allowed to be persisted.
"""
DEFAULT_SOURCE_ID: ClassVar[str] = "file_history"
DEFAULT_SESSION_FILE_STEM: ClassVar[str] = "default"
FILE_EXTENSION: ClassVar[str] = ".jsonl"
_FILE_LOCK_STRIPE_COUNT: ClassVar[int] = 64
_ENCODED_SESSION_PREFIX: ClassVar[str] = "~session-"
_FILE_WRITE_LOCKS: ClassVar[tuple[threading.Lock, ...]] = tuple(
threading.Lock() for _ in range(_FILE_LOCK_STRIPE_COUNT)
)
_WINDOWS_RESERVED_FILE_STEMS: ClassVar[frozenset[str]] = frozenset({
"CON",
"PRN",
"AUX",
"NUL",
"COM1",
"COM2",
"COM3",
"COM4",
"COM5",
"COM6",
"COM7",
"COM8",
"COM9",
"LPT1",
"LPT2",
"LPT3",
"LPT4",
"LPT5",
"LPT6",
"LPT7",
"LPT8",
"LPT9",
})
def __init__(
self,
storage_path: str | Path,
*,
source_id: str = DEFAULT_SOURCE_ID,
load_messages: bool = True,
store_inputs: bool = True,
store_context_messages: bool = False,
store_context_from: set[str] | None = None,
store_outputs: bool = True,
skip_excluded: bool = False,
dumps: JsonDumps | None = None,
loads: JsonLoads | None = None,
) -> None:
"""Initialize the file history provider.
Args:
storage_path: Directory path where session history files will be stored.
Keyword Args:
source_id: Unique identifier for this provider instance.
load_messages: Whether to load messages before invocation.
store_inputs: Whether to store input messages.
store_context_messages: Whether to store context from other providers.
store_context_from: If set, only store context from these source_ids.
store_outputs: Whether to store response messages.
skip_excluded: When True, ``get_messages`` omits messages whose
``additional_properties["_excluded"]`` is truthy.
dumps: Callable that serializes a message payload dict to JSON text
or UTF-8 bytes. The returned JSON must fit on a single line.
loads: Callable that deserializes JSON text or bytes back to a
message payload dict.
"""
super().__init__(
source_id=source_id,
load_messages=load_messages,
store_inputs=store_inputs,
store_context_messages=store_context_messages,
store_context_from=store_context_from,
store_outputs=store_outputs,
)
self.storage_path = Path(storage_path)
self.storage_path.mkdir(parents=True, exist_ok=True)
self._storage_root = self.storage_path.resolve()
self.skip_excluded = skip_excluded
self.dumps = dumps or _default_json_dumps
self.loads = loads or _default_json_loads
self._async_write_locks_by_loop: weakref.WeakKeyDictionary[
asyncio.AbstractEventLoop,
tuple[asyncio.Lock, ...],
] = weakref.WeakKeyDictionary()
async def get_messages(
self,
session_id: str | None,
*,
state: dict[str, Any] | None = None,
**kwargs: Any,
) -> list[Message]:
"""Retrieve messages from the session's JSON Lines file."""
del state, kwargs
file_path = self._session_file_path(session_id)
async_lock = self._session_async_write_lock(file_path)
thread_lock = self._session_write_lock(file_path)
def _read_messages() -> list[Message]:
with thread_lock:
if not file_path.exists():
return []
messages: list[Message] = []
with file_path.open(encoding="utf-8") as file_handle:
for line_number, line in enumerate(file_handle, start=1):
serialized = line.strip()
if not serialized:
continue
try:
payload = self.loads(serialized)
except (TypeError, ValueError) as exc:
raise ValueError(
f"Failed to deserialize history line {line_number} from '{file_path}'."
) from exc
if not isinstance(payload, Mapping):
raise ValueError(
f"History line {line_number} in '{file_path}' did not deserialize to a mapping."
)
try:
message = Message.from_dict(dict(cast(Mapping[str, Any], payload)))
except ValueError as exc:
raise ValueError(
f"History line {line_number} in '{file_path}' is not a valid Message payload."
) from exc
messages.append(message)
return messages
async with async_lock:
messages = await asyncio.to_thread(_read_messages)
if self.skip_excluded:
messages = [m for m in messages if not m.additional_properties.get("_excluded", False)]
return messages
async def save_messages(
self,
session_id: str | None,
messages: Sequence[Message],
*,
state: dict[str, Any] | None = None,
**kwargs: Any,
) -> None:
"""Append messages to the session's JSON Lines file."""
del state, kwargs
if not messages:
return
file_path = self._session_file_path(session_id)
async_lock = self._session_async_write_lock(file_path)
file_lock = self._session_write_lock(file_path)
def _append_messages() -> None:
with file_lock, file_path.open("a", encoding="utf-8") as file_handle:
for message in messages:
file_handle.write(f"{self._serialize_message(message)}\n")
async with async_lock:
await asyncio.to_thread(_append_messages)
def _serialize_message(self, message: Message) -> str:
"""Serialize a message payload to a single JSON Lines record."""
serialized = self.dumps(message.to_dict())
if isinstance(serialized, bytes):
serialized_text = serialized.decode("utf-8")
elif isinstance(serialized, str):
serialized_text = serialized
else:
raise TypeError("FileHistoryProvider.dumps must return str or bytes.")
if "\n" in serialized_text or "\r" in serialized_text:
raise ValueError("FileHistoryProvider.dumps must return single-line JSON for JSON Lines storage.")
return serialized_text
def _session_file_path(self, session_id: str | None) -> Path:
"""Resolve the on-disk history file path for a session."""
file_path = (self._storage_root / f"{self._session_file_stem(session_id)}{self.FILE_EXTENSION}").resolve()
if not file_path.is_relative_to(self._storage_root):
raise ValueError(f"Session history path escaped storage directory: {session_id!r}")
return file_path
def _session_file_stem(self, session_id: str | None) -> str:
"""Return the filename stem for a session."""
raw_session_id = session_id or self.DEFAULT_SESSION_FILE_STEM
if self._is_literal_session_file_stem_safe(raw_session_id):
return raw_session_id
encoded_session_id = urlsafe_b64encode(raw_session_id.encode("utf-8")).decode("ascii").rstrip("=")
return f"{self._ENCODED_SESSION_PREFIX}{encoded_session_id or self.DEFAULT_SESSION_FILE_STEM}"
def _session_async_write_lock(self, file_path: Path) -> asyncio.Lock:
"""Return the event-loop-local async lock for a session history file."""
loop = asyncio.get_running_loop()
locks = self._async_write_locks_by_loop.get(loop)
if locks is None:
locks = tuple(asyncio.Lock() for _ in range(self._FILE_LOCK_STRIPE_COUNT))
self._async_write_locks_by_loop[loop] = locks
return locks[self._lock_index(file_path)]
@classmethod
def _session_write_lock(cls, file_path: Path) -> threading.Lock:
"""Return the process-local thread lock for a session history file."""
return cls._FILE_WRITE_LOCKS[cls._lock_index(file_path)]
@classmethod
def _lock_index(cls, file_path: Path) -> int:
"""Map a session history file to a bounded lock stripe."""
return hash(file_path) % cls._FILE_LOCK_STRIPE_COUNT
@classmethod
def _is_literal_session_file_stem_safe(cls, session_id: str) -> bool:
"""Return whether the session ID can be used directly as a filename stem."""
if (
not session_id
or session_id.startswith(".")
or session_id.endswith((" ", "."))
or session_id.upper() in cls._WINDOWS_RESERVED_FILE_STEMS
):
return False
if any(ord(character) < 32 for character in session_id):
return False
return all(character.isalnum() or character in "._-" for character in session_id)
@@ -244,10 +244,10 @@ class FileCheckpointStorage:
is serialized using pickle and embedded as base64-encoded strings within the JSON. This allows
for human-readable checkpoint files while preserving the ability to store complex Python objects.
By default, checkpoint deserialization is restricted to a built-in set of safe
Python types (primitives, datetime, uuid, ...) and all ``agent_framework``
internal types. To allow additional application-specific types, pass them via
the ``allowed_checkpoint_types`` parameter using ``"module:qualname"`` format.
By default, checkpoint deserialization is restricted to a built-in set of safe Python types
(primitives, datetime, uuid, ...), all ``agent_framework`` internal types, and OpenAI SDK types
(``openai.types``). To allow additional application-specific types, pass them via the
``allowed_checkpoint_types`` parameter using ``"module:qualname"`` format.
Example::
@@ -10,9 +10,9 @@ This hybrid approach provides:
When ``allowed_types`` is supplied to :func:`decode_checkpoint_value`, a
``RestrictedUnpickler`` is used that limits which classes may be instantiated
during deserialization. The default built-in safe set covers common Python
value types (primitives, datetime, uuid, ...) and all ``agent_framework``
internal types. Callers can extend the set by passing additional
``"module:qualname"`` strings.
value types (primitives, datetime, uuid, ...), all ``agent_framework`` internal
types, and all ``openai.types`` types. Callers can extend the set by passing
additional ``"module:qualname"`` strings.
"""
from __future__ import annotations
@@ -37,6 +37,9 @@ _JSON_NATIVE_TYPES = (str, int, float, bool, type(None))
# Module prefix for framework-internal types that are always allowed
_FRAMEWORK_MODULE_PREFIX = "agent_framework."
# Module prefix for OpenAI SDK types that are always allowed
_OPENAI_MODULE_PREFIX = "openai.types."
# Built-in types considered safe for checkpoint deserialization.
# Each entry is a ``module:qualname`` string matching the format produced by
# :func:`_type_to_key`. These are the classes for which pickle's
@@ -84,8 +87,9 @@ class _RestrictedUnpickler(pickle.Unpickler): # noqa: S301
"""Unpickler that restricts which classes may be instantiated.
Only classes whose ``module:qualname`` key appears in the combined allow
set (built-in safe types + framework types + caller-specified extras) are
permitted. All other classes raise :class:`pickle.UnpicklingError`.
set (built-in safe types + framework types + OpenAI SDK types +
caller-specified extras) are permitted. All other classes raise
:class:`pickle.UnpicklingError`.
"""
def __init__(self, data: bytes, allowed_types: frozenset[str]) -> None:
@@ -99,6 +103,7 @@ class _RestrictedUnpickler(pickle.Unpickler): # noqa: S301
type_key in _BUILTIN_ALLOWED_TYPE_KEYS
or type_key in self._allowed_types
or module.startswith(_FRAMEWORK_MODULE_PREFIX)
or module.startswith(_OPENAI_MODULE_PREFIX)
):
return super().find_class(module, name) # type: ignore[no-any-return] # nosec
@@ -11,11 +11,11 @@ import logging
import types
import uuid
from collections.abc import AsyncIterable, Awaitable, Callable, Mapping, Sequence
from typing import Any, Literal, overload
from typing import TYPE_CHECKING, Any, Literal, overload
from .._sessions import ContextProvider
from .._types import ResponseStream
from ..observability import OtelAttr, capture_exception, create_workflow_span
from ._agent import WorkflowAgent
from ._checkpoint import CheckpointStorage
from ._const import DEFAULT_MAX_ITERATIONS, GLOBAL_KWARGS_KEY, WORKFLOW_RUN_KWARGS_KEY
from ._edge import (
@@ -35,6 +35,9 @@ from ._runner_context import RunnerContext
from ._state import State
from ._typing_utils import is_instance_of, try_coerce_to_type
if TYPE_CHECKING:
from ._agent import WorkflowAgent
logger = logging.getLogger(__name__)
@@ -910,7 +913,14 @@ class Workflow(DictConvertible):
return list(output_types)
def as_agent(self, name: str | None = None) -> WorkflowAgent:
def as_agent(
self,
name: str | None = None,
*,
description: str | None = None,
context_providers: Sequence[ContextProvider] | None = None,
**kwargs: Any,
) -> WorkflowAgent:
"""Create a WorkflowAgent that wraps this workflow.
The returned agent converts standard agent inputs (strings, Message, or lists of these)
@@ -924,7 +934,10 @@ class Workflow(DictConvertible):
initialization will fail with a ValueError.
Args:
name: Optional name for the agent. If None, a default name will be generated.
name: Optional name for the agent. Defaults to workflow name.
description: Optional description of the agent. Defaults to workflow description.
context_providers: Optional sequence of context providers for the agent.
**kwargs: Additional keyword arguments passed to BaseAgent.
Returns:
A WorkflowAgent instance that wraps this workflow.
@@ -935,4 +948,10 @@ class Workflow(DictConvertible):
# Import here to avoid circular imports
from ._agent import WorkflowAgent
return WorkflowAgent(workflow=self, name=name)
return WorkflowAgent(
workflow=self,
name=name if name is not None else self.name,
description=description if description is not None else self.description,
context_providers=context_providers,
**kwargs,
)
@@ -1,7 +1,12 @@
# Copyright (c) Microsoft. All rights reserved.
import asyncio
import json
import threading
import time
from collections.abc import Awaitable, Callable, Sequence
from pathlib import Path
from typing import Any
import pytest
@@ -10,6 +15,8 @@ from agent_framework import (
AgentSession,
ChatContext,
ContextProvider,
ExperimentalFeature,
FileHistoryProvider,
HistoryProvider,
InMemoryHistoryProvider,
Message,
@@ -505,3 +512,217 @@ class TestInMemoryHistoryProvider:
ctx = SessionContext(session_id="s1", input_messages=[])
ctx.extend_messages("custom-source", [Message(role="user", contents=["test"])])
assert "custom-source" in ctx.context_messages
class TestFileHistoryProvider:
def test_is_marked_experimental(self) -> None:
assert FileHistoryProvider.__feature_stage__ == "experimental"
assert FileHistoryProvider.__feature_id__ == ExperimentalFeature.FILE_HISTORY.value
assert FileHistoryProvider.__doc__ is not None
assert ".. warning:: Experimental" in FileHistoryProvider.__doc__
async def test_stores_and_loads_messages(self, tmp_path: Path) -> None:
from agent_framework import AgentResponse
provider = FileHistoryProvider(tmp_path)
session = AgentSession(session_id="s1")
input_message = Message(role="user", contents=["hello"])
response_message = Message(role="assistant", contents=["hi there"])
first_context = SessionContext(session_id=session.session_id, input_messages=[input_message])
await provider.before_run( # type: ignore[arg-type]
agent=None,
session=session,
context=first_context,
state={},
)
first_context._response = AgentResponse(messages=[response_message])
await provider.after_run( # type: ignore[arg-type]
agent=None,
session=session,
context=first_context,
state={},
)
session_file = provider._session_file_path(session.session_id)
assert session_file.name == "s1.jsonl"
assert session_file.exists()
raw_lines = (await asyncio.to_thread(session_file.read_text, encoding="utf-8")).splitlines()
assert len(raw_lines) == 2
payloads = [json.loads(line) for line in raw_lines]
assert all(payload["type"] == "message" for payload in payloads)
assert all("session_id" not in payload for payload in payloads)
second_context = SessionContext(
session_id=session.session_id, input_messages=[Message(role="user", contents=["again"])]
)
await provider.before_run( # type: ignore[arg-type]
agent=None,
session=session,
context=second_context,
state={},
)
loaded = second_context.context_messages.get(provider.source_id, [])
assert len(loaded) == 2
assert loaded[0].text == "hello"
assert loaded[1].text == "hi there"
def test_creates_storage_directory(self, tmp_path: Path) -> None:
nested_path = tmp_path / "nested" / "history"
provider = FileHistoryProvider(nested_path)
assert provider.storage_path == nested_path
assert nested_path.exists()
assert nested_path.is_dir()
async def test_uses_encoded_filename_for_unsafe_session_id(self, tmp_path: Path) -> None:
provider = FileHistoryProvider(tmp_path)
unsafe_session_id = "../unsafe/session"
await provider.save_messages(unsafe_session_id, [Message(role="user", contents=["hello"])])
session_file = provider._session_file_path(unsafe_session_id)
assert session_file.parent == provider.storage_path
assert session_file.name.startswith("~session-")
assert session_file.suffix == ".jsonl"
assert session_file.exists()
jsonl_files = await asyncio.to_thread(
lambda: sorted(path.name for path in provider.storage_path.glob("*.jsonl"))
)
assert jsonl_files == [session_file.name]
async def test_allows_custom_serializers_returning_bytes(self, tmp_path: Path) -> None:
calls: list[str] = []
def dumps(payload: object) -> bytes:
calls.append("dumps")
return json.dumps(payload).encode("utf-8")
def loads(payload: str | bytes) -> object:
calls.append("loads")
if isinstance(payload, bytes):
payload = payload.decode("utf-8")
return json.loads(payload)
provider = FileHistoryProvider(tmp_path, dumps=dumps, loads=loads)
await provider.save_messages("custom-serializer", [Message(role="user", contents=["hello"])])
loaded = await provider.get_messages("custom-serializer")
assert calls == ["dumps", "loads"]
assert len(loaded) == 1
assert loaded[0].text == "hello"
async def test_invalid_jsonl_line_raises(self, tmp_path: Path) -> None:
provider = FileHistoryProvider(tmp_path)
await asyncio.to_thread(provider._session_file_path("broken").write_text, "{not-json}\n", encoding="utf-8")
with pytest.raises(ValueError, match="Failed to deserialize history line 1"):
await provider.get_messages("broken")
async def test_missing_session_file_returns_empty_messages(self, tmp_path: Path) -> None:
provider = FileHistoryProvider(tmp_path)
loaded = await provider.get_messages("missing")
assert loaded == []
async def test_none_session_id_uses_default_jsonl_file(self, tmp_path: Path) -> None:
provider = FileHistoryProvider(tmp_path)
await provider.save_messages(None, [Message(role="user", contents=["hello"])])
session_file = provider._session_file_path(None)
assert session_file.name == "default.jsonl"
loaded = await provider.get_messages(None)
assert [message.text for message in loaded] == ["hello"]
async def test_non_mapping_jsonl_line_raises(self, tmp_path: Path) -> None:
provider = FileHistoryProvider(tmp_path)
await asyncio.to_thread(provider._session_file_path("non-mapping").write_text, "[1, 2, 3]\n", encoding="utf-8")
with pytest.raises(ValueError, match="did not deserialize to a mapping"):
await provider.get_messages("non-mapping")
async def test_skip_excluded_omits_excluded_messages(self, tmp_path: Path) -> None:
provider = FileHistoryProvider(tmp_path, skip_excluded=True)
await provider.save_messages(
"skip-excluded",
[
Message(role="user", contents=["keep"]),
Message(role="assistant", contents=["skip"], additional_properties={"_excluded": True}),
],
)
loaded = await provider.get_messages("skip-excluded")
assert [message.text for message in loaded] == ["keep"]
async def test_serializer_must_return_single_line_json(self, tmp_path: Path) -> None:
def dumps(payload: object) -> str:
return json.dumps(payload, indent=2)
provider = FileHistoryProvider(tmp_path, dumps=dumps)
with pytest.raises(ValueError, match="single-line JSON"):
await provider.save_messages("pretty-json", [Message(role="user", contents=["hello"])])
async def test_concurrent_writes_for_same_session_are_locked(
self,
tmp_path: Path,
monkeypatch: pytest.MonkeyPatch,
) -> None:
provider = FileHistoryProvider(tmp_path)
session_id = "shared-session"
file_path = provider._session_file_path(session_id)
real_open = Path.open
write_started = threading.Event()
active_writes = 0
overlap_detected = False
class _TrackingFile:
def __init__(self, wrapped: Any) -> None:
self._wrapped = wrapped
def __enter__(self) -> "_TrackingFile":
self._wrapped.__enter__()
return self
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
self._wrapped.__exit__(exc_type, exc_val, exc_tb)
def write(self, data: str) -> int:
nonlocal active_writes, overlap_detected
write_started.set()
active_writes += 1
overlap_detected = overlap_detected or active_writes > 1
try:
time.sleep(0.05)
return int(self._wrapped.write(data))
finally:
active_writes -= 1
def __getattr__(self, name: str) -> Any:
return getattr(self._wrapped, name)
def tracked_open(path: Path, *args: Any, **kwargs: Any) -> Any:
handle = real_open(path, *args, **kwargs)
if path == file_path and args and args[0] == "a":
return _TrackingFile(handle)
return handle
monkeypatch.setattr(Path, "open", tracked_open)
first_save = asyncio.create_task(provider.save_messages(session_id, [Message(role="user", contents=["first"])]))
started = await asyncio.to_thread(write_started.wait, 1.0)
assert started
second_save = asyncio.create_task(
provider.save_messages(session_id, [Message(role="assistant", contents=["second"])])
)
await asyncio.gather(first_save, second_save)
assert not overlap_detected
loaded = await provider.get_messages(session_id)
assert [message.text for message in loaded] == ["first", "second"]
@@ -216,3 +216,50 @@ def test_restricted_unpickler_raises_pickle_error():
unpickler = _RestrictedUnpickler(pickled, frozenset())
with pytest.raises(pickle.UnpicklingError, match="deserialization blocked"):
unpickler.load()
def test_restricted_decode_allows_openai_types():
"""OpenAI SDK types are always allowed during restricted deserialization."""
from openai.types.chat.chat_completion import ChatCompletion, Choice
from openai.types.chat.chat_completion_message import ChatCompletionMessage
from openai.types.completion_usage import CompletionUsage
completion = ChatCompletion(
id="chatcmpl-test",
choices=[
Choice(
finish_reason="stop",
index=0,
message=ChatCompletionMessage(role="assistant", content="hello"),
)
],
created=1700000000,
model="gpt-4",
object="chat.completion",
usage=CompletionUsage(completion_tokens=1, prompt_tokens=1, total_tokens=2),
)
encoded = encode_checkpoint_value(completion)
decoded = decode_checkpoint_value(encoded, allowed_types=frozenset())
assert isinstance(decoded, ChatCompletion)
assert decoded.id == "chatcmpl-test"
assert decoded.choices[0].message.content == "hello"
def test_restricted_decode_allows_openai_response_types():
"""OpenAI Responses API types are always allowed during restricted deserialization."""
from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails, ResponseUsage
usage = ResponseUsage(
input_tokens=10,
output_tokens=20,
total_tokens=30,
input_tokens_details=InputTokensDetails(cached_tokens=0),
output_tokens_details=OutputTokensDetails(reasoning_tokens=0),
)
encoded = encode_checkpoint_value(usage)
decoded = decode_checkpoint_value(encoded, allowed_types=frozenset())
assert isinstance(decoded, ResponseUsage)
assert decoded.input_tokens == 10
assert decoded.output_tokens == 20
@@ -313,6 +313,37 @@ class TestWorkflowAgent:
assert isinstance(agent_no_name, WorkflowAgent)
assert agent_no_name.workflow is workflow
def test_workflow_as_agent_with_description_and_context_providers(self) -> None:
"""Test that Workflow.as_agent() forwards description and context_providers."""
executor = SimpleExecutor(id="executor1", response_text="Response")
workflow = WorkflowBuilder(start_executor=executor).build()
history_provider = InMemoryHistoryProvider()
agent = workflow.as_agent(
name="MyAgent",
description="A test agent",
context_providers=[history_provider],
)
assert isinstance(agent, WorkflowAgent)
assert agent.name == "MyAgent"
assert agent.description == "A test agent"
assert history_provider in agent.context_providers
def test_workflow_as_agent_defaults_name_and_description_from_workflow(self) -> None:
"""Test that as_agent() defaults name and description to the workflow's own values."""
executor = SimpleExecutor(id="executor1", response_text="Response")
workflow = WorkflowBuilder(
start_executor=executor,
name="my-workflow",
description="Workflow description",
).build()
agent = workflow.as_agent()
assert agent.name == "my-workflow"
assert agent.description == "Workflow description"
def test_workflow_as_agent_cannot_handle_agent_inputs(self) -> None:
"""Test that Workflow.as_agent() raises an error if the start executor cannot handle agent inputs."""