mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Merge branch 'main' into feature/python-foundry-hosted-agent-vnext
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
@@ -1503,6 +1504,8 @@ async def test_anthropic_client_integration_function_calling() -> None:
|
||||
@skip_if_anthropic_integration_tests_disabled
|
||||
async def test_anthropic_client_integration_hosted_tools() -> None:
|
||||
"""Integration test for hosted tools."""
|
||||
import anthropic
|
||||
|
||||
client = AnthropicClient()
|
||||
|
||||
messages = [Message(role="user", contents=["What tools do you have available?"])]
|
||||
@@ -1515,10 +1518,18 @@ async def test_anthropic_client_integration_hosted_tools() -> None:
|
||||
),
|
||||
]
|
||||
|
||||
response = await client.get_response(
|
||||
messages=messages,
|
||||
options={"tools": tools, "max_tokens": 100},
|
||||
)
|
||||
try:
|
||||
response = await client.get_response(
|
||||
messages=messages,
|
||||
options={"tools": tools, "max_tokens": 100},
|
||||
)
|
||||
except (
|
||||
anthropic.BadRequestError,
|
||||
anthropic.InternalServerError,
|
||||
anthropic.APIConnectionError,
|
||||
anthropic.APITimeoutError,
|
||||
) as e:
|
||||
pytest.skip(f"Upstream MCP server unavailable: {e}")
|
||||
|
||||
assert response is not None
|
||||
assert response.text is not None
|
||||
@@ -1607,7 +1618,8 @@ async def test_anthropic_client_integration_images() -> None:
|
||||
|
||||
assert response is not None
|
||||
assert response.messages[0].text is not None
|
||||
assert "house" in response.messages[0].text.lower()
|
||||
text = response.messages[0].text.lower()
|
||||
assert re.search(r"\b(house|home|building|cottage|mansion|villa)\b", text)
|
||||
|
||||
|
||||
# Response Format Tests
|
||||
|
||||
@@ -63,6 +63,8 @@ agent_framework/
|
||||
- **`SessionContext`** - Context object for session-scoped data during agent runs
|
||||
- **`ContextProvider`** - Base class for context providers (RAG, memory systems)
|
||||
- **`HistoryProvider`** - Base class for conversation history storage
|
||||
- **`InMemoryHistoryProvider`** - Built-in session-state history provider for local runs
|
||||
- **`FileHistoryProvider`** - JSON Lines file-backed history provider storing one file per session with one message record per line
|
||||
|
||||
### Skills (`_skills.py`)
|
||||
|
||||
|
||||
@@ -103,6 +103,7 @@ from ._middleware import (
|
||||
from ._sessions import (
|
||||
AgentSession,
|
||||
ContextProvider,
|
||||
FileHistoryProvider,
|
||||
HistoryProvider,
|
||||
InMemoryHistoryProvider,
|
||||
SessionContext,
|
||||
@@ -318,6 +319,7 @@ __all__ = [
|
||||
"FanInEdgeGroup",
|
||||
"FanOutEdgeGroup",
|
||||
"FileCheckpointStorage",
|
||||
"FileHistoryProvider",
|
||||
"FinalT",
|
||||
"FinishReason",
|
||||
"FinishReasonLiteral",
|
||||
|
||||
@@ -47,6 +47,7 @@ class ExperimentalFeature(str, Enum):
|
||||
"""
|
||||
|
||||
EVALS = "EVALS"
|
||||
FILE_HISTORY = "FILE_HISTORY"
|
||||
SKILLS = "SKILLS"
|
||||
|
||||
|
||||
|
||||
@@ -8,16 +8,24 @@ This module provides the core types for the context provider pipeline:
|
||||
- HistoryProvider: Base class for history storage providers
|
||||
- AgentSession: Lightweight session state container
|
||||
- InMemoryHistoryProvider: Built-in in-memory history provider
|
||||
- FileHistoryProvider: Built-in JSON Lines file history provider
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import copy
|
||||
import json
|
||||
import threading
|
||||
import uuid
|
||||
import weakref
|
||||
from abc import abstractmethod
|
||||
from base64 import urlsafe_b64encode
|
||||
from collections.abc import Awaitable, Callable, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, TypeGuard, cast
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, TypeAlias, TypeGuard, cast
|
||||
|
||||
from ._feature_stage import ExperimentalFeature, experimental
|
||||
from ._middleware import ChatContext, ChatMiddleware
|
||||
from ._types import AgentResponse, ChatResponse, Message, ResponseStream
|
||||
from .exceptions import ChatClientInvalidResponseException
|
||||
@@ -30,6 +38,17 @@ if TYPE_CHECKING:
|
||||
# Registry of known types for state deserialization
|
||||
_STATE_TYPE_REGISTRY: dict[str, type] = {}
|
||||
|
||||
# Signature of a pluggable JSON serializer: takes any value and returns JSON as text or bytes.
JsonDumps: TypeAlias = Callable[[Any], str | bytes]
# Signature of a pluggable JSON deserializer: takes JSON text or bytes and returns the decoded value.
JsonLoads: TypeAlias = Callable[[str | bytes], Any]
|
||||
|
||||
|
||||
def _default_json_dumps(value: Any) -> str:
|
||||
return json.dumps(value, ensure_ascii=False)
|
||||
|
||||
|
||||
def _default_json_loads(value: str | bytes) -> Any:
|
||||
return json.loads(value)
|
||||
|
||||
|
||||
def _is_middleware_sequence(
|
||||
middleware: MiddlewareTypes | Sequence[MiddlewareTypes],
|
||||
@@ -837,3 +856,247 @@ class InMemoryHistoryProvider(HistoryProvider):
|
||||
return
|
||||
existing = state.get("messages", [])
|
||||
state["messages"] = [*existing, *messages]
|
||||
|
||||
|
||||
@experimental(feature_id=ExperimentalFeature.FILE_HISTORY)
class FileHistoryProvider(HistoryProvider):
    """File-backed history provider that stores one JSON Lines file per session.

    Each persisted message is written as a single JSON object per line. The
    provider does not serialize full session snapshots into the file. By default
    it uses the standard library ``json`` module, but callers can inject
    alternative ``dumps`` and ``loads`` callables compatible with the JSON
    Lines format.

    File naming:
        A session ID made only of alphanumerics, ``.``, ``_`` and ``-`` is used
        verbatim as the filename stem, unless it starts with ``.``, ends with
        ``.`` or a space, or its first dot-separated segment is a Windows
        reserved device name (``CON``, ``NUL.txt``, ...). Every other session
        ID is stored under ``~session-<urlsafe base64 of the UTF-8 ID>``. A
        missing or empty session ID maps to ``default``.

    Security posture:
        Persisted history is stored as plaintext JSONL on the local filesystem.
        Treat ``storage_path`` as trusted application storage, not as a secret
        store. Encoded fallback filenames and resolved-path validation help
        prevent path traversal via ``session_id``, but they do not encrypt file
        contents or provide cross-process / cross-host locking. Use OS-level
        file permissions, trusted directories, and carefully review what agent
        or tool output is allowed to be persisted.
    """

    DEFAULT_SOURCE_ID: ClassVar[str] = "file_history"
    DEFAULT_SESSION_FILE_STEM: ClassVar[str] = "default"
    FILE_EXTENSION: ClassVar[str] = ".jsonl"
    # Locks are striped: a file maps to one of a fixed number of locks, so unrelated
    # sessions may occasionally share a lock but the lock table never grows.
    _FILE_LOCK_STRIPE_COUNT: ClassVar[int] = 64
    # "~" is not accepted in literal stems, so encoded names can never collide with literal ones.
    _ENCODED_SESSION_PREFIX: ClassVar[str] = "~session-"
    _FILE_WRITE_LOCKS: ClassVar[tuple[threading.Lock, ...]] = tuple(
        threading.Lock() for _ in range(_FILE_LOCK_STRIPE_COUNT)
    )
    # Device names Windows reserves regardless of extension. The superscript-digit
    # variants are listed because ``str.isalnum`` accepts those characters.
    _WINDOWS_RESERVED_FILE_STEMS: ClassVar[frozenset[str]] = frozenset({
        "CON",
        "PRN",
        "AUX",
        "NUL",
        "COM0",
        "COM1",
        "COM2",
        "COM3",
        "COM4",
        "COM5",
        "COM6",
        "COM7",
        "COM8",
        "COM9",
        "COM\u00b9",
        "COM\u00b2",
        "COM\u00b3",
        "LPT0",
        "LPT1",
        "LPT2",
        "LPT3",
        "LPT4",
        "LPT5",
        "LPT6",
        "LPT7",
        "LPT8",
        "LPT9",
        "LPT\u00b9",
        "LPT\u00b2",
        "LPT\u00b3",
    })

    def __init__(
        self,
        storage_path: str | Path,
        *,
        source_id: str = DEFAULT_SOURCE_ID,
        load_messages: bool = True,
        store_inputs: bool = True,
        store_context_messages: bool = False,
        store_context_from: set[str] | None = None,
        store_outputs: bool = True,
        skip_excluded: bool = False,
        dumps: JsonDumps | None = None,
        loads: JsonLoads | None = None,
    ) -> None:
        """Initialize the file history provider.

        The storage directory (including missing parents) is created eagerly.

        Args:
            storage_path: Directory path where session history files will be stored.

        Keyword Args:
            source_id: Unique identifier for this provider instance.
            load_messages: Whether to load messages before invocation.
            store_inputs: Whether to store input messages.
            store_context_messages: Whether to store context from other providers.
            store_context_from: If set, only store context from these source_ids.
            store_outputs: Whether to store response messages.
            skip_excluded: When True, ``get_messages`` omits messages whose
                ``additional_properties["_excluded"]`` is truthy.
            dumps: Callable that serializes a message payload dict to JSON text
                or UTF-8 bytes. The returned JSON must fit on a single line.
            loads: Callable that deserializes JSON text or bytes back to a
                message payload dict.
        """
        super().__init__(
            source_id=source_id,
            load_messages=load_messages,
            store_inputs=store_inputs,
            store_context_messages=store_context_messages,
            store_context_from=store_context_from,
            store_outputs=store_outputs,
        )
        self.storage_path = Path(storage_path)
        self.storage_path.mkdir(parents=True, exist_ok=True)
        # Resolved once so every session path can be validated against the same root.
        self._storage_root = self.storage_path.resolve()
        self.skip_excluded = skip_excluded
        self.dumps = dumps or _default_json_dumps
        self.loads = loads or _default_json_loads
        # asyncio locks are kept per event loop; weak keys let entries go when a loop is collected.
        self._async_write_locks_by_loop: weakref.WeakKeyDictionary[
            asyncio.AbstractEventLoop,
            tuple[asyncio.Lock, ...],
        ] = weakref.WeakKeyDictionary()

    async def get_messages(
        self,
        session_id: str | None,
        *,
        state: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> list[Message]:
        """Retrieve messages from the session's JSON Lines file.

        Args:
            session_id: Session whose history should be read; ``None`` selects
                the default session file.

        Keyword Args:
            state: Accepted for interface compatibility and ignored.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            The stored messages in file order, or an empty list when the session
            file does not exist. Blank lines are skipped. When ``skip_excluded``
            is set, messages flagged ``_excluded`` are filtered out.

        Raises:
            ValueError: If a line cannot be deserialized, does not decode to a
                mapping, or is rejected by ``Message.from_dict`` with a
                ``ValueError``.
        """
        del state, kwargs
        file_path = self._session_file_path(session_id)
        async_lock = self._session_async_write_lock(file_path)
        thread_lock = self._session_write_lock(file_path)

        def _read_messages() -> list[Message]:
            with thread_lock:
                if not file_path.exists():
                    return []

                loaded: list[Message] = []
                with file_path.open(encoding="utf-8") as file_handle:
                    for line_number, line in enumerate(file_handle, start=1):
                        serialized = line.strip()
                        if not serialized:
                            continue
                        loaded.append(
                            self._parse_history_line(serialized, line_number=line_number, file_path=file_path)
                        )
                return loaded

        # Blocking file I/O runs in a worker thread so the event loop stays responsive.
        async with async_lock:
            messages = await asyncio.to_thread(_read_messages)
        if self.skip_excluded:
            messages = [m for m in messages if not m.additional_properties.get("_excluded", False)]
        return messages

    async def save_messages(
        self,
        session_id: str | None,
        messages: Sequence[Message],
        *,
        state: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> None:
        """Append messages to the session's JSON Lines file.

        The whole batch is serialized before the file is opened, so a failing
        serializer neither creates the file nor leaves part of the batch on disk.

        Args:
            session_id: Session whose history should be extended; ``None``
                selects the default session file.
            messages: Messages to append, in order. An empty sequence is a no-op.

        Keyword Args:
            state: Accepted for interface compatibility and ignored.
            **kwargs: Accepted for interface compatibility and ignored.

        Raises:
            TypeError: If ``dumps`` returns something other than ``str`` or ``bytes``.
            ValueError: If ``dumps`` returns JSON spanning multiple lines.
        """
        del state, kwargs
        if not messages:
            return

        file_path = self._session_file_path(session_id)
        async_lock = self._session_async_write_lock(file_path)
        file_lock = self._session_write_lock(file_path)

        def _append_messages() -> None:
            # Serialize first: any error surfaces before the file is touched.
            records = "".join(f"{self._serialize_message(message)}\n" for message in messages)
            with file_lock, file_path.open("a", encoding="utf-8") as file_handle:
                file_handle.write(records)

        async with async_lock:
            await asyncio.to_thread(_append_messages)

    def _parse_history_line(self, serialized: str, *, line_number: int, file_path: Path) -> Message:
        """Decode one non-empty JSON Lines record into a ``Message``."""
        try:
            payload = self.loads(serialized)
        except (TypeError, ValueError) as exc:
            raise ValueError(f"Failed to deserialize history line {line_number} from '{file_path}'.") from exc
        if not isinstance(payload, Mapping):
            raise ValueError(f"History line {line_number} in '{file_path}' did not deserialize to a mapping.")

        try:
            return Message.from_dict(dict(cast(Mapping[str, Any], payload)))
        except ValueError as exc:
            raise ValueError(f"History line {line_number} in '{file_path}' is not a valid Message payload.") from exc

    def _serialize_message(self, message: Message) -> str:
        """Serialize a message payload to a single JSON Lines record.

        Raises:
            TypeError: If ``dumps`` returns neither ``str`` nor ``bytes``.
            ValueError: If the serialized text contains a line break.
        """
        serialized = self.dumps(message.to_dict())
        if isinstance(serialized, bytes):
            serialized_text = serialized.decode("utf-8")
        elif isinstance(serialized, str):
            serialized_text = serialized
        else:
            raise TypeError("FileHistoryProvider.dumps must return str or bytes.")

        # A raw line break would split one record into several lines on read-back.
        if "\n" in serialized_text or "\r" in serialized_text:
            raise ValueError("FileHistoryProvider.dumps must return single-line JSON for JSON Lines storage.")
        return serialized_text

    def _session_file_path(self, session_id: str | None) -> Path:
        """Resolve the on-disk history file path for a session.

        Raises:
            ValueError: If the resolved path is not inside the storage directory.
        """
        file_path = (self._storage_root / f"{self._session_file_stem(session_id)}{self.FILE_EXTENSION}").resolve()
        # Defense in depth on top of the stem validation/encoding.
        if not file_path.is_relative_to(self._storage_root):
            raise ValueError(f"Session history path escaped storage directory: {session_id!r}")
        return file_path

    def _session_file_stem(self, session_id: str | None) -> str:
        """Return the filename stem for a session.

        Safe session IDs are used literally; all others are replaced by the
        encoded-session prefix followed by unpadded URL-safe base64 of the ID.
        """
        raw_session_id = session_id or self.DEFAULT_SESSION_FILE_STEM
        if self._is_literal_session_file_stem_safe(raw_session_id):
            return raw_session_id

        encoded_session_id = urlsafe_b64encode(raw_session_id.encode("utf-8")).decode("ascii").rstrip("=")
        return f"{self._ENCODED_SESSION_PREFIX}{encoded_session_id or self.DEFAULT_SESSION_FILE_STEM}"

    def _session_async_write_lock(self, file_path: Path) -> asyncio.Lock:
        """Return the event-loop-local async lock for a session history file.

        Must be called from a running event loop; the lock stripe for that loop
        is created on first use.
        """
        loop = asyncio.get_running_loop()
        locks = self._async_write_locks_by_loop.get(loop)
        if locks is None:
            locks = tuple(asyncio.Lock() for _ in range(self._FILE_LOCK_STRIPE_COUNT))
            self._async_write_locks_by_loop[loop] = locks
        return locks[self._lock_index(file_path)]

    @classmethod
    def _session_write_lock(cls, file_path: Path) -> threading.Lock:
        """Return the process-local thread lock for a session history file."""
        return cls._FILE_WRITE_LOCKS[cls._lock_index(file_path)]

    @classmethod
    def _lock_index(cls, file_path: Path) -> int:
        """Map a session history file to a bounded lock stripe."""
        return hash(file_path) % cls._FILE_LOCK_STRIPE_COUNT

    @classmethod
    def _is_literal_session_file_stem_safe(cls, session_id: str) -> bool:
        """Return whether the session ID can be used directly as a filename stem.

        A literal stem must be non-empty, must not start with ``.`` or end with
        ``.`` or a space, may contain only alphanumerics, ``.``, ``_`` and
        ``-``, and must not begin with a Windows reserved device name.
        """
        if not session_id or session_id.startswith(".") or session_id.endswith((" ", ".")):
            return False
        if not all(character.isalnum() or character in "._-" for character in session_id):
            return False
        # Windows treats "NUL.txt" or "con.tar.gz" as the device itself, so only the
        # segment before the first dot decides whether the name is reserved.
        device_name = session_id.partition(".")[0].upper()
        return device_name not in cls._WINDOWS_RESERVED_FILE_STEMS
|
||||
|
||||
@@ -244,10 +244,10 @@ class FileCheckpointStorage:
|
||||
is serialized using pickle and embedded as base64-encoded strings within the JSON. This allows
|
||||
for human-readable checkpoint files while preserving the ability to store complex Python objects.
|
||||
|
||||
By default, checkpoint deserialization is restricted to a built-in set of safe
|
||||
Python types (primitives, datetime, uuid, ...) and all ``agent_framework``
|
||||
internal types. To allow additional application-specific types, pass them via
|
||||
the ``allowed_checkpoint_types`` parameter using ``"module:qualname"`` format.
|
||||
By default, checkpoint deserialization is restricted to a built-in set of safe Python types
|
||||
(primitives, datetime, uuid, ...), all ``agent_framework`` internal types, and OpenAI SDK types
|
||||
(``openai.types``). To allow additional application-specific types, pass them via the
|
||||
``allowed_checkpoint_types`` parameter using ``"module:qualname"`` format.
|
||||
|
||||
Example::
|
||||
|
||||
|
||||
@@ -10,9 +10,9 @@ This hybrid approach provides:
|
||||
When ``allowed_types`` is supplied to :func:`decode_checkpoint_value`, a
|
||||
``RestrictedUnpickler`` is used that limits which classes may be instantiated
|
||||
during deserialization. The default built-in safe set covers common Python
|
||||
value types (primitives, datetime, uuid, ...) and all ``agent_framework``
|
||||
internal types. Callers can extend the set by passing additional
|
||||
``"module:qualname"`` strings.
|
||||
value types (primitives, datetime, uuid, ...), all ``agent_framework`` internal
|
||||
types, and all ``openai.types`` types. Callers can extend the set by passing
|
||||
additional ``"module:qualname"`` strings.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -37,6 +37,9 @@ _JSON_NATIVE_TYPES = (str, int, float, bool, type(None))
|
||||
# Module prefix for framework-internal types that are always allowed
|
||||
_FRAMEWORK_MODULE_PREFIX = "agent_framework."
|
||||
|
||||
# Module prefix for OpenAI SDK types that are always allowed
|
||||
_OPENAI_MODULE_PREFIX = "openai.types."
|
||||
|
||||
# Built-in types considered safe for checkpoint deserialization.
|
||||
# Each entry is a ``module:qualname`` string matching the format produced by
|
||||
# :func:`_type_to_key`. These are the classes for which pickle's
|
||||
@@ -84,8 +87,9 @@ class _RestrictedUnpickler(pickle.Unpickler): # noqa: S301
|
||||
"""Unpickler that restricts which classes may be instantiated.
|
||||
|
||||
Only classes whose ``module:qualname`` key appears in the combined allow
|
||||
set (built-in safe types + framework types + caller-specified extras) are
|
||||
permitted. All other classes raise :class:`pickle.UnpicklingError`.
|
||||
set (built-in safe types + framework types + OpenAI SDK types +
|
||||
caller-specified extras) are permitted. All other classes raise
|
||||
:class:`pickle.UnpicklingError`.
|
||||
"""
|
||||
|
||||
def __init__(self, data: bytes, allowed_types: frozenset[str]) -> None:
|
||||
@@ -99,6 +103,7 @@ class _RestrictedUnpickler(pickle.Unpickler): # noqa: S301
|
||||
type_key in _BUILTIN_ALLOWED_TYPE_KEYS
|
||||
or type_key in self._allowed_types
|
||||
or module.startswith(_FRAMEWORK_MODULE_PREFIX)
|
||||
or module.startswith(_OPENAI_MODULE_PREFIX)
|
||||
):
|
||||
return super().find_class(module, name) # type: ignore[no-any-return] # nosec
|
||||
|
||||
|
||||
@@ -11,11 +11,11 @@ import logging
|
||||
import types
|
||||
import uuid
|
||||
from collections.abc import AsyncIterable, Awaitable, Callable, Mapping, Sequence
|
||||
from typing import Any, Literal, overload
|
||||
from typing import TYPE_CHECKING, Any, Literal, overload
|
||||
|
||||
from .._sessions import ContextProvider
|
||||
from .._types import ResponseStream
|
||||
from ..observability import OtelAttr, capture_exception, create_workflow_span
|
||||
from ._agent import WorkflowAgent
|
||||
from ._checkpoint import CheckpointStorage
|
||||
from ._const import DEFAULT_MAX_ITERATIONS, GLOBAL_KWARGS_KEY, WORKFLOW_RUN_KWARGS_KEY
|
||||
from ._edge import (
|
||||
@@ -35,6 +35,9 @@ from ._runner_context import RunnerContext
|
||||
from ._state import State
|
||||
from ._typing_utils import is_instance_of, try_coerce_to_type
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ._agent import WorkflowAgent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -910,7 +913,14 @@ class Workflow(DictConvertible):
|
||||
|
||||
return list(output_types)
|
||||
|
||||
def as_agent(self, name: str | None = None) -> WorkflowAgent:
|
||||
def as_agent(
|
||||
self,
|
||||
name: str | None = None,
|
||||
*,
|
||||
description: str | None = None,
|
||||
context_providers: Sequence[ContextProvider] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> WorkflowAgent:
|
||||
"""Create a WorkflowAgent that wraps this workflow.
|
||||
|
||||
The returned agent converts standard agent inputs (strings, Message, or lists of these)
|
||||
@@ -924,7 +934,10 @@ class Workflow(DictConvertible):
|
||||
initialization will fail with a ValueError.
|
||||
|
||||
Args:
|
||||
name: Optional name for the agent. If None, a default name will be generated.
|
||||
name: Optional name for the agent. Defaults to workflow name.
|
||||
description: Optional description of the agent. Defaults to workflow description.
|
||||
context_providers: Optional sequence of context providers for the agent.
|
||||
**kwargs: Additional keyword arguments passed to BaseAgent.
|
||||
|
||||
Returns:
|
||||
A WorkflowAgent instance that wraps this workflow.
|
||||
@@ -935,4 +948,10 @@ class Workflow(DictConvertible):
|
||||
# Import here to avoid circular imports
|
||||
from ._agent import WorkflowAgent
|
||||
|
||||
return WorkflowAgent(workflow=self, name=name)
|
||||
return WorkflowAgent(
|
||||
workflow=self,
|
||||
name=name if name is not None else self.name,
|
||||
description=description if description is not None else self.description,
|
||||
context_providers=context_providers,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -1,7 +1,12 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable, Sequence
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -10,6 +15,8 @@ from agent_framework import (
|
||||
AgentSession,
|
||||
ChatContext,
|
||||
ContextProvider,
|
||||
ExperimentalFeature,
|
||||
FileHistoryProvider,
|
||||
HistoryProvider,
|
||||
InMemoryHistoryProvider,
|
||||
Message,
|
||||
@@ -505,3 +512,217 @@ class TestInMemoryHistoryProvider:
|
||||
ctx = SessionContext(session_id="s1", input_messages=[])
|
||||
ctx.extend_messages("custom-source", [Message(role="user", contents=["test"])])
|
||||
assert "custom-source" in ctx.context_messages
|
||||
|
||||
|
||||
class TestFileHistoryProvider:
    """Tests for the JSON Lines file-backed ``FileHistoryProvider``."""

    def test_is_marked_experimental(self) -> None:
        """The class carries the experimental stage, feature id and docstring warning."""
        assert FileHistoryProvider.__feature_stage__ == "experimental"
        assert FileHistoryProvider.__feature_id__ == ExperimentalFeature.FILE_HISTORY.value
        assert FileHistoryProvider.__doc__ is not None
        assert ".. warning:: Experimental" in FileHistoryProvider.__doc__

    async def test_stores_and_loads_messages(self, tmp_path: Path) -> None:
        """A run persists input and response as JSONL and a later run loads them back."""
        from agent_framework import AgentResponse

        provider = FileHistoryProvider(tmp_path)
        session = AgentSession(session_id="s1")

        input_message = Message(role="user", contents=["hello"])
        response_message = Message(role="assistant", contents=["hi there"])
        first_context = SessionContext(session_id=session.session_id, input_messages=[input_message])

        await provider.before_run(  # type: ignore[arg-type]
            agent=None,
            session=session,
            context=first_context,
            state={},
        )
        # Attach a response directly so after_run has an output message to persist.
        first_context._response = AgentResponse(messages=[response_message])
        await provider.after_run(  # type: ignore[arg-type]
            agent=None,
            session=session,
            context=first_context,
            state={},
        )

        session_file = provider._session_file_path(session.session_id)
        assert session_file.name == "s1.jsonl"
        assert session_file.exists()
        raw_lines = (await asyncio.to_thread(session_file.read_text, encoding="utf-8")).splitlines()
        # One line for the input message and one for the response message.
        assert len(raw_lines) == 2
        payloads = [json.loads(line) for line in raw_lines]
        assert all(payload["type"] == "message" for payload in payloads)
        # Records hold message payloads only; the session id lives in the filename.
        assert all("session_id" not in payload for payload in payloads)

        second_context = SessionContext(
            session_id=session.session_id, input_messages=[Message(role="user", contents=["again"])]
        )
        await provider.before_run(  # type: ignore[arg-type]
            agent=None,
            session=session,
            context=second_context,
            state={},
        )
        loaded = second_context.context_messages.get(provider.source_id, [])
        assert len(loaded) == 2
        assert loaded[0].text == "hello"
        assert loaded[1].text == "hi there"

    def test_creates_storage_directory(self, tmp_path: Path) -> None:
        """Constructing the provider creates a missing nested storage directory."""
        nested_path = tmp_path / "nested" / "history"
        provider = FileHistoryProvider(nested_path)
        assert provider.storage_path == nested_path
        assert nested_path.exists()
        assert nested_path.is_dir()

    async def test_uses_encoded_filename_for_unsafe_session_id(self, tmp_path: Path) -> None:
        """A path-traversal style session id is stored under an encoded name inside the storage dir."""
        provider = FileHistoryProvider(tmp_path)
        unsafe_session_id = "../unsafe/session"

        await provider.save_messages(unsafe_session_id, [Message(role="user", contents=["hello"])])

        session_file = provider._session_file_path(unsafe_session_id)
        assert session_file.parent == provider.storage_path
        assert session_file.name.startswith("~session-")
        assert session_file.suffix == ".jsonl"
        assert session_file.exists()
        jsonl_files = await asyncio.to_thread(
            lambda: sorted(path.name for path in provider.storage_path.glob("*.jsonl"))
        )
        # The encoded file is the only history file in the storage directory.
        assert jsonl_files == [session_file.name]

    async def test_allows_custom_serializers_returning_bytes(self, tmp_path: Path) -> None:
        """Custom ``dumps``/``loads`` callables are used, including a bytes-returning ``dumps``."""
        calls: list[str] = []

        def dumps(payload: object) -> bytes:
            calls.append("dumps")
            return json.dumps(payload).encode("utf-8")

        def loads(payload: str | bytes) -> object:
            calls.append("loads")
            if isinstance(payload, bytes):
                payload = payload.decode("utf-8")
            return json.loads(payload)

        provider = FileHistoryProvider(tmp_path, dumps=dumps, loads=loads)

        await provider.save_messages("custom-serializer", [Message(role="user", contents=["hello"])])
        loaded = await provider.get_messages("custom-serializer")

        # Exactly one serialize on save and one deserialize on load.
        assert calls == ["dumps", "loads"]
        assert len(loaded) == 1
        assert loaded[0].text == "hello"

    async def test_invalid_jsonl_line_raises(self, tmp_path: Path) -> None:
        """A line that is not valid JSON raises ``ValueError`` naming the line number."""
        provider = FileHistoryProvider(tmp_path)
        await asyncio.to_thread(provider._session_file_path("broken").write_text, "{not-json}\n", encoding="utf-8")

        with pytest.raises(ValueError, match="Failed to deserialize history line 1"):
            await provider.get_messages("broken")

    async def test_missing_session_file_returns_empty_messages(self, tmp_path: Path) -> None:
        """Reading a session with no file yields an empty list."""
        provider = FileHistoryProvider(tmp_path)

        loaded = await provider.get_messages("missing")

        assert loaded == []

    async def test_none_session_id_uses_default_jsonl_file(self, tmp_path: Path) -> None:
        """A ``None`` session id reads and writes ``default.jsonl``."""
        provider = FileHistoryProvider(tmp_path)

        await provider.save_messages(None, [Message(role="user", contents=["hello"])])

        session_file = provider._session_file_path(None)
        assert session_file.name == "default.jsonl"
        loaded = await provider.get_messages(None)
        assert [message.text for message in loaded] == ["hello"]

    async def test_non_mapping_jsonl_line_raises(self, tmp_path: Path) -> None:
        """A valid JSON line that is not an object raises ``ValueError``."""
        provider = FileHistoryProvider(tmp_path)
        await asyncio.to_thread(provider._session_file_path("non-mapping").write_text, "[1, 2, 3]\n", encoding="utf-8")

        with pytest.raises(ValueError, match="did not deserialize to a mapping"):
            await provider.get_messages("non-mapping")

    async def test_skip_excluded_omits_excluded_messages(self, tmp_path: Path) -> None:
        """With ``skip_excluded=True`` messages flagged ``_excluded`` are not returned."""
        provider = FileHistoryProvider(tmp_path, skip_excluded=True)

        await provider.save_messages(
            "skip-excluded",
            [
                Message(role="user", contents=["keep"]),
                Message(role="assistant", contents=["skip"], additional_properties={"_excluded": True}),
            ],
        )

        loaded = await provider.get_messages("skip-excluded")

        assert [message.text for message in loaded] == ["keep"]

    async def test_serializer_must_return_single_line_json(self, tmp_path: Path) -> None:
        """A serializer producing indented (multi-line) JSON is rejected on save."""

        def dumps(payload: object) -> str:
            return json.dumps(payload, indent=2)

        provider = FileHistoryProvider(tmp_path, dumps=dumps)

        with pytest.raises(ValueError, match="single-line JSON"):
            await provider.save_messages("pretty-json", [Message(role="user", contents=["hello"])])

    async def test_concurrent_writes_for_same_session_are_locked(
        self,
        tmp_path: Path,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """Two overlapping saves to one session never write to the file at the same time."""
        provider = FileHistoryProvider(tmp_path)
        session_id = "shared-session"
        file_path = provider._session_file_path(session_id)
        real_open = Path.open
        write_started = threading.Event()
        active_writes = 0
        overlap_detected = False

        class _TrackingFile:
            """File wrapper that slows down ``write`` and records overlapping writers."""

            def __init__(self, wrapped: Any) -> None:
                self._wrapped = wrapped

            def __enter__(self) -> "_TrackingFile":
                self._wrapped.__enter__()
                return self

            def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
                self._wrapped.__exit__(exc_type, exc_val, exc_tb)

            def write(self, data: str) -> int:
                nonlocal active_writes, overlap_detected
                write_started.set()
                active_writes += 1
                # More than one writer inside write() at once means the locking failed.
                overlap_detected = overlap_detected or active_writes > 1
                try:
                    # Keep the write open long enough for the second save to start.
                    time.sleep(0.05)
                    return int(self._wrapped.write(data))
                finally:
                    active_writes -= 1

            def __getattr__(self, name: str) -> Any:
                return getattr(self._wrapped, name)

        def tracked_open(path: Path, *args: Any, **kwargs: Any) -> Any:
            handle = real_open(path, *args, **kwargs)
            # Only append-mode opens of the session file under test are instrumented.
            if path == file_path and args and args[0] == "a":
                return _TrackingFile(handle)
            return handle

        monkeypatch.setattr(Path, "open", tracked_open)

        first_save = asyncio.create_task(provider.save_messages(session_id, [Message(role="user", contents=["first"])]))
        # Wait until the first save is inside write() before launching the second one.
        started = await asyncio.to_thread(write_started.wait, 1.0)
        assert started

        second_save = asyncio.create_task(
            provider.save_messages(session_id, [Message(role="assistant", contents=["second"])])
        )
        await asyncio.gather(first_save, second_save)

        assert not overlap_detected
        loaded = await provider.get_messages(session_id)
        assert [message.text for message in loaded] == ["first", "second"]
|
||||
|
||||
@@ -216,3 +216,50 @@ def test_restricted_unpickler_raises_pickle_error():
|
||||
unpickler = _RestrictedUnpickler(pickled, frozenset())
|
||||
with pytest.raises(pickle.UnpicklingError, match="deserialization blocked"):
|
||||
unpickler.load()
|
||||
|
||||
|
||||
def test_restricted_decode_allows_openai_types():
    """A Chat Completions object round-trips through restricted decoding with no extra allowed types."""
    from openai.types.chat.chat_completion import ChatCompletion, Choice
    from openai.types.chat.chat_completion_message import ChatCompletionMessage
    from openai.types.completion_usage import CompletionUsage

    # Build the nested pieces first so the top-level object reads flat.
    assistant_message = ChatCompletionMessage(role="assistant", content="hello")
    only_choice = Choice(finish_reason="stop", index=0, message=assistant_message)
    token_usage = CompletionUsage(completion_tokens=1, prompt_tokens=1, total_tokens=2)
    original = ChatCompletion(
        id="chatcmpl-test",
        choices=[only_choice],
        created=1700000000,
        model="gpt-4",
        object="chat.completion",
        usage=token_usage,
    )

    # An empty allow-list means only the built-in/framework/OpenAI rules can admit the type.
    restored = decode_checkpoint_value(encode_checkpoint_value(original), allowed_types=frozenset())

    assert isinstance(restored, ChatCompletion)
    assert restored.id == "chatcmpl-test"
    assert restored.choices[0].message.content == "hello"
|
||||
|
||||
|
||||
def test_restricted_decode_allows_openai_response_types():
    """Round-trip an OpenAI Responses API ``ResponseUsage`` through restricted decoding with an empty allow-list."""
    from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails, ResponseUsage

    input_details = InputTokensDetails(cached_tokens=0)
    output_details = OutputTokensDetails(reasoning_tokens=0)
    original = ResponseUsage(
        input_tokens=10,
        output_tokens=20,
        total_tokens=30,
        input_tokens_details=input_details,
        output_tokens_details=output_details,
    )

    # No caller-supplied types are allowed here; the value must still decode.
    payload = encode_checkpoint_value(original)
    restored = decode_checkpoint_value(payload, allowed_types=frozenset())

    assert isinstance(restored, ResponseUsage)
    assert restored.input_tokens == 10
    assert restored.output_tokens == 20
|
||||
|
||||
@@ -313,6 +313,37 @@ class TestWorkflowAgent:
|
||||
assert isinstance(agent_no_name, WorkflowAgent)
|
||||
assert agent_no_name.workflow is workflow
|
||||
|
||||
def test_workflow_as_agent_with_description_and_context_providers(self) -> None:
    """Check that ``Workflow.as_agent()`` hands ``description`` and ``context_providers`` to the agent."""
    start = SimpleExecutor(id="executor1", response_text="Response")
    built_workflow = WorkflowBuilder(start_executor=start).build()
    provider = InMemoryHistoryProvider()

    wrapped = built_workflow.as_agent(
        name="MyAgent",
        description="A test agent",
        context_providers=[provider],
    )

    assert isinstance(wrapped, WorkflowAgent)
    assert wrapped.name == "MyAgent"
    assert wrapped.description == "A test agent"
    # The very same provider instance must be registered on the agent.
    assert provider in wrapped.context_providers
|
||||
|
||||
def test_workflow_as_agent_defaults_name_and_description_from_workflow(self) -> None:
    """Check that ``as_agent()`` without arguments reuses the workflow's own name and description."""
    start = SimpleExecutor(id="executor1", response_text="Response")
    builder = WorkflowBuilder(
        start_executor=start,
        name="my-workflow",
        description="Workflow description",
    )
    built_workflow = builder.build()

    # No name/description passed: both are expected to come from the workflow.
    wrapped = built_workflow.as_agent()

    assert wrapped.name == "my-workflow"
    assert wrapped.description == "Workflow description"
|
||||
|
||||
def test_workflow_as_agent_cannot_handle_agent_inputs(self) -> None:
|
||||
"""Test that Workflow.as_agent() raises an error if the start executor cannot handle agent inputs."""
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ description = "Debug UI for Microsoft Agent Framework with OpenAI-compatible API
|
||||
authors = [{ name = "Microsoft", email = "af-support@microsoft.com"}]
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
version = "1.0.0b260409"
|
||||
version = "1.0.0b260414"
|
||||
license-files = ["LICENSE"]
|
||||
urls.homepage = "https://github.com/microsoft/agent-framework"
|
||||
urls.source = "https://github.com/microsoft/agent-framework/tree/main/python"
|
||||
|
||||
@@ -31,7 +31,7 @@ from agent_framework.exceptions import AgentException
|
||||
try:
|
||||
from copilot import CopilotClient, CopilotSession, SubprocessConfig
|
||||
from copilot.generated.session_events import PermissionRequest, SessionEvent, SessionEventType
|
||||
from copilot.session import MCPServerConfig, PermissionRequestResult, SystemMessageConfig
|
||||
from copilot.session import MCPServerConfig, PermissionRequestResult, ProviderConfig, SystemMessageConfig
|
||||
from copilot.tools import Tool as CopilotTool
|
||||
from copilot.tools import ToolInvocation, ToolResult
|
||||
except ImportError as _copilot_import_error:
|
||||
@@ -120,6 +120,12 @@ class GitHubCopilotOptions(TypedDict, total=False):
|
||||
Supports both local (stdio) and remote (HTTP/SSE) servers.
|
||||
"""
|
||||
|
||||
provider: ProviderConfig
|
||||
"""Custom API provider configuration for BYOK (Bring Your Own Key) scenarios.
|
||||
Allows routing requests through your own OpenAI, Azure, or Anthropic endpoint
|
||||
instead of the default GitHub Copilot backend.
|
||||
"""
|
||||
|
||||
|
||||
OptionsT = TypeVar(
|
||||
"OptionsT",
|
||||
@@ -232,6 +238,7 @@ class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
|
||||
log_level = opts.pop("log_level", None)
|
||||
on_permission_request: PermissionHandlerType | None = opts.pop("on_permission_request", None)
|
||||
mcp_servers: dict[str, MCPServerConfig] | None = opts.pop("mcp_servers", None)
|
||||
provider: ProviderConfig | None = opts.pop("provider", None)
|
||||
|
||||
self._settings = load_settings(
|
||||
GitHubCopilotSettings,
|
||||
@@ -247,6 +254,7 @@ class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
|
||||
self._tools = normalize_tools(tools)
|
||||
self._permission_handler = on_permission_request
|
||||
self._mcp_servers = mcp_servers
|
||||
self._provider = provider
|
||||
self._default_options = opts
|
||||
self._started = False
|
||||
|
||||
@@ -730,6 +738,7 @@ class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
|
||||
opts.get("on_permission_request") or self._permission_handler or _deny_all_permissions
|
||||
)
|
||||
mcp_servers = opts.get("mcp_servers") or self._mcp_servers or None
|
||||
provider = opts.get("provider") or self._provider or None
|
||||
tools = self._prepare_tools(self._tools) if self._tools else None
|
||||
|
||||
return await self._client.create_session(
|
||||
@@ -739,6 +748,7 @@ class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
|
||||
system_message=system_message or None,
|
||||
tools=tools or None,
|
||||
mcp_servers=mcp_servers or None,
|
||||
provider=provider or None,
|
||||
)
|
||||
|
||||
async def _resume_session(self, session_id: str, streaming: bool) -> CopilotSession:
|
||||
@@ -755,4 +765,5 @@ class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
|
||||
streaming=streaming,
|
||||
tools=tools or None,
|
||||
mcp_servers=self._mcp_servers or None,
|
||||
provider=self._provider or None,
|
||||
)
|
||||
|
||||
@@ -861,6 +861,7 @@ class TestGitHubCopilotAgentSessionManagement:
|
||||
streaming=unittest.mock.ANY,
|
||||
tools=unittest.mock.ANY,
|
||||
mcp_servers=unittest.mock.ANY,
|
||||
provider=unittest.mock.ANY,
|
||||
)
|
||||
|
||||
async def test_session_config_includes_model(
|
||||
@@ -1084,6 +1085,198 @@ class TestGitHubCopilotAgentMCPServers:
|
||||
assert config["mcp_servers"] is None
|
||||
|
||||
|
||||
class TestGitHubCopilotAgentProvider:
    """Tests for how the ``provider`` option (BYOK / Managed Identity) reaches Copilot sessions."""

    async def test_provider_passed_to_create_session(
        self,
        mock_client: MagicMock,
    ) -> None:
        """A provider set in ``default_options`` is forwarded to ``create_session``."""
        from copilot.session import ProviderConfig

        azure_provider: ProviderConfig = {
            "type": "azure",
            "base_url": "https://my-resource.openai.azure.com",
            "bearer_token": "test-token",
        }
        agent: GitHubCopilotAgent[GitHubCopilotOptions] = GitHubCopilotAgent(
            client=mock_client,
            default_options={"provider": azure_provider},
        )
        await agent.start()

        await agent._get_or_create_session(AgentSession())  # type: ignore

        sent_kwargs = mock_client.create_session.call_args.kwargs
        sent_provider = sent_kwargs["provider"]
        assert sent_provider["type"] == "azure"
        assert sent_provider["base_url"] == "https://my-resource.openai.azure.com"
        assert sent_provider["bearer_token"] == "test-token"

    async def test_provider_passed_to_resume_session(
        self,
        mock_client: MagicMock,
    ) -> None:
        """A provider set in ``default_options`` is forwarded to ``resume_session``."""
        from copilot.session import ProviderConfig

        azure_provider: ProviderConfig = {
            "type": "azure",
            "base_url": "https://my-resource.openai.azure.com",
            "bearer_token": "test-token",
        }
        agent: GitHubCopilotAgent[GitHubCopilotOptions] = GitHubCopilotAgent(
            client=mock_client,
            default_options={"provider": azure_provider},
        )
        await agent.start()

        # A session that already carries a service id takes the resume path.
        existing = AgentSession()
        existing.service_session_id = "existing-session-id"

        await agent._get_or_create_session(existing)  # type: ignore

        mock_client.resume_session.assert_called_once()
        sent_kwargs = mock_client.resume_session.call_args.kwargs
        assert sent_kwargs["provider"]["type"] == "azure"

    async def test_session_config_excludes_provider_when_not_set(
        self,
        mock_client: MagicMock,
    ) -> None:
        """Without a configured provider, ``create_session`` receives ``provider=None``."""
        agent = GitHubCopilotAgent(client=mock_client)
        await agent.start()

        await agent._get_or_create_session(AgentSession())  # type: ignore

        sent_kwargs = mock_client.create_session.call_args.kwargs
        assert sent_kwargs["provider"] is None

    async def test_resume_session_excludes_provider_when_not_set(
        self,
        mock_client: MagicMock,
    ) -> None:
        """Without a configured provider, ``resume_session`` receives ``provider=None``."""
        agent = GitHubCopilotAgent(client=mock_client)
        await agent.start()

        existing = AgentSession()
        existing.service_session_id = "existing-session-id"

        await agent._get_or_create_session(existing)  # type: ignore

        sent_kwargs = mock_client.resume_session.call_args.kwargs
        assert sent_kwargs["provider"] is None

    async def test_runtime_provider_takes_precedence(
        self,
        mock_client: MagicMock,
    ) -> None:
        """A provider passed in runtime options wins over the one in ``default_options``."""
        from copilot.session import ProviderConfig

        configured: ProviderConfig = {
            "type": "azure",
            "base_url": "https://default.openai.azure.com",
            "bearer_token": "default-token",
        }
        override: ProviderConfig = {
            "type": "openai",
            "base_url": "https://runtime.openai.com",
            "api_key": "runtime-key",
        }
        agent: GitHubCopilotAgent[GitHubCopilotOptions] = GitHubCopilotAgent(
            client=mock_client,
            default_options={"provider": configured},
        )
        await agent.start()

        await agent._get_or_create_session(  # type: ignore
            AgentSession(),
            runtime_options={"provider": override},
        )

        sent_provider = mock_client.create_session.call_args.kwargs["provider"]
        assert sent_provider["type"] == "openai"
        assert sent_provider["base_url"] == "https://runtime.openai.com"

    async def test_provider_not_leaked_into_default_options(
        self,
        mock_client: MagicMock,
    ) -> None:
        """The provider is stored on ``_provider`` and removed from ``_default_options``."""
        from copilot.session import ProviderConfig

        azure_provider: ProviderConfig = {
            "type": "azure",
            "base_url": "https://my-resource.openai.azure.com",
            "bearer_token": "test-token",
        }
        agent: GitHubCopilotAgent[GitHubCopilotOptions] = GitHubCopilotAgent(
            client=mock_client,
            default_options={"provider": azure_provider, "model": "gpt-5"},
        )

        assert "provider" not in agent._default_options
        stored = agent._provider
        assert stored is not None
        assert stored["type"] == "azure"

    async def test_provider_coexists_with_other_options(
        self,
        mock_client: MagicMock,
    ) -> None:
        """Provider, model, tools and MCP servers can all be configured together."""
        from copilot.session import MCPServerConfig, ProviderConfig

        def my_tool(arg: str) -> str:
            """A test tool."""
            return arg

        azure_provider: ProviderConfig = {
            "type": "azure",
            "base_url": "https://my-resource.openai.azure.com",
            "bearer_token": "test-token",
        }
        servers: dict[str, MCPServerConfig] = {
            "test-server": {
                "type": "stdio",
                "command": "echo",
                "args": ["hello"],
                "tools": ["*"],
            },
        }
        combined_options = {
            "model": "gpt-5",
            "provider": azure_provider,
            "mcp_servers": servers,
        }
        agent: GitHubCopilotAgent[GitHubCopilotOptions] = GitHubCopilotAgent(
            client=mock_client,
            tools=[my_tool],
            default_options=combined_options,
        )
        await agent.start()

        await agent._get_or_create_session(AgentSession())  # type: ignore

        sent_kwargs = mock_client.create_session.call_args.kwargs
        assert sent_kwargs["provider"]["type"] == "azure"
        assert sent_kwargs["model"] == "gpt-5"
        assert sent_kwargs["mcp_servers"] is not None
        assert sent_kwargs["tools"] is not None
|
||||
|
||||
|
||||
class TestGitHubCopilotAgentToolConversion:
|
||||
"""Test cases for tool conversion."""
|
||||
|
||||
|
||||
@@ -1161,7 +1161,16 @@ class RawOpenAIChatClient( # type: ignore[misc]
|
||||
# First turn: prepend instructions as system message
|
||||
messages = prepend_instructions_to_messages(list(messages), instructions, role="system")
|
||||
# Continuation turn: instructions already exist in conversation context, skip prepending
|
||||
request_input = self._prepare_messages_for_openai(messages)
|
||||
request_uses_service_side_storage = False
|
||||
for key in ("conversation_id", "previous_response_id", "conversation"):
|
||||
value = options.get(key)
|
||||
if isinstance(value, str) and value:
|
||||
request_uses_service_side_storage = True
|
||||
break
|
||||
request_input = self._prepare_messages_for_openai(
|
||||
messages,
|
||||
request_uses_service_side_storage=request_uses_service_side_storage,
|
||||
)
|
||||
if not request_input:
|
||||
raise ChatClientInvalidRequestException("Messages are required for chat completions")
|
||||
conversation_id = options.get("conversation_id")
|
||||
@@ -1235,7 +1244,12 @@ class RawOpenAIChatClient( # type: ignore[misc]
|
||||
raise ValueError("model must be a non-empty string")
|
||||
options["model"] = self.model
|
||||
|
||||
def _prepare_messages_for_openai(self, chat_messages: Sequence[Message]) -> list[dict[str, Any]]:
|
||||
def _prepare_messages_for_openai(
|
||||
self,
|
||||
chat_messages: Sequence[Message],
|
||||
*,
|
||||
request_uses_service_side_storage: bool = True,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Prepare the chat messages for a request.
|
||||
|
||||
Allowing customization of the key names for role/author, and optionally overriding the role.
|
||||
@@ -1248,31 +1262,27 @@ class RawOpenAIChatClient( # type: ignore[misc]
|
||||
|
||||
Args:
|
||||
chat_messages: The chat history to prepare.
|
||||
request_uses_service_side_storage: Whether this request continues a service-managed
|
||||
response/conversation and can safely reference service-scoped response items.
|
||||
|
||||
Returns:
|
||||
The prepared chat messages for a request.
|
||||
"""
|
||||
list_of_list = [self._prepare_message_for_openai(message) for message in chat_messages]
|
||||
list_of_list = [
|
||||
self._prepare_message_for_openai(
|
||||
message,
|
||||
request_uses_service_side_storage=request_uses_service_side_storage,
|
||||
)
|
||||
for message in chat_messages
|
||||
]
|
||||
# Flatten the list of lists into a single list
|
||||
return list(chain.from_iterable(list_of_list))
|
||||
|
||||
@staticmethod
def _message_replays_provider_context(message: Message) -> bool:
    """Tell whether ``message`` carries provider attribution, i.e. is replayed context.

    Responses ``fc_id`` values are response-scoped and only valid while the same
    live tool loop is being replayed. A message that arrives through a context
    provider (for example, loaded session history) is historical input and must
    not reuse the original response-scoped ``fc_id``.

    Returns:
        True when the message has non-empty ``additional_properties`` containing
        the ``"_attribution"`` key, False otherwise.
    """
    # getattr keeps this tolerant of objects that lack the attribute entirely.
    props = getattr(message, "additional_properties", None)
    return bool(props) and "_attribution" in props
|
||||
|
||||
def _prepare_message_for_openai(
|
||||
self,
|
||||
message: Message,
|
||||
*,
|
||||
request_uses_service_side_storage: bool = True,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Prepare a chat message for the OpenAI Responses API format."""
|
||||
all_messages: list[dict[str, Any]] = []
|
||||
@@ -1280,34 +1290,63 @@ class RawOpenAIChatClient( # type: ignore[misc]
|
||||
"type": "message",
|
||||
"role": message.role,
|
||||
}
|
||||
additional_properties = message.additional_properties
|
||||
replays_local_storage = "_attribution" in additional_properties
|
||||
uses_service_side_storage = request_uses_service_side_storage and not replays_local_storage
|
||||
# Reasoning items are only valid in input when they directly preceded a function_call
|
||||
# in the same response. Including a reasoning item that preceded a text response
|
||||
# in the same response. Including a reasoning item that preceded a text response
|
||||
# (i.e. no function_call in the same message) causes an API error:
|
||||
# "reasoning was provided without its required following item."
|
||||
#
|
||||
# Local storage is stricter: response-scoped reasoning items (rs_*) cannot be replayed
|
||||
# back to the service unless that message is using service-side storage.
|
||||
# In that mode we omit reasoning items and rely on function call + tool output replay.
|
||||
has_function_call = any(c.type == "function_call" for c in message.contents)
|
||||
for content in message.contents:
|
||||
match content.type:
|
||||
case "text_reasoning":
|
||||
if not has_function_call:
|
||||
if not uses_service_side_storage or not has_function_call:
|
||||
continue # reasoning not followed by a function_call is invalid in input
|
||||
reasoning = self._prepare_content_for_openai(message.role, content, message=message)
|
||||
reasoning = self._prepare_content_for_openai(
|
||||
message.role,
|
||||
content,
|
||||
replays_local_storage=replays_local_storage,
|
||||
)
|
||||
if reasoning:
|
||||
all_messages.append(reasoning)
|
||||
case "function_result":
|
||||
new_args: dict[str, Any] = {}
|
||||
new_args.update(self._prepare_content_for_openai(message.role, content, message=message))
|
||||
new_args.update(
|
||||
self._prepare_content_for_openai(
|
||||
message.role,
|
||||
content,
|
||||
replays_local_storage=replays_local_storage,
|
||||
)
|
||||
)
|
||||
if new_args:
|
||||
all_messages.append(new_args)
|
||||
case "function_call":
|
||||
function_call = self._prepare_content_for_openai(message.role, content, message=message)
|
||||
function_call = self._prepare_content_for_openai(
|
||||
message.role,
|
||||
content,
|
||||
replays_local_storage=replays_local_storage,
|
||||
)
|
||||
if function_call:
|
||||
all_messages.append(function_call)
|
||||
case "function_approval_response" | "function_approval_request":
|
||||
prepared = self._prepare_content_for_openai(message.role, content, message=message)
|
||||
prepared = self._prepare_content_for_openai(
|
||||
message.role,
|
||||
content,
|
||||
replays_local_storage=replays_local_storage,
|
||||
)
|
||||
if prepared:
|
||||
all_messages.append(prepared)
|
||||
case _:
|
||||
prepared_content = self._prepare_content_for_openai(message.role, content, message=message)
|
||||
prepared_content = self._prepare_content_for_openai(
|
||||
message.role,
|
||||
content,
|
||||
replays_local_storage=replays_local_storage,
|
||||
)
|
||||
if prepared_content:
|
||||
if "content" not in args:
|
||||
args["content"] = []
|
||||
@@ -1321,7 +1360,7 @@ class RawOpenAIChatClient( # type: ignore[misc]
|
||||
role: Role | str,
|
||||
content: Content,
|
||||
*,
|
||||
message: Message | None = None,
|
||||
replays_local_storage: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
"""Prepare content for the OpenAI Responses API format."""
|
||||
role = Role(role)
|
||||
@@ -1401,11 +1440,7 @@ class RawOpenAIChatClient( # type: ignore[misc]
|
||||
logger.warning(f"FunctionCallContent missing call_id for function '{content.name}'")
|
||||
return {}
|
||||
fc_id = content.call_id
|
||||
if (
|
||||
message is not None
|
||||
and not self._message_replays_provider_context(message)
|
||||
and content.additional_properties
|
||||
):
|
||||
if not replays_local_storage and content.additional_properties:
|
||||
live_fc_id = content.additional_properties.get("fc_id")
|
||||
if isinstance(live_fc_id, str) and live_fc_id:
|
||||
fc_id = live_fc_id
|
||||
|
||||
@@ -1015,6 +1015,84 @@ async def test_shell_call_is_invoked_as_local_shell_function_loop() -> None:
|
||||
assert len(local_shell_outputs) == 0
|
||||
|
||||
|
||||
async def test_tool_loop_store_false_omits_reasoning_items_from_second_request() -> None:
    """With ``store=False``, the second tool-loop request must not replay reasoning items."""

    def make_response(response_id: str, created_at: int, finish_reason: str, output: list[MagicMock]) -> MagicMock:
        # Both mocked responses share every field except id, timestamp, finish reason and output.
        mocked = MagicMock()
        mocked.output_parsed = None
        mocked.metadata = {}
        mocked.usage = None
        mocked.id = response_id
        mocked.model = "test-model"
        mocked.created_at = created_at
        mocked.status = "completed"
        mocked.finish_reason = finish_reason
        mocked.incomplete = None
        mocked.conversation = None
        mocked.output = output
        return mocked

    client = OpenAIChatClient(model="test-model", api_key="test-key")

    # First response: a reasoning item followed by a function call.
    reasoning_item = MagicMock()
    reasoning_item.type = "reasoning"
    reasoning_item.id = "rs_local_only"
    reasoning_item.content = []
    reasoning_item.summary = []
    reasoning_item.encrypted_content = None

    function_call_item = MagicMock()
    function_call_item.type = "function_call"
    function_call_item.id = "fc_tool123"
    function_call_item.call_id = "call_123"
    function_call_item.name = "get_weather"
    function_call_item.arguments = '{"location":"Amsterdam"}'
    function_call_item.status = "completed"

    first_response = make_response("resp-1", 1000000000, "tool_calls", [reasoning_item, function_call_item])

    # Second response: the final text answer.
    text_part = MagicMock()
    text_part.type = "output_text"
    text_part.text = "The weather in Amsterdam is sunny."
    text_item = MagicMock()
    text_item.type = "message"
    text_item.content = [text_part]

    second_response = make_response("resp-2", 1000000001, "stop", [text_item])

    request_options = {
        "store": False,
        "tools": [get_weather],
        "tool_choice": {"mode": "required", "required_function_name": "get_weather"},
    }
    with patch.object(client.client.responses, "create", side_effect=[first_response, second_response]) as mock_create:
        response = await client.get_response(
            messages=[Message(role="user", contents=["What's the weather in Amsterdam?"])],
            options=request_options,
        )

    assert response.text == "The weather in Amsterdam is sunny."
    assert mock_create.call_count == 2

    second_request_input = mock_create.call_args_list[1].kwargs["input"]

    def items_of_type(kind: str) -> list:
        return [entry for entry in second_request_input if entry.get("type") == kind]

    assert not items_of_type("reasoning")

    replayed_calls = items_of_type("function_call")
    assert len(replayed_calls) == 1
    assert replayed_calls[0]["id"] == "fc_tool123"

    replayed_outputs = items_of_type("function_call_output")
    assert len(replayed_outputs) == 1
    assert replayed_outputs[0]["call_id"] == "call_123"
|
||||
|
||||
|
||||
def test_response_content_creation_with_shell_call() -> None:
|
||||
"""Test _parse_response_from_openai with shell_call output."""
|
||||
client = OpenAIChatClient(model="test-model", api_key="test-key")
|
||||
@@ -3221,6 +3299,164 @@ async def test_prepare_options_store_parameter_handling() -> None:
|
||||
assert "previous_response_id" not in options
|
||||
|
||||
|
||||
async def test_prepare_options_store_false_omits_reasoning_items_for_stateless_replay() -> None:
    """With ``store=False`` and no conversation id, prepared input drops reasoning but keeps the tool call and its output."""
    client = OpenAIChatClient(model="test-model", api_key="test-key")

    user_turn = Message(role="user", contents=[Content.from_text(text="search for hotels")])
    assistant_turn = Message(
        role="assistant",
        contents=[
            Content.from_text_reasoning(
                id="rs_test123",
                text="I need to search for hotels",
                additional_properties={"status": "completed"},
            ),
            Content.from_function_call(
                call_id="call_1",
                name="search_hotels",
                arguments='{"city": "Paris"}',
                additional_properties={"fc_id": "fc_test456"},
            ),
        ],
    )
    tool_turn = Message(
        role="tool",
        contents=[Content.from_function_result(call_id="call_1", result="Found 3 hotels in Paris")],
    )
    messages = [user_turn, assistant_turn, tool_turn]

    options = await client._prepare_options(messages, ChatOptions(store=False))  # type: ignore[arg-type]

    prepared_types = [entry.get("type") for entry in options["input"]]
    assert "reasoning" not in prepared_types
    assert "function_call" in prepared_types
    assert "function_call_output" in prepared_types
|
||||
|
||||
|
||||
async def test_prepare_options_with_conversation_id_keeps_reasoning_items() -> None:
    """When a ``conversation_id`` is supplied, the reasoning item stays in the prepared input."""
    client = OpenAIChatClient(model="test-model", api_key="test-key")

    user_turn = Message(role="user", contents=[Content.from_text(text="search for hotels")])
    assistant_turn = Message(
        role="assistant",
        contents=[
            Content.from_text_reasoning(
                id="rs_test123",
                text="I need to search for hotels",
                additional_properties={"status": "completed"},
            ),
            Content.from_function_call(
                call_id="call_1",
                name="search_hotels",
                arguments='{"city": "Paris"}',
                additional_properties={"fc_id": "fc_test456"},
            ),
        ],
    )
    tool_turn = Message(
        role="tool",
        contents=[Content.from_function_result(call_id="call_1", result="Found 3 hotels in Paris")],
    )
    messages = [user_turn, assistant_turn, tool_turn]

    options = await client._prepare_options(
        messages,
        ChatOptions(store=False, conversation_id="resp_prev123"),  # type: ignore[arg-type]
    )

    kept_reasoning = [entry for entry in options["input"] if entry.get("type") == "reasoning"]
    assert len(kept_reasoning) == 1
    assert kept_reasoning[0]["id"] == "rs_test123"
    # The test expects the supplied id to surface as previous_response_id.
    assert options["previous_response_id"] == "resp_prev123"
|
||||
|
||||
|
||||
async def test_prepare_options_with_conversation_id_omits_reasoning_items_for_attributed_replay() -> None:
    """Reasoning from an ``_attribution``-tagged message is dropped; reasoning from the live turn is kept."""
    client = OpenAIChatClient(model="test-model", api_key="test-key")

    user_turn = Message(role="user", contents=[Content.from_text(text="search for hotels")])

    # Historical tool round-trip: the assistant message is tagged with provider attribution.
    history_assistant = Message(
        role="assistant",
        contents=[
            Content.from_text_reasoning(
                id="rs_history123",
                text="I need to search history for hotels",
                additional_properties={"status": "completed"},
            ),
            Content.from_function_call(
                call_id="call_history",
                name="search_hotels",
                arguments='{"city": "Paris"}',
                additional_properties={"fc_id": "fc_history456"},
            ),
        ],
        additional_properties={"_attribution": {"source_id": "history", "source_type": "InMemoryHistoryProvider"}},
    )
    history_tool = Message(
        role="tool",
        contents=[Content.from_function_result(call_id="call_history", result="Found 3 hotels in Paris")],
    )

    # Live tool round-trip: no attribution on the assistant message.
    live_assistant = Message(
        role="assistant",
        contents=[
            Content.from_text_reasoning(
                id="rs_live123",
                text="I should refine the search for a live follow-up",
                additional_properties={"status": "completed"},
            ),
            Content.from_function_call(
                call_id="call_live",
                name="search_hotels",
                arguments='{"city": "London"}',
                additional_properties={"fc_id": "fc_live456"},
            ),
        ],
    )
    live_tool = Message(
        role="tool",
        contents=[Content.from_function_result(call_id="call_live", result="Found 4 hotels in London")],
    )

    messages = [user_turn, history_assistant, history_tool, live_assistant, live_tool]

    options = await client._prepare_options(
        messages,
        ChatOptions(store=False, conversation_id="resp_prev123"),  # type: ignore[arg-type]
    )

    prepared = options["input"]
    reasoning_ids = [entry["id"] for entry in prepared if entry.get("type") == "reasoning"]
    assert reasoning_ids == ["rs_live123"]

    # Both round-trips must still replay their function call and its output.
    type_and_call_id = [(entry.get("type"), entry.get("call_id")) for entry in prepared]
    assert ("function_call", "call_history") in type_and_call_id
    assert ("function_call", "call_live") in type_and_call_id
    assert ("function_call_output", "call_history") in type_and_call_id
    assert ("function_call_output", "call_live") in type_and_call_id
    assert options["previous_response_id"] == "resp_prev123"
|
||||
|
||||
|
||||
def _create_mock_responses_text_response(*, response_id: str) -> MagicMock:
|
||||
mock_response = MagicMock()
|
||||
mock_response.id = response_id
|
||||
|
||||
Reference in New Issue
Block a user