mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Implement annotation-based context compaction (#4469)
* Implement annotation-based context compaction Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Handle missing compaction attributes in BaseChatClient Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Fix CI typing and bandit issues Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Optimize incremental compaction annotation pass Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * refinement * Python: add ToolResultCompactionStrategy and CompactionProvider Add ToolResultCompactionStrategy that collapses older tool-call groups into short summary messages (e.g. [Tool calls: get_weather]) while keeping the most recent groups verbatim. This mirrors the .NET ToolResultCompactionStrategy from PR #4533. Add CompactionProvider as a context-provider that auto-applies compaction before each agent turn and stores compacted history in session state after each turn. Includes tests and samples for both features. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * refinement and alignment with dotnet PR * updated tool result compaction * updated tool result compaction * Python: add ToolResultCompactionStrategy, CompactionProvider, and skip_excluded - ToolResultCompactionStrategy collapses older tool-call groups into [Tool results: func_name: result] summaries with bidirectional tracing (same pattern as SummarizationStrategy). - CompactionProvider as BaseContextProvider with separate before_strategy and after_strategy parameters. before_strategy compacts loaded context; after_strategy compacts stored history via history_source_id. - InMemoryHistoryProvider gains skip_excluded flag to filter out messages marked as excluded by compaction strategies. - Tests, samples, and exports updated. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * fixed checks * fix mypy * Fix: ensure summary messages from both strategies get full compaction annotations SummarizationStrategy was not calling annotate_message_groups after inserting its summary message, so the summary lacked core group annotations (id, kind, index, has_reasoning, _excluded). Added the missing call. ToolResultCompactionStrategy already had it. Added tests verifying both strategies produce fully annotated summaries. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * updated propagation * fix mypy --------- Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
565c0b1623
commit
3e03a305f6
@@ -29,6 +29,34 @@ from ._clients import (
|
||||
SupportsMCPTool,
|
||||
SupportsWebSearchTool,
|
||||
)
|
||||
from ._compaction import (
|
||||
COMPACTION_STATE_KEY,
|
||||
EXCLUDE_REASON_KEY,
|
||||
EXCLUDED_KEY,
|
||||
GROUP_ANNOTATION_KEY,
|
||||
GROUP_HAS_REASONING_KEY,
|
||||
GROUP_ID_KEY,
|
||||
GROUP_INDEX_KEY,
|
||||
GROUP_KIND_KEY,
|
||||
GROUP_TOKEN_COUNT_KEY,
|
||||
SUMMARIZED_BY_SUMMARY_ID_KEY,
|
||||
SUMMARY_OF_GROUP_IDS_KEY,
|
||||
SUMMARY_OF_MESSAGE_IDS_KEY,
|
||||
CharacterEstimatorTokenizer,
|
||||
CompactionProvider,
|
||||
CompactionStrategy,
|
||||
SelectiveToolCallCompactionStrategy,
|
||||
SlidingWindowStrategy,
|
||||
SummarizationStrategy,
|
||||
TokenBudgetComposedStrategy,
|
||||
TokenizerProtocol,
|
||||
ToolResultCompactionStrategy,
|
||||
TruncationStrategy,
|
||||
annotate_message_groups,
|
||||
apply_compaction,
|
||||
included_messages,
|
||||
included_token_count,
|
||||
)
|
||||
from ._mcp import MCPStdioTool, MCPStreamableHTTPTool, MCPWebsocketTool
|
||||
from ._middleware import (
|
||||
AgentContext,
|
||||
@@ -196,7 +224,19 @@ from .exceptions import (
|
||||
__all__ = [
|
||||
"AGENT_FRAMEWORK_USER_AGENT",
|
||||
"APP_INFO",
|
||||
"COMPACTION_STATE_KEY",
|
||||
"DEFAULT_MAX_ITERATIONS",
|
||||
"EXCLUDED_KEY",
|
||||
"EXCLUDE_REASON_KEY",
|
||||
"GROUP_ANNOTATION_KEY",
|
||||
"GROUP_HAS_REASONING_KEY",
|
||||
"GROUP_ID_KEY",
|
||||
"GROUP_INDEX_KEY",
|
||||
"GROUP_KIND_KEY",
|
||||
"GROUP_TOKEN_COUNT_KEY",
|
||||
"SUMMARIZED_BY_SUMMARY_ID_KEY",
|
||||
"SUMMARY_OF_GROUP_IDS_KEY",
|
||||
"SUMMARY_OF_MESSAGE_IDS_KEY",
|
||||
"USER_AGENT_KEY",
|
||||
"USER_AGENT_TELEMETRY_DISABLED_ENV_VAR",
|
||||
"Agent",
|
||||
@@ -218,6 +258,7 @@ __all__ = [
|
||||
"BaseEmbeddingClient",
|
||||
"BaseHistoryProvider",
|
||||
"Case",
|
||||
"CharacterEstimatorTokenizer",
|
||||
"ChatAndFunctionMiddlewareTypes",
|
||||
"ChatContext",
|
||||
"ChatMiddleware",
|
||||
@@ -227,6 +268,8 @@ __all__ = [
|
||||
"ChatResponse",
|
||||
"ChatResponseUpdate",
|
||||
"CheckpointStorage",
|
||||
"CompactionProvider",
|
||||
"CompactionStrategy",
|
||||
"Content",
|
||||
"ContinuationToken",
|
||||
"Default",
|
||||
@@ -273,6 +316,7 @@ __all__ = [
|
||||
"Runner",
|
||||
"RunnerContext",
|
||||
"SecretString",
|
||||
"SelectiveToolCallCompactionStrategy",
|
||||
"SessionContext",
|
||||
"SingleEdgeGroup",
|
||||
"Skill",
|
||||
@@ -280,8 +324,10 @@ __all__ = [
|
||||
"SkillScript",
|
||||
"SkillScriptRunner",
|
||||
"SkillsProvider",
|
||||
"SlidingWindowStrategy",
|
||||
"SubWorkflowRequestMessage",
|
||||
"SubWorkflowResponseMessage",
|
||||
"SummarizationStrategy",
|
||||
"SupportsAgentRun",
|
||||
"SupportsChatGetResponse",
|
||||
"SupportsCodeInterpreterTool",
|
||||
@@ -294,8 +340,12 @@ __all__ = [
|
||||
"SwitchCaseEdgeGroupCase",
|
||||
"SwitchCaseEdgeGroupDefault",
|
||||
"TextSpanRegion",
|
||||
"TokenBudgetComposedStrategy",
|
||||
"TokenizerProtocol",
|
||||
"ToolMode",
|
||||
"ToolResultCompactionStrategy",
|
||||
"ToolTypes",
|
||||
"TruncationStrategy",
|
||||
"TypeCompatibilityError",
|
||||
"UpdateT",
|
||||
"UsageDetails",
|
||||
@@ -322,12 +372,16 @@ __all__ = [
|
||||
"__version__",
|
||||
"add_usage_details",
|
||||
"agent_middleware",
|
||||
"annotate_message_groups",
|
||||
"apply_compaction",
|
||||
"chat_middleware",
|
||||
"create_edge_runner",
|
||||
"detect_media_type_from_base64",
|
||||
"executor",
|
||||
"function_middleware",
|
||||
"handler",
|
||||
"included_messages",
|
||||
"included_token_count",
|
||||
"load_settings",
|
||||
"map_chat_to_agent_update",
|
||||
"merge_chat_options",
|
||||
|
||||
@@ -74,6 +74,7 @@ else:
|
||||
from typing_extensions import Self, TypedDict # pragma: no cover
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ._compaction import CompactionStrategy, TokenizerProtocol
|
||||
from ._types import ChatOptions
|
||||
|
||||
logger = logging.getLogger("agent_framework")
|
||||
@@ -177,6 +178,8 @@ class _RunContext(TypedDict):
|
||||
session_messages: Sequence[Message]
|
||||
agent_name: str
|
||||
chat_options: MutableMapping[str, Any]
|
||||
compaction_strategy: CompactionStrategy | None
|
||||
tokenizer: TokenizerProtocol | None
|
||||
filtered_kwargs: Mapping[str, Any]
|
||||
finalize_kwargs: Mapping[str, Any]
|
||||
|
||||
@@ -665,6 +668,8 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None,
|
||||
default_options: OptionsCoT | None = None,
|
||||
context_providers: Sequence[BaseContextProvider] | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize a Agent instance.
|
||||
@@ -688,6 +693,10 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
Note: response_format typing does not flow into run outputs when set via default_options.
|
||||
These can be overridden at runtime via the ``options`` parameter of ``run()``.
|
||||
tools: The tools to use for the request.
|
||||
compaction_strategy: Optional agent-level in-run compaction.
|
||||
If both this and a compaction_strategy on the underlying client are set, this one is used.
|
||||
tokenizer: Optional agent-level tokenizer.
|
||||
If both this and a tokenizer on the underlying client are set, this one is used.
|
||||
kwargs: Any additional keyword arguments. Will be stored as ``additional_properties``.
|
||||
"""
|
||||
opts = dict(default_options) if default_options else {}
|
||||
@@ -705,6 +714,8 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
**kwargs,
|
||||
)
|
||||
self.client = client
|
||||
self.compaction_strategy = compaction_strategy
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
# Get tools from options or named parameter (named param takes precedence)
|
||||
tools_ = tools if tools is not None else opts.pop("tools", None)
|
||||
@@ -799,6 +810,8 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
session: AgentSession | None = None,
|
||||
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None,
|
||||
options: ChatOptions[ResponseModelBoundT],
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse[ResponseModelBoundT]]: ...
|
||||
|
||||
@@ -811,6 +824,8 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
session: AgentSession | None = None,
|
||||
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None,
|
||||
options: OptionsCoT | ChatOptions[None] | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse[Any]]: ...
|
||||
|
||||
@@ -823,6 +838,8 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
session: AgentSession | None = None,
|
||||
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None,
|
||||
options: OptionsCoT | ChatOptions[Any] | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
|
||||
|
||||
@@ -834,6 +851,8 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
session: AgentSession | None = None,
|
||||
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None,
|
||||
options: OptionsCoT | ChatOptions[Any] | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
|
||||
"""Run the agent with the given messages and options.
|
||||
@@ -857,8 +876,14 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
``Agent[OpenAIChatOptions]``, this enables IDE autocomplete for
|
||||
provider-specific options including temperature, max_tokens, model_id,
|
||||
tool_choice, and provider-specific options like reasoning_effort.
|
||||
kwargs: Additional keyword arguments for the agent.
|
||||
Will only be passed to functions that are called.
|
||||
compaction_strategy: Optional per-run compaction override passed to
|
||||
``client.get_response()``. When omitted, the agent-level override
|
||||
is used, falling back to the client default.
|
||||
tokenizer: Optional per-run tokenizer override passed to
|
||||
``client.get_response()``. When omitted, the agent-level override
|
||||
is used, falling back to the client default.
|
||||
kwargs: Additional keyword arguments for the agent. These are only
|
||||
passed to functions that are called.
|
||||
|
||||
Returns:
|
||||
When stream=False: An Awaitable[AgentResponse] containing the agent's response.
|
||||
@@ -873,6 +898,8 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
session=session,
|
||||
tools=tools,
|
||||
options=options,
|
||||
compaction_strategy=compaction_strategy,
|
||||
tokenizer=tokenizer,
|
||||
kwargs=kwargs,
|
||||
)
|
||||
response = cast(
|
||||
@@ -881,6 +908,8 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
messages=ctx["session_messages"],
|
||||
stream=False,
|
||||
options=ctx["chat_options"], # type: ignore[reportArgumentType]
|
||||
compaction_strategy=ctx["compaction_strategy"],
|
||||
tokenizer=ctx["tokenizer"],
|
||||
**ctx["filtered_kwargs"],
|
||||
),
|
||||
)
|
||||
@@ -954,6 +983,8 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
session=session,
|
||||
tools=tools,
|
||||
options=options,
|
||||
compaction_strategy=compaction_strategy,
|
||||
tokenizer=tokenizer,
|
||||
kwargs=kwargs,
|
||||
)
|
||||
ctx: _RunContext = ctx_holder["ctx"] # type: ignore[assignment] # Safe: we just assigned it
|
||||
@@ -961,6 +992,8 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
messages=ctx["session_messages"],
|
||||
stream=True,
|
||||
options=ctx["chat_options"], # type: ignore[reportArgumentType]
|
||||
compaction_strategy=ctx["compaction_strategy"],
|
||||
tokenizer=ctx["tokenizer"],
|
||||
**ctx["filtered_kwargs"],
|
||||
)
|
||||
|
||||
@@ -1047,6 +1080,8 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
session: AgentSession | None,
|
||||
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None,
|
||||
options: Mapping[str, Any] | None,
|
||||
compaction_strategy: CompactionStrategy | None,
|
||||
tokenizer: TokenizerProtocol | None,
|
||||
kwargs: dict[str, Any],
|
||||
) -> _RunContext:
|
||||
opts = dict(options) if options else {}
|
||||
@@ -1081,9 +1116,10 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
options=opts,
|
||||
)
|
||||
|
||||
agent_name = self._get_agent_name()
|
||||
|
||||
# Normalize tools
|
||||
normalized_tools = normalize_tools(tools_)
|
||||
agent_name = self._get_agent_name()
|
||||
|
||||
# Resolve final tool list (runtime provided tools + local MCP server tools)
|
||||
final_tools: list[FunctionTool | Callable[..., Any] | dict[str, Any] | Any] = []
|
||||
@@ -1153,6 +1189,8 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
"session_messages": session_messages,
|
||||
"agent_name": agent_name,
|
||||
"chat_options": co,
|
||||
"compaction_strategy": compaction_strategy or self.compaction_strategy,
|
||||
"tokenizer": tokenizer or self.tokenizer,
|
||||
"filtered_kwargs": filtered_kwargs,
|
||||
"finalize_kwargs": finalize_kwargs,
|
||||
}
|
||||
@@ -1408,6 +1446,8 @@ class Agent(
|
||||
default_options: OptionsCoT | None = None,
|
||||
context_providers: Sequence[BaseContextProvider] | None = None,
|
||||
middleware: Sequence[MiddlewareTypes] | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize a Agent instance."""
|
||||
@@ -1421,5 +1461,7 @@ class Agent(
|
||||
default_options=default_options,
|
||||
context_providers=context_providers,
|
||||
middleware=middleware,
|
||||
compaction_strategy=compaction_strategy,
|
||||
tokenizer=tokenizer,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -52,6 +52,7 @@ else:
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ._agents import Agent
|
||||
from ._compaction import CompactionStrategy, TokenizerProtocol
|
||||
from ._middleware import (
|
||||
MiddlewareTypes,
|
||||
)
|
||||
@@ -134,6 +135,8 @@ class SupportsChatGetResponse(Protocol[OptionsContraT]):
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
options: ChatOptions[ResponseModelBoundT],
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[ChatResponse[ResponseModelBoundT]]: ...
|
||||
|
||||
@@ -144,6 +147,8 @@ class SupportsChatGetResponse(Protocol[OptionsContraT]):
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
options: OptionsContraT | ChatOptions[None] | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[ChatResponse[Any]]: ...
|
||||
|
||||
@@ -154,6 +159,8 @@ class SupportsChatGetResponse(Protocol[OptionsContraT]):
|
||||
*,
|
||||
stream: Literal[True],
|
||||
options: OptionsContraT | ChatOptions[Any] | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: ...
|
||||
|
||||
@@ -163,6 +170,8 @@ class SupportsChatGetResponse(Protocol[OptionsContraT]):
|
||||
*,
|
||||
stream: bool = False,
|
||||
options: OptionsContraT | ChatOptions[Any] | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]:
|
||||
"""Send input and return the response.
|
||||
@@ -171,6 +180,8 @@ class SupportsChatGetResponse(Protocol[OptionsContraT]):
|
||||
messages: The sequence of input messages to send.
|
||||
stream: Whether to stream the response. Defaults to False.
|
||||
options: Chat options as a TypedDict.
|
||||
compaction_strategy: Optional per-call compaction override.
|
||||
tokenizer: Optional per-call tokenizer override.
|
||||
**kwargs: Additional chat options.
|
||||
|
||||
Returns:
|
||||
@@ -252,7 +263,13 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
|
||||
"""
|
||||
|
||||
OTEL_PROVIDER_NAME: ClassVar[str] = "unknown"
|
||||
DEFAULT_EXCLUDE: ClassVar[set[str]] = {"additional_properties"}
|
||||
compaction_strategy: CompactionStrategy | None = None
|
||||
tokenizer: TokenizerProtocol | None = None
|
||||
DEFAULT_EXCLUDE: ClassVar[set[str]] = {
|
||||
"additional_properties",
|
||||
"compaction_strategy",
|
||||
"tokenizer",
|
||||
}
|
||||
STORES_BY_DEFAULT: ClassVar[bool] = False
|
||||
"""Whether this client stores conversation history server-side by default.
|
||||
|
||||
@@ -267,15 +284,21 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
|
||||
self,
|
||||
*,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize a BaseChatClient instance.
|
||||
|
||||
Keyword Args:
|
||||
additional_properties: Additional properties for the client.
|
||||
compaction_strategy: Optional compaction strategy to apply before model calls.
|
||||
tokenizer: Optional tokenizer used by token-aware compaction strategies.
|
||||
kwargs: Additional keyword arguments (merged into additional_properties).
|
||||
"""
|
||||
self.additional_properties = additional_properties or {}
|
||||
self.compaction_strategy = compaction_strategy
|
||||
self.tokenizer = tokenizer
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def to_dict(self, *, exclude: set[str] | None = None, exclude_none: bool = True) -> dict[str, Any]:
|
||||
@@ -337,6 +360,46 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
|
||||
finalizer=lambda updates: self._finalize_response_updates(updates, response_format=response_format),
|
||||
)
|
||||
|
||||
async def _prepare_messages_for_model_call(
|
||||
self,
|
||||
messages: Sequence[Message],
|
||||
*,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
) -> list[Message]:
|
||||
prepared_messages = list(messages)
|
||||
if compaction_strategy is None:
|
||||
if tokenizer is None:
|
||||
return prepared_messages
|
||||
from ._compaction import annotate_message_groups
|
||||
|
||||
annotate_message_groups(prepared_messages, tokenizer=tokenizer)
|
||||
return prepared_messages
|
||||
from ._compaction import apply_compaction
|
||||
|
||||
return await apply_compaction(
|
||||
prepared_messages,
|
||||
strategy=compaction_strategy,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
def _resolve_compaction_overrides(
|
||||
self,
|
||||
*,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
) -> dict[str, Any]:
|
||||
current_compaction_strategy = getattr(self, "compaction_strategy", None)
|
||||
current_tokenizer = getattr(self, "tokenizer", None)
|
||||
ret: dict[str, Any] = {}
|
||||
if current_compaction_strategy is not None or compaction_strategy is not None:
|
||||
ret["compaction_strategy"] = (
|
||||
current_compaction_strategy if compaction_strategy is None else compaction_strategy
|
||||
)
|
||||
if current_tokenizer is not None or tokenizer is not None:
|
||||
ret["tokenizer"] = current_tokenizer if tokenizer is None else tokenizer
|
||||
return ret
|
||||
|
||||
# region Internal method to be implemented by derived classes
|
||||
|
||||
@abstractmethod
|
||||
@@ -374,6 +437,8 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
options: ChatOptions[ResponseModelBoundT],
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[ChatResponse[ResponseModelBoundT]]: ...
|
||||
|
||||
@@ -384,6 +449,8 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
options: OptionsCoT | ChatOptions[None] | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[ChatResponse[Any]]: ...
|
||||
|
||||
@@ -394,6 +461,8 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
|
||||
*,
|
||||
stream: Literal[True],
|
||||
options: OptionsCoT | ChatOptions[Any] | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: ...
|
||||
|
||||
@@ -403,6 +472,8 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
|
||||
*,
|
||||
stream: bool = False,
|
||||
options: OptionsCoT | ChatOptions[Any] | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]:
|
||||
"""Get a response from a chat client.
|
||||
@@ -411,17 +482,62 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
|
||||
messages: The message or messages to send to the model.
|
||||
stream: Whether to stream the response. Defaults to False.
|
||||
options: Chat options as a TypedDict.
|
||||
compaction_strategy: Optional per-call override for in-run compaction.
|
||||
When omitted, the client-level default is used.
|
||||
tokenizer: Optional per-call tokenizer override. When omitted, the
|
||||
client-level default is used.
|
||||
**kwargs: Other keyword arguments, can be used to pass function specific parameters.
|
||||
|
||||
Returns:
|
||||
When streaming a response stream of ChatResponseUpdates, otherwise an Awaitable ChatResponse.
|
||||
"""
|
||||
return self._inner_get_response(
|
||||
messages=messages,
|
||||
stream=stream,
|
||||
options=options or {}, # type: ignore[arg-type]
|
||||
**kwargs,
|
||||
compaction_overrides = self._resolve_compaction_overrides(
|
||||
compaction_strategy=compaction_strategy,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
if not compaction_overrides:
|
||||
return self._inner_get_response(
|
||||
messages=messages,
|
||||
stream=stream,
|
||||
options=options or {},
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if stream:
|
||||
|
||||
async def _get_stream() -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]]:
|
||||
prepared_messages = await self._prepare_messages_for_model_call(
|
||||
messages,
|
||||
**compaction_overrides,
|
||||
)
|
||||
stream_response = self._inner_get_response(
|
||||
messages=prepared_messages,
|
||||
stream=True,
|
||||
options=options or {},
|
||||
**kwargs,
|
||||
)
|
||||
if isinstance(stream_response, ResponseStream):
|
||||
return stream_response # type: ignore[reportUnknownVariableType]
|
||||
awaited_stream_response = await stream_response
|
||||
if isinstance(awaited_stream_response, ResponseStream):
|
||||
return awaited_stream_response
|
||||
raise ValueError("Streaming responses must return a ResponseStream.")
|
||||
|
||||
return ResponseStream.from_awaitable(_get_stream()) # type: ignore[reportUnknownVariableType]
|
||||
|
||||
async def _get_response() -> ChatResponse[Any]:
|
||||
prepared_messages = await self._prepare_messages_for_model_call(
|
||||
messages,
|
||||
**compaction_overrides,
|
||||
)
|
||||
return await self._inner_get_response(
|
||||
messages=prepared_messages,
|
||||
stream=False,
|
||||
options=options or {},
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return _get_response()
|
||||
|
||||
def service_url(self) -> str:
|
||||
"""Get the URL of the service.
|
||||
@@ -446,6 +562,8 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
|
||||
context_providers: Sequence[Any] | None = None,
|
||||
middleware: Sequence[MiddlewareTypes] | None = None,
|
||||
function_invocation_configuration: FunctionInvocationConfiguration | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Agent[OptionsCoT]:
|
||||
"""Create a Agent with this client.
|
||||
@@ -468,6 +586,10 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
|
||||
context_providers: Context providers to include during agent invocation.
|
||||
middleware: List of middleware to intercept agent and function invocations.
|
||||
function_invocation_configuration: Optional function invocation configuration override.
|
||||
compaction_strategy: Optional agent-level compaction override. When omitted,
|
||||
client-level compaction defaults remain in effect for each call.
|
||||
tokenizer: Optional agent-level tokenizer override. When omitted,
|
||||
client-level tokenizer defaults remain in effect for each call.
|
||||
kwargs: Any additional keyword arguments. Will be stored as ``additional_properties``.
|
||||
|
||||
Returns:
|
||||
@@ -504,6 +626,8 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
|
||||
context_providers=context_providers,
|
||||
middleware=middleware,
|
||||
function_invocation_configuration=function_invocation_configuration,
|
||||
compaction_strategy=compaction_strategy,
|
||||
tokenizer=tokenizer,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -37,6 +37,7 @@ if TYPE_CHECKING:
|
||||
|
||||
from ._agents import SupportsAgentRun
|
||||
from ._clients import SupportsChatGetResponse
|
||||
from ._compaction import CompactionStrategy, TokenizerProtocol
|
||||
from ._sessions import AgentSession
|
||||
from ._tools import FunctionTool
|
||||
from ._types import ChatOptions, ChatResponse, ChatResponseUpdate
|
||||
@@ -101,6 +102,8 @@ class AgentContext:
|
||||
session: The agent session for this invocation, if any.
|
||||
options: The options for the agent invocation as a dict.
|
||||
stream: Whether this is a streaming invocation.
|
||||
compaction_strategy: Optional per-run compaction override.
|
||||
tokenizer: Optional per-run tokenizer override.
|
||||
metadata: Metadata dictionary for sharing data between agent middleware.
|
||||
result: Agent execution result. Can be observed after calling ``call_next()``
|
||||
to see the actual execution result or can be set to override the execution result.
|
||||
@@ -139,6 +142,8 @@ class AgentContext:
|
||||
session: AgentSession | None = None,
|
||||
options: Mapping[str, Any] | None = None,
|
||||
stream: bool = False,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
metadata: Mapping[str, Any] | None = None,
|
||||
result: AgentResponse | ResponseStream[AgentResponseUpdate, AgentResponse] | None = None,
|
||||
kwargs: Mapping[str, Any] | None = None,
|
||||
@@ -158,6 +163,8 @@ class AgentContext:
|
||||
session: The agent session for this invocation, if any.
|
||||
options: The options for the agent invocation as a dict.
|
||||
stream: Whether this is a streaming invocation.
|
||||
compaction_strategy: Optional per-run compaction override.
|
||||
tokenizer: Optional per-run tokenizer override.
|
||||
metadata: Metadata dictionary for sharing data between agent middleware.
|
||||
result: Agent execution result.
|
||||
kwargs: Additional keyword arguments passed to the agent run method.
|
||||
@@ -170,6 +177,8 @@ class AgentContext:
|
||||
self.session = session
|
||||
self.options = options
|
||||
self.stream = stream
|
||||
self.compaction_strategy = compaction_strategy
|
||||
self.tokenizer = tokenizer
|
||||
self.metadata: dict[str, Any] = dict(metadata) if metadata is not None else {}
|
||||
self.result = result
|
||||
self.kwargs: dict[str, Any] = dict(kwargs) if kwargs is not None else {}
|
||||
@@ -969,6 +978,8 @@ class ChatMiddlewareLayer(Generic[OptionsCoT]):
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
options: ChatOptions[ResponseModelBoundT],
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[ChatResponse[ResponseModelBoundT]]: ...
|
||||
|
||||
@@ -979,6 +990,8 @@ class ChatMiddlewareLayer(Generic[OptionsCoT]):
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
options: OptionsCoT | ChatOptions[None] | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[ChatResponse[Any]]: ...
|
||||
|
||||
@@ -989,6 +1002,8 @@ class ChatMiddlewareLayer(Generic[OptionsCoT]):
|
||||
*,
|
||||
stream: Literal[True],
|
||||
options: OptionsCoT | ChatOptions[Any] | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: ...
|
||||
|
||||
@@ -998,11 +1013,18 @@ class ChatMiddlewareLayer(Generic[OptionsCoT]):
|
||||
*,
|
||||
stream: bool = False,
|
||||
options: OptionsCoT | ChatOptions[Any] | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]:
|
||||
"""Execute the chat pipeline if middleware is configured."""
|
||||
super_get_response = super().get_response # type: ignore[misc]
|
||||
|
||||
if compaction_strategy is not None:
|
||||
kwargs["compaction_strategy"] = compaction_strategy
|
||||
if tokenizer is not None:
|
||||
kwargs["tokenizer"] = tokenizer
|
||||
|
||||
call_middleware = kwargs.pop("middleware", [])
|
||||
middleware = categorize_middleware(call_middleware)
|
||||
kwargs["function_middleware"] = middleware["function"]
|
||||
@@ -1091,6 +1113,8 @@ class AgentMiddlewareLayer:
|
||||
session: AgentSession | None = None,
|
||||
middleware: Sequence[MiddlewareTypes] | None = None,
|
||||
options: ChatOptions[ResponseModelBoundT],
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse[ResponseModelBoundT]]: ...
|
||||
|
||||
@@ -1103,6 +1127,8 @@ class AgentMiddlewareLayer:
|
||||
session: AgentSession | None = None,
|
||||
middleware: Sequence[MiddlewareTypes] | None = None,
|
||||
options: ChatOptions[None] | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse[Any]]: ...
|
||||
|
||||
@@ -1115,6 +1141,8 @@ class AgentMiddlewareLayer:
|
||||
session: AgentSession | None = None,
|
||||
middleware: Sequence[MiddlewareTypes] | None = None,
|
||||
options: ChatOptions[Any] | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
|
||||
|
||||
@@ -1126,6 +1154,8 @@ class AgentMiddlewareLayer:
|
||||
session: AgentSession | None = None,
|
||||
middleware: Sequence[MiddlewareTypes] | None = None,
|
||||
options: ChatOptions[Any] | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
|
||||
"""MiddlewareTypes-enabled unified run method."""
|
||||
@@ -1150,7 +1180,15 @@ class AgentMiddlewareLayer:
|
||||
|
||||
# Execute with middleware if available
|
||||
if not pipeline.has_middlewares:
|
||||
return super().run(messages, stream=stream, session=session, options=options, **combined_kwargs) # type: ignore[misc, no-any-return]
|
||||
return super().run( # type: ignore[misc, no-any-return]
|
||||
messages,
|
||||
stream=stream,
|
||||
session=session,
|
||||
options=options,
|
||||
compaction_strategy=compaction_strategy,
|
||||
tokenizer=tokenizer,
|
||||
**combined_kwargs,
|
||||
)
|
||||
|
||||
context = AgentContext(
|
||||
agent=self, # type: ignore[arg-type]
|
||||
@@ -1158,6 +1196,8 @@ class AgentMiddlewareLayer:
|
||||
session=session,
|
||||
options=options,
|
||||
stream=stream,
|
||||
compaction_strategy=compaction_strategy,
|
||||
tokenizer=tokenizer,
|
||||
kwargs=combined_kwargs,
|
||||
)
|
||||
|
||||
@@ -1195,6 +1235,8 @@ class AgentMiddlewareLayer:
|
||||
stream=context.stream,
|
||||
session=context.session,
|
||||
options=context.options,
|
||||
compaction_strategy=context.compaction_strategy,
|
||||
tokenizer=context.tokenizer,
|
||||
**context.kwargs,
|
||||
)
|
||||
|
||||
|
||||
@@ -547,6 +547,7 @@ class InMemoryHistoryProvider(BaseHistoryProvider):
|
||||
store_context_messages: bool = False,
|
||||
store_context_from: set[str] | None = None,
|
||||
store_outputs: bool = True,
|
||||
skip_excluded: bool = False,
|
||||
) -> None:
|
||||
"""Initialize the in-memory history provider.
|
||||
|
||||
@@ -558,6 +559,11 @@ class InMemoryHistoryProvider(BaseHistoryProvider):
|
||||
store_context_messages: Whether to store context from other providers.
|
||||
store_context_from: If set, only store context from these source_ids.
|
||||
store_outputs: Whether to store response messages.
|
||||
skip_excluded: When True, ``get_messages`` omits messages whose
|
||||
``additional_properties["_excluded"]`` is truthy. This is
|
||||
useful when a ``CompactionProvider`` marks messages as excluded
|
||||
in stored history and you want the loaded context to reflect
|
||||
those exclusions. Defaults to False (load all messages).
|
||||
"""
|
||||
super().__init__(
|
||||
source_id=source_id or self.DEFAULT_SOURCE_ID,
|
||||
@@ -567,6 +573,7 @@ class InMemoryHistoryProvider(BaseHistoryProvider):
|
||||
store_context_from=store_context_from,
|
||||
store_outputs=store_outputs,
|
||||
)
|
||||
self.skip_excluded = skip_excluded
|
||||
|
||||
async def get_messages(
|
||||
self, session_id: str | None, *, state: dict[str, Any] | None = None, **kwargs: Any
|
||||
@@ -574,7 +581,10 @@ class InMemoryHistoryProvider(BaseHistoryProvider):
|
||||
"""Retrieve messages from session state."""
|
||||
if state is None:
|
||||
return []
|
||||
return list(state.get("messages", []))
|
||||
messages = list(state.get("messages", []))
|
||||
if self.skip_excluded:
|
||||
messages = [m for m in messages if not m.additional_properties.get("_excluded", False)]
|
||||
return messages
|
||||
|
||||
async def save_messages(
|
||||
self,
|
||||
|
||||
@@ -196,9 +196,7 @@ class SkillScript:
|
||||
self._accepts_kwargs: bool = False
|
||||
if function is not None:
|
||||
sig = inspect.signature(function)
|
||||
self._accepts_kwargs = any(
|
||||
p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()
|
||||
)
|
||||
self._accepts_kwargs = any(p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values())
|
||||
|
||||
@property
|
||||
def parameters_schema(self) -> dict[str, Any] | None:
|
||||
@@ -454,9 +452,7 @@ class SkillScriptRunner(Protocol):
|
||||
satisfies this protocol.
|
||||
"""
|
||||
|
||||
def __call__(
|
||||
self, skill: Skill, script: SkillScript, args: dict[str, Any] | None = None
|
||||
) -> Any:
|
||||
def __call__(self, skill: Skill, script: SkillScript, args: dict[str, Any] | None = None) -> Any:
|
||||
"""Run a skill script.
|
||||
|
||||
The :class:`SkillsProvider` resolves skill and script names
|
||||
@@ -677,7 +673,7 @@ class SkillsProvider(BaseContextProvider):
|
||||
self._instructions = _create_instructions(
|
||||
prompt_template=instruction_template,
|
||||
skills=self._skills,
|
||||
include_script_runner_instructions=has_file_scripts or has_code_scripts
|
||||
include_script_runner_instructions=has_file_scripts or has_code_scripts,
|
||||
)
|
||||
|
||||
self._tools = self._create_tools(
|
||||
|
||||
@@ -59,6 +59,7 @@ else:
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ._clients import SupportsChatGetResponse
|
||||
from ._compaction import CompactionStrategy, TokenizerProtocol
|
||||
from ._mcp import MCPTool
|
||||
from ._middleware import FunctionMiddlewarePipeline, FunctionMiddlewareTypes
|
||||
from ._types import (
|
||||
@@ -1811,6 +1812,8 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
options: ChatOptions[ResponseModelBoundT],
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[ChatResponse[ResponseModelBoundT]]: ...
|
||||
|
||||
@@ -1821,6 +1824,8 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
options: OptionsCoT | ChatOptions[None] | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[ChatResponse[Any]]: ...
|
||||
|
||||
@@ -1831,6 +1836,8 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
|
||||
*,
|
||||
stream: Literal[True],
|
||||
options: OptionsCoT | ChatOptions[Any] | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: ...
|
||||
|
||||
@@ -1841,6 +1848,8 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
|
||||
stream: bool = False,
|
||||
options: OptionsCoT | ChatOptions[Any] | None = None,
|
||||
function_middleware: Sequence[FunctionMiddlewareTypes] | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]:
|
||||
from ._middleware import FunctionMiddlewarePipeline
|
||||
@@ -1869,6 +1878,10 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
|
||||
middleware_pipeline=function_middleware_pipeline,
|
||||
)
|
||||
filtered_kwargs = {k: v for k, v in kwargs.items() if k != "session"}
|
||||
if compaction_strategy is not None:
|
||||
filtered_kwargs["compaction_strategy"] = compaction_strategy
|
||||
if tokenizer is not None:
|
||||
filtered_kwargs["tokenizer"] = tokenizer
|
||||
|
||||
# Make options mutable so we can update conversation_id during function invocation loop
|
||||
mutable_options: dict[str, Any] = dict(options) if options else {}
|
||||
|
||||
@@ -277,6 +277,17 @@ def _serialize_value(value: Any, exclude_none: bool) -> Any:
|
||||
return value
|
||||
|
||||
|
||||
def _restore_compaction_annotation_in_additional_properties(
|
||||
additional_properties: MutableMapping[str, Any] | None,
|
||||
*,
|
||||
allow_none: bool = False,
|
||||
) -> dict[str, Any] | None:
|
||||
if additional_properties is None:
|
||||
return None if allow_none else {}
|
||||
|
||||
return dict(additional_properties)
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# region Constants and types
|
||||
@@ -509,7 +520,9 @@ class Content:
|
||||
"""
|
||||
self.type = type
|
||||
self.annotations = annotations
|
||||
self.additional_properties: dict[str, Any] = additional_properties or {} # type: ignore[assignment]
|
||||
self.additional_properties: dict[str, Any] = (
|
||||
_restore_compaction_annotation_in_additional_properties(additional_properties) or {}
|
||||
)
|
||||
self.raw_representation = raw_representation
|
||||
|
||||
# Set all content-specific attributes
|
||||
@@ -1638,7 +1651,9 @@ class Message(SerializationMixin):
|
||||
self.contents = parsed_contents
|
||||
self.author_name = author_name
|
||||
self.message_id = message_id
|
||||
self.additional_properties = additional_properties or {}
|
||||
self.additional_properties = (
|
||||
_restore_compaction_annotation_in_additional_properties(additional_properties) or {}
|
||||
)
|
||||
self.raw_representation = raw_representation
|
||||
|
||||
@property
|
||||
@@ -1989,7 +2004,9 @@ class ChatResponse(SerializationMixin, Generic[ResponseModelT]):
|
||||
self._value: ResponseModelT | None = value
|
||||
self._response_format: type[BaseModel] | None = response_format
|
||||
self._value_parsed: bool = value is not None
|
||||
self.additional_properties = additional_properties or {}
|
||||
self.additional_properties = (
|
||||
_restore_compaction_annotation_in_additional_properties(additional_properties) or {}
|
||||
)
|
||||
self.continuation_token = continuation_token
|
||||
self.raw_representation: Any | list[Any] | None = raw_representation
|
||||
|
||||
@@ -2239,7 +2256,10 @@ class ChatResponseUpdate(SerializationMixin):
|
||||
self.created_at = created_at
|
||||
self.finish_reason = finish_reason
|
||||
self.continuation_token = continuation_token
|
||||
self.additional_properties = additional_properties
|
||||
self.additional_properties = _restore_compaction_annotation_in_additional_properties(
|
||||
additional_properties,
|
||||
allow_none=True,
|
||||
)
|
||||
self.raw_representation = raw_representation
|
||||
|
||||
@property
|
||||
@@ -2352,7 +2372,9 @@ class AgentResponse(SerializationMixin, Generic[ResponseModelT]):
|
||||
self._value: ResponseModelT | None = value
|
||||
self._response_format: type[BaseModel] | None = response_format
|
||||
self._value_parsed: bool = value is not None
|
||||
self.additional_properties = additional_properties or {}
|
||||
self.additional_properties = (
|
||||
_restore_compaction_annotation_in_additional_properties(additional_properties) or {}
|
||||
)
|
||||
self.continuation_token = continuation_token
|
||||
self.raw_representation = raw_representation
|
||||
|
||||
@@ -2582,7 +2604,10 @@ class AgentResponseUpdate(SerializationMixin):
|
||||
self.message_id = message_id
|
||||
self.created_at = created_at
|
||||
self.continuation_token = continuation_token
|
||||
self.additional_properties = additional_properties
|
||||
self.additional_properties = _restore_compaction_annotation_in_additional_properties(
|
||||
additional_properties,
|
||||
allow_none=True,
|
||||
)
|
||||
self.raw_representation: Any | list[Any] | None = raw_representation
|
||||
|
||||
@property
|
||||
@@ -3381,7 +3406,9 @@ class Embedding(Generic[EmbeddingT]):
|
||||
self._dimensions = dimensions
|
||||
self.model_id = model_id
|
||||
self.created_at = created_at
|
||||
self.additional_properties = additional_properties or {}
|
||||
self.additional_properties = (
|
||||
_restore_compaction_annotation_in_additional_properties(additional_properties) or {}
|
||||
)
|
||||
|
||||
@property
|
||||
def dimensions(self) -> int | None:
|
||||
@@ -3439,7 +3466,9 @@ class GeneratedEmbeddings(list[Embedding[EmbeddingT]], Generic[EmbeddingT, Embed
|
||||
super().__init__(embeddings or [])
|
||||
self.options = options
|
||||
self.usage = usage
|
||||
self.additional_properties = additional_properties or {}
|
||||
self.additional_properties = (
|
||||
_restore_compaction_annotation_in_additional_properties(additional_properties) or {}
|
||||
)
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
@@ -49,6 +49,7 @@ if TYPE_CHECKING: # pragma: no cover
|
||||
|
||||
from ._agents import SupportsAgentRun
|
||||
from ._clients import SupportsChatGetResponse
|
||||
from ._compaction import CompactionStrategy, TokenizerProtocol
|
||||
from ._sessions import AgentSession
|
||||
from ._tools import FunctionTool
|
||||
from ._types import (
|
||||
@@ -1122,6 +1123,8 @@ class ChatTelemetryLayer(Generic[OptionsCoT]):
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
options: ChatOptions[ResponseModelBoundT],
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[ChatResponse[ResponseModelBoundT]]: ...
|
||||
|
||||
@@ -1132,6 +1135,8 @@ class ChatTelemetryLayer(Generic[OptionsCoT]):
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
options: OptionsCoT | ChatOptions[None] | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[ChatResponse[Any]]: ...
|
||||
|
||||
@@ -1142,6 +1147,8 @@ class ChatTelemetryLayer(Generic[OptionsCoT]):
|
||||
*,
|
||||
stream: Literal[True],
|
||||
options: OptionsCoT | ChatOptions[Any] | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: ...
|
||||
|
||||
@@ -1151,6 +1158,8 @@ class ChatTelemetryLayer(Generic[OptionsCoT]):
|
||||
*,
|
||||
stream: bool = False,
|
||||
options: OptionsCoT | ChatOptions[Any] | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]:
|
||||
"""Trace chat responses with OpenTelemetry spans and metrics."""
|
||||
@@ -1160,7 +1169,14 @@ class ChatTelemetryLayer(Generic[OptionsCoT]):
|
||||
super_get_response = super().get_response # type: ignore[misc]
|
||||
|
||||
if not OBSERVABILITY_SETTINGS.ENABLED:
|
||||
return super_get_response(messages=messages, stream=stream, options=options, **kwargs) # type: ignore[no-any-return]
|
||||
return super_get_response( # type: ignore[no-any-return]
|
||||
messages=messages,
|
||||
stream=stream,
|
||||
options=options,
|
||||
compaction_strategy=compaction_strategy,
|
||||
tokenizer=tokenizer,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
opts: dict[str, Any] = options or {} # type: ignore[assignment]
|
||||
provider_name = str(getattr(self, "otel_provider_name", "unknown"))
|
||||
@@ -1178,7 +1194,14 @@ class ChatTelemetryLayer(Generic[OptionsCoT]):
|
||||
if stream:
|
||||
result_stream = cast(
|
||||
ResponseStream[ChatResponseUpdate, ChatResponse[Any]],
|
||||
super_get_response(messages=messages, stream=True, options=opts, **kwargs),
|
||||
super_get_response(
|
||||
messages=messages,
|
||||
stream=True,
|
||||
options=opts,
|
||||
compaction_strategy=compaction_strategy,
|
||||
tokenizer=tokenizer,
|
||||
**kwargs,
|
||||
),
|
||||
)
|
||||
|
||||
# Create span directly without trace.use_span() context attachment.
|
||||
@@ -1266,6 +1289,8 @@ class ChatTelemetryLayer(Generic[OptionsCoT]):
|
||||
messages=messages,
|
||||
stream=False,
|
||||
options=opts,
|
||||
compaction_strategy=compaction_strategy,
|
||||
tokenizer=tokenizer,
|
||||
**kwargs,
|
||||
),
|
||||
)
|
||||
@@ -1393,6 +1418,8 @@ class AgentTelemetryLayer:
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
session: AgentSession | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse[Any]]: ...
|
||||
|
||||
@@ -1403,6 +1430,8 @@ class AgentTelemetryLayer:
|
||||
*,
|
||||
stream: Literal[True],
|
||||
session: AgentSession | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
|
||||
|
||||
@@ -1412,6 +1441,8 @@ class AgentTelemetryLayer:
|
||||
*,
|
||||
stream: bool = False,
|
||||
session: AgentSession | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
|
||||
"""Trace agent runs with OpenTelemetry spans and metrics."""
|
||||
@@ -1430,6 +1461,8 @@ class AgentTelemetryLayer:
|
||||
messages=messages,
|
||||
stream=stream,
|
||||
session=session,
|
||||
compaction_strategy=compaction_strategy,
|
||||
tokenizer=tokenizer,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -1452,6 +1485,8 @@ class AgentTelemetryLayer:
|
||||
messages=messages,
|
||||
stream=True,
|
||||
session=session,
|
||||
compaction_strategy=compaction_strategy,
|
||||
tokenizer=tokenizer,
|
||||
**kwargs,
|
||||
)
|
||||
if isinstance(run_result, ResponseStream):
|
||||
@@ -1541,6 +1576,8 @@ class AgentTelemetryLayer:
|
||||
messages=messages,
|
||||
stream=False,
|
||||
session=session,
|
||||
compaction_strategy=compaction_strategy,
|
||||
tokenizer=tokenizer,
|
||||
**kwargs,
|
||||
)
|
||||
except Exception as exception:
|
||||
|
||||
@@ -1164,7 +1164,6 @@ class RawOpenAIResponsesClient( # type: ignore[misc]
|
||||
"type": "function_call",
|
||||
"name": content.name,
|
||||
"arguments": content.arguments,
|
||||
"status": None,
|
||||
}
|
||||
case "function_result":
|
||||
shell_output_type = (
|
||||
|
||||
@@ -10,6 +10,8 @@ import pytest
|
||||
from pytest import raises
|
||||
|
||||
from agent_framework import (
|
||||
GROUP_ANNOTATION_KEY,
|
||||
GROUP_TOKEN_COUNT_KEY,
|
||||
Agent,
|
||||
AgentResponse,
|
||||
AgentResponseUpdate,
|
||||
@@ -21,14 +23,24 @@ from agent_framework import (
|
||||
Content,
|
||||
FunctionTool,
|
||||
Message,
|
||||
SlidingWindowStrategy,
|
||||
SupportsAgentRun,
|
||||
SupportsChatGetResponse,
|
||||
TruncationStrategy,
|
||||
tool,
|
||||
)
|
||||
from agent_framework._agents import _get_tool_name, _merge_options, _sanitize_agent_name
|
||||
from agent_framework._mcp import MCPTool
|
||||
|
||||
|
||||
class _FixedTokenizer:
|
||||
def __init__(self, token_count: int) -> None:
|
||||
self.token_count = token_count
|
||||
|
||||
def count_tokens(self, text: str) -> int:
|
||||
return self.token_count
|
||||
|
||||
|
||||
def test_agent_session_type(agent_session: AgentSession) -> None:
|
||||
assert isinstance(agent_session, AgentSession)
|
||||
|
||||
@@ -217,6 +229,30 @@ async def test_prepare_session_does_not_mutate_agent_chat_options(
|
||||
assert len(agent.default_options["tools"]) == 1
|
||||
|
||||
|
||||
async def test_prepare_run_context_keeps_compaction_overrides_out_of_kwargs(
|
||||
chat_client_base: SupportsChatGetResponse,
|
||||
) -> None:
|
||||
strategy = SlidingWindowStrategy(keep_last_groups=2)
|
||||
tokenizer = _FixedTokenizer(13)
|
||||
agent = Agent(client=chat_client_base)
|
||||
|
||||
ctx = await agent._prepare_run_context( # type: ignore[reportPrivateUsage]
|
||||
messages=[Message(role="user", text="Hello")],
|
||||
session=None,
|
||||
tools=None,
|
||||
options=None,
|
||||
compaction_strategy=strategy,
|
||||
tokenizer=tokenizer,
|
||||
kwargs={"custom_flag": True},
|
||||
)
|
||||
|
||||
assert ctx["compaction_strategy"] is strategy
|
||||
assert ctx["tokenizer"] is tokenizer
|
||||
assert ctx["filtered_kwargs"].get("custom_flag") is True
|
||||
assert "compaction_strategy" not in ctx["filtered_kwargs"]
|
||||
assert "tokenizer" not in ctx["filtered_kwargs"]
|
||||
|
||||
|
||||
async def test_chat_client_agent_run_with_session(
|
||||
chat_client_base: SupportsChatGetResponse,
|
||||
) -> None:
|
||||
@@ -1128,6 +1164,102 @@ async def test_chat_agent_tool_choice_none_at_run_preserves_agent_level(chat_cli
|
||||
assert captured_options[0]["tool_choice"] == "auto"
|
||||
|
||||
|
||||
async def test_chat_agent_compaction_overrides_client_defaults(chat_client_base: Any) -> None:
|
||||
captured_roles: list[list[str]] = []
|
||||
captured_token_counts: list[list[int | None]] = []
|
||||
original_inner = chat_client_base._inner_get_response
|
||||
|
||||
async def capturing_inner(
|
||||
*, messages: MutableSequence[Message], options: dict[str, Any], **kwargs: Any
|
||||
) -> ChatResponse:
|
||||
captured_roles.append([message.role for message in messages])
|
||||
captured_token_counts.append([
|
||||
group.get(GROUP_TOKEN_COUNT_KEY) if isinstance(group, dict) else None
|
||||
for group in (message.additional_properties.get(GROUP_ANNOTATION_KEY) for message in messages)
|
||||
])
|
||||
return await original_inner(messages=messages, options=options, **kwargs)
|
||||
|
||||
chat_client_base._inner_get_response = capturing_inner
|
||||
chat_client_base.function_invocation_configuration["enabled"] = False
|
||||
chat_client_base.compaction_strategy = TruncationStrategy(max_n=1, compact_to=1)
|
||||
chat_client_base.tokenizer = _FixedTokenizer(5)
|
||||
|
||||
agent = Agent(
|
||||
client=chat_client_base,
|
||||
compaction_strategy=SlidingWindowStrategy(keep_last_groups=2),
|
||||
tokenizer=_FixedTokenizer(9),
|
||||
)
|
||||
|
||||
await agent.run([
|
||||
Message(role="user", text="Hello"),
|
||||
Message(role="assistant", text="Previous response"),
|
||||
])
|
||||
|
||||
assert captured_roles == [["user", "assistant"]]
|
||||
assert captured_token_counts == [[9, 9]]
|
||||
|
||||
|
||||
async def test_chat_agent_uses_client_compaction_defaults_when_agent_unset(chat_client_base: Any) -> None:
|
||||
captured_roles: list[list[str]] = []
|
||||
original_inner = chat_client_base._inner_get_response
|
||||
|
||||
async def capturing_inner(
|
||||
*, messages: MutableSequence[Message], options: dict[str, Any], **kwargs: Any
|
||||
) -> ChatResponse:
|
||||
captured_roles.append([message.role for message in messages])
|
||||
return await original_inner(messages=messages, options=options, **kwargs)
|
||||
|
||||
chat_client_base._inner_get_response = capturing_inner
|
||||
chat_client_base.function_invocation_configuration["enabled"] = False
|
||||
chat_client_base.compaction_strategy = TruncationStrategy(max_n=1, compact_to=1)
|
||||
|
||||
agent = Agent(client=chat_client_base)
|
||||
|
||||
await agent.run([
|
||||
Message(role="user", text="Hello"),
|
||||
Message(role="assistant", text="Previous response"),
|
||||
])
|
||||
|
||||
assert captured_roles == [["assistant"]]
|
||||
|
||||
|
||||
async def test_chat_agent_run_level_compaction_and_tokenizer_override_agent_defaults(chat_client_base: Any) -> None:
|
||||
captured_roles: list[list[str]] = []
|
||||
captured_token_counts: list[list[int | None]] = []
|
||||
original_inner = chat_client_base._inner_get_response
|
||||
|
||||
async def capturing_inner(
|
||||
*, messages: MutableSequence[Message], options: dict[str, Any], **kwargs: Any
|
||||
) -> ChatResponse:
|
||||
captured_roles.append([message.role for message in messages])
|
||||
captured_token_counts.append([
|
||||
group.get(GROUP_TOKEN_COUNT_KEY) if isinstance(group, dict) else None
|
||||
for group in (message.additional_properties.get(GROUP_ANNOTATION_KEY) for message in messages)
|
||||
])
|
||||
return await original_inner(messages=messages, options=options, **kwargs)
|
||||
|
||||
chat_client_base._inner_get_response = capturing_inner
|
||||
chat_client_base.function_invocation_configuration["enabled"] = False
|
||||
|
||||
agent = Agent(
|
||||
client=chat_client_base,
|
||||
compaction_strategy=SlidingWindowStrategy(keep_last_groups=2),
|
||||
tokenizer=_FixedTokenizer(9),
|
||||
)
|
||||
|
||||
await agent.run(
|
||||
[
|
||||
Message(role="user", text="Hello"),
|
||||
Message(role="assistant", text="Previous response"),
|
||||
],
|
||||
compaction_strategy=TruncationStrategy(max_n=1, compact_to=1),
|
||||
tokenizer=_FixedTokenizer(23),
|
||||
)
|
||||
|
||||
assert captured_roles == [["assistant"]]
|
||||
assert captured_token_counts == [[23]]
|
||||
|
||||
|
||||
# region Test _merge_options
|
||||
|
||||
|
||||
|
||||
@@ -1,21 +1,34 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
from agent_framework import (
|
||||
GROUP_ANNOTATION_KEY,
|
||||
GROUP_TOKEN_COUNT_KEY,
|
||||
BaseChatClient,
|
||||
ChatResponse,
|
||||
Message,
|
||||
SlidingWindowStrategy,
|
||||
SupportsChatGetResponse,
|
||||
SupportsCodeInterpreterTool,
|
||||
SupportsFileSearchTool,
|
||||
SupportsImageGenerationTool,
|
||||
SupportsMCPTool,
|
||||
SupportsWebSearchTool,
|
||||
TruncationStrategy,
|
||||
)
|
||||
|
||||
|
||||
class _FixedTokenizer:
|
||||
def __init__(self, token_count: int) -> None:
|
||||
self.token_count = token_count
|
||||
|
||||
def count_tokens(self, text: str) -> int:
|
||||
return self.token_count
|
||||
|
||||
|
||||
def test_chat_client_type(client: SupportsChatGetResponse):
|
||||
assert isinstance(client, SupportsChatGetResponse)
|
||||
|
||||
@@ -48,6 +61,190 @@ async def test_base_client_get_response_streaming(chat_client_base: SupportsChat
|
||||
assert update.text == "update - Hello" or update.text == "another update"
|
||||
|
||||
|
||||
async def test_base_client_applies_compaction_before_non_streaming_inner_call(
|
||||
chat_client_base: SupportsChatGetResponse,
|
||||
):
|
||||
chat_client_base.function_invocation_configuration["enabled"] = False # type: ignore[attr-defined]
|
||||
chat_client_base.compaction_strategy = TruncationStrategy(max_n=1, compact_to=1) # type: ignore[attr-defined]
|
||||
captured_roles: list[list[str]] = []
|
||||
original = chat_client_base._get_non_streaming_response # type: ignore[attr-defined]
|
||||
|
||||
async def _capture(
|
||||
*,
|
||||
messages: list[Message],
|
||||
options: dict[str, Any],
|
||||
**kwargs: Any,
|
||||
) -> ChatResponse:
|
||||
captured_roles.append([message.role for message in messages])
|
||||
return await original(messages=messages, options=options, **kwargs)
|
||||
|
||||
chat_client_base._get_non_streaming_response = _capture # type: ignore[attr-defined,method-assign]
|
||||
await chat_client_base.get_response([
|
||||
Message(role="user", text="Hello"),
|
||||
Message(role="assistant", text="Previous response"),
|
||||
])
|
||||
assert captured_roles == [["assistant"]]
|
||||
|
||||
|
||||
async def test_base_client_applies_compaction_before_streaming_inner_call(
|
||||
chat_client_base: SupportsChatGetResponse,
|
||||
):
|
||||
chat_client_base.function_invocation_configuration["enabled"] = False # type: ignore[attr-defined]
|
||||
chat_client_base.compaction_strategy = TruncationStrategy(max_n=1, compact_to=1) # type: ignore[attr-defined]
|
||||
captured_roles: list[list[str]] = []
|
||||
original = chat_client_base._get_streaming_response # type: ignore[attr-defined]
|
||||
|
||||
def _capture(
|
||||
*,
|
||||
messages: list[Message],
|
||||
options: dict[str, Any],
|
||||
**kwargs: Any,
|
||||
):
|
||||
captured_roles.append([message.role for message in messages])
|
||||
return original(messages=messages, options=options, **kwargs)
|
||||
|
||||
chat_client_base._get_streaming_response = _capture # type: ignore[attr-defined,method-assign]
|
||||
async for _ in chat_client_base.get_response(
|
||||
[
|
||||
Message(role="user", text="Hello"),
|
||||
Message(role="assistant", text="Previous response"),
|
||||
],
|
||||
stream=True,
|
||||
):
|
||||
pass
|
||||
assert captured_roles == [["assistant"]]
|
||||
|
||||
|
||||
async def test_base_client_per_call_compaction_override_applies_before_inner_call(
|
||||
chat_client_base: SupportsChatGetResponse,
|
||||
) -> None:
|
||||
chat_client_base.function_invocation_configuration["enabled"] = False # type: ignore[attr-defined]
|
||||
captured_roles: list[list[str]] = []
|
||||
original = chat_client_base._get_non_streaming_response # type: ignore[attr-defined]
|
||||
|
||||
async def _capture(
|
||||
*,
|
||||
messages: list[Message],
|
||||
options: dict[str, Any],
|
||||
**kwargs: Any,
|
||||
) -> ChatResponse:
|
||||
captured_roles.append([message.role for message in messages])
|
||||
return await original(messages=messages, options=options, **kwargs)
|
||||
|
||||
chat_client_base._get_non_streaming_response = _capture # type: ignore[attr-defined,method-assign]
|
||||
await chat_client_base.get_response(
|
||||
[
|
||||
Message(role="user", text="Hello"),
|
||||
Message(role="assistant", text="Previous response"),
|
||||
],
|
||||
compaction_strategy=TruncationStrategy(max_n=1, compact_to=1),
|
||||
)
|
||||
assert captured_roles == [["assistant"]]
|
||||
|
||||
|
||||
async def test_base_client_per_call_tokenizer_override_annotates_messages(
|
||||
chat_client_base: SupportsChatGetResponse,
|
||||
) -> None:
|
||||
chat_client_base.function_invocation_configuration["enabled"] = False # type: ignore[attr-defined]
|
||||
captured_token_counts: list[list[int | None]] = []
|
||||
original = chat_client_base._get_non_streaming_response # type: ignore[attr-defined]
|
||||
|
||||
async def _capture(
|
||||
*,
|
||||
messages: list[Message],
|
||||
options: dict[str, Any],
|
||||
**kwargs: Any,
|
||||
) -> ChatResponse:
|
||||
captured_token_counts.append([
|
||||
group.get(GROUP_TOKEN_COUNT_KEY) if isinstance(group, dict) else None
|
||||
for group in (message.additional_properties.get(GROUP_ANNOTATION_KEY) for message in messages)
|
||||
])
|
||||
return await original(messages=messages, options=options, **kwargs)
|
||||
|
||||
chat_client_base._get_non_streaming_response = _capture # type: ignore[attr-defined,method-assign]
|
||||
await chat_client_base.get_response(
|
||||
[
|
||||
Message(role="user", text="Hello"),
|
||||
Message(role="assistant", text="Previous response"),
|
||||
],
|
||||
compaction_strategy=SlidingWindowStrategy(keep_last_groups=2),
|
||||
tokenizer=_FixedTokenizer(17),
|
||||
)
|
||||
assert captured_token_counts == [[17, 17]]
|
||||
|
||||
|
||||
async def test_base_client_per_call_tokenizer_override_without_strategy_annotates_messages(
|
||||
chat_client_base: SupportsChatGetResponse,
|
||||
) -> None:
|
||||
chat_client_base.function_invocation_configuration["enabled"] = False # type: ignore[attr-defined]
|
||||
captured_token_counts: list[list[int | None]] = []
|
||||
original = chat_client_base._get_non_streaming_response # type: ignore[attr-defined]
|
||||
|
||||
async def _capture(
|
||||
*,
|
||||
messages: list[Message],
|
||||
options: dict[str, Any],
|
||||
**kwargs: Any,
|
||||
) -> ChatResponse:
|
||||
captured_token_counts.append([
|
||||
group.get(GROUP_TOKEN_COUNT_KEY) if isinstance(group, dict) else None
|
||||
for group in (message.additional_properties.get(GROUP_ANNOTATION_KEY) for message in messages)
|
||||
])
|
||||
return await original(messages=messages, options=options, **kwargs)
|
||||
|
||||
chat_client_base._get_non_streaming_response = _capture # type: ignore[attr-defined,method-assign]
|
||||
await chat_client_base.get_response(
|
||||
[
|
||||
Message(role="user", text="Hello"),
|
||||
Message(role="assistant", text="Previous response"),
|
||||
],
|
||||
tokenizer=_FixedTokenizer(17),
|
||||
)
|
||||
assert captured_token_counts == [[17, 17]]
|
||||
|
||||
|
||||
async def test_base_client_default_tokenizer_without_strategy_annotates_messages(
|
||||
chat_client_base: SupportsChatGetResponse,
|
||||
) -> None:
|
||||
chat_client_base.function_invocation_configuration["enabled"] = False # type: ignore[attr-defined]
|
||||
chat_client_base.tokenizer = _FixedTokenizer(19) # type: ignore[attr-defined]
|
||||
captured_token_counts: list[list[int | None]] = []
|
||||
original = chat_client_base._get_non_streaming_response # type: ignore[attr-defined]
|
||||
|
||||
async def _capture(
|
||||
*,
|
||||
messages: list[Message],
|
||||
options: dict[str, Any],
|
||||
**kwargs: Any,
|
||||
) -> ChatResponse:
|
||||
captured_token_counts.append([
|
||||
group.get(GROUP_TOKEN_COUNT_KEY) if isinstance(group, dict) else None
|
||||
for group in (message.additional_properties.get(GROUP_ANNOTATION_KEY) for message in messages)
|
||||
])
|
||||
return await original(messages=messages, options=options, **kwargs)
|
||||
|
||||
chat_client_base._get_non_streaming_response = _capture # type: ignore[attr-defined,method-assign]
|
||||
await chat_client_base.get_response([
|
||||
Message(role="user", text="Hello"),
|
||||
Message(role="assistant", text="Previous response"),
|
||||
])
|
||||
assert captured_token_counts == [[19, 19]]
|
||||
|
||||
|
||||
def test_base_client_as_agent_does_not_copy_client_compaction_defaults(
|
||||
chat_client_base: SupportsChatGetResponse,
|
||||
) -> None:
|
||||
strategy = TruncationStrategy(max_n=1, compact_to=1)
|
||||
tokenizer = _FixedTokenizer(11)
|
||||
chat_client_base.compaction_strategy = strategy # type: ignore[attr-defined]
|
||||
chat_client_base.tokenizer = tokenizer # type: ignore[attr-defined]
|
||||
|
||||
agent = chat_client_base.as_agent(name="shared-client-agent")
|
||||
|
||||
assert agent.compaction_strategy is None # type: ignore[attr-defined]
|
||||
assert agent.tokenizer is None # type: ignore[attr-defined]
|
||||
|
||||
|
||||
async def test_chat_client_instructions_handling(chat_client_base: SupportsChatGetResponse):
|
||||
instructions = "You are a helpful assistant."
|
||||
|
||||
|
||||
@@ -0,0 +1,954 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from agent_framework import (
|
||||
EXCLUDED_KEY,
|
||||
GROUP_ANNOTATION_KEY,
|
||||
GROUP_HAS_REASONING_KEY,
|
||||
GROUP_ID_KEY,
|
||||
GROUP_KIND_KEY,
|
||||
GROUP_TOKEN_COUNT_KEY,
|
||||
SUMMARIZED_BY_SUMMARY_ID_KEY,
|
||||
SUMMARY_OF_GROUP_IDS_KEY,
|
||||
SUMMARY_OF_MESSAGE_IDS_KEY,
|
||||
CharacterEstimatorTokenizer,
|
||||
ChatResponse,
|
||||
CompactionProvider,
|
||||
Content,
|
||||
Message,
|
||||
SelectiveToolCallCompactionStrategy,
|
||||
SlidingWindowStrategy,
|
||||
SummarizationStrategy,
|
||||
TokenBudgetComposedStrategy,
|
||||
ToolResultCompactionStrategy,
|
||||
TruncationStrategy,
|
||||
annotate_message_groups,
|
||||
apply_compaction,
|
||||
included_messages,
|
||||
included_token_count,
|
||||
)
|
||||
from agent_framework._compaction import (
|
||||
append_compaction_message,
|
||||
extend_compaction_messages,
|
||||
)
|
||||
|
||||
|
||||
def _assistant_function_call(call_id: str) -> Message:
|
||||
return Message(
|
||||
role="assistant",
|
||||
contents=[Content.from_function_call(call_id=call_id, name="tool", arguments='{"value":"x"}')],
|
||||
)
|
||||
|
||||
|
||||
def _assistant_reasoning_and_function_calls(*call_ids: str) -> Message:
|
||||
contents: list[Content] = [Content.from_text_reasoning(text="thinking")]
|
||||
for call_id in call_ids:
|
||||
contents.append(
|
||||
Content.from_function_call(
|
||||
call_id=call_id,
|
||||
name="tool",
|
||||
arguments='{"value":"x"}',
|
||||
)
|
||||
)
|
||||
return Message(role="assistant", contents=contents)
|
||||
|
||||
|
||||
def _tool_result(call_id: str, result: str) -> Message:
|
||||
return Message(
|
||||
role="tool",
|
||||
contents=[Content.from_function_result(call_id=call_id, result=result)],
|
||||
)
|
||||
|
||||
|
||||
def _group_id(message: Message) -> str | None:
|
||||
annotation = message.additional_properties.get(GROUP_ANNOTATION_KEY)
|
||||
if not isinstance(annotation, dict):
|
||||
return None
|
||||
value = annotation.get(GROUP_ID_KEY)
|
||||
return value if isinstance(value, str) else None
|
||||
|
||||
|
||||
def _group_kind(message: Message) -> str | None:
|
||||
annotation = message.additional_properties.get(GROUP_ANNOTATION_KEY)
|
||||
if not isinstance(annotation, dict):
|
||||
return None
|
||||
value = annotation.get(GROUP_KIND_KEY)
|
||||
return value if isinstance(value, str) else None
|
||||
|
||||
|
||||
def _group_has_reasoning(message: Message) -> bool | None:
|
||||
annotation = message.additional_properties.get(GROUP_ANNOTATION_KEY)
|
||||
if not isinstance(annotation, dict):
|
||||
return None
|
||||
value = annotation.get(GROUP_HAS_REASONING_KEY)
|
||||
return value if isinstance(value, bool) else None
|
||||
|
||||
|
||||
def _token_count(message: Message) -> int | None:
|
||||
annotation = message.additional_properties.get(GROUP_ANNOTATION_KEY)
|
||||
if not isinstance(annotation, dict):
|
||||
return None
|
||||
value = annotation.get(GROUP_TOKEN_COUNT_KEY)
|
||||
return value if isinstance(value, int) else None
|
||||
|
||||
|
||||
def _group_unknown_value(message: Message, key: str) -> Any:
|
||||
annotation = message.additional_properties.get(GROUP_ANNOTATION_KEY)
|
||||
if not isinstance(annotation, dict):
|
||||
return None
|
||||
return annotation.get(key)
|
||||
|
||||
|
||||
def test_group_annotations_keep_tool_call_and_tool_result_atomic() -> None:
|
||||
messages = [
|
||||
Message(role="user", text="hello"),
|
||||
_assistant_function_call("c1"),
|
||||
_tool_result("c1", "ok"),
|
||||
Message(role="assistant", text="final"),
|
||||
]
|
||||
|
||||
annotate_message_groups(messages)
|
||||
|
||||
call_group = _group_id(messages[1])
|
||||
assert call_group is not None
|
||||
assert call_group == _group_id(messages[2])
|
||||
assert _group_id(messages[1]) != _group_id(messages[0])
|
||||
|
||||
|
||||
def test_group_annotations_include_reasoning_in_tool_call_group() -> None:
|
||||
messages = [
|
||||
_assistant_reasoning_and_function_calls("c2"),
|
||||
_tool_result("c2", "ok"),
|
||||
]
|
||||
|
||||
annotate_message_groups(messages)
|
||||
|
||||
first_group = _group_id(messages[0])
|
||||
assert first_group is not None
|
||||
assert _group_id(messages[1]) == first_group
|
||||
assert _group_has_reasoning(messages[0]) is True
|
||||
assert _group_kind(messages[0]) == "tool_call"
|
||||
|
||||
|
||||
def test_group_annotations_handle_same_message_reasoning_and_function_calls() -> None:
|
||||
messages = [
|
||||
Message(role="user", text="hello"),
|
||||
_assistant_reasoning_and_function_calls("c1", "c2"),
|
||||
_tool_result("c1", "ok1"),
|
||||
_tool_result("c2", "ok2"),
|
||||
Message(role="assistant", text="final"),
|
||||
]
|
||||
|
||||
annotate_message_groups(messages)
|
||||
|
||||
call_group = _group_id(messages[1])
|
||||
assert call_group is not None
|
||||
assert _group_id(messages[2]) == call_group
|
||||
assert _group_id(messages[3]) == call_group
|
||||
assert _group_kind(messages[1]) == "tool_call"
|
||||
assert _group_has_reasoning(messages[1]) is True
|
||||
|
||||
|
||||
def test_annotate_message_groups_with_tokenizer_adds_token_counts() -> None:
|
||||
messages = [
|
||||
Message(role="user", text="hello"),
|
||||
Message(role="assistant", text="world"),
|
||||
]
|
||||
|
||||
annotate_message_groups(
|
||||
messages,
|
||||
tokenizer=CharacterEstimatorTokenizer(),
|
||||
)
|
||||
|
||||
assert isinstance(_token_count(messages[0]), int)
|
||||
assert isinstance(_token_count(messages[1]), int)
|
||||
|
||||
|
||||
def test_extend_compaction_messages_preserves_existing_annotations_and_tokens() -> None:
|
||||
tokenizer = CharacterEstimatorTokenizer()
|
||||
messages = [_assistant_function_call("c3")]
|
||||
annotate_message_groups(messages)
|
||||
old_group_id = _group_id(messages[0])
|
||||
assert old_group_id is not None
|
||||
old_token_count = tokenizer.count_tokens("precomputed")
|
||||
annotation = messages[0].additional_properties.get(GROUP_ANNOTATION_KEY)
|
||||
if isinstance(annotation, dict):
|
||||
annotation[GROUP_TOKEN_COUNT_KEY] = old_token_count
|
||||
|
||||
extend_compaction_messages(messages, [_tool_result("c3", "ok")], tokenizer=tokenizer)
|
||||
|
||||
assert _group_id(messages[1]) == old_group_id
|
||||
assert _token_count(messages[0]) == old_token_count
|
||||
assert isinstance(_token_count(messages[1]), int)
|
||||
|
||||
|
||||
def test_append_compaction_message_annotates_new_message() -> None:
|
||||
messages = [Message(role="user", text="hello")]
|
||||
annotate_message_groups(messages)
|
||||
append_compaction_message(messages, Message(role="assistant", text="world"))
|
||||
|
||||
assert len(messages) == 2
|
||||
assert isinstance(_group_id(messages[1]), str)
|
||||
|
||||
|
||||
async def test_truncation_strategy_keeps_system_anchor() -> None:
|
||||
messages = [
|
||||
Message(role="system", text="you are helpful"),
|
||||
Message(role="user", text="u1"),
|
||||
Message(role="assistant", text="a1"),
|
||||
Message(role="user", text="u2"),
|
||||
Message(role="assistant", text="a2"),
|
||||
]
|
||||
strategy = TruncationStrategy(max_n=3, compact_to=3, preserve_system=True)
|
||||
annotate_message_groups(messages)
|
||||
|
||||
changed = await strategy(messages)
|
||||
|
||||
assert changed is True
|
||||
projected = included_messages(messages)
|
||||
assert projected[0].role == "system"
|
||||
assert len(projected) <= 3
|
||||
|
||||
|
||||
async def test_truncation_strategy_compacts_when_token_limit_exceeded() -> None:
|
||||
tokenizer = CharacterEstimatorTokenizer()
|
||||
messages = [
|
||||
Message(role="system", text="you are helpful"),
|
||||
Message(role="user", text="u1 " * 200),
|
||||
Message(role="assistant", text="a1 " * 200),
|
||||
]
|
||||
strategy = TruncationStrategy(
|
||||
max_n=80,
|
||||
compact_to=40,
|
||||
tokenizer=tokenizer,
|
||||
preserve_system=True,
|
||||
)
|
||||
annotate_message_groups(messages, tokenizer=tokenizer)
|
||||
|
||||
changed = await strategy(messages)
|
||||
|
||||
assert changed is True
|
||||
projected = included_messages(messages)
|
||||
assert projected[0].role == "system"
|
||||
assert included_token_count(messages) <= 40
|
||||
|
||||
|
||||
def test_truncation_strategy_validates_token_targets() -> None:
|
||||
try:
|
||||
TruncationStrategy(max_n=3, compact_to=4)
|
||||
except ValueError as exc:
|
||||
assert "compact_to must be less than or equal to max_n" in str(exc)
|
||||
else:
|
||||
raise AssertionError("Expected ValueError when compact_to is greater than max_n.")
|
||||
|
||||
|
||||
async def test_selective_tool_call_strategy_excludes_older_tool_groups() -> None:
|
||||
messages = [
|
||||
Message(role="user", text="u"),
|
||||
_assistant_function_call("call-1"),
|
||||
_tool_result("call-1", "r1"),
|
||||
_assistant_function_call("call-2"),
|
||||
_tool_result("call-2", "r2"),
|
||||
Message(role="assistant", text="done"),
|
||||
]
|
||||
strategy = SelectiveToolCallCompactionStrategy(keep_last_tool_call_groups=1)
|
||||
annotate_message_groups(messages)
|
||||
|
||||
changed = await strategy(messages)
|
||||
|
||||
assert changed is True
|
||||
assert messages[1].additional_properties.get(EXCLUDED_KEY) is True
|
||||
assert messages[2].additional_properties.get(EXCLUDED_KEY) is True
|
||||
assert messages[3].additional_properties.get(EXCLUDED_KEY) is not True
|
||||
assert messages[4].additional_properties.get(EXCLUDED_KEY) is not True
|
||||
|
||||
|
||||
async def test_selective_tool_call_strategy_with_zero_removes_assistant_tool_pair() -> None:
|
||||
messages = [
|
||||
Message(role="user", text="u"),
|
||||
_assistant_function_call("call-1"),
|
||||
_tool_result("call-1", "r1"),
|
||||
Message(role="assistant", text="done"),
|
||||
]
|
||||
strategy = SelectiveToolCallCompactionStrategy(keep_last_tool_call_groups=0)
|
||||
annotate_message_groups(messages)
|
||||
|
||||
changed = await strategy(messages)
|
||||
|
||||
assert changed is True
|
||||
assert messages[1].additional_properties.get(EXCLUDED_KEY) is True
|
||||
assert messages[2].additional_properties.get(EXCLUDED_KEY) is True
|
||||
assert messages[0].additional_properties.get(EXCLUDED_KEY) is not True
|
||||
assert messages[3].additional_properties.get(EXCLUDED_KEY) is not True
|
||||
|
||||
|
||||
def test_selective_tool_call_strategy_rejects_negative_keep_count() -> None:
|
||||
try:
|
||||
SelectiveToolCallCompactionStrategy(keep_last_tool_call_groups=-1)
|
||||
except ValueError as exc:
|
||||
assert "must be greater than or equal to 0" in str(exc)
|
||||
else:
|
||||
raise AssertionError("Expected ValueError for negative keep_last_tool_call_groups.")
|
||||
|
||||
|
||||
class _FakeSummarizer:
|
||||
async def get_response(
|
||||
self,
|
||||
messages: list[Message],
|
||||
*,
|
||||
stream: bool = False,
|
||||
options: dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResponse:
|
||||
return ChatResponse(messages=[Message(role="assistant", text="summarized context")])
|
||||
|
||||
|
||||
class _FailingSummarizer:
|
||||
async def get_response(
|
||||
self,
|
||||
messages: list[Message],
|
||||
*,
|
||||
stream: bool = False,
|
||||
options: dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResponse:
|
||||
raise RuntimeError("summary failed")
|
||||
|
||||
|
||||
class _EmptySummarizer:
|
||||
async def get_response(
|
||||
self,
|
||||
messages: list[Message],
|
||||
*,
|
||||
stream: bool = False,
|
||||
options: dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResponse:
|
||||
return ChatResponse(messages=[Message(role="assistant", text=" ")])
|
||||
|
||||
|
||||
async def test_summarization_strategy_adds_bidirectional_trace_links() -> None:
|
||||
messages = [
|
||||
Message(role="user", text="u1"),
|
||||
Message(role="assistant", text="a1"),
|
||||
Message(role="user", text="u2"),
|
||||
Message(role="assistant", text="a2"),
|
||||
Message(role="user", text="u3"),
|
||||
Message(role="assistant", text="a3"),
|
||||
]
|
||||
strategy = SummarizationStrategy(client=_FakeSummarizer(), target_count=2, threshold=0)
|
||||
annotate_message_groups(messages)
|
||||
|
||||
changed = await strategy(messages)
|
||||
|
||||
assert changed is True
|
||||
summary_messages = [
|
||||
message for message in messages if _group_unknown_value(message, SUMMARY_OF_MESSAGE_IDS_KEY) is not None
|
||||
]
|
||||
assert len(summary_messages) == 1
|
||||
summary = summary_messages[0]
|
||||
summary_id = summary.message_id
|
||||
assert summary_id is not None
|
||||
assert _group_unknown_value(summary, SUMMARY_OF_GROUP_IDS_KEY)
|
||||
summarized_message_ids = _group_unknown_value(summary, SUMMARY_OF_MESSAGE_IDS_KEY)
|
||||
assert isinstance(summarized_message_ids, list)
|
||||
for message in messages:
|
||||
if message.message_id in summarized_message_ids:
|
||||
assert _group_unknown_value(message, SUMMARIZED_BY_SUMMARY_ID_KEY) == summary_id
|
||||
assert message.additional_properties.get(EXCLUDED_KEY) is True
|
||||
|
||||
|
||||
async def test_summarization_strategy_returns_false_when_summary_generation_fails(
|
||||
caplog: Any,
|
||||
) -> None:
|
||||
messages = [
|
||||
Message(role="user", text="u1"),
|
||||
Message(role="assistant", text="a1"),
|
||||
Message(role="user", text="u2"),
|
||||
Message(role="assistant", text="a2"),
|
||||
Message(role="user", text="u3"),
|
||||
Message(role="assistant", text="a3"),
|
||||
]
|
||||
strategy = SummarizationStrategy(client=_FailingSummarizer(), target_count=2, threshold=0)
|
||||
annotate_message_groups(messages)
|
||||
|
||||
with caplog.at_level(logging.WARNING, logger="agent_framework"):
|
||||
changed = await strategy(messages)
|
||||
|
||||
assert changed is False
|
||||
assert any("summary generation failed" in record.message for record in caplog.records)
|
||||
assert all(message.additional_properties.get(EXCLUDED_KEY) is not True for message in messages)
|
||||
|
||||
|
||||
async def test_summarization_strategy_returns_false_when_summary_is_empty(
|
||||
caplog: Any,
|
||||
) -> None:
|
||||
messages = [
|
||||
Message(role="user", text="u1"),
|
||||
Message(role="assistant", text="a1"),
|
||||
Message(role="user", text="u2"),
|
||||
Message(role="assistant", text="a2"),
|
||||
Message(role="user", text="u3"),
|
||||
Message(role="assistant", text="a3"),
|
||||
]
|
||||
strategy = SummarizationStrategy(client=_EmptySummarizer(), target_count=2, threshold=0)
|
||||
annotate_message_groups(messages)
|
||||
|
||||
with caplog.at_level(logging.WARNING, logger="agent_framework"):
|
||||
changed = await strategy(messages)
|
||||
|
||||
assert changed is False
|
||||
assert any("returned no text" in record.message for record in caplog.records)
|
||||
assert all(message.additional_properties.get(EXCLUDED_KEY) is not True for message in messages)
|
||||
|
||||
|
||||
async def test_token_budget_composed_strategy_meets_budget_or_falls_back() -> None:
|
||||
messages = [
|
||||
Message(role="system", text="system"),
|
||||
Message(role="user", text="user " * 200),
|
||||
Message(role="assistant", text="assistant " * 200),
|
||||
]
|
||||
strategy = TokenBudgetComposedStrategy(
|
||||
token_budget=20,
|
||||
tokenizer=CharacterEstimatorTokenizer(),
|
||||
strategies=[SlidingWindowStrategy(keep_last_groups=1)],
|
||||
)
|
||||
|
||||
changed = await strategy(messages)
|
||||
|
||||
assert changed is True
|
||||
assert included_token_count(messages) <= 20
|
||||
|
||||
|
||||
class _ExcludeOldestNonSystem:
|
||||
async def __call__(self, messages: list[Message]) -> bool:
|
||||
group_ids = annotate_message_groups(messages)
|
||||
kinds: dict[str, str] = {}
|
||||
for message in messages:
|
||||
group_id = _group_id(message)
|
||||
kind = _group_kind(message)
|
||||
if group_id is not None and kind is not None and group_id not in kinds:
|
||||
kinds[group_id] = kind
|
||||
for group_id in group_ids:
|
||||
if kinds.get(group_id) == "system":
|
||||
continue
|
||||
for message in messages:
|
||||
if _group_id(message) == group_id:
|
||||
message.additional_properties[EXCLUDED_KEY] = True
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
async def test_apply_compaction_projects_included_messages_only() -> None:
|
||||
messages = [
|
||||
Message(role="system", text="sys"),
|
||||
Message(role="user", text="hello"),
|
||||
Message(role="assistant", text="world"),
|
||||
]
|
||||
|
||||
projected = await apply_compaction(messages, strategy=_ExcludeOldestNonSystem())
|
||||
|
||||
assert len(projected) < len(messages)
|
||||
assert projected[0].role == "system"
|
||||
|
||||
|
||||
# --- ToolResultCompactionStrategy tests ---
|
||||
|
||||
|
||||
async def test_tool_result_compaction_collapses_old_groups_into_summary() -> None:
|
||||
"""Old tool-call groups are collapsed into summary messages, newest kept."""
|
||||
messages = [
|
||||
Message(role="user", text="u"),
|
||||
_assistant_function_call("call-1"),
|
||||
_tool_result("call-1", "r1"),
|
||||
_assistant_function_call("call-2"),
|
||||
_tool_result("call-2", "r2"),
|
||||
Message(role="assistant", text="done"),
|
||||
]
|
||||
strategy = ToolResultCompactionStrategy(keep_last_tool_call_groups=1)
|
||||
annotate_message_groups(messages)
|
||||
|
||||
changed = await strategy(messages)
|
||||
|
||||
assert changed is True
|
||||
projected = included_messages(messages)
|
||||
texts = [m.text or "" for m in projected]
|
||||
summary_msgs = [t for t in texts if t.startswith("[Tool results:")]
|
||||
assert len(summary_msgs) == 1
|
||||
assert "r1" in summary_msgs[0]
|
||||
assert any(m.role == "tool" for m in projected)
|
||||
|
||||
|
||||
async def test_tool_result_compaction_zero_collapses_all() -> None:
|
||||
"""With keep=0, all tool-call groups are collapsed into summaries."""
|
||||
messages = [
|
||||
Message(role="user", text="u"),
|
||||
_assistant_function_call("call-1"),
|
||||
_tool_result("call-1", "r1"),
|
||||
_assistant_function_call("call-2"),
|
||||
_tool_result("call-2", "r2"),
|
||||
Message(role="assistant", text="done"),
|
||||
]
|
||||
strategy = ToolResultCompactionStrategy(keep_last_tool_call_groups=0)
|
||||
annotate_message_groups(messages)
|
||||
|
||||
changed = await strategy(messages)
|
||||
|
||||
assert changed is True
|
||||
projected = included_messages(messages)
|
||||
summary_msgs = [m for m in projected if (m.text or "").startswith("[Tool results:")]
|
||||
assert len(summary_msgs) == 2
|
||||
assert not any(m.role == "tool" for m in projected)
|
||||
|
||||
|
||||
async def test_tool_result_compaction_no_change_when_within_limit() -> None:
|
||||
"""No compaction when tool groups count does not exceed keep limit."""
|
||||
messages = [
|
||||
Message(role="user", text="u"),
|
||||
_assistant_function_call("call-1"),
|
||||
_tool_result("call-1", "r1"),
|
||||
]
|
||||
strategy = ToolResultCompactionStrategy(keep_last_tool_call_groups=1)
|
||||
annotate_message_groups(messages)
|
||||
|
||||
changed = await strategy(messages)
|
||||
|
||||
assert changed is False
|
||||
|
||||
|
||||
def test_tool_result_compaction_rejects_negative() -> None:
|
||||
try:
|
||||
ToolResultCompactionStrategy(keep_last_tool_call_groups=-1)
|
||||
except ValueError as exc:
|
||||
assert "must be greater than or equal to 0" in str(exc)
|
||||
else:
|
||||
raise AssertionError("Expected ValueError for negative keep_last_tool_call_groups.")
|
||||
|
||||
|
||||
async def test_tool_result_compaction_preserves_tool_results_in_summary() -> None:
|
||||
"""Summary text should include the tool results from the collapsed group."""
|
||||
messages = [
|
||||
Message(role="user", text="u"),
|
||||
Message(
|
||||
role="assistant",
|
||||
contents=[
|
||||
Content.from_function_call(call_id="c1", name="get_weather", arguments="{}"),
|
||||
Content.from_function_call(call_id="c2", name="search_docs", arguments="{}"),
|
||||
],
|
||||
),
|
||||
_tool_result("c1", "sunny"),
|
||||
_tool_result("c2", "found 3 docs"),
|
||||
Message(role="assistant", text="done"),
|
||||
]
|
||||
strategy = ToolResultCompactionStrategy(keep_last_tool_call_groups=0)
|
||||
annotate_message_groups(messages)
|
||||
|
||||
await strategy(messages)
|
||||
|
||||
projected = included_messages(messages)
|
||||
summary_msgs = [m for m in projected if (m.text or "").startswith("[Tool results:")]
|
||||
assert len(summary_msgs) == 1
|
||||
assert "sunny" in summary_msgs[0].text # type: ignore[operator]
|
||||
assert "found 3 docs" in summary_msgs[0].text # type: ignore[operator]
|
||||
|
||||
|
||||
async def test_tool_result_compaction_bidirectional_tracing() -> None:
|
||||
"""Summary and originals should link to each other like SummarizationStrategy does."""
|
||||
messages = [
|
||||
Message(role="user", text="u"),
|
||||
_assistant_function_call("call-1"),
|
||||
_tool_result("call-1", "r1"),
|
||||
Message(role="assistant", text="done"),
|
||||
]
|
||||
strategy = ToolResultCompactionStrategy(keep_last_tool_call_groups=0)
|
||||
annotate_message_groups(messages)
|
||||
|
||||
await strategy(messages)
|
||||
|
||||
# Find the summary message.
|
||||
summary_msgs = [m for m in messages if _group_unknown_value(m, SUMMARY_OF_MESSAGE_IDS_KEY) is not None]
|
||||
assert len(summary_msgs) == 1
|
||||
summary = summary_msgs[0]
|
||||
summary_id = summary.message_id
|
||||
assert summary_id is not None
|
||||
|
||||
# Forward link: summary knows which messages/groups it replaces.
|
||||
assert isinstance(_group_unknown_value(summary, SUMMARY_OF_MESSAGE_IDS_KEY), list)
|
||||
assert isinstance(_group_unknown_value(summary, SUMMARY_OF_GROUP_IDS_KEY), list)
|
||||
|
||||
# Back link: excluded originals know which summary replaced them.
|
||||
for m in messages:
|
||||
if m.additional_properties.get(EXCLUDED_KEY):
|
||||
assert _group_unknown_value(m, SUMMARIZED_BY_SUMMARY_ID_KEY) == summary_id
|
||||
|
||||
# Core compaction annotations must be present on the summary message.
|
||||
assert _group_id(summary) is not None
|
||||
assert _group_kind(summary) is not None
|
||||
assert summary.additional_properties.get(EXCLUDED_KEY) is False
|
||||
|
||||
|
||||
async def test_tool_result_compaction_summary_has_full_annotations() -> None:
|
||||
"""Summary messages inserted by ToolResultCompactionStrategy must have all compaction annotations."""
|
||||
messages = [
|
||||
Message(role="user", text="u"),
|
||||
_assistant_function_call("c1"),
|
||||
_tool_result("c1", "r1"),
|
||||
Message(role="assistant", text="done"),
|
||||
]
|
||||
strategy = ToolResultCompactionStrategy(keep_last_tool_call_groups=0)
|
||||
annotate_message_groups(messages)
|
||||
|
||||
await strategy(messages)
|
||||
|
||||
summary = next(m for m in messages if (m.text or "").startswith("[Tool results:"))
|
||||
annotation = summary.additional_properties.get(GROUP_ANNOTATION_KEY)
|
||||
assert isinstance(annotation, dict)
|
||||
assert GROUP_ID_KEY in annotation
|
||||
assert GROUP_KIND_KEY in annotation
|
||||
assert GROUP_HAS_REASONING_KEY in annotation
|
||||
assert SUMMARY_OF_MESSAGE_IDS_KEY in annotation
|
||||
assert summary.additional_properties.get(EXCLUDED_KEY) is False
|
||||
|
||||
|
||||
async def test_summarization_strategy_summary_has_full_annotations() -> None:
|
||||
"""Summary messages inserted by SummarizationStrategy must have all compaction annotations."""
|
||||
messages = [
|
||||
Message(role="user", text="u1"),
|
||||
Message(role="assistant", text="a1"),
|
||||
Message(role="user", text="u2"),
|
||||
Message(role="assistant", text="a2"),
|
||||
Message(role="user", text="u3"),
|
||||
Message(role="assistant", text="a3"),
|
||||
]
|
||||
strategy = SummarizationStrategy(client=_FakeSummarizer(), target_count=2, threshold=0)
|
||||
annotate_message_groups(messages)
|
||||
|
||||
changed = await strategy(messages)
|
||||
|
||||
assert changed is True
|
||||
summary = next(m for m in messages if _group_unknown_value(m, SUMMARY_OF_MESSAGE_IDS_KEY) is not None)
|
||||
annotation = summary.additional_properties.get(GROUP_ANNOTATION_KEY)
|
||||
assert isinstance(annotation, dict)
|
||||
assert GROUP_ID_KEY in annotation
|
||||
assert GROUP_KIND_KEY in annotation
|
||||
assert GROUP_HAS_REASONING_KEY in annotation
|
||||
assert SUMMARY_OF_MESSAGE_IDS_KEY in annotation
|
||||
assert summary.additional_properties.get(EXCLUDED_KEY) is False
|
||||
|
||||
|
||||
async def test_tool_result_compaction_multiple_groups_combined() -> None:
|
||||
"""Multiple tool-call groups collapsed independently, each with its own summary.
|
||||
|
||||
Scenario: 3 tool-call groups, keep_last=1 → groups 1 and 2 each get a
|
||||
separate summary, group 3 stays verbatim.
|
||||
"""
|
||||
messages = [
|
||||
Message(role="user", text="Compare weather in London, Paris, and Tokyo"),
|
||||
# Group 1: get_weather for London
|
||||
Message(
|
||||
role="assistant",
|
||||
contents=[Content.from_function_call(call_id="c1", name="get_weather", arguments='{"city":"London"}')],
|
||||
),
|
||||
_tool_result("c1", '{"temp":12,"condition":"cloudy","wind":"NW 15km/h"}'),
|
||||
Message(role="assistant", text="London is cloudy at 12°C."),
|
||||
# Group 2: get_weather for Paris + search_hotels
|
||||
Message(
|
||||
role="assistant",
|
||||
contents=[
|
||||
Content.from_function_call(call_id="c2", name="get_weather", arguments='{"city":"Paris"}'),
|
||||
Content.from_function_call(call_id="c3", name="search_hotels", arguments='{"city":"Paris"}'),
|
||||
],
|
||||
),
|
||||
_tool_result("c2", '{"temp":18,"condition":"sunny"}'),
|
||||
_tool_result("c3", "Grand Hotel (€120), Le Petit (€85)"),
|
||||
Message(role="assistant", text="Paris is sunny at 18°C. Found 2 hotels."),
|
||||
# Group 3: get_weather for Tokyo (most recent — should be kept)
|
||||
Message(
|
||||
role="assistant",
|
||||
contents=[Content.from_function_call(call_id="c4", name="get_weather", arguments='{"city":"Tokyo"}')],
|
||||
),
|
||||
_tool_result("c4", '{"temp":22,"condition":"rainy"}'),
|
||||
Message(role="assistant", text="Tokyo is rainy at 22°C."),
|
||||
]
|
||||
strategy = ToolResultCompactionStrategy(keep_last_tool_call_groups=1)
|
||||
annotate_message_groups(messages)
|
||||
|
||||
changed = await strategy(messages)
|
||||
|
||||
assert changed is True
|
||||
projected = included_messages(messages)
|
||||
summary_msgs = [m for m in projected if (m.text or "").startswith("[Tool results:")]
|
||||
|
||||
# Two summaries: one for group 1, one for group 2.
|
||||
assert len(summary_msgs) == 2
|
||||
|
||||
# Group 1 summary: London weather result.
|
||||
g1_text = summary_msgs[0].text or ""
|
||||
assert "12" in g1_text
|
||||
assert "cloudy" in g1_text
|
||||
|
||||
# Group 2 summary: Paris weather + hotel results combined.
|
||||
g2_text = summary_msgs[1].text or ""
|
||||
assert "18" in g2_text
|
||||
assert "Grand Hotel" in g2_text
|
||||
|
||||
# Group 3 (Tokyo) stays verbatim — tool role messages still present.
|
||||
verbatim_tool_msgs = [m for m in projected if m.role == "tool"]
|
||||
assert len(verbatim_tool_msgs) == 1
|
||||
assert "rainy" in (verbatim_tool_msgs[0].contents[0].result or "")
|
||||
|
||||
# All text assistant messages should still be present.
|
||||
text_msgs = [m for m in projected if m.role == "assistant" and m.text and not m.text.startswith("[Tool results:")]
|
||||
texts = [m.text for m in text_msgs]
|
||||
assert "London is cloudy at 12°C." in texts
|
||||
assert "Paris is sunny at 18°C. Found 2 hotels." in texts
|
||||
assert "Tokyo is rainy at 22°C." in texts
|
||||
|
||||
# Final projected shape: 8 messages in order.
|
||||
assert len(projected) == 8
|
||||
assert projected[0].role == "user" # original user message
|
||||
assert projected[1].text == '[Tool results: get_weather: {"temp":12,"condition":"cloudy","wind":"NW 15km/h"}]'
|
||||
assert projected[2].text == "London is cloudy at 12°C."
|
||||
expected_g2 = (
|
||||
'[Tool results: get_weather: {"temp":18,"condition":"sunny"};'
|
||||
" search_hotels: Grand Hotel (€120), Le Petit (€85)]"
|
||||
)
|
||||
assert projected[3].text == expected_g2
|
||||
assert projected[4].text == "Paris is sunny at 18°C. Found 2 hotels." # group 2 assistant text
|
||||
assert projected[5].role == "assistant" # group 3 function_call (verbatim)
|
||||
assert projected[6].role == "tool" # group 3 tool result (verbatim)
|
||||
assert projected[7].text == "Tokyo is rainy at 22°C." # group 3 assistant text
|
||||
|
||||
|
||||
# --- CompactionProvider tests ---
|
||||
|
||||
|
||||
class _MockSessionContext:
|
||||
"""Minimal mock for SessionContext used in CompactionProvider tests."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.context_messages: dict[str, list[Message]] = {}
|
||||
self.input_messages: list[Message] = []
|
||||
self._response: Any = None
|
||||
|
||||
@property
|
||||
def response(self) -> Any:
|
||||
return self._response
|
||||
|
||||
def extend_messages(self, provider: Any, messages: list[Message]) -> None:
|
||||
source_id = getattr(provider, "source_id", "unknown")
|
||||
self.context_messages.setdefault(source_id, []).extend(messages)
|
||||
|
||||
def get_messages(self) -> list[Message]:
|
||||
result: list[Message] = []
|
||||
for msgs in self.context_messages.values():
|
||||
result.extend(msgs)
|
||||
return result
|
||||
|
||||
|
||||
async def test_compaction_provider_compacts_existing_context_messages() -> None:
|
||||
"""CompactionProvider.before_run compacts messages already in context from earlier providers."""
|
||||
provider = CompactionProvider(
|
||||
before_strategy=SlidingWindowStrategy(keep_last_groups=2, preserve_system=True),
|
||||
)
|
||||
|
||||
context = _MockSessionContext()
|
||||
context.context_messages["history"] = [
|
||||
Message(role="system", text="sys"),
|
||||
Message(role="user", text="u1"),
|
||||
Message(role="assistant", text="a1"),
|
||||
Message(role="user", text="u2"),
|
||||
Message(role="assistant", text="a2"),
|
||||
Message(role="user", text="u3"),
|
||||
Message(role="assistant", text="a3"),
|
||||
]
|
||||
|
||||
await provider.before_run(agent=None, session=None, context=context, state={})
|
||||
|
||||
remaining = context.context_messages["history"]
|
||||
assert len(remaining) == 3
|
||||
assert remaining[0].role == "system"
|
||||
assert remaining[1].text == "u3"
|
||||
assert remaining[2].text == "a3"
|
||||
|
||||
|
||||
async def test_compaction_provider_noop_when_no_context_messages() -> None:
|
||||
"""before_run with no context messages does nothing."""
|
||||
provider = CompactionProvider(
|
||||
before_strategy=SlidingWindowStrategy(keep_last_groups=2),
|
||||
)
|
||||
|
||||
context = _MockSessionContext()
|
||||
await provider.before_run(agent=None, session=None, context=context, state={})
|
||||
|
||||
assert context.context_messages == {}
|
||||
|
||||
|
||||
async def test_compaction_provider_preserves_messages_from_multiple_sources() -> None:
|
||||
"""CompactionProvider correctly filters across multiple provider sources."""
|
||||
provider = CompactionProvider(
|
||||
before_strategy=SlidingWindowStrategy(keep_last_groups=2, preserve_system=True),
|
||||
)
|
||||
|
||||
context = _MockSessionContext()
|
||||
context.context_messages["history"] = [
|
||||
Message(role="system", text="sys"),
|
||||
Message(role="user", text="old_user"),
|
||||
Message(role="assistant", text="old_assistant"),
|
||||
]
|
||||
context.context_messages["rag"] = [
|
||||
Message(role="user", text="recent_rag_context"),
|
||||
Message(role="assistant", text="recent_rag_answer"),
|
||||
]
|
||||
|
||||
await provider.before_run(agent=None, session=None, context=context, state={})
|
||||
|
||||
all_remaining = context.get_messages()
|
||||
assert any(m.role == "system" for m in all_remaining)
|
||||
assert len(all_remaining) < 5
|
||||
|
||||
|
||||
class _MockSession:
|
||||
"""Minimal mock for AgentSession used in CompactionProvider after_run tests."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.state: dict[str, Any] = {}
|
||||
|
||||
|
||||
async def test_compaction_provider_after_run_compacts_stored_history() -> None:
|
||||
"""after_run annotates exclusions on stored messages without removing them."""
|
||||
provider = CompactionProvider(
|
||||
after_strategy=SelectiveToolCallCompactionStrategy(keep_last_tool_call_groups=0),
|
||||
history_source_id="in_memory_history",
|
||||
)
|
||||
|
||||
session = _MockSession()
|
||||
session.state["in_memory_history"] = {
|
||||
"messages": [
|
||||
Message(role="user", text="old question"),
|
||||
Message(role="assistant", text="old answer"),
|
||||
_assistant_function_call("c1"),
|
||||
_tool_result("c1", "result"),
|
||||
Message(role="assistant", text="final answer"),
|
||||
]
|
||||
}
|
||||
|
||||
context = _MockSessionContext()
|
||||
await provider.after_run(agent=None, session=session, context=context, state={})
|
||||
|
||||
stored = session.state["in_memory_history"]["messages"]
|
||||
# All messages are kept; tool-call group is excluded via annotation.
|
||||
assert len(stored) == 5
|
||||
excluded = [m for m in stored if m.additional_properties.get("_excluded", False)]
|
||||
assert len(excluded) == 2 # assistant function_call + tool result
|
||||
assert any(m.text == "final answer" for m in stored if not m.additional_properties.get("_excluded", False))
|
||||
|
||||
|
||||
async def test_compaction_provider_after_run_noop_without_history() -> None:
|
||||
"""after_run does nothing when there is no history state."""
|
||||
provider = CompactionProvider(
|
||||
after_strategy=SlidingWindowStrategy(keep_last_groups=2),
|
||||
history_source_id="in_memory_history",
|
||||
)
|
||||
|
||||
session = _MockSession()
|
||||
context = _MockSessionContext()
|
||||
await provider.after_run(agent=None, session=session, context=context, state={})
|
||||
|
||||
assert "in_memory_history" not in session.state
|
||||
|
||||
|
||||
async def test_compaction_provider_both_strategies() -> None:
|
||||
"""Both before_strategy and after_strategy work independently."""
|
||||
provider = CompactionProvider(
|
||||
before_strategy=SlidingWindowStrategy(keep_last_groups=2, preserve_system=True),
|
||||
after_strategy=SelectiveToolCallCompactionStrategy(keep_last_tool_call_groups=0),
|
||||
history_source_id="history",
|
||||
)
|
||||
|
||||
# before_run: compact loaded context
|
||||
context = _MockSessionContext()
|
||||
context.context_messages["history"] = [
|
||||
Message(role="system", text="sys"),
|
||||
Message(role="user", text="u1"),
|
||||
Message(role="assistant", text="a1"),
|
||||
Message(role="user", text="u2"),
|
||||
Message(role="assistant", text="a2"),
|
||||
]
|
||||
await provider.before_run(agent=None, session=None, context=context, state={})
|
||||
assert len(context.get_messages()) == 3
|
||||
|
||||
# after_run: compact stored history
|
||||
session = _MockSession()
|
||||
session.state["history"] = {
|
||||
"messages": [
|
||||
Message(role="user", text="q"),
|
||||
_assistant_function_call("c1"),
|
||||
_tool_result("c1", "ok"),
|
||||
Message(role="assistant", text="done"),
|
||||
]
|
||||
}
|
||||
await provider.after_run(agent=None, session=session, context=_MockSessionContext(), state={})
|
||||
stored = session.state["history"]["messages"]
|
||||
excluded = [m for m in stored if m.additional_properties.get("_excluded", False)]
|
||||
assert len(excluded) == 2 # tool-call group excluded
|
||||
|
||||
|
||||
async def test_compaction_provider_none_strategies_are_noop() -> None:
|
||||
"""When both strategies are None, before_run and after_run are no-ops."""
|
||||
provider = CompactionProvider()
|
||||
|
||||
context = _MockSessionContext()
|
||||
context.context_messages["history"] = [
|
||||
Message(role="user", text="hello"),
|
||||
Message(role="assistant", text="hi"),
|
||||
]
|
||||
|
||||
await provider.before_run(agent=None, session=None, context=context, state={})
|
||||
assert len(context.get_messages()) == 2
|
||||
|
||||
session = _MockSession()
|
||||
await provider.after_run(agent=None, session=session, context=context, state={})
|
||||
assert "in_memory_history" not in session.state
|
||||
|
||||
|
||||
async def test_in_memory_history_provider_skip_excluded() -> None:
|
||||
"""InMemoryHistoryProvider with skip_excluded=True omits excluded messages."""
|
||||
from agent_framework._compaction import EXCLUDED_KEY
|
||||
from agent_framework._sessions import InMemoryHistoryProvider as _InMemoryHistoryProvider
|
||||
|
||||
provider = _InMemoryHistoryProvider(skip_excluded=True)
|
||||
state: dict[str, Any] = {
|
||||
"messages": [
|
||||
Message(role="user", text="u1"),
|
||||
Message(role="assistant", text="a1", additional_properties={EXCLUDED_KEY: True}),
|
||||
Message(role="user", text="u2"),
|
||||
Message(role="assistant", text="a2"),
|
||||
]
|
||||
}
|
||||
|
||||
loaded = await provider.get_messages(session_id="test", state=state)
|
||||
assert len(loaded) == 3
|
||||
assert all(m.text != "a1" for m in loaded)
|
||||
|
||||
|
||||
async def test_in_memory_history_provider_default_loads_all() -> None:
|
||||
"""InMemoryHistoryProvider with default settings loads all messages including excluded."""
|
||||
from agent_framework._compaction import EXCLUDED_KEY
|
||||
from agent_framework._sessions import InMemoryHistoryProvider as _InMemoryHistoryProvider
|
||||
|
||||
provider = _InMemoryHistoryProvider()
|
||||
state: dict[str, Any] = {
|
||||
"messages": [
|
||||
Message(role="user", text="u1"),
|
||||
Message(role="assistant", text="a1", additional_properties={EXCLUDED_KEY: True}),
|
||||
Message(role="user", text="u2"),
|
||||
]
|
||||
}
|
||||
|
||||
loaded = await provider.get_messages(session_id="test", state=state)
|
||||
assert len(loaded) == 3
|
||||
@@ -15,9 +15,27 @@ from agent_framework import (
|
||||
SupportsChatGetResponse,
|
||||
tool,
|
||||
)
|
||||
from agent_framework._compaction import (
|
||||
EXCLUDED_KEY,
|
||||
GROUP_ANNOTATION_KEY,
|
||||
GROUP_ID_KEY,
|
||||
CharacterEstimatorTokenizer,
|
||||
SlidingWindowStrategy,
|
||||
TokenBudgetComposedStrategy,
|
||||
annotate_message_groups,
|
||||
included_token_count,
|
||||
)
|
||||
from agent_framework._middleware import FunctionInvocationContext, FunctionMiddleware, MiddlewareTermination
|
||||
|
||||
|
||||
def _group_id(message: Message) -> str | None:
|
||||
annotation = message.additional_properties.get(GROUP_ANNOTATION_KEY)
|
||||
if not isinstance(annotation, dict):
|
||||
return None
|
||||
value = annotation.get(GROUP_ID_KEY)
|
||||
return value if isinstance(value, str) else None
|
||||
|
||||
|
||||
async def test_base_client_with_function_calling(chat_client_base: SupportsChatGetResponse):
|
||||
exec_counter = 0
|
||||
|
||||
@@ -131,6 +149,127 @@ async def test_base_client_with_function_calling_resets(chat_client_base: Suppor
|
||||
assert response.messages[3].contents[0].type == "function_result"
|
||||
|
||||
|
||||
async def test_function_loop_applies_compaction_projection_each_model_call(chat_client_base: SupportsChatGetResponse):
|
||||
@tool(name="test_function", approval_mode="never_require")
|
||||
def ai_func(arg1: str) -> str:
|
||||
return f"Processed {arg1}"
|
||||
|
||||
class _ExcludeOldestGroupAfterFirstTurn:
|
||||
async def __call__(self, messages: list[Message]) -> bool:
|
||||
groups = annotate_message_groups(messages)
|
||||
if len(groups) <= 1:
|
||||
return False
|
||||
oldest_group_id = groups[0]
|
||||
changed = False
|
||||
for message in messages:
|
||||
if _group_id(message) == oldest_group_id:
|
||||
if message.additional_properties.get(EXCLUDED_KEY) is not True:
|
||||
changed = True
|
||||
message.additional_properties[EXCLUDED_KEY] = True
|
||||
return changed
|
||||
|
||||
captured_roles: list[list[str]] = []
|
||||
original = chat_client_base._get_non_streaming_response # type: ignore[attr-defined]
|
||||
|
||||
async def _capture(
|
||||
*,
|
||||
messages: list[Message],
|
||||
options: dict[str, Any],
|
||||
**kwargs: Any,
|
||||
) -> ChatResponse:
|
||||
captured_roles.append([message.role for message in messages])
|
||||
return await original(messages=messages, options=options, **kwargs)
|
||||
|
||||
chat_client_base._get_non_streaming_response = _capture # type: ignore[attr-defined,method-assign]
|
||||
chat_client_base.compaction_strategy = _ExcludeOldestGroupAfterFirstTurn() # type: ignore[attr-defined]
|
||||
|
||||
chat_client_base.run_responses = [
|
||||
ChatResponse(
|
||||
messages=Message(
|
||||
role="assistant",
|
||||
contents=[
|
||||
Content.from_function_call(call_id="1", name="test_function", arguments='{"arg1": "value1"}')
|
||||
],
|
||||
)
|
||||
),
|
||||
ChatResponse(messages=Message(role="assistant", text="done")),
|
||||
]
|
||||
|
||||
await chat_client_base.get_response(
|
||||
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [ai_func]}
|
||||
)
|
||||
|
||||
assert len(captured_roles) >= 2
|
||||
assert "user" in captured_roles[0]
|
||||
assert "user" not in captured_roles[1]
|
||||
|
||||
|
||||
async def test_function_loop_token_budget_strategy_caps_tokens_each_iteration(
|
||||
chat_client_base: SupportsChatGetResponse,
|
||||
):
|
||||
exec_counter = 0
|
||||
token_budget = 500
|
||||
tokenizer = CharacterEstimatorTokenizer()
|
||||
|
||||
@tool(name="test_function", approval_mode="never_require")
|
||||
def ai_func(arg1: str) -> str:
|
||||
nonlocal exec_counter
|
||||
exec_counter += 1
|
||||
return f"Processed {arg1}. " + ("result " * 120)
|
||||
|
||||
captured_token_counts: list[int] = []
|
||||
original = chat_client_base._get_non_streaming_response # type: ignore[attr-defined]
|
||||
|
||||
async def _capture(
|
||||
*,
|
||||
messages: list[Message],
|
||||
options: dict[str, Any],
|
||||
**kwargs: Any,
|
||||
) -> ChatResponse:
|
||||
annotate_message_groups(messages, force_reannotate=True, tokenizer=tokenizer)
|
||||
captured_token_counts.append(included_token_count(messages))
|
||||
return await original(messages=messages, options=options, **kwargs)
|
||||
|
||||
chat_client_base._get_non_streaming_response = _capture # type: ignore[attr-defined,method-assign]
|
||||
chat_client_base.tokenizer = tokenizer # type: ignore[attr-defined]
|
||||
chat_client_base.function_invocation_configuration["max_iterations"] = 3 # type: ignore[attr-defined]
|
||||
chat_client_base.compaction_strategy = TokenBudgetComposedStrategy( # type: ignore[attr-defined]
|
||||
token_budget=token_budget,
|
||||
tokenizer=tokenizer,
|
||||
strategies=[SlidingWindowStrategy(keep_last_groups=2)],
|
||||
)
|
||||
chat_client_base.run_responses = [
|
||||
ChatResponse(
|
||||
messages=Message(
|
||||
role="assistant",
|
||||
contents=[
|
||||
Content.from_function_call(call_id="1", name="test_function", arguments='{"arg1": "value1"}')
|
||||
],
|
||||
)
|
||||
),
|
||||
ChatResponse(
|
||||
messages=Message(
|
||||
role="assistant",
|
||||
contents=[
|
||||
Content.from_function_call(call_id="2", name="test_function", arguments='{"arg1": "value2"}')
|
||||
],
|
||||
)
|
||||
),
|
||||
ChatResponse(messages=Message(role="assistant", text="done")),
|
||||
]
|
||||
|
||||
response = await chat_client_base.get_response(
|
||||
[Message(role="user", text="hello " * 160)],
|
||||
options={"tool_choice": "auto", "tools": [ai_func]},
|
||||
)
|
||||
|
||||
assert response.messages[-1].text == "done"
|
||||
assert exec_counter == 2
|
||||
assert len(captured_token_counts) >= 3
|
||||
assert all(token_count > 0 for token_count in captured_token_counts)
|
||||
assert all(token_count <= token_budget for token_count in captured_token_counts)
|
||||
|
||||
|
||||
async def test_base_client_with_streaming_function_calling(chat_client_base: SupportsChatGetResponse):
|
||||
exec_counter = 0
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ from agent_framework._skills import (
|
||||
|
||||
async def _noop_script_runner(skill: Any, script: Any, args: Any = None) -> None:
|
||||
"""No-op script runner for tests that need a SkillScriptRunner."""
|
||||
return None
|
||||
return
|
||||
|
||||
|
||||
def _symlinks_supported(tmp: Path) -> bool:
|
||||
@@ -1994,7 +1994,7 @@ class TestSkillScriptRunnerProtocol:
|
||||
"""Tests for the SkillScriptRunner protocol."""
|
||||
|
||||
async def test_async_callable_satisfies_protocol(self) -> None:
|
||||
from agent_framework import SkillScriptRunner, SkillScript
|
||||
from agent_framework import SkillScript, SkillScriptRunner
|
||||
|
||||
results: list[tuple] = []
|
||||
|
||||
@@ -2015,7 +2015,7 @@ class TestSkillScriptRunnerProtocol:
|
||||
assert results[0] == ("test-skill", "my-script", {"key": "val"})
|
||||
|
||||
async def test_callable_class_satisfies_protocol(self) -> None:
|
||||
from agent_framework import SkillScriptRunner, SkillScript
|
||||
from agent_framework import SkillScript, SkillScriptRunner
|
||||
|
||||
class _CustomRunner:
|
||||
async def __call__(self, skill, script, args=None):
|
||||
@@ -2056,7 +2056,7 @@ class TestSkillScriptRunnerProtocol:
|
||||
assert result == {"exit_code": 0, "output": "ok"}
|
||||
|
||||
def test_sync_callable_satisfies_protocol(self) -> None:
|
||||
from agent_framework import SkillScriptRunner, SkillScript
|
||||
from agent_framework import SkillScript, SkillScriptRunner
|
||||
|
||||
results: list[tuple] = []
|
||||
|
||||
@@ -2077,7 +2077,7 @@ class TestSkillScriptRunnerProtocol:
|
||||
assert results[0] == ("test-skill", "my-script", {"key": "val"})
|
||||
|
||||
def test_sync_callable_class_satisfies_protocol(self) -> None:
|
||||
from agent_framework import SkillScriptRunner, SkillScript
|
||||
from agent_framework import SkillScript, SkillScriptRunner
|
||||
|
||||
class _SyncRunner:
|
||||
def __call__(self, skill, script, args=None):
|
||||
@@ -2117,6 +2117,7 @@ class TestSkillScriptRunnerProtocol:
|
||||
result = dict_runner(skill, script)
|
||||
assert result == {"exit_code": 0, "output": "ok"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SkillsProvider static factory tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -28,6 +28,12 @@ from agent_framework import (
|
||||
merge_chat_options,
|
||||
tool,
|
||||
)
|
||||
from agent_framework._compaction import (
|
||||
GROUP_ANNOTATION_KEY,
|
||||
GROUP_HAS_REASONING_KEY,
|
||||
GROUP_ID_KEY,
|
||||
GROUP_TOKEN_COUNT_KEY,
|
||||
)
|
||||
from agent_framework._types import (
|
||||
_get_data_bytes,
|
||||
_get_data_bytes_as_str,
|
||||
@@ -1654,6 +1660,78 @@ def test_chat_message_complex_content_serialization():
|
||||
assert reconstructed.contents[2].type == "function_result"
|
||||
|
||||
|
||||
def test_message_roundtrip_preserves_compaction_annotation_dict() -> None:
|
||||
message = Message(
|
||||
role="assistant",
|
||||
contents=[Content.from_text("Hello")],
|
||||
additional_properties={
|
||||
GROUP_ANNOTATION_KEY: {
|
||||
"id": "group_1",
|
||||
"kind": "assistant_text",
|
||||
"index": 1,
|
||||
"has_reasoning": False,
|
||||
"token_count": 42,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
restored = Message.from_dict(message.to_dict())
|
||||
annotation = restored.additional_properties.get(GROUP_ANNOTATION_KEY)
|
||||
|
||||
assert isinstance(annotation, dict)
|
||||
assert annotation[GROUP_ID_KEY] == "group_1"
|
||||
assert annotation[GROUP_TOKEN_COUNT_KEY] == 42
|
||||
|
||||
|
||||
def test_content_roundtrip_preserves_compaction_annotation_dict() -> None:
|
||||
content = Content.from_text(
|
||||
text="Hello",
|
||||
additional_properties={
|
||||
GROUP_ANNOTATION_KEY: {
|
||||
"id": "group_2",
|
||||
"kind": "assistant_text",
|
||||
"index": 2,
|
||||
"has_reasoning": False,
|
||||
"token_count": None,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
restored = Content.from_dict(content.to_dict())
|
||||
annotation = restored.additional_properties.get(GROUP_ANNOTATION_KEY)
|
||||
|
||||
assert isinstance(annotation, dict)
|
||||
assert annotation[GROUP_ID_KEY] == "group_2"
|
||||
assert annotation[GROUP_TOKEN_COUNT_KEY] is None
|
||||
|
||||
|
||||
def test_chat_response_roundtrip_preserves_compaction_annotation_dict() -> None:
|
||||
response = ChatResponse(
|
||||
messages=[
|
||||
Message(
|
||||
role="assistant",
|
||||
contents=[Content.from_text("Hello")],
|
||||
additional_properties={
|
||||
GROUP_ANNOTATION_KEY: {
|
||||
"id": "group_3",
|
||||
"kind": "assistant_text",
|
||||
"index": 3,
|
||||
"has_reasoning": True,
|
||||
"token_count": 15,
|
||||
}
|
||||
},
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
restored = ChatResponse.from_dict(response.to_dict())
|
||||
annotation = restored.messages[0].additional_properties.get(GROUP_ANNOTATION_KEY)
|
||||
|
||||
assert isinstance(annotation, dict)
|
||||
assert annotation[GROUP_ID_KEY] == "group_3"
|
||||
assert annotation[GROUP_HAS_REASONING_KEY] is True
|
||||
|
||||
|
||||
def test_usage_content_serialization_with_details():
|
||||
"""Test UsageContent from_dict and to_dict with UsageDetails conversion."""
|
||||
|
||||
|
||||
@@ -524,6 +524,58 @@ def test_response_content_creation_with_reasoning() -> None:
|
||||
assert response.messages[0].contents[0].text == "Reasoning step"
|
||||
|
||||
|
||||
def test_response_content_keeps_reasoning_and_function_calls_in_one_message() -> None:
|
||||
"""Reasoning + function calls should parse into one assistant message."""
|
||||
client = OpenAIResponsesClient(model_id="test-model", api_key="test-key")
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.output_parsed = None
|
||||
mock_response.metadata = {}
|
||||
mock_response.usage = None
|
||||
mock_response.id = "test-id"
|
||||
mock_response.model = "test-model"
|
||||
mock_response.created_at = 1000000000
|
||||
|
||||
mock_reasoning_content = MagicMock()
|
||||
mock_reasoning_content.text = "Reasoning step"
|
||||
|
||||
mock_reasoning_item = MagicMock()
|
||||
mock_reasoning_item.type = "reasoning"
|
||||
mock_reasoning_item.id = "rs_123"
|
||||
mock_reasoning_item.content = [mock_reasoning_content]
|
||||
mock_reasoning_item.summary = []
|
||||
|
||||
mock_function_call_item_1 = MagicMock()
|
||||
mock_function_call_item_1.type = "function_call"
|
||||
mock_function_call_item_1.id = "fc_1"
|
||||
mock_function_call_item_1.call_id = "call_1"
|
||||
mock_function_call_item_1.name = "tool_1"
|
||||
mock_function_call_item_1.arguments = '{"x": 1}'
|
||||
|
||||
mock_function_call_item_2 = MagicMock()
|
||||
mock_function_call_item_2.type = "function_call"
|
||||
mock_function_call_item_2.id = "fc_2"
|
||||
mock_function_call_item_2.call_id = "call_2"
|
||||
mock_function_call_item_2.name = "tool_2"
|
||||
mock_function_call_item_2.arguments = '{"y": 2}'
|
||||
|
||||
mock_response.output = [
|
||||
mock_reasoning_item,
|
||||
mock_function_call_item_1,
|
||||
mock_function_call_item_2,
|
||||
]
|
||||
|
||||
response = client._parse_response_from_openai(mock_response, options={}) # type: ignore
|
||||
|
||||
assert len(response.messages) == 1
|
||||
assert response.messages[0].role == "assistant"
|
||||
assert [content.type for content in response.messages[0].contents] == [
|
||||
"text_reasoning",
|
||||
"function_call",
|
||||
"function_call",
|
||||
]
|
||||
|
||||
|
||||
def test_response_content_creation_with_code_interpreter() -> None:
|
||||
"""Test _parse_response_from_openai with code interpreter outputs."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user