Python: Implement annotation-based context compaction (#4469)

* Implement annotation-based context compaction

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Handle missing compaction attributes in BaseChatClient

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Fix CI typing and bandit issues

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Optimize incremental compaction annotation pass

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* refinement

* Python: add ToolResultCompactionStrategy and CompactionProvider

Add ToolResultCompactionStrategy that collapses older tool-call groups
into short summary messages (e.g. [Tool calls: get_weather]) while
keeping the most recent groups verbatim. This mirrors the .NET
ToolResultCompactionStrategy from PR #4533.

Add CompactionProvider as a context-provider that auto-applies compaction
before each agent turn and stores compacted history in session state
after each turn.

Includes tests and samples for both features.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* refinement and alignment with dotnet PR

* updated tool result compaction

* updated tool result compaction

* Python: add ToolResultCompactionStrategy, CompactionProvider, and skip_excluded

- ToolResultCompactionStrategy collapses older tool-call groups into
  [Tool results: func_name: result] summaries with bidirectional tracing
  (same pattern as SummarizationStrategy).
- CompactionProvider as BaseContextProvider with separate before_strategy
  and after_strategy parameters. before_strategy compacts loaded context;
  after_strategy compacts stored history via history_source_id.
- InMemoryHistoryProvider gains skip_excluded flag to filter out messages
  marked as excluded by compaction strategies.
- Tests, samples, and exports updated.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* fixed checks

* fix mypy

* Fix: ensure summary messages from both strategies get full compaction annotations

SummarizationStrategy was not calling annotate_message_groups after
inserting its summary message, so the summary lacked core group
annotations (id, kind, index, has_reasoning, _excluded). Added the
missing call. ToolResultCompactionStrategy already had it.

Added tests verifying both strategies produce fully annotated summaries.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* updated propagation

* fix mypy

---------

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
Eduard van Valkenburg
2026-03-11 20:23:00 +01:00
committed by GitHub
Unverified
parent 565c0b1623
commit 3e03a305f6
29 changed files with 4397 additions and 205 deletions
@@ -29,6 +29,34 @@ from ._clients import (
SupportsMCPTool,
SupportsWebSearchTool,
)
from ._compaction import (
COMPACTION_STATE_KEY,
EXCLUDE_REASON_KEY,
EXCLUDED_KEY,
GROUP_ANNOTATION_KEY,
GROUP_HAS_REASONING_KEY,
GROUP_ID_KEY,
GROUP_INDEX_KEY,
GROUP_KIND_KEY,
GROUP_TOKEN_COUNT_KEY,
SUMMARIZED_BY_SUMMARY_ID_KEY,
SUMMARY_OF_GROUP_IDS_KEY,
SUMMARY_OF_MESSAGE_IDS_KEY,
CharacterEstimatorTokenizer,
CompactionProvider,
CompactionStrategy,
SelectiveToolCallCompactionStrategy,
SlidingWindowStrategy,
SummarizationStrategy,
TokenBudgetComposedStrategy,
TokenizerProtocol,
ToolResultCompactionStrategy,
TruncationStrategy,
annotate_message_groups,
apply_compaction,
included_messages,
included_token_count,
)
from ._mcp import MCPStdioTool, MCPStreamableHTTPTool, MCPWebsocketTool
from ._middleware import (
AgentContext,
@@ -196,7 +224,19 @@ from .exceptions import (
__all__ = [
"AGENT_FRAMEWORK_USER_AGENT",
"APP_INFO",
"COMPACTION_STATE_KEY",
"DEFAULT_MAX_ITERATIONS",
"EXCLUDED_KEY",
"EXCLUDE_REASON_KEY",
"GROUP_ANNOTATION_KEY",
"GROUP_HAS_REASONING_KEY",
"GROUP_ID_KEY",
"GROUP_INDEX_KEY",
"GROUP_KIND_KEY",
"GROUP_TOKEN_COUNT_KEY",
"SUMMARIZED_BY_SUMMARY_ID_KEY",
"SUMMARY_OF_GROUP_IDS_KEY",
"SUMMARY_OF_MESSAGE_IDS_KEY",
"USER_AGENT_KEY",
"USER_AGENT_TELEMETRY_DISABLED_ENV_VAR",
"Agent",
@@ -218,6 +258,7 @@ __all__ = [
"BaseEmbeddingClient",
"BaseHistoryProvider",
"Case",
"CharacterEstimatorTokenizer",
"ChatAndFunctionMiddlewareTypes",
"ChatContext",
"ChatMiddleware",
@@ -227,6 +268,8 @@ __all__ = [
"ChatResponse",
"ChatResponseUpdate",
"CheckpointStorage",
"CompactionProvider",
"CompactionStrategy",
"Content",
"ContinuationToken",
"Default",
@@ -273,6 +316,7 @@ __all__ = [
"Runner",
"RunnerContext",
"SecretString",
"SelectiveToolCallCompactionStrategy",
"SessionContext",
"SingleEdgeGroup",
"Skill",
@@ -280,8 +324,10 @@ __all__ = [
"SkillScript",
"SkillScriptRunner",
"SkillsProvider",
"SlidingWindowStrategy",
"SubWorkflowRequestMessage",
"SubWorkflowResponseMessage",
"SummarizationStrategy",
"SupportsAgentRun",
"SupportsChatGetResponse",
"SupportsCodeInterpreterTool",
@@ -294,8 +340,12 @@ __all__ = [
"SwitchCaseEdgeGroupCase",
"SwitchCaseEdgeGroupDefault",
"TextSpanRegion",
"TokenBudgetComposedStrategy",
"TokenizerProtocol",
"ToolMode",
"ToolResultCompactionStrategy",
"ToolTypes",
"TruncationStrategy",
"TypeCompatibilityError",
"UpdateT",
"UsageDetails",
@@ -322,12 +372,16 @@ __all__ = [
"__version__",
"add_usage_details",
"agent_middleware",
"annotate_message_groups",
"apply_compaction",
"chat_middleware",
"create_edge_runner",
"detect_media_type_from_base64",
"executor",
"function_middleware",
"handler",
"included_messages",
"included_token_count",
"load_settings",
"map_chat_to_agent_update",
"merge_chat_options",
@@ -74,6 +74,7 @@ else:
from typing_extensions import Self, TypedDict # pragma: no cover
if TYPE_CHECKING:
from ._compaction import CompactionStrategy, TokenizerProtocol
from ._types import ChatOptions
logger = logging.getLogger("agent_framework")
@@ -177,6 +178,8 @@ class _RunContext(TypedDict):
session_messages: Sequence[Message]
agent_name: str
chat_options: MutableMapping[str, Any]
compaction_strategy: CompactionStrategy | None
tokenizer: TokenizerProtocol | None
filtered_kwargs: Mapping[str, Any]
finalize_kwargs: Mapping[str, Any]
@@ -665,6 +668,8 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None,
default_options: OptionsCoT | None = None,
context_providers: Sequence[BaseContextProvider] | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
**kwargs: Any,
) -> None:
"""Initialize a Agent instance.
@@ -688,6 +693,10 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
Note: response_format typing does not flow into run outputs when set via default_options.
These can be overridden at runtime via the ``options`` parameter of ``run()``.
tools: The tools to use for the request.
compaction_strategy: Optional agent-level in-run compaction.
If both this and a compaction_strategy on the underlying client are set, this one is used.
tokenizer: Optional agent-level tokenizer.
If both this and a tokenizer on the underlying client are set, this one is used.
kwargs: Any additional keyword arguments. Will be stored as ``additional_properties``.
"""
opts = dict(default_options) if default_options else {}
@@ -705,6 +714,8 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
**kwargs,
)
self.client = client
self.compaction_strategy = compaction_strategy
self.tokenizer = tokenizer
# Get tools from options or named parameter (named param takes precedence)
tools_ = tools if tools is not None else opts.pop("tools", None)
@@ -799,6 +810,8 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
session: AgentSession | None = None,
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None,
options: ChatOptions[ResponseModelBoundT],
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse[ResponseModelBoundT]]: ...
@@ -811,6 +824,8 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
session: AgentSession | None = None,
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None,
options: OptionsCoT | ChatOptions[None] | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse[Any]]: ...
@@ -823,6 +838,8 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
session: AgentSession | None = None,
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None,
options: OptionsCoT | ChatOptions[Any] | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
**kwargs: Any,
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
@@ -834,6 +851,8 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
session: AgentSession | None = None,
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None,
options: OptionsCoT | ChatOptions[Any] | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
"""Run the agent with the given messages and options.
@@ -857,8 +876,14 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
``Agent[OpenAIChatOptions]``, this enables IDE autocomplete for
provider-specific options including temperature, max_tokens, model_id,
tool_choice, and provider-specific options like reasoning_effort.
kwargs: Additional keyword arguments for the agent.
Will only be passed to functions that are called.
compaction_strategy: Optional per-run compaction override passed to
``client.get_response()``. When omitted, the agent-level override
is used, falling back to the client default.
tokenizer: Optional per-run tokenizer override passed to
``client.get_response()``. When omitted, the agent-level override
is used, falling back to the client default.
kwargs: Additional keyword arguments for the agent. These are only
passed to functions that are called.
Returns:
When stream=False: An Awaitable[AgentResponse] containing the agent's response.
@@ -873,6 +898,8 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
session=session,
tools=tools,
options=options,
compaction_strategy=compaction_strategy,
tokenizer=tokenizer,
kwargs=kwargs,
)
response = cast(
@@ -881,6 +908,8 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
messages=ctx["session_messages"],
stream=False,
options=ctx["chat_options"], # type: ignore[reportArgumentType]
compaction_strategy=ctx["compaction_strategy"],
tokenizer=ctx["tokenizer"],
**ctx["filtered_kwargs"],
),
)
@@ -954,6 +983,8 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
session=session,
tools=tools,
options=options,
compaction_strategy=compaction_strategy,
tokenizer=tokenizer,
kwargs=kwargs,
)
ctx: _RunContext = ctx_holder["ctx"] # type: ignore[assignment] # Safe: we just assigned it
@@ -961,6 +992,8 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
messages=ctx["session_messages"],
stream=True,
options=ctx["chat_options"], # type: ignore[reportArgumentType]
compaction_strategy=ctx["compaction_strategy"],
tokenizer=ctx["tokenizer"],
**ctx["filtered_kwargs"],
)
@@ -1047,6 +1080,8 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
session: AgentSession | None,
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None,
options: Mapping[str, Any] | None,
compaction_strategy: CompactionStrategy | None,
tokenizer: TokenizerProtocol | None,
kwargs: dict[str, Any],
) -> _RunContext:
opts = dict(options) if options else {}
@@ -1081,9 +1116,10 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
options=opts,
)
agent_name = self._get_agent_name()
# Normalize tools
normalized_tools = normalize_tools(tools_)
agent_name = self._get_agent_name()
# Resolve final tool list (runtime provided tools + local MCP server tools)
final_tools: list[FunctionTool | Callable[..., Any] | dict[str, Any] | Any] = []
@@ -1153,6 +1189,8 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
"session_messages": session_messages,
"agent_name": agent_name,
"chat_options": co,
"compaction_strategy": compaction_strategy or self.compaction_strategy,
"tokenizer": tokenizer or self.tokenizer,
"filtered_kwargs": filtered_kwargs,
"finalize_kwargs": finalize_kwargs,
}
@@ -1408,6 +1446,8 @@ class Agent(
default_options: OptionsCoT | None = None,
context_providers: Sequence[BaseContextProvider] | None = None,
middleware: Sequence[MiddlewareTypes] | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
**kwargs: Any,
) -> None:
"""Initialize a Agent instance."""
@@ -1421,5 +1461,7 @@ class Agent(
default_options=default_options,
context_providers=context_providers,
middleware=middleware,
compaction_strategy=compaction_strategy,
tokenizer=tokenizer,
**kwargs,
)
@@ -52,6 +52,7 @@ else:
if TYPE_CHECKING:
from ._agents import Agent
from ._compaction import CompactionStrategy, TokenizerProtocol
from ._middleware import (
MiddlewareTypes,
)
@@ -134,6 +135,8 @@ class SupportsChatGetResponse(Protocol[OptionsContraT]):
*,
stream: Literal[False] = ...,
options: ChatOptions[ResponseModelBoundT],
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
**kwargs: Any,
) -> Awaitable[ChatResponse[ResponseModelBoundT]]: ...
@@ -144,6 +147,8 @@ class SupportsChatGetResponse(Protocol[OptionsContraT]):
*,
stream: Literal[False] = ...,
options: OptionsContraT | ChatOptions[None] | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
**kwargs: Any,
) -> Awaitable[ChatResponse[Any]]: ...
@@ -154,6 +159,8 @@ class SupportsChatGetResponse(Protocol[OptionsContraT]):
*,
stream: Literal[True],
options: OptionsContraT | ChatOptions[Any] | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
**kwargs: Any,
) -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: ...
@@ -163,6 +170,8 @@ class SupportsChatGetResponse(Protocol[OptionsContraT]):
*,
stream: bool = False,
options: OptionsContraT | ChatOptions[Any] | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
**kwargs: Any,
) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]:
"""Send input and return the response.
@@ -171,6 +180,8 @@ class SupportsChatGetResponse(Protocol[OptionsContraT]):
messages: The sequence of input messages to send.
stream: Whether to stream the response. Defaults to False.
options: Chat options as a TypedDict.
compaction_strategy: Optional per-call compaction override.
tokenizer: Optional per-call tokenizer override.
**kwargs: Additional chat options.
Returns:
@@ -252,7 +263,13 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
"""
OTEL_PROVIDER_NAME: ClassVar[str] = "unknown"
DEFAULT_EXCLUDE: ClassVar[set[str]] = {"additional_properties"}
compaction_strategy: CompactionStrategy | None = None
tokenizer: TokenizerProtocol | None = None
DEFAULT_EXCLUDE: ClassVar[set[str]] = {
"additional_properties",
"compaction_strategy",
"tokenizer",
}
STORES_BY_DEFAULT: ClassVar[bool] = False
"""Whether this client stores conversation history server-side by default.
@@ -267,15 +284,21 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
self,
*,
additional_properties: dict[str, Any] | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
**kwargs: Any,
) -> None:
"""Initialize a BaseChatClient instance.
Keyword Args:
additional_properties: Additional properties for the client.
compaction_strategy: Optional compaction strategy to apply before model calls.
tokenizer: Optional tokenizer used by token-aware compaction strategies.
kwargs: Additional keyword arguments (merged into additional_properties).
"""
self.additional_properties = additional_properties or {}
self.compaction_strategy = compaction_strategy
self.tokenizer = tokenizer
super().__init__(**kwargs)
def to_dict(self, *, exclude: set[str] | None = None, exclude_none: bool = True) -> dict[str, Any]:
@@ -337,6 +360,46 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
finalizer=lambda updates: self._finalize_response_updates(updates, response_format=response_format),
)
async def _prepare_messages_for_model_call(
self,
messages: Sequence[Message],
*,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
) -> list[Message]:
prepared_messages = list(messages)
if compaction_strategy is None:
if tokenizer is None:
return prepared_messages
from ._compaction import annotate_message_groups
annotate_message_groups(prepared_messages, tokenizer=tokenizer)
return prepared_messages
from ._compaction import apply_compaction
return await apply_compaction(
prepared_messages,
strategy=compaction_strategy,
tokenizer=tokenizer,
)
def _resolve_compaction_overrides(
self,
*,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
) -> dict[str, Any]:
current_compaction_strategy = getattr(self, "compaction_strategy", None)
current_tokenizer = getattr(self, "tokenizer", None)
ret: dict[str, Any] = {}
if current_compaction_strategy is not None or compaction_strategy is not None:
ret["compaction_strategy"] = (
current_compaction_strategy if compaction_strategy is None else compaction_strategy
)
if current_tokenizer is not None or tokenizer is not None:
ret["tokenizer"] = current_tokenizer if tokenizer is None else tokenizer
return ret
# region Internal method to be implemented by derived classes
@abstractmethod
@@ -374,6 +437,8 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
*,
stream: Literal[False] = ...,
options: ChatOptions[ResponseModelBoundT],
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
**kwargs: Any,
) -> Awaitable[ChatResponse[ResponseModelBoundT]]: ...
@@ -384,6 +449,8 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
*,
stream: Literal[False] = ...,
options: OptionsCoT | ChatOptions[None] | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
**kwargs: Any,
) -> Awaitable[ChatResponse[Any]]: ...
@@ -394,6 +461,8 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
*,
stream: Literal[True],
options: OptionsCoT | ChatOptions[Any] | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
**kwargs: Any,
) -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: ...
@@ -403,6 +472,8 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
*,
stream: bool = False,
options: OptionsCoT | ChatOptions[Any] | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
**kwargs: Any,
) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]:
"""Get a response from a chat client.
@@ -411,17 +482,62 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
messages: The message or messages to send to the model.
stream: Whether to stream the response. Defaults to False.
options: Chat options as a TypedDict.
compaction_strategy: Optional per-call override for in-run compaction.
When omitted, the client-level default is used.
tokenizer: Optional per-call tokenizer override. When omitted, the
client-level default is used.
**kwargs: Other keyword arguments, can be used to pass function specific parameters.
Returns:
When streaming a response stream of ChatResponseUpdates, otherwise an Awaitable ChatResponse.
"""
return self._inner_get_response(
messages=messages,
stream=stream,
options=options or {}, # type: ignore[arg-type]
**kwargs,
compaction_overrides = self._resolve_compaction_overrides(
compaction_strategy=compaction_strategy,
tokenizer=tokenizer,
)
if not compaction_overrides:
return self._inner_get_response(
messages=messages,
stream=stream,
options=options or {},
**kwargs,
)
if stream:
async def _get_stream() -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]]:
prepared_messages = await self._prepare_messages_for_model_call(
messages,
**compaction_overrides,
)
stream_response = self._inner_get_response(
messages=prepared_messages,
stream=True,
options=options or {},
**kwargs,
)
if isinstance(stream_response, ResponseStream):
return stream_response # type: ignore[reportUnknownVariableType]
awaited_stream_response = await stream_response
if isinstance(awaited_stream_response, ResponseStream):
return awaited_stream_response
raise ValueError("Streaming responses must return a ResponseStream.")
return ResponseStream.from_awaitable(_get_stream()) # type: ignore[reportUnknownVariableType]
async def _get_response() -> ChatResponse[Any]:
prepared_messages = await self._prepare_messages_for_model_call(
messages,
**compaction_overrides,
)
return await self._inner_get_response(
messages=prepared_messages,
stream=False,
options=options or {},
**kwargs,
)
return _get_response()
def service_url(self) -> str:
"""Get the URL of the service.
@@ -446,6 +562,8 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
context_providers: Sequence[Any] | None = None,
middleware: Sequence[MiddlewareTypes] | None = None,
function_invocation_configuration: FunctionInvocationConfiguration | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
**kwargs: Any,
) -> Agent[OptionsCoT]:
"""Create a Agent with this client.
@@ -468,6 +586,10 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
context_providers: Context providers to include during agent invocation.
middleware: List of middleware to intercept agent and function invocations.
function_invocation_configuration: Optional function invocation configuration override.
compaction_strategy: Optional agent-level compaction override. When omitted,
client-level compaction defaults remain in effect for each call.
tokenizer: Optional agent-level tokenizer override. When omitted,
client-level tokenizer defaults remain in effect for each call.
kwargs: Any additional keyword arguments. Will be stored as ``additional_properties``.
Returns:
@@ -504,6 +626,8 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
context_providers=context_providers,
middleware=middleware,
function_invocation_configuration=function_invocation_configuration,
compaction_strategy=compaction_strategy,
tokenizer=tokenizer,
**kwargs,
)
File diff suppressed because it is too large Load Diff
@@ -37,6 +37,7 @@ if TYPE_CHECKING:
from ._agents import SupportsAgentRun
from ._clients import SupportsChatGetResponse
from ._compaction import CompactionStrategy, TokenizerProtocol
from ._sessions import AgentSession
from ._tools import FunctionTool
from ._types import ChatOptions, ChatResponse, ChatResponseUpdate
@@ -101,6 +102,8 @@ class AgentContext:
session: The agent session for this invocation, if any.
options: The options for the agent invocation as a dict.
stream: Whether this is a streaming invocation.
compaction_strategy: Optional per-run compaction override.
tokenizer: Optional per-run tokenizer override.
metadata: Metadata dictionary for sharing data between agent middleware.
result: Agent execution result. Can be observed after calling ``call_next()``
to see the actual execution result or can be set to override the execution result.
@@ -139,6 +142,8 @@ class AgentContext:
session: AgentSession | None = None,
options: Mapping[str, Any] | None = None,
stream: bool = False,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
metadata: Mapping[str, Any] | None = None,
result: AgentResponse | ResponseStream[AgentResponseUpdate, AgentResponse] | None = None,
kwargs: Mapping[str, Any] | None = None,
@@ -158,6 +163,8 @@ class AgentContext:
session: The agent session for this invocation, if any.
options: The options for the agent invocation as a dict.
stream: Whether this is a streaming invocation.
compaction_strategy: Optional per-run compaction override.
tokenizer: Optional per-run tokenizer override.
metadata: Metadata dictionary for sharing data between agent middleware.
result: Agent execution result.
kwargs: Additional keyword arguments passed to the agent run method.
@@ -170,6 +177,8 @@ class AgentContext:
self.session = session
self.options = options
self.stream = stream
self.compaction_strategy = compaction_strategy
self.tokenizer = tokenizer
self.metadata: dict[str, Any] = dict(metadata) if metadata is not None else {}
self.result = result
self.kwargs: dict[str, Any] = dict(kwargs) if kwargs is not None else {}
@@ -969,6 +978,8 @@ class ChatMiddlewareLayer(Generic[OptionsCoT]):
*,
stream: Literal[False] = ...,
options: ChatOptions[ResponseModelBoundT],
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
**kwargs: Any,
) -> Awaitable[ChatResponse[ResponseModelBoundT]]: ...
@@ -979,6 +990,8 @@ class ChatMiddlewareLayer(Generic[OptionsCoT]):
*,
stream: Literal[False] = ...,
options: OptionsCoT | ChatOptions[None] | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
**kwargs: Any,
) -> Awaitable[ChatResponse[Any]]: ...
@@ -989,6 +1002,8 @@ class ChatMiddlewareLayer(Generic[OptionsCoT]):
*,
stream: Literal[True],
options: OptionsCoT | ChatOptions[Any] | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
**kwargs: Any,
) -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: ...
@@ -998,11 +1013,18 @@ class ChatMiddlewareLayer(Generic[OptionsCoT]):
*,
stream: bool = False,
options: OptionsCoT | ChatOptions[Any] | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
**kwargs: Any,
) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]:
"""Execute the chat pipeline if middleware is configured."""
super_get_response = super().get_response # type: ignore[misc]
if compaction_strategy is not None:
kwargs["compaction_strategy"] = compaction_strategy
if tokenizer is not None:
kwargs["tokenizer"] = tokenizer
call_middleware = kwargs.pop("middleware", [])
middleware = categorize_middleware(call_middleware)
kwargs["function_middleware"] = middleware["function"]
@@ -1091,6 +1113,8 @@ class AgentMiddlewareLayer:
session: AgentSession | None = None,
middleware: Sequence[MiddlewareTypes] | None = None,
options: ChatOptions[ResponseModelBoundT],
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse[ResponseModelBoundT]]: ...
@@ -1103,6 +1127,8 @@ class AgentMiddlewareLayer:
session: AgentSession | None = None,
middleware: Sequence[MiddlewareTypes] | None = None,
options: ChatOptions[None] | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse[Any]]: ...
@@ -1115,6 +1141,8 @@ class AgentMiddlewareLayer:
session: AgentSession | None = None,
middleware: Sequence[MiddlewareTypes] | None = None,
options: ChatOptions[Any] | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
**kwargs: Any,
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
@@ -1126,6 +1154,8 @@ class AgentMiddlewareLayer:
session: AgentSession | None = None,
middleware: Sequence[MiddlewareTypes] | None = None,
options: ChatOptions[Any] | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
"""MiddlewareTypes-enabled unified run method."""
@@ -1150,7 +1180,15 @@ class AgentMiddlewareLayer:
# Execute with middleware if available
if not pipeline.has_middlewares:
return super().run(messages, stream=stream, session=session, options=options, **combined_kwargs) # type: ignore[misc, no-any-return]
return super().run( # type: ignore[misc, no-any-return]
messages,
stream=stream,
session=session,
options=options,
compaction_strategy=compaction_strategy,
tokenizer=tokenizer,
**combined_kwargs,
)
context = AgentContext(
agent=self, # type: ignore[arg-type]
@@ -1158,6 +1196,8 @@ class AgentMiddlewareLayer:
session=session,
options=options,
stream=stream,
compaction_strategy=compaction_strategy,
tokenizer=tokenizer,
kwargs=combined_kwargs,
)
@@ -1195,6 +1235,8 @@ class AgentMiddlewareLayer:
stream=context.stream,
session=context.session,
options=context.options,
compaction_strategy=context.compaction_strategy,
tokenizer=context.tokenizer,
**context.kwargs,
)
@@ -547,6 +547,7 @@ class InMemoryHistoryProvider(BaseHistoryProvider):
store_context_messages: bool = False,
store_context_from: set[str] | None = None,
store_outputs: bool = True,
skip_excluded: bool = False,
) -> None:
"""Initialize the in-memory history provider.
@@ -558,6 +559,11 @@ class InMemoryHistoryProvider(BaseHistoryProvider):
store_context_messages: Whether to store context from other providers.
store_context_from: If set, only store context from these source_ids.
store_outputs: Whether to store response messages.
skip_excluded: When True, ``get_messages`` omits messages whose
``additional_properties["_excluded"]`` is truthy. This is
useful when a ``CompactionProvider`` marks messages as excluded
in stored history and you want the loaded context to reflect
those exclusions. Defaults to False (load all messages).
"""
super().__init__(
source_id=source_id or self.DEFAULT_SOURCE_ID,
@@ -567,6 +573,7 @@ class InMemoryHistoryProvider(BaseHistoryProvider):
store_context_from=store_context_from,
store_outputs=store_outputs,
)
self.skip_excluded = skip_excluded
async def get_messages(
self, session_id: str | None, *, state: dict[str, Any] | None = None, **kwargs: Any
@@ -574,7 +581,10 @@ class InMemoryHistoryProvider(BaseHistoryProvider):
"""Retrieve messages from session state."""
if state is None:
return []
return list(state.get("messages", []))
messages = list(state.get("messages", []))
if self.skip_excluded:
messages = [m for m in messages if not m.additional_properties.get("_excluded", False)]
return messages
async def save_messages(
self,
@@ -196,9 +196,7 @@ class SkillScript:
self._accepts_kwargs: bool = False
if function is not None:
sig = inspect.signature(function)
self._accepts_kwargs = any(
p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()
)
self._accepts_kwargs = any(p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values())
@property
def parameters_schema(self) -> dict[str, Any] | None:
@@ -454,9 +452,7 @@ class SkillScriptRunner(Protocol):
satisfies this protocol.
"""
def __call__(
self, skill: Skill, script: SkillScript, args: dict[str, Any] | None = None
) -> Any:
def __call__(self, skill: Skill, script: SkillScript, args: dict[str, Any] | None = None) -> Any:
"""Run a skill script.
The :class:`SkillsProvider` resolves skill and script names
@@ -677,7 +673,7 @@ class SkillsProvider(BaseContextProvider):
self._instructions = _create_instructions(
prompt_template=instruction_template,
skills=self._skills,
include_script_runner_instructions=has_file_scripts or has_code_scripts
include_script_runner_instructions=has_file_scripts or has_code_scripts,
)
self._tools = self._create_tools(
@@ -59,6 +59,7 @@ else:
if TYPE_CHECKING:
from ._clients import SupportsChatGetResponse
from ._compaction import CompactionStrategy, TokenizerProtocol
from ._mcp import MCPTool
from ._middleware import FunctionMiddlewarePipeline, FunctionMiddlewareTypes
from ._types import (
@@ -1811,6 +1812,8 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
*,
stream: Literal[False] = ...,
options: ChatOptions[ResponseModelBoundT],
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
**kwargs: Any,
) -> Awaitable[ChatResponse[ResponseModelBoundT]]: ...
@@ -1821,6 +1824,8 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
*,
stream: Literal[False] = ...,
options: OptionsCoT | ChatOptions[None] | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
**kwargs: Any,
) -> Awaitable[ChatResponse[Any]]: ...
@@ -1831,6 +1836,8 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
*,
stream: Literal[True],
options: OptionsCoT | ChatOptions[Any] | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
**kwargs: Any,
) -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: ...
@@ -1841,6 +1848,8 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
stream: bool = False,
options: OptionsCoT | ChatOptions[Any] | None = None,
function_middleware: Sequence[FunctionMiddlewareTypes] | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
**kwargs: Any,
) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]:
from ._middleware import FunctionMiddlewarePipeline
@@ -1869,6 +1878,10 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
middleware_pipeline=function_middleware_pipeline,
)
filtered_kwargs = {k: v for k, v in kwargs.items() if k != "session"}
if compaction_strategy is not None:
filtered_kwargs["compaction_strategy"] = compaction_strategy
if tokenizer is not None:
filtered_kwargs["tokenizer"] = tokenizer
# Make options mutable so we can update conversation_id during function invocation loop
mutable_options: dict[str, Any] = dict(options) if options else {}
+37 -8
View File
@@ -277,6 +277,17 @@ def _serialize_value(value: Any, exclude_none: bool) -> Any:
return value
def _restore_compaction_annotation_in_additional_properties(
additional_properties: MutableMapping[str, Any] | None,
*,
allow_none: bool = False,
) -> dict[str, Any] | None:
if additional_properties is None:
return None if allow_none else {}
return dict(additional_properties)
# endregion
# region Constants and types
@@ -509,7 +520,9 @@ class Content:
"""
self.type = type
self.annotations = annotations
self.additional_properties: dict[str, Any] = additional_properties or {} # type: ignore[assignment]
self.additional_properties: dict[str, Any] = (
_restore_compaction_annotation_in_additional_properties(additional_properties) or {}
)
self.raw_representation = raw_representation
# Set all content-specific attributes
@@ -1638,7 +1651,9 @@ class Message(SerializationMixin):
self.contents = parsed_contents
self.author_name = author_name
self.message_id = message_id
self.additional_properties = additional_properties or {}
self.additional_properties = (
_restore_compaction_annotation_in_additional_properties(additional_properties) or {}
)
self.raw_representation = raw_representation
@property
@@ -1989,7 +2004,9 @@ class ChatResponse(SerializationMixin, Generic[ResponseModelT]):
self._value: ResponseModelT | None = value
self._response_format: type[BaseModel] | None = response_format
self._value_parsed: bool = value is not None
self.additional_properties = additional_properties or {}
self.additional_properties = (
_restore_compaction_annotation_in_additional_properties(additional_properties) or {}
)
self.continuation_token = continuation_token
self.raw_representation: Any | list[Any] | None = raw_representation
@@ -2239,7 +2256,10 @@ class ChatResponseUpdate(SerializationMixin):
self.created_at = created_at
self.finish_reason = finish_reason
self.continuation_token = continuation_token
self.additional_properties = additional_properties
self.additional_properties = _restore_compaction_annotation_in_additional_properties(
additional_properties,
allow_none=True,
)
self.raw_representation = raw_representation
@property
@@ -2352,7 +2372,9 @@ class AgentResponse(SerializationMixin, Generic[ResponseModelT]):
self._value: ResponseModelT | None = value
self._response_format: type[BaseModel] | None = response_format
self._value_parsed: bool = value is not None
self.additional_properties = additional_properties or {}
self.additional_properties = (
_restore_compaction_annotation_in_additional_properties(additional_properties) or {}
)
self.continuation_token = continuation_token
self.raw_representation = raw_representation
@@ -2582,7 +2604,10 @@ class AgentResponseUpdate(SerializationMixin):
self.message_id = message_id
self.created_at = created_at
self.continuation_token = continuation_token
self.additional_properties = additional_properties
self.additional_properties = _restore_compaction_annotation_in_additional_properties(
additional_properties,
allow_none=True,
)
self.raw_representation: Any | list[Any] | None = raw_representation
@property
@@ -3381,7 +3406,9 @@ class Embedding(Generic[EmbeddingT]):
self._dimensions = dimensions
self.model_id = model_id
self.created_at = created_at
self.additional_properties = additional_properties or {}
self.additional_properties = (
_restore_compaction_annotation_in_additional_properties(additional_properties) or {}
)
@property
def dimensions(self) -> int | None:
@@ -3439,7 +3466,9 @@ class GeneratedEmbeddings(list[Embedding[EmbeddingT]], Generic[EmbeddingT, Embed
super().__init__(embeddings or [])
self.options = options
self.usage = usage
self.additional_properties = additional_properties or {}
self.additional_properties = (
_restore_compaction_annotation_in_additional_properties(additional_properties) or {}
)
# endregion
@@ -49,6 +49,7 @@ if TYPE_CHECKING: # pragma: no cover
from ._agents import SupportsAgentRun
from ._clients import SupportsChatGetResponse
from ._compaction import CompactionStrategy, TokenizerProtocol
from ._sessions import AgentSession
from ._tools import FunctionTool
from ._types import (
@@ -1122,6 +1123,8 @@ class ChatTelemetryLayer(Generic[OptionsCoT]):
*,
stream: Literal[False] = ...,
options: ChatOptions[ResponseModelBoundT],
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
**kwargs: Any,
) -> Awaitable[ChatResponse[ResponseModelBoundT]]: ...
@@ -1132,6 +1135,8 @@ class ChatTelemetryLayer(Generic[OptionsCoT]):
*,
stream: Literal[False] = ...,
options: OptionsCoT | ChatOptions[None] | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
**kwargs: Any,
) -> Awaitable[ChatResponse[Any]]: ...
@@ -1142,6 +1147,8 @@ class ChatTelemetryLayer(Generic[OptionsCoT]):
*,
stream: Literal[True],
options: OptionsCoT | ChatOptions[Any] | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
**kwargs: Any,
) -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: ...
@@ -1151,6 +1158,8 @@ class ChatTelemetryLayer(Generic[OptionsCoT]):
*,
stream: bool = False,
options: OptionsCoT | ChatOptions[Any] | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
**kwargs: Any,
) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]:
"""Trace chat responses with OpenTelemetry spans and metrics."""
@@ -1160,7 +1169,14 @@ class ChatTelemetryLayer(Generic[OptionsCoT]):
super_get_response = super().get_response # type: ignore[misc]
if not OBSERVABILITY_SETTINGS.ENABLED:
return super_get_response(messages=messages, stream=stream, options=options, **kwargs) # type: ignore[no-any-return]
return super_get_response( # type: ignore[no-any-return]
messages=messages,
stream=stream,
options=options,
compaction_strategy=compaction_strategy,
tokenizer=tokenizer,
**kwargs,
)
opts: dict[str, Any] = options or {} # type: ignore[assignment]
provider_name = str(getattr(self, "otel_provider_name", "unknown"))
@@ -1178,7 +1194,14 @@ class ChatTelemetryLayer(Generic[OptionsCoT]):
if stream:
result_stream = cast(
ResponseStream[ChatResponseUpdate, ChatResponse[Any]],
super_get_response(messages=messages, stream=True, options=opts, **kwargs),
super_get_response(
messages=messages,
stream=True,
options=opts,
compaction_strategy=compaction_strategy,
tokenizer=tokenizer,
**kwargs,
),
)
# Create span directly without trace.use_span() context attachment.
@@ -1266,6 +1289,8 @@ class ChatTelemetryLayer(Generic[OptionsCoT]):
messages=messages,
stream=False,
options=opts,
compaction_strategy=compaction_strategy,
tokenizer=tokenizer,
**kwargs,
),
)
@@ -1393,6 +1418,8 @@ class AgentTelemetryLayer:
*,
stream: Literal[False] = ...,
session: AgentSession | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse[Any]]: ...
@@ -1403,6 +1430,8 @@ class AgentTelemetryLayer:
*,
stream: Literal[True],
session: AgentSession | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
**kwargs: Any,
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
@@ -1412,6 +1441,8 @@ class AgentTelemetryLayer:
*,
stream: bool = False,
session: AgentSession | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
"""Trace agent runs with OpenTelemetry spans and metrics."""
@@ -1430,6 +1461,8 @@ class AgentTelemetryLayer:
messages=messages,
stream=stream,
session=session,
compaction_strategy=compaction_strategy,
tokenizer=tokenizer,
**kwargs,
)
@@ -1452,6 +1485,8 @@ class AgentTelemetryLayer:
messages=messages,
stream=True,
session=session,
compaction_strategy=compaction_strategy,
tokenizer=tokenizer,
**kwargs,
)
if isinstance(run_result, ResponseStream):
@@ -1541,6 +1576,8 @@ class AgentTelemetryLayer:
messages=messages,
stream=False,
session=session,
compaction_strategy=compaction_strategy,
tokenizer=tokenizer,
**kwargs,
)
except Exception as exception:
@@ -1164,7 +1164,6 @@ class RawOpenAIResponsesClient( # type: ignore[misc]
"type": "function_call",
"name": content.name,
"arguments": content.arguments,
"status": None,
}
case "function_result":
shell_output_type = (
@@ -10,6 +10,8 @@ import pytest
from pytest import raises
from agent_framework import (
GROUP_ANNOTATION_KEY,
GROUP_TOKEN_COUNT_KEY,
Agent,
AgentResponse,
AgentResponseUpdate,
@@ -21,14 +23,24 @@ from agent_framework import (
Content,
FunctionTool,
Message,
SlidingWindowStrategy,
SupportsAgentRun,
SupportsChatGetResponse,
TruncationStrategy,
tool,
)
from agent_framework._agents import _get_tool_name, _merge_options, _sanitize_agent_name
from agent_framework._mcp import MCPTool
class _FixedTokenizer:
def __init__(self, token_count: int) -> None:
self.token_count = token_count
def count_tokens(self, text: str) -> int:
return self.token_count
def test_agent_session_type(agent_session: AgentSession) -> None:
assert isinstance(agent_session, AgentSession)
@@ -217,6 +229,30 @@ async def test_prepare_session_does_not_mutate_agent_chat_options(
assert len(agent.default_options["tools"]) == 1
async def test_prepare_run_context_keeps_compaction_overrides_out_of_kwargs(
chat_client_base: SupportsChatGetResponse,
) -> None:
strategy = SlidingWindowStrategy(keep_last_groups=2)
tokenizer = _FixedTokenizer(13)
agent = Agent(client=chat_client_base)
ctx = await agent._prepare_run_context( # type: ignore[reportPrivateUsage]
messages=[Message(role="user", text="Hello")],
session=None,
tools=None,
options=None,
compaction_strategy=strategy,
tokenizer=tokenizer,
kwargs={"custom_flag": True},
)
assert ctx["compaction_strategy"] is strategy
assert ctx["tokenizer"] is tokenizer
assert ctx["filtered_kwargs"].get("custom_flag") is True
assert "compaction_strategy" not in ctx["filtered_kwargs"]
assert "tokenizer" not in ctx["filtered_kwargs"]
async def test_chat_client_agent_run_with_session(
chat_client_base: SupportsChatGetResponse,
) -> None:
@@ -1128,6 +1164,102 @@ async def test_chat_agent_tool_choice_none_at_run_preserves_agent_level(chat_cli
assert captured_options[0]["tool_choice"] == "auto"
async def test_chat_agent_compaction_overrides_client_defaults(chat_client_base: Any) -> None:
captured_roles: list[list[str]] = []
captured_token_counts: list[list[int | None]] = []
original_inner = chat_client_base._inner_get_response
async def capturing_inner(
*, messages: MutableSequence[Message], options: dict[str, Any], **kwargs: Any
) -> ChatResponse:
captured_roles.append([message.role for message in messages])
captured_token_counts.append([
group.get(GROUP_TOKEN_COUNT_KEY) if isinstance(group, dict) else None
for group in (message.additional_properties.get(GROUP_ANNOTATION_KEY) for message in messages)
])
return await original_inner(messages=messages, options=options, **kwargs)
chat_client_base._inner_get_response = capturing_inner
chat_client_base.function_invocation_configuration["enabled"] = False
chat_client_base.compaction_strategy = TruncationStrategy(max_n=1, compact_to=1)
chat_client_base.tokenizer = _FixedTokenizer(5)
agent = Agent(
client=chat_client_base,
compaction_strategy=SlidingWindowStrategy(keep_last_groups=2),
tokenizer=_FixedTokenizer(9),
)
await agent.run([
Message(role="user", text="Hello"),
Message(role="assistant", text="Previous response"),
])
assert captured_roles == [["user", "assistant"]]
assert captured_token_counts == [[9, 9]]
async def test_chat_agent_uses_client_compaction_defaults_when_agent_unset(chat_client_base: Any) -> None:
captured_roles: list[list[str]] = []
original_inner = chat_client_base._inner_get_response
async def capturing_inner(
*, messages: MutableSequence[Message], options: dict[str, Any], **kwargs: Any
) -> ChatResponse:
captured_roles.append([message.role for message in messages])
return await original_inner(messages=messages, options=options, **kwargs)
chat_client_base._inner_get_response = capturing_inner
chat_client_base.function_invocation_configuration["enabled"] = False
chat_client_base.compaction_strategy = TruncationStrategy(max_n=1, compact_to=1)
agent = Agent(client=chat_client_base)
await agent.run([
Message(role="user", text="Hello"),
Message(role="assistant", text="Previous response"),
])
assert captured_roles == [["assistant"]]
async def test_chat_agent_run_level_compaction_and_tokenizer_override_agent_defaults(chat_client_base: Any) -> None:
captured_roles: list[list[str]] = []
captured_token_counts: list[list[int | None]] = []
original_inner = chat_client_base._inner_get_response
async def capturing_inner(
*, messages: MutableSequence[Message], options: dict[str, Any], **kwargs: Any
) -> ChatResponse:
captured_roles.append([message.role for message in messages])
captured_token_counts.append([
group.get(GROUP_TOKEN_COUNT_KEY) if isinstance(group, dict) else None
for group in (message.additional_properties.get(GROUP_ANNOTATION_KEY) for message in messages)
])
return await original_inner(messages=messages, options=options, **kwargs)
chat_client_base._inner_get_response = capturing_inner
chat_client_base.function_invocation_configuration["enabled"] = False
agent = Agent(
client=chat_client_base,
compaction_strategy=SlidingWindowStrategy(keep_last_groups=2),
tokenizer=_FixedTokenizer(9),
)
await agent.run(
[
Message(role="user", text="Hello"),
Message(role="assistant", text="Previous response"),
],
compaction_strategy=TruncationStrategy(max_n=1, compact_to=1),
tokenizer=_FixedTokenizer(23),
)
assert captured_roles == [["assistant"]]
assert captured_token_counts == [[23]]
# region Test _merge_options
@@ -1,21 +1,34 @@
# Copyright (c) Microsoft. All rights reserved.
from typing import Any
from unittest.mock import patch
from agent_framework import (
GROUP_ANNOTATION_KEY,
GROUP_TOKEN_COUNT_KEY,
BaseChatClient,
ChatResponse,
Message,
SlidingWindowStrategy,
SupportsChatGetResponse,
SupportsCodeInterpreterTool,
SupportsFileSearchTool,
SupportsImageGenerationTool,
SupportsMCPTool,
SupportsWebSearchTool,
TruncationStrategy,
)
class _FixedTokenizer:
def __init__(self, token_count: int) -> None:
self.token_count = token_count
def count_tokens(self, text: str) -> int:
return self.token_count
def test_chat_client_type(client: SupportsChatGetResponse):
assert isinstance(client, SupportsChatGetResponse)
@@ -48,6 +61,190 @@ async def test_base_client_get_response_streaming(chat_client_base: SupportsChat
assert update.text == "update - Hello" or update.text == "another update"
async def test_base_client_applies_compaction_before_non_streaming_inner_call(
chat_client_base: SupportsChatGetResponse,
):
chat_client_base.function_invocation_configuration["enabled"] = False # type: ignore[attr-defined]
chat_client_base.compaction_strategy = TruncationStrategy(max_n=1, compact_to=1) # type: ignore[attr-defined]
captured_roles: list[list[str]] = []
original = chat_client_base._get_non_streaming_response # type: ignore[attr-defined]
async def _capture(
*,
messages: list[Message],
options: dict[str, Any],
**kwargs: Any,
) -> ChatResponse:
captured_roles.append([message.role for message in messages])
return await original(messages=messages, options=options, **kwargs)
chat_client_base._get_non_streaming_response = _capture # type: ignore[attr-defined,method-assign]
await chat_client_base.get_response([
Message(role="user", text="Hello"),
Message(role="assistant", text="Previous response"),
])
assert captured_roles == [["assistant"]]
async def test_base_client_applies_compaction_before_streaming_inner_call(
chat_client_base: SupportsChatGetResponse,
):
chat_client_base.function_invocation_configuration["enabled"] = False # type: ignore[attr-defined]
chat_client_base.compaction_strategy = TruncationStrategy(max_n=1, compact_to=1) # type: ignore[attr-defined]
captured_roles: list[list[str]] = []
original = chat_client_base._get_streaming_response # type: ignore[attr-defined]
def _capture(
*,
messages: list[Message],
options: dict[str, Any],
**kwargs: Any,
):
captured_roles.append([message.role for message in messages])
return original(messages=messages, options=options, **kwargs)
chat_client_base._get_streaming_response = _capture # type: ignore[attr-defined,method-assign]
async for _ in chat_client_base.get_response(
[
Message(role="user", text="Hello"),
Message(role="assistant", text="Previous response"),
],
stream=True,
):
pass
assert captured_roles == [["assistant"]]
async def test_base_client_per_call_compaction_override_applies_before_inner_call(
chat_client_base: SupportsChatGetResponse,
) -> None:
chat_client_base.function_invocation_configuration["enabled"] = False # type: ignore[attr-defined]
captured_roles: list[list[str]] = []
original = chat_client_base._get_non_streaming_response # type: ignore[attr-defined]
async def _capture(
*,
messages: list[Message],
options: dict[str, Any],
**kwargs: Any,
) -> ChatResponse:
captured_roles.append([message.role for message in messages])
return await original(messages=messages, options=options, **kwargs)
chat_client_base._get_non_streaming_response = _capture # type: ignore[attr-defined,method-assign]
await chat_client_base.get_response(
[
Message(role="user", text="Hello"),
Message(role="assistant", text="Previous response"),
],
compaction_strategy=TruncationStrategy(max_n=1, compact_to=1),
)
assert captured_roles == [["assistant"]]
async def test_base_client_per_call_tokenizer_override_annotates_messages(
chat_client_base: SupportsChatGetResponse,
) -> None:
chat_client_base.function_invocation_configuration["enabled"] = False # type: ignore[attr-defined]
captured_token_counts: list[list[int | None]] = []
original = chat_client_base._get_non_streaming_response # type: ignore[attr-defined]
async def _capture(
*,
messages: list[Message],
options: dict[str, Any],
**kwargs: Any,
) -> ChatResponse:
captured_token_counts.append([
group.get(GROUP_TOKEN_COUNT_KEY) if isinstance(group, dict) else None
for group in (message.additional_properties.get(GROUP_ANNOTATION_KEY) for message in messages)
])
return await original(messages=messages, options=options, **kwargs)
chat_client_base._get_non_streaming_response = _capture # type: ignore[attr-defined,method-assign]
await chat_client_base.get_response(
[
Message(role="user", text="Hello"),
Message(role="assistant", text="Previous response"),
],
compaction_strategy=SlidingWindowStrategy(keep_last_groups=2),
tokenizer=_FixedTokenizer(17),
)
assert captured_token_counts == [[17, 17]]
async def test_base_client_per_call_tokenizer_override_without_strategy_annotates_messages(
chat_client_base: SupportsChatGetResponse,
) -> None:
chat_client_base.function_invocation_configuration["enabled"] = False # type: ignore[attr-defined]
captured_token_counts: list[list[int | None]] = []
original = chat_client_base._get_non_streaming_response # type: ignore[attr-defined]
async def _capture(
*,
messages: list[Message],
options: dict[str, Any],
**kwargs: Any,
) -> ChatResponse:
captured_token_counts.append([
group.get(GROUP_TOKEN_COUNT_KEY) if isinstance(group, dict) else None
for group in (message.additional_properties.get(GROUP_ANNOTATION_KEY) for message in messages)
])
return await original(messages=messages, options=options, **kwargs)
chat_client_base._get_non_streaming_response = _capture # type: ignore[attr-defined,method-assign]
await chat_client_base.get_response(
[
Message(role="user", text="Hello"),
Message(role="assistant", text="Previous response"),
],
tokenizer=_FixedTokenizer(17),
)
assert captured_token_counts == [[17, 17]]
async def test_base_client_default_tokenizer_without_strategy_annotates_messages(
chat_client_base: SupportsChatGetResponse,
) -> None:
chat_client_base.function_invocation_configuration["enabled"] = False # type: ignore[attr-defined]
chat_client_base.tokenizer = _FixedTokenizer(19) # type: ignore[attr-defined]
captured_token_counts: list[list[int | None]] = []
original = chat_client_base._get_non_streaming_response # type: ignore[attr-defined]
async def _capture(
*,
messages: list[Message],
options: dict[str, Any],
**kwargs: Any,
) -> ChatResponse:
captured_token_counts.append([
group.get(GROUP_TOKEN_COUNT_KEY) if isinstance(group, dict) else None
for group in (message.additional_properties.get(GROUP_ANNOTATION_KEY) for message in messages)
])
return await original(messages=messages, options=options, **kwargs)
chat_client_base._get_non_streaming_response = _capture # type: ignore[attr-defined,method-assign]
await chat_client_base.get_response([
Message(role="user", text="Hello"),
Message(role="assistant", text="Previous response"),
])
assert captured_token_counts == [[19, 19]]
def test_base_client_as_agent_does_not_copy_client_compaction_defaults(
chat_client_base: SupportsChatGetResponse,
) -> None:
strategy = TruncationStrategy(max_n=1, compact_to=1)
tokenizer = _FixedTokenizer(11)
chat_client_base.compaction_strategy = strategy # type: ignore[attr-defined]
chat_client_base.tokenizer = tokenizer # type: ignore[attr-defined]
agent = chat_client_base.as_agent(name="shared-client-agent")
assert agent.compaction_strategy is None # type: ignore[attr-defined]
assert agent.tokenizer is None # type: ignore[attr-defined]
async def test_chat_client_instructions_handling(chat_client_base: SupportsChatGetResponse):
instructions = "You are a helpful assistant."
@@ -0,0 +1,954 @@
# Copyright (c) Microsoft. All rights reserved.
from __future__ import annotations
import logging
from typing import Any
from agent_framework import (
EXCLUDED_KEY,
GROUP_ANNOTATION_KEY,
GROUP_HAS_REASONING_KEY,
GROUP_ID_KEY,
GROUP_KIND_KEY,
GROUP_TOKEN_COUNT_KEY,
SUMMARIZED_BY_SUMMARY_ID_KEY,
SUMMARY_OF_GROUP_IDS_KEY,
SUMMARY_OF_MESSAGE_IDS_KEY,
CharacterEstimatorTokenizer,
ChatResponse,
CompactionProvider,
Content,
Message,
SelectiveToolCallCompactionStrategy,
SlidingWindowStrategy,
SummarizationStrategy,
TokenBudgetComposedStrategy,
ToolResultCompactionStrategy,
TruncationStrategy,
annotate_message_groups,
apply_compaction,
included_messages,
included_token_count,
)
from agent_framework._compaction import (
append_compaction_message,
extend_compaction_messages,
)
def _assistant_function_call(call_id: str) -> Message:
return Message(
role="assistant",
contents=[Content.from_function_call(call_id=call_id, name="tool", arguments='{"value":"x"}')],
)
def _assistant_reasoning_and_function_calls(*call_ids: str) -> Message:
contents: list[Content] = [Content.from_text_reasoning(text="thinking")]
for call_id in call_ids:
contents.append(
Content.from_function_call(
call_id=call_id,
name="tool",
arguments='{"value":"x"}',
)
)
return Message(role="assistant", contents=contents)
def _tool_result(call_id: str, result: str) -> Message:
return Message(
role="tool",
contents=[Content.from_function_result(call_id=call_id, result=result)],
)
def _group_id(message: Message) -> str | None:
annotation = message.additional_properties.get(GROUP_ANNOTATION_KEY)
if not isinstance(annotation, dict):
return None
value = annotation.get(GROUP_ID_KEY)
return value if isinstance(value, str) else None
def _group_kind(message: Message) -> str | None:
annotation = message.additional_properties.get(GROUP_ANNOTATION_KEY)
if not isinstance(annotation, dict):
return None
value = annotation.get(GROUP_KIND_KEY)
return value if isinstance(value, str) else None
def _group_has_reasoning(message: Message) -> bool | None:
annotation = message.additional_properties.get(GROUP_ANNOTATION_KEY)
if not isinstance(annotation, dict):
return None
value = annotation.get(GROUP_HAS_REASONING_KEY)
return value if isinstance(value, bool) else None
def _token_count(message: Message) -> int | None:
annotation = message.additional_properties.get(GROUP_ANNOTATION_KEY)
if not isinstance(annotation, dict):
return None
value = annotation.get(GROUP_TOKEN_COUNT_KEY)
return value if isinstance(value, int) else None
def _group_unknown_value(message: Message, key: str) -> Any:
annotation = message.additional_properties.get(GROUP_ANNOTATION_KEY)
if not isinstance(annotation, dict):
return None
return annotation.get(key)
def test_group_annotations_keep_tool_call_and_tool_result_atomic() -> None:
messages = [
Message(role="user", text="hello"),
_assistant_function_call("c1"),
_tool_result("c1", "ok"),
Message(role="assistant", text="final"),
]
annotate_message_groups(messages)
call_group = _group_id(messages[1])
assert call_group is not None
assert call_group == _group_id(messages[2])
assert _group_id(messages[1]) != _group_id(messages[0])
def test_group_annotations_include_reasoning_in_tool_call_group() -> None:
messages = [
_assistant_reasoning_and_function_calls("c2"),
_tool_result("c2", "ok"),
]
annotate_message_groups(messages)
first_group = _group_id(messages[0])
assert first_group is not None
assert _group_id(messages[1]) == first_group
assert _group_has_reasoning(messages[0]) is True
assert _group_kind(messages[0]) == "tool_call"
def test_group_annotations_handle_same_message_reasoning_and_function_calls() -> None:
messages = [
Message(role="user", text="hello"),
_assistant_reasoning_and_function_calls("c1", "c2"),
_tool_result("c1", "ok1"),
_tool_result("c2", "ok2"),
Message(role="assistant", text="final"),
]
annotate_message_groups(messages)
call_group = _group_id(messages[1])
assert call_group is not None
assert _group_id(messages[2]) == call_group
assert _group_id(messages[3]) == call_group
assert _group_kind(messages[1]) == "tool_call"
assert _group_has_reasoning(messages[1]) is True
def test_annotate_message_groups_with_tokenizer_adds_token_counts() -> None:
messages = [
Message(role="user", text="hello"),
Message(role="assistant", text="world"),
]
annotate_message_groups(
messages,
tokenizer=CharacterEstimatorTokenizer(),
)
assert isinstance(_token_count(messages[0]), int)
assert isinstance(_token_count(messages[1]), int)
def test_extend_compaction_messages_preserves_existing_annotations_and_tokens() -> None:
tokenizer = CharacterEstimatorTokenizer()
messages = [_assistant_function_call("c3")]
annotate_message_groups(messages)
old_group_id = _group_id(messages[0])
assert old_group_id is not None
old_token_count = tokenizer.count_tokens("precomputed")
annotation = messages[0].additional_properties.get(GROUP_ANNOTATION_KEY)
if isinstance(annotation, dict):
annotation[GROUP_TOKEN_COUNT_KEY] = old_token_count
extend_compaction_messages(messages, [_tool_result("c3", "ok")], tokenizer=tokenizer)
assert _group_id(messages[1]) == old_group_id
assert _token_count(messages[0]) == old_token_count
assert isinstance(_token_count(messages[1]), int)
def test_append_compaction_message_annotates_new_message() -> None:
messages = [Message(role="user", text="hello")]
annotate_message_groups(messages)
append_compaction_message(messages, Message(role="assistant", text="world"))
assert len(messages) == 2
assert isinstance(_group_id(messages[1]), str)
async def test_truncation_strategy_keeps_system_anchor() -> None:
messages = [
Message(role="system", text="you are helpful"),
Message(role="user", text="u1"),
Message(role="assistant", text="a1"),
Message(role="user", text="u2"),
Message(role="assistant", text="a2"),
]
strategy = TruncationStrategy(max_n=3, compact_to=3, preserve_system=True)
annotate_message_groups(messages)
changed = await strategy(messages)
assert changed is True
projected = included_messages(messages)
assert projected[0].role == "system"
assert len(projected) <= 3
async def test_truncation_strategy_compacts_when_token_limit_exceeded() -> None:
tokenizer = CharacterEstimatorTokenizer()
messages = [
Message(role="system", text="you are helpful"),
Message(role="user", text="u1 " * 200),
Message(role="assistant", text="a1 " * 200),
]
strategy = TruncationStrategy(
max_n=80,
compact_to=40,
tokenizer=tokenizer,
preserve_system=True,
)
annotate_message_groups(messages, tokenizer=tokenizer)
changed = await strategy(messages)
assert changed is True
projected = included_messages(messages)
assert projected[0].role == "system"
assert included_token_count(messages) <= 40
def test_truncation_strategy_validates_token_targets() -> None:
try:
TruncationStrategy(max_n=3, compact_to=4)
except ValueError as exc:
assert "compact_to must be less than or equal to max_n" in str(exc)
else:
raise AssertionError("Expected ValueError when compact_to is greater than max_n.")
async def test_selective_tool_call_strategy_excludes_older_tool_groups() -> None:
messages = [
Message(role="user", text="u"),
_assistant_function_call("call-1"),
_tool_result("call-1", "r1"),
_assistant_function_call("call-2"),
_tool_result("call-2", "r2"),
Message(role="assistant", text="done"),
]
strategy = SelectiveToolCallCompactionStrategy(keep_last_tool_call_groups=1)
annotate_message_groups(messages)
changed = await strategy(messages)
assert changed is True
assert messages[1].additional_properties.get(EXCLUDED_KEY) is True
assert messages[2].additional_properties.get(EXCLUDED_KEY) is True
assert messages[3].additional_properties.get(EXCLUDED_KEY) is not True
assert messages[4].additional_properties.get(EXCLUDED_KEY) is not True
async def test_selective_tool_call_strategy_with_zero_removes_assistant_tool_pair() -> None:
messages = [
Message(role="user", text="u"),
_assistant_function_call("call-1"),
_tool_result("call-1", "r1"),
Message(role="assistant", text="done"),
]
strategy = SelectiveToolCallCompactionStrategy(keep_last_tool_call_groups=0)
annotate_message_groups(messages)
changed = await strategy(messages)
assert changed is True
assert messages[1].additional_properties.get(EXCLUDED_KEY) is True
assert messages[2].additional_properties.get(EXCLUDED_KEY) is True
assert messages[0].additional_properties.get(EXCLUDED_KEY) is not True
assert messages[3].additional_properties.get(EXCLUDED_KEY) is not True
def test_selective_tool_call_strategy_rejects_negative_keep_count() -> None:
try:
SelectiveToolCallCompactionStrategy(keep_last_tool_call_groups=-1)
except ValueError as exc:
assert "must be greater than or equal to 0" in str(exc)
else:
raise AssertionError("Expected ValueError for negative keep_last_tool_call_groups.")
class _FakeSummarizer:
async def get_response(
self,
messages: list[Message],
*,
stream: bool = False,
options: dict[str, Any] | None = None,
**kwargs: Any,
) -> ChatResponse:
return ChatResponse(messages=[Message(role="assistant", text="summarized context")])
class _FailingSummarizer:
async def get_response(
self,
messages: list[Message],
*,
stream: bool = False,
options: dict[str, Any] | None = None,
**kwargs: Any,
) -> ChatResponse:
raise RuntimeError("summary failed")
class _EmptySummarizer:
async def get_response(
self,
messages: list[Message],
*,
stream: bool = False,
options: dict[str, Any] | None = None,
**kwargs: Any,
) -> ChatResponse:
return ChatResponse(messages=[Message(role="assistant", text=" ")])
async def test_summarization_strategy_adds_bidirectional_trace_links() -> None:
messages = [
Message(role="user", text="u1"),
Message(role="assistant", text="a1"),
Message(role="user", text="u2"),
Message(role="assistant", text="a2"),
Message(role="user", text="u3"),
Message(role="assistant", text="a3"),
]
strategy = SummarizationStrategy(client=_FakeSummarizer(), target_count=2, threshold=0)
annotate_message_groups(messages)
changed = await strategy(messages)
assert changed is True
summary_messages = [
message for message in messages if _group_unknown_value(message, SUMMARY_OF_MESSAGE_IDS_KEY) is not None
]
assert len(summary_messages) == 1
summary = summary_messages[0]
summary_id = summary.message_id
assert summary_id is not None
assert _group_unknown_value(summary, SUMMARY_OF_GROUP_IDS_KEY)
summarized_message_ids = _group_unknown_value(summary, SUMMARY_OF_MESSAGE_IDS_KEY)
assert isinstance(summarized_message_ids, list)
for message in messages:
if message.message_id in summarized_message_ids:
assert _group_unknown_value(message, SUMMARIZED_BY_SUMMARY_ID_KEY) == summary_id
assert message.additional_properties.get(EXCLUDED_KEY) is True
async def test_summarization_strategy_returns_false_when_summary_generation_fails(
caplog: Any,
) -> None:
messages = [
Message(role="user", text="u1"),
Message(role="assistant", text="a1"),
Message(role="user", text="u2"),
Message(role="assistant", text="a2"),
Message(role="user", text="u3"),
Message(role="assistant", text="a3"),
]
strategy = SummarizationStrategy(client=_FailingSummarizer(), target_count=2, threshold=0)
annotate_message_groups(messages)
with caplog.at_level(logging.WARNING, logger="agent_framework"):
changed = await strategy(messages)
assert changed is False
assert any("summary generation failed" in record.message for record in caplog.records)
assert all(message.additional_properties.get(EXCLUDED_KEY) is not True for message in messages)
async def test_summarization_strategy_returns_false_when_summary_is_empty(
caplog: Any,
) -> None:
messages = [
Message(role="user", text="u1"),
Message(role="assistant", text="a1"),
Message(role="user", text="u2"),
Message(role="assistant", text="a2"),
Message(role="user", text="u3"),
Message(role="assistant", text="a3"),
]
strategy = SummarizationStrategy(client=_EmptySummarizer(), target_count=2, threshold=0)
annotate_message_groups(messages)
with caplog.at_level(logging.WARNING, logger="agent_framework"):
changed = await strategy(messages)
assert changed is False
assert any("returned no text" in record.message for record in caplog.records)
assert all(message.additional_properties.get(EXCLUDED_KEY) is not True for message in messages)
async def test_token_budget_composed_strategy_meets_budget_or_falls_back() -> None:
messages = [
Message(role="system", text="system"),
Message(role="user", text="user " * 200),
Message(role="assistant", text="assistant " * 200),
]
strategy = TokenBudgetComposedStrategy(
token_budget=20,
tokenizer=CharacterEstimatorTokenizer(),
strategies=[SlidingWindowStrategy(keep_last_groups=1)],
)
changed = await strategy(messages)
assert changed is True
assert included_token_count(messages) <= 20
class _ExcludeOldestNonSystem:
async def __call__(self, messages: list[Message]) -> bool:
group_ids = annotate_message_groups(messages)
kinds: dict[str, str] = {}
for message in messages:
group_id = _group_id(message)
kind = _group_kind(message)
if group_id is not None and kind is not None and group_id not in kinds:
kinds[group_id] = kind
for group_id in group_ids:
if kinds.get(group_id) == "system":
continue
for message in messages:
if _group_id(message) == group_id:
message.additional_properties[EXCLUDED_KEY] = True
return True
return False
async def test_apply_compaction_projects_included_messages_only() -> None:
messages = [
Message(role="system", text="sys"),
Message(role="user", text="hello"),
Message(role="assistant", text="world"),
]
projected = await apply_compaction(messages, strategy=_ExcludeOldestNonSystem())
assert len(projected) < len(messages)
assert projected[0].role == "system"
# --- ToolResultCompactionStrategy tests ---
async def test_tool_result_compaction_collapses_old_groups_into_summary() -> None:
"""Old tool-call groups are collapsed into summary messages, newest kept."""
messages = [
Message(role="user", text="u"),
_assistant_function_call("call-1"),
_tool_result("call-1", "r1"),
_assistant_function_call("call-2"),
_tool_result("call-2", "r2"),
Message(role="assistant", text="done"),
]
strategy = ToolResultCompactionStrategy(keep_last_tool_call_groups=1)
annotate_message_groups(messages)
changed = await strategy(messages)
assert changed is True
projected = included_messages(messages)
texts = [m.text or "" for m in projected]
summary_msgs = [t for t in texts if t.startswith("[Tool results:")]
assert len(summary_msgs) == 1
assert "r1" in summary_msgs[0]
assert any(m.role == "tool" for m in projected)
async def test_tool_result_compaction_zero_collapses_all() -> None:
"""With keep=0, all tool-call groups are collapsed into summaries."""
messages = [
Message(role="user", text="u"),
_assistant_function_call("call-1"),
_tool_result("call-1", "r1"),
_assistant_function_call("call-2"),
_tool_result("call-2", "r2"),
Message(role="assistant", text="done"),
]
strategy = ToolResultCompactionStrategy(keep_last_tool_call_groups=0)
annotate_message_groups(messages)
changed = await strategy(messages)
assert changed is True
projected = included_messages(messages)
summary_msgs = [m for m in projected if (m.text or "").startswith("[Tool results:")]
assert len(summary_msgs) == 2
assert not any(m.role == "tool" for m in projected)
async def test_tool_result_compaction_no_change_when_within_limit() -> None:
"""No compaction when tool groups count does not exceed keep limit."""
messages = [
Message(role="user", text="u"),
_assistant_function_call("call-1"),
_tool_result("call-1", "r1"),
]
strategy = ToolResultCompactionStrategy(keep_last_tool_call_groups=1)
annotate_message_groups(messages)
changed = await strategy(messages)
assert changed is False
def test_tool_result_compaction_rejects_negative() -> None:
try:
ToolResultCompactionStrategy(keep_last_tool_call_groups=-1)
except ValueError as exc:
assert "must be greater than or equal to 0" in str(exc)
else:
raise AssertionError("Expected ValueError for negative keep_last_tool_call_groups.")
async def test_tool_result_compaction_preserves_tool_results_in_summary() -> None:
"""Summary text should include the tool results from the collapsed group."""
messages = [
Message(role="user", text="u"),
Message(
role="assistant",
contents=[
Content.from_function_call(call_id="c1", name="get_weather", arguments="{}"),
Content.from_function_call(call_id="c2", name="search_docs", arguments="{}"),
],
),
_tool_result("c1", "sunny"),
_tool_result("c2", "found 3 docs"),
Message(role="assistant", text="done"),
]
strategy = ToolResultCompactionStrategy(keep_last_tool_call_groups=0)
annotate_message_groups(messages)
await strategy(messages)
projected = included_messages(messages)
summary_msgs = [m for m in projected if (m.text or "").startswith("[Tool results:")]
assert len(summary_msgs) == 1
assert "sunny" in summary_msgs[0].text # type: ignore[operator]
assert "found 3 docs" in summary_msgs[0].text # type: ignore[operator]
async def test_tool_result_compaction_bidirectional_tracing() -> None:
"""Summary and originals should link to each other like SummarizationStrategy does."""
messages = [
Message(role="user", text="u"),
_assistant_function_call("call-1"),
_tool_result("call-1", "r1"),
Message(role="assistant", text="done"),
]
strategy = ToolResultCompactionStrategy(keep_last_tool_call_groups=0)
annotate_message_groups(messages)
await strategy(messages)
# Find the summary message.
summary_msgs = [m for m in messages if _group_unknown_value(m, SUMMARY_OF_MESSAGE_IDS_KEY) is not None]
assert len(summary_msgs) == 1
summary = summary_msgs[0]
summary_id = summary.message_id
assert summary_id is not None
# Forward link: summary knows which messages/groups it replaces.
assert isinstance(_group_unknown_value(summary, SUMMARY_OF_MESSAGE_IDS_KEY), list)
assert isinstance(_group_unknown_value(summary, SUMMARY_OF_GROUP_IDS_KEY), list)
# Back link: excluded originals know which summary replaced them.
for m in messages:
if m.additional_properties.get(EXCLUDED_KEY):
assert _group_unknown_value(m, SUMMARIZED_BY_SUMMARY_ID_KEY) == summary_id
# Core compaction annotations must be present on the summary message.
assert _group_id(summary) is not None
assert _group_kind(summary) is not None
assert summary.additional_properties.get(EXCLUDED_KEY) is False
async def test_tool_result_compaction_summary_has_full_annotations() -> None:
"""Summary messages inserted by ToolResultCompactionStrategy must have all compaction annotations."""
messages = [
Message(role="user", text="u"),
_assistant_function_call("c1"),
_tool_result("c1", "r1"),
Message(role="assistant", text="done"),
]
strategy = ToolResultCompactionStrategy(keep_last_tool_call_groups=0)
annotate_message_groups(messages)
await strategy(messages)
summary = next(m for m in messages if (m.text or "").startswith("[Tool results:"))
annotation = summary.additional_properties.get(GROUP_ANNOTATION_KEY)
assert isinstance(annotation, dict)
assert GROUP_ID_KEY in annotation
assert GROUP_KIND_KEY in annotation
assert GROUP_HAS_REASONING_KEY in annotation
assert SUMMARY_OF_MESSAGE_IDS_KEY in annotation
assert summary.additional_properties.get(EXCLUDED_KEY) is False
async def test_summarization_strategy_summary_has_full_annotations() -> None:
"""Summary messages inserted by SummarizationStrategy must have all compaction annotations."""
messages = [
Message(role="user", text="u1"),
Message(role="assistant", text="a1"),
Message(role="user", text="u2"),
Message(role="assistant", text="a2"),
Message(role="user", text="u3"),
Message(role="assistant", text="a3"),
]
strategy = SummarizationStrategy(client=_FakeSummarizer(), target_count=2, threshold=0)
annotate_message_groups(messages)
changed = await strategy(messages)
assert changed is True
summary = next(m for m in messages if _group_unknown_value(m, SUMMARY_OF_MESSAGE_IDS_KEY) is not None)
annotation = summary.additional_properties.get(GROUP_ANNOTATION_KEY)
assert isinstance(annotation, dict)
assert GROUP_ID_KEY in annotation
assert GROUP_KIND_KEY in annotation
assert GROUP_HAS_REASONING_KEY in annotation
assert SUMMARY_OF_MESSAGE_IDS_KEY in annotation
assert summary.additional_properties.get(EXCLUDED_KEY) is False
async def test_tool_result_compaction_multiple_groups_combined() -> None:
"""Multiple tool-call groups collapsed independently, each with its own summary.
Scenario: 3 tool-call groups, keep_last=1 groups 1 and 2 each get a
separate summary, group 3 stays verbatim.
"""
messages = [
Message(role="user", text="Compare weather in London, Paris, and Tokyo"),
# Group 1: get_weather for London
Message(
role="assistant",
contents=[Content.from_function_call(call_id="c1", name="get_weather", arguments='{"city":"London"}')],
),
_tool_result("c1", '{"temp":12,"condition":"cloudy","wind":"NW 15km/h"}'),
Message(role="assistant", text="London is cloudy at 12°C."),
# Group 2: get_weather for Paris + search_hotels
Message(
role="assistant",
contents=[
Content.from_function_call(call_id="c2", name="get_weather", arguments='{"city":"Paris"}'),
Content.from_function_call(call_id="c3", name="search_hotels", arguments='{"city":"Paris"}'),
],
),
_tool_result("c2", '{"temp":18,"condition":"sunny"}'),
_tool_result("c3", "Grand Hotel (€120), Le Petit (€85)"),
Message(role="assistant", text="Paris is sunny at 18°C. Found 2 hotels."),
# Group 3: get_weather for Tokyo (most recent — should be kept)
Message(
role="assistant",
contents=[Content.from_function_call(call_id="c4", name="get_weather", arguments='{"city":"Tokyo"}')],
),
_tool_result("c4", '{"temp":22,"condition":"rainy"}'),
Message(role="assistant", text="Tokyo is rainy at 22°C."),
]
strategy = ToolResultCompactionStrategy(keep_last_tool_call_groups=1)
annotate_message_groups(messages)
changed = await strategy(messages)
assert changed is True
projected = included_messages(messages)
summary_msgs = [m for m in projected if (m.text or "").startswith("[Tool results:")]
# Two summaries: one for group 1, one for group 2.
assert len(summary_msgs) == 2
# Group 1 summary: London weather result.
g1_text = summary_msgs[0].text or ""
assert "12" in g1_text
assert "cloudy" in g1_text
# Group 2 summary: Paris weather + hotel results combined.
g2_text = summary_msgs[1].text or ""
assert "18" in g2_text
assert "Grand Hotel" in g2_text
# Group 3 (Tokyo) stays verbatim — tool role messages still present.
verbatim_tool_msgs = [m for m in projected if m.role == "tool"]
assert len(verbatim_tool_msgs) == 1
assert "rainy" in (verbatim_tool_msgs[0].contents[0].result or "")
# All text assistant messages should still be present.
text_msgs = [m for m in projected if m.role == "assistant" and m.text and not m.text.startswith("[Tool results:")]
texts = [m.text for m in text_msgs]
assert "London is cloudy at 12°C." in texts
assert "Paris is sunny at 18°C. Found 2 hotels." in texts
assert "Tokyo is rainy at 22°C." in texts
# Final projected shape: 8 messages in order.
assert len(projected) == 8
assert projected[0].role == "user" # original user message
assert projected[1].text == '[Tool results: get_weather: {"temp":12,"condition":"cloudy","wind":"NW 15km/h"}]'
assert projected[2].text == "London is cloudy at 12°C."
expected_g2 = (
'[Tool results: get_weather: {"temp":18,"condition":"sunny"};'
" search_hotels: Grand Hotel (€120), Le Petit (€85)]"
)
assert projected[3].text == expected_g2
assert projected[4].text == "Paris is sunny at 18°C. Found 2 hotels." # group 2 assistant text
assert projected[5].role == "assistant" # group 3 function_call (verbatim)
assert projected[6].role == "tool" # group 3 tool result (verbatim)
assert projected[7].text == "Tokyo is rainy at 22°C." # group 3 assistant text
# --- CompactionProvider tests ---
class _MockSessionContext:
"""Minimal mock for SessionContext used in CompactionProvider tests."""
def __init__(self) -> None:
self.context_messages: dict[str, list[Message]] = {}
self.input_messages: list[Message] = []
self._response: Any = None
@property
def response(self) -> Any:
return self._response
def extend_messages(self, provider: Any, messages: list[Message]) -> None:
source_id = getattr(provider, "source_id", "unknown")
self.context_messages.setdefault(source_id, []).extend(messages)
def get_messages(self) -> list[Message]:
result: list[Message] = []
for msgs in self.context_messages.values():
result.extend(msgs)
return result
async def test_compaction_provider_compacts_existing_context_messages() -> None:
"""CompactionProvider.before_run compacts messages already in context from earlier providers."""
provider = CompactionProvider(
before_strategy=SlidingWindowStrategy(keep_last_groups=2, preserve_system=True),
)
context = _MockSessionContext()
context.context_messages["history"] = [
Message(role="system", text="sys"),
Message(role="user", text="u1"),
Message(role="assistant", text="a1"),
Message(role="user", text="u2"),
Message(role="assistant", text="a2"),
Message(role="user", text="u3"),
Message(role="assistant", text="a3"),
]
await provider.before_run(agent=None, session=None, context=context, state={})
remaining = context.context_messages["history"]
assert len(remaining) == 3
assert remaining[0].role == "system"
assert remaining[1].text == "u3"
assert remaining[2].text == "a3"
async def test_compaction_provider_noop_when_no_context_messages() -> None:
"""before_run with no context messages does nothing."""
provider = CompactionProvider(
before_strategy=SlidingWindowStrategy(keep_last_groups=2),
)
context = _MockSessionContext()
await provider.before_run(agent=None, session=None, context=context, state={})
assert context.context_messages == {}
async def test_compaction_provider_preserves_messages_from_multiple_sources() -> None:
"""CompactionProvider correctly filters across multiple provider sources."""
provider = CompactionProvider(
before_strategy=SlidingWindowStrategy(keep_last_groups=2, preserve_system=True),
)
context = _MockSessionContext()
context.context_messages["history"] = [
Message(role="system", text="sys"),
Message(role="user", text="old_user"),
Message(role="assistant", text="old_assistant"),
]
context.context_messages["rag"] = [
Message(role="user", text="recent_rag_context"),
Message(role="assistant", text="recent_rag_answer"),
]
await provider.before_run(agent=None, session=None, context=context, state={})
all_remaining = context.get_messages()
assert any(m.role == "system" for m in all_remaining)
assert len(all_remaining) < 5
class _MockSession:
"""Minimal mock for AgentSession used in CompactionProvider after_run tests."""
def __init__(self) -> None:
self.state: dict[str, Any] = {}
async def test_compaction_provider_after_run_compacts_stored_history() -> None:
"""after_run annotates exclusions on stored messages without removing them."""
provider = CompactionProvider(
after_strategy=SelectiveToolCallCompactionStrategy(keep_last_tool_call_groups=0),
history_source_id="in_memory_history",
)
session = _MockSession()
session.state["in_memory_history"] = {
"messages": [
Message(role="user", text="old question"),
Message(role="assistant", text="old answer"),
_assistant_function_call("c1"),
_tool_result("c1", "result"),
Message(role="assistant", text="final answer"),
]
}
context = _MockSessionContext()
await provider.after_run(agent=None, session=session, context=context, state={})
stored = session.state["in_memory_history"]["messages"]
# All messages are kept; tool-call group is excluded via annotation.
assert len(stored) == 5
excluded = [m for m in stored if m.additional_properties.get("_excluded", False)]
assert len(excluded) == 2 # assistant function_call + tool result
assert any(m.text == "final answer" for m in stored if not m.additional_properties.get("_excluded", False))
async def test_compaction_provider_after_run_noop_without_history() -> None:
"""after_run does nothing when there is no history state."""
provider = CompactionProvider(
after_strategy=SlidingWindowStrategy(keep_last_groups=2),
history_source_id="in_memory_history",
)
session = _MockSession()
context = _MockSessionContext()
await provider.after_run(agent=None, session=session, context=context, state={})
assert "in_memory_history" not in session.state
async def test_compaction_provider_both_strategies() -> None:
"""Both before_strategy and after_strategy work independently."""
provider = CompactionProvider(
before_strategy=SlidingWindowStrategy(keep_last_groups=2, preserve_system=True),
after_strategy=SelectiveToolCallCompactionStrategy(keep_last_tool_call_groups=0),
history_source_id="history",
)
# before_run: compact loaded context
context = _MockSessionContext()
context.context_messages["history"] = [
Message(role="system", text="sys"),
Message(role="user", text="u1"),
Message(role="assistant", text="a1"),
Message(role="user", text="u2"),
Message(role="assistant", text="a2"),
]
await provider.before_run(agent=None, session=None, context=context, state={})
assert len(context.get_messages()) == 3
# after_run: compact stored history
session = _MockSession()
session.state["history"] = {
"messages": [
Message(role="user", text="q"),
_assistant_function_call("c1"),
_tool_result("c1", "ok"),
Message(role="assistant", text="done"),
]
}
await provider.after_run(agent=None, session=session, context=_MockSessionContext(), state={})
stored = session.state["history"]["messages"]
excluded = [m for m in stored if m.additional_properties.get("_excluded", False)]
assert len(excluded) == 2 # tool-call group excluded
async def test_compaction_provider_none_strategies_are_noop() -> None:
"""When both strategies are None, before_run and after_run are no-ops."""
provider = CompactionProvider()
context = _MockSessionContext()
context.context_messages["history"] = [
Message(role="user", text="hello"),
Message(role="assistant", text="hi"),
]
await provider.before_run(agent=None, session=None, context=context, state={})
assert len(context.get_messages()) == 2
session = _MockSession()
await provider.after_run(agent=None, session=session, context=context, state={})
assert "in_memory_history" not in session.state
async def test_in_memory_history_provider_skip_excluded() -> None:
"""InMemoryHistoryProvider with skip_excluded=True omits excluded messages."""
from agent_framework._compaction import EXCLUDED_KEY
from agent_framework._sessions import InMemoryHistoryProvider as _InMemoryHistoryProvider
provider = _InMemoryHistoryProvider(skip_excluded=True)
state: dict[str, Any] = {
"messages": [
Message(role="user", text="u1"),
Message(role="assistant", text="a1", additional_properties={EXCLUDED_KEY: True}),
Message(role="user", text="u2"),
Message(role="assistant", text="a2"),
]
}
loaded = await provider.get_messages(session_id="test", state=state)
assert len(loaded) == 3
assert all(m.text != "a1" for m in loaded)
async def test_in_memory_history_provider_default_loads_all() -> None:
"""InMemoryHistoryProvider with default settings loads all messages including excluded."""
from agent_framework._compaction import EXCLUDED_KEY
from agent_framework._sessions import InMemoryHistoryProvider as _InMemoryHistoryProvider
provider = _InMemoryHistoryProvider()
state: dict[str, Any] = {
"messages": [
Message(role="user", text="u1"),
Message(role="assistant", text="a1", additional_properties={EXCLUDED_KEY: True}),
Message(role="user", text="u2"),
]
}
loaded = await provider.get_messages(session_id="test", state=state)
assert len(loaded) == 3
@@ -15,9 +15,27 @@ from agent_framework import (
SupportsChatGetResponse,
tool,
)
from agent_framework._compaction import (
EXCLUDED_KEY,
GROUP_ANNOTATION_KEY,
GROUP_ID_KEY,
CharacterEstimatorTokenizer,
SlidingWindowStrategy,
TokenBudgetComposedStrategy,
annotate_message_groups,
included_token_count,
)
from agent_framework._middleware import FunctionInvocationContext, FunctionMiddleware, MiddlewareTermination
def _group_id(message: Message) -> str | None:
annotation = message.additional_properties.get(GROUP_ANNOTATION_KEY)
if not isinstance(annotation, dict):
return None
value = annotation.get(GROUP_ID_KEY)
return value if isinstance(value, str) else None
async def test_base_client_with_function_calling(chat_client_base: SupportsChatGetResponse):
exec_counter = 0
@@ -131,6 +149,127 @@ async def test_base_client_with_function_calling_resets(chat_client_base: Suppor
assert response.messages[3].contents[0].type == "function_result"
async def test_function_loop_applies_compaction_projection_each_model_call(chat_client_base: SupportsChatGetResponse):
@tool(name="test_function", approval_mode="never_require")
def ai_func(arg1: str) -> str:
return f"Processed {arg1}"
class _ExcludeOldestGroupAfterFirstTurn:
async def __call__(self, messages: list[Message]) -> bool:
groups = annotate_message_groups(messages)
if len(groups) <= 1:
return False
oldest_group_id = groups[0]
changed = False
for message in messages:
if _group_id(message) == oldest_group_id:
if message.additional_properties.get(EXCLUDED_KEY) is not True:
changed = True
message.additional_properties[EXCLUDED_KEY] = True
return changed
captured_roles: list[list[str]] = []
original = chat_client_base._get_non_streaming_response # type: ignore[attr-defined]
async def _capture(
*,
messages: list[Message],
options: dict[str, Any],
**kwargs: Any,
) -> ChatResponse:
captured_roles.append([message.role for message in messages])
return await original(messages=messages, options=options, **kwargs)
chat_client_base._get_non_streaming_response = _capture # type: ignore[attr-defined,method-assign]
chat_client_base.compaction_strategy = _ExcludeOldestGroupAfterFirstTurn() # type: ignore[attr-defined]
chat_client_base.run_responses = [
ChatResponse(
messages=Message(
role="assistant",
contents=[
Content.from_function_call(call_id="1", name="test_function", arguments='{"arg1": "value1"}')
],
)
),
ChatResponse(messages=Message(role="assistant", text="done")),
]
await chat_client_base.get_response(
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [ai_func]}
)
assert len(captured_roles) >= 2
assert "user" in captured_roles[0]
assert "user" not in captured_roles[1]
async def test_function_loop_token_budget_strategy_caps_tokens_each_iteration(
chat_client_base: SupportsChatGetResponse,
):
exec_counter = 0
token_budget = 500
tokenizer = CharacterEstimatorTokenizer()
@tool(name="test_function", approval_mode="never_require")
def ai_func(arg1: str) -> str:
nonlocal exec_counter
exec_counter += 1
return f"Processed {arg1}. " + ("result " * 120)
captured_token_counts: list[int] = []
original = chat_client_base._get_non_streaming_response # type: ignore[attr-defined]
async def _capture(
*,
messages: list[Message],
options: dict[str, Any],
**kwargs: Any,
) -> ChatResponse:
annotate_message_groups(messages, force_reannotate=True, tokenizer=tokenizer)
captured_token_counts.append(included_token_count(messages))
return await original(messages=messages, options=options, **kwargs)
chat_client_base._get_non_streaming_response = _capture # type: ignore[attr-defined,method-assign]
chat_client_base.tokenizer = tokenizer # type: ignore[attr-defined]
chat_client_base.function_invocation_configuration["max_iterations"] = 3 # type: ignore[attr-defined]
chat_client_base.compaction_strategy = TokenBudgetComposedStrategy( # type: ignore[attr-defined]
token_budget=token_budget,
tokenizer=tokenizer,
strategies=[SlidingWindowStrategy(keep_last_groups=2)],
)
chat_client_base.run_responses = [
ChatResponse(
messages=Message(
role="assistant",
contents=[
Content.from_function_call(call_id="1", name="test_function", arguments='{"arg1": "value1"}')
],
)
),
ChatResponse(
messages=Message(
role="assistant",
contents=[
Content.from_function_call(call_id="2", name="test_function", arguments='{"arg1": "value2"}')
],
)
),
ChatResponse(messages=Message(role="assistant", text="done")),
]
response = await chat_client_base.get_response(
[Message(role="user", text="hello " * 160)],
options={"tool_choice": "auto", "tools": [ai_func]},
)
assert response.messages[-1].text == "done"
assert exec_counter == 2
assert len(captured_token_counts) >= 3
assert all(token_count > 0 for token_count in captured_token_counts)
assert all(token_count <= token_budget for token_count in captured_token_counts)
async def test_base_client_with_streaming_function_calling(chat_client_base: SupportsChatGetResponse):
exec_counter = 0
@@ -35,7 +35,7 @@ from agent_framework._skills import (
async def _noop_script_runner(skill: Any, script: Any, args: Any = None) -> None:
"""No-op script runner for tests that need a SkillScriptRunner."""
return None
return
def _symlinks_supported(tmp: Path) -> bool:
@@ -1994,7 +1994,7 @@ class TestSkillScriptRunnerProtocol:
"""Tests for the SkillScriptRunner protocol."""
async def test_async_callable_satisfies_protocol(self) -> None:
from agent_framework import SkillScriptRunner, SkillScript
from agent_framework import SkillScript, SkillScriptRunner
results: list[tuple] = []
@@ -2015,7 +2015,7 @@ class TestSkillScriptRunnerProtocol:
assert results[0] == ("test-skill", "my-script", {"key": "val"})
async def test_callable_class_satisfies_protocol(self) -> None:
from agent_framework import SkillScriptRunner, SkillScript
from agent_framework import SkillScript, SkillScriptRunner
class _CustomRunner:
async def __call__(self, skill, script, args=None):
@@ -2056,7 +2056,7 @@ class TestSkillScriptRunnerProtocol:
assert result == {"exit_code": 0, "output": "ok"}
def test_sync_callable_satisfies_protocol(self) -> None:
from agent_framework import SkillScriptRunner, SkillScript
from agent_framework import SkillScript, SkillScriptRunner
results: list[tuple] = []
@@ -2077,7 +2077,7 @@ class TestSkillScriptRunnerProtocol:
assert results[0] == ("test-skill", "my-script", {"key": "val"})
def test_sync_callable_class_satisfies_protocol(self) -> None:
from agent_framework import SkillScriptRunner, SkillScript
from agent_framework import SkillScript, SkillScriptRunner
class _SyncRunner:
def __call__(self, skill, script, args=None):
@@ -2117,6 +2117,7 @@ class TestSkillScriptRunnerProtocol:
result = dict_runner(skill, script)
assert result == {"exit_code": 0, "output": "ok"}
# ---------------------------------------------------------------------------
# SkillsProvider static factory tests
# ---------------------------------------------------------------------------
@@ -28,6 +28,12 @@ from agent_framework import (
merge_chat_options,
tool,
)
from agent_framework._compaction import (
GROUP_ANNOTATION_KEY,
GROUP_HAS_REASONING_KEY,
GROUP_ID_KEY,
GROUP_TOKEN_COUNT_KEY,
)
from agent_framework._types import (
_get_data_bytes,
_get_data_bytes_as_str,
@@ -1654,6 +1660,78 @@ def test_chat_message_complex_content_serialization():
assert reconstructed.contents[2].type == "function_result"
def test_message_roundtrip_preserves_compaction_annotation_dict() -> None:
message = Message(
role="assistant",
contents=[Content.from_text("Hello")],
additional_properties={
GROUP_ANNOTATION_KEY: {
"id": "group_1",
"kind": "assistant_text",
"index": 1,
"has_reasoning": False,
"token_count": 42,
}
},
)
restored = Message.from_dict(message.to_dict())
annotation = restored.additional_properties.get(GROUP_ANNOTATION_KEY)
assert isinstance(annotation, dict)
assert annotation[GROUP_ID_KEY] == "group_1"
assert annotation[GROUP_TOKEN_COUNT_KEY] == 42
def test_content_roundtrip_preserves_compaction_annotation_dict() -> None:
content = Content.from_text(
text="Hello",
additional_properties={
GROUP_ANNOTATION_KEY: {
"id": "group_2",
"kind": "assistant_text",
"index": 2,
"has_reasoning": False,
"token_count": None,
}
},
)
restored = Content.from_dict(content.to_dict())
annotation = restored.additional_properties.get(GROUP_ANNOTATION_KEY)
assert isinstance(annotation, dict)
assert annotation[GROUP_ID_KEY] == "group_2"
assert annotation[GROUP_TOKEN_COUNT_KEY] is None
def test_chat_response_roundtrip_preserves_compaction_annotation_dict() -> None:
response = ChatResponse(
messages=[
Message(
role="assistant",
contents=[Content.from_text("Hello")],
additional_properties={
GROUP_ANNOTATION_KEY: {
"id": "group_3",
"kind": "assistant_text",
"index": 3,
"has_reasoning": True,
"token_count": 15,
}
},
)
]
)
restored = ChatResponse.from_dict(response.to_dict())
annotation = restored.messages[0].additional_properties.get(GROUP_ANNOTATION_KEY)
assert isinstance(annotation, dict)
assert annotation[GROUP_ID_KEY] == "group_3"
assert annotation[GROUP_HAS_REASONING_KEY] is True
def test_usage_content_serialization_with_details():
"""Test UsageContent from_dict and to_dict with UsageDetails conversion."""
@@ -524,6 +524,58 @@ def test_response_content_creation_with_reasoning() -> None:
assert response.messages[0].contents[0].text == "Reasoning step"
def test_response_content_keeps_reasoning_and_function_calls_in_one_message() -> None:
"""Reasoning + function calls should parse into one assistant message."""
client = OpenAIResponsesClient(model_id="test-model", api_key="test-key")
mock_response = MagicMock()
mock_response.output_parsed = None
mock_response.metadata = {}
mock_response.usage = None
mock_response.id = "test-id"
mock_response.model = "test-model"
mock_response.created_at = 1000000000
mock_reasoning_content = MagicMock()
mock_reasoning_content.text = "Reasoning step"
mock_reasoning_item = MagicMock()
mock_reasoning_item.type = "reasoning"
mock_reasoning_item.id = "rs_123"
mock_reasoning_item.content = [mock_reasoning_content]
mock_reasoning_item.summary = []
mock_function_call_item_1 = MagicMock()
mock_function_call_item_1.type = "function_call"
mock_function_call_item_1.id = "fc_1"
mock_function_call_item_1.call_id = "call_1"
mock_function_call_item_1.name = "tool_1"
mock_function_call_item_1.arguments = '{"x": 1}'
mock_function_call_item_2 = MagicMock()
mock_function_call_item_2.type = "function_call"
mock_function_call_item_2.id = "fc_2"
mock_function_call_item_2.call_id = "call_2"
mock_function_call_item_2.name = "tool_2"
mock_function_call_item_2.arguments = '{"y": 2}'
mock_response.output = [
mock_reasoning_item,
mock_function_call_item_1,
mock_function_call_item_2,
]
response = client._parse_response_from_openai(mock_response, options={}) # type: ignore
assert len(response.messages) == 1
assert response.messages[0].role == "assistant"
assert [content.type for content in response.messages[0].contents] == [
"text_reasoning",
"function_call",
"function_call",
]
def test_response_content_creation_with_code_interpreter() -> None:
"""Test _parse_response_from_openai with code interpreter outputs."""