From c99c63810aed231023685fa950e17cede6c11524 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Thu, 16 Apr 2026 13:54:15 -0700 Subject: [PATCH] Address comments --- .../_responses.py | 105 ++++++++++++++---- 1 file changed, 85 insertions(+), 20 deletions(-) diff --git a/python/packages/foundry_hosting/agent_framework_foundry_hosting/_responses.py b/python/packages/foundry_hosting/agent_framework_foundry_hosting/_responses.py index 4826e1eed6..d41e70d94e 100644 --- a/python/packages/foundry_hosting/agent_framework_foundry_hosting/_responses.py +++ b/python/packages/foundry_hosting/agent_framework_foundry_hosting/_responses.py @@ -11,6 +11,7 @@ from collections.abc import AsyncIterable, AsyncIterator, Generator, Mapping from agent_framework import ( ChatOptions, Content, + ContextProvider, FileCheckpointStorage, HistoryProvider, Message, @@ -106,22 +107,27 @@ class ResponsesHostServer(ResponsesAgentServerHost): "There shouldn't be a history provider with `load_messages=True` already present. " "History is managed by the hosting infrastructure." ) + provider = cast(ContextProvider, provider) + logger.warning( + "Context provider %s is present. If it maintains context in memory, " + "the context may be lost between requests. Use with caution.", + provider.source_id, + ) self._is_workflow_agent = False - self._checkpoint_storage = None + self._checkpoint_storage_path = None if isinstance(agent, WorkflowAgent): if agent.workflow._runner_context.has_checkpointing(): # pyright: ignore[reportPrivateUsage] raise RuntimeError( "There should not be a checkpoint storage already present in the workflow agent. " "The hosting infrastructure will manage checkpoints instead." ) - checkpoint_storage_path = ( + self._checkpoint_storage_path = ( self.CHECKPOINT_STORAGE_PATH if self.config.is_hosted else os.path.join(os.getcwd(), self.CHECKPOINT_STORAGE_PATH.lstrip("/")) ) self._is_workflow_agent = True - self._checkpoint_storage = FileCheckpointStorage(checkpoint_storage_path) self._agent = agent self.response_handler(self._handler) # pyright: ignore[reportUnknownMemberType] @@ -129,6 +135,11 @@ class ResponsesHostServer(ResponsesAgentServerHost): # Append the user agent prefix for telemetry purposes append_to_user_agent(self.USER_AGENT_PREFIX) + @staticmethod + def _is_streaming_request(request: CreateResponse) -> bool: + """Check if the request is a streaming request.""" + return request.stream is not None and request.stream is True + async def _handler( self, request: CreateResponse, @@ -146,19 +157,22 @@ class ResponsesHostServer(ResponsesAgentServerHost): history = await context.get_history() messages = [*_to_messages(history), input_text] - chat_options = _to_chat_options(request) + chat_options, are_options_set = _to_chat_options(request) + is_streaming_request = self._is_streaming_request(request) response_event_stream = ResponseEventStream(response_id=context.response_id, model=request.model) yield response_event_stream.emit_created() yield response_event_stream.emit_in_progress() - if request.stream is None or request.stream is False: + if not is_streaming_request: # Run the agent in non-streaming mode if isinstance(self._agent, RawAgent): raw_agent = cast("RawAgent[Any]", self._agent) # pyright: ignore[reportUnknownMemberType] response = await raw_agent.run(messages, stream=False, options=chat_options) else: + if are_options_set: + logger.warning("Agent doesn't support runtime options. They will be ignored.") response = await self._agent.run(messages, stream=False) for message in response.messages: @@ -174,6 +188,8 @@ class ResponsesHostServer(ResponsesAgentServerHost): raw_agent = cast("RawAgent[Any]", self._agent) # pyright: ignore[reportUnknownMemberType] response_stream = raw_agent.run(messages, stream=True, options=chat_options) else: + if are_options_set: + logger.warning("Agent doesn't support runtime options. They will be ignored.") response_stream = self._agent.run(messages, stream=True) # Track the current active output item builder for streaming; @@ -208,44 +224,71 @@ class ResponsesHostServer(ResponsesAgentServerHost): by the hosting infrastructure or files will be preserved upon deactivation. """ input_text = await context.get_input_text() - stream = request.stream is not None and request.stream is True + is_streaming_request = self._is_streaming_request(request) + + _, are_options_set = _to_chat_options(request) + if are_options_set: + logger.warning("Workflow agent doesn't support runtime options. They will be ignored.") + + if request.previous_response_id is not None and context.conversation_id is not None: + raise RuntimeError("Previous response ID cannot be used in conjunction with conversation ID.") + context_id = request.previous_response_id or context.conversation_id # The following should never happen due to the checks above. # This is for type safety and defensive programming. - if self._checkpoint_storage is None: - raise RuntimeError("Checkpoint storage is not available for workflow agent.") + if self._checkpoint_storage_path is None: + raise RuntimeError("Checkpoint storage path is not configured for workflow agent.") if not isinstance(self._agent, WorkflowAgent): raise RuntimeError("Agent is not a workflow agent.") # Restore from the latest checkpoint if available, otherwise start with an empty history - latest_checkpoint = await self._checkpoint_storage.get_latest(workflow_name=self._agent.workflow.name) - if latest_checkpoint is not None: - _ = await self._agent.run( - stream=stream, - checkpoint_id=latest_checkpoint.checkpoint_id, - checkpoint_storage=self._checkpoint_storage, - ) + if context_id is not None: + checkpoint_storage = FileCheckpointStorage(os.path.join(self._checkpoint_storage_path, context_id)) + latest_checkpoint = await checkpoint_storage.get_latest(workflow_name=self._agent.workflow.name) + if latest_checkpoint is not None: + if not is_streaming_request: + _ = await self._agent.run( + stream=False, + checkpoint_id=latest_checkpoint.checkpoint_id, + checkpoint_storage=checkpoint_storage, + ) + else: + # Consume the streaming or the invocation will result in a no-op + async for _ in self._agent.run( + stream=True, + checkpoint_id=latest_checkpoint.checkpoint_id, + checkpoint_storage=checkpoint_storage, + ): + pass # Now run the agent with the latest input response_event_stream = ResponseEventStream(response_id=context.response_id, model=request.model) + # Create a new checkpoint storage for this response based on the following rules: + # - If no previous response ID or conversation ID is provided, create a new checkpoint storage for this response + # - If a previous response ID is provided, create a new checkpoint storage for this response + # - If a conversation ID is provided, reuse the existing checkpoint storage for the conversation + context_id = context.conversation_id or context.response_id + checkpoint_storage = FileCheckpointStorage(os.path.join(self._checkpoint_storage_path, context_id)) + yield response_event_stream.emit_created() yield response_event_stream.emit_in_progress() - if not stream: + if not is_streaming_request: # Run the agent in non-streaming mode - response = await self._agent.run(input_text, stream=False, checkpoint_storage=self._checkpoint_storage) + response = await self._agent.run(input_text, stream=False, checkpoint_storage=checkpoint_storage) for message in response.messages: for content in message.contents: async for item in _to_outputs(response_event_stream, content): yield item + await self._delete_not_latest_checkpoints(checkpoint_storage, self._agent.workflow.name) yield response_event_stream.emit_completed() return # Run the agent in streaming mode - response_stream = self._agent.run(input_text, stream=True, checkpoint_storage=self._checkpoint_storage) + response_stream = self._agent.run(input_text, stream=True, checkpoint_storage=checkpoint_storage) # Track the current active output item builder for streaming; # lazily created on matching content, closed when a different type arrives. @@ -264,7 +307,22 @@ class ResponsesHostServer(ResponsesAgentServerHost): for event in tracker.close(): yield event + await self._delete_not_latest_checkpoints(checkpoint_storage, self._agent.workflow.name) yield response_event_stream.emit_completed() + return + + @staticmethod + async def _delete_not_latest_checkpoints(checkpoint_storage: FileCheckpointStorage, workflow_name: str): + """Delete all checkpoints except the latest one. + + We only need the last checkpoint for each invocation. + """ + latest_checkpoint = await checkpoint_storage.get_latest(workflow_name=workflow_name) + if latest_checkpoint is not None: + all_checkpoints = await checkpoint_storage.list_checkpoints(workflow_name=workflow_name) + for checkpoint in all_checkpoints: + if checkpoint.checkpoint_id != latest_checkpoint.checkpoint_id: + await checkpoint_storage.delete(checkpoint.checkpoint_id) # region Active Builder State @@ -418,7 +476,7 @@ class _OutputItemTracker: # region Option Conversion -def _to_chat_options(request: CreateResponse) -> ChatOptions: +def _to_chat_options(request: CreateResponse) -> tuple[ChatOptions, bool]: """Converts a CreateResponse request to ChatOptions. Args: @@ -426,19 +484,26 @@ def _to_chat_options(request: CreateResponse) -> ChatOptions: Returns: ChatOptions: The converted ChatOptions. + bool: Whether any options were set. + """ chat_options = ChatOptions() + are_options_set = False if request.temperature is not None: chat_options["temperature"] = request.temperature + are_options_set = True if request.top_p is not None: chat_options["top_p"] = request.top_p + are_options_set = True if request.max_output_tokens is not None: chat_options["max_tokens"] = request.max_output_tokens + are_options_set = True if request.parallel_tool_calls is not None: chat_options["allow_multiple_tool_calls"] = request.parallel_tool_calls + are_options_set = True - return chat_options + return chat_options, are_options_set # endregion