mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Address PR review feedback: clarify URL-decode comment, isolate test root, add e2e workflow rejection tests
Agent-Logs-Url: https://github.com/microsoft/agent-framework/sessions/832f45a6-c01e-4da9-bf85-1ba7b5f302e6 Co-authored-by: lokitoth <6936551+lokitoth@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
9e18c60745
commit
cafcc53483
@@ -221,11 +221,14 @@ def _checkpoint_storage_for_context(root: str, context_id: str) -> FileCheckpoin
|
||||
if not isinstance(context_id, str) or not context_id:
|
||||
raise RuntimeError("Invalid checkpoint context id: must be a non-empty string.")
|
||||
# Reject any segment that is not a single safe path component. This covers
|
||||
# POSIX/Windows separators, NUL bytes, drive letters, all-dot segments
|
||||
# (``.``, ``..``, ``...``, ...), and embedded URL-encoded forms once
|
||||
# decoded by the framework. We deliberately do not attempt to "sanitize"
|
||||
# by stripping characters because that can introduce collisions between
|
||||
# distinct ids.
|
||||
# POSIX/Windows separators, NUL bytes, drive letters, and all-dot segments
|
||||
# (``.``, ``..``, ``...``, ...). We deliberately do not URL-decode the id
|
||||
# here: the hosting layer never decodes context ids before joining them, so
|
||||
# forms such as ``%2e%2e`` are accepted as literal directory names. Do NOT
|
||||
# add decoding here without re-validating after the decode -- decode-then-
|
||||
# join is exactly the pattern that reintroduces traversal. We also do not
|
||||
# attempt to "sanitize" by stripping characters because that can introduce
|
||||
# collisions between distinct ids.
|
||||
if (
|
||||
"/" in context_id
|
||||
or "\\" in context_id
|
||||
|
||||
@@ -2679,9 +2679,11 @@ class TestCheckpointContextPathValidation:
|
||||
|
||||
def test_valid_segment_creates_storage_under_root(self, tmp_path: Any) -> None:
|
||||
helper = self._helper()
|
||||
storage = helper(str(tmp_path), "resp_abc123")
|
||||
root = tmp_path / "root"
|
||||
root.mkdir()
|
||||
storage = helper(str(root), "resp_abc123")
|
||||
assert storage.storage_path.is_dir()
|
||||
assert storage.storage_path.parent == tmp_path.resolve()
|
||||
assert storage.storage_path.parent == root.resolve()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"bad_id",
|
||||
@@ -2705,14 +2707,20 @@ class TestCheckpointContextPathValidation:
|
||||
)
|
||||
def test_traversal_and_separator_payloads_are_rejected(self, tmp_path: Any, bad_id: str) -> None:
|
||||
helper = self._helper()
|
||||
before = sorted(p.name for p in tmp_path.parent.iterdir())
|
||||
# Use a dedicated root *inside* tmp_path so we can assert that nothing
|
||||
# was created anywhere under tmp_path (root, siblings, or above).
|
||||
# Asserting against tmp_path.parent would be flaky under parallel test
|
||||
# execution because tmp_path.parent is shared across tests.
|
||||
root = tmp_path / "root"
|
||||
root.mkdir()
|
||||
before = sorted(p.name for p in tmp_path.iterdir())
|
||||
with pytest.raises(RuntimeError):
|
||||
helper(str(tmp_path), bad_id)
|
||||
# Ensure no escape directory was created adjacent to (or above) the root.
|
||||
after = sorted(p.name for p in tmp_path.parent.iterdir())
|
||||
helper(str(root), bad_id)
|
||||
# No sibling/escape directory should have been created next to the root.
|
||||
after = sorted(p.name for p in tmp_path.iterdir())
|
||||
assert before == after, f"Unexpected filesystem artifacts created for payload {bad_id!r}"
|
||||
# And nothing inside the root either.
|
||||
assert list(tmp_path.iterdir()) == []
|
||||
assert list(root.iterdir()) == []
|
||||
|
||||
def test_non_string_context_id_is_rejected(self, tmp_path: Any) -> None:
|
||||
helper = self._helper()
|
||||
@@ -2726,9 +2734,85 @@ class TestCheckpointContextPathValidation:
|
||||
should accept ``%2e%2e`` as a single literal segment (no escape).
|
||||
"""
|
||||
helper = self._helper()
|
||||
storage = helper(str(tmp_path), "%2e%2e")
|
||||
assert storage.storage_path.parent == tmp_path.resolve()
|
||||
root = tmp_path / "root"
|
||||
root.mkdir()
|
||||
storage = helper(str(root), "%2e%2e")
|
||||
assert storage.storage_path.parent == root.resolve()
|
||||
assert storage.storage_path.name == "%2e%2e"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"context_field,bad_id",
|
||||
[
|
||||
# Restore sink: caller-controlled previous_response_id.
|
||||
("previous_response_id", "../../escape"),
|
||||
("previous_response_id", "/tmp/escape-abs"),
|
||||
("previous_response_id", "caresp_x/../../service-data/api-made-dir" + "A" * 14),
|
||||
# Restore sink: server-issued conversation_id (defense in depth).
|
||||
("conversation_id", "../../escape"),
|
||||
# Write sink: malicious response_id (defense in depth).
|
||||
("response_id", "../../escape"),
|
||||
],
|
||||
)
|
||||
async def test_handle_inner_workflow_rejects_malicious_context_id(
|
||||
self, tmp_path: Any, context_field: str, bad_id: str
|
||||
) -> None:
|
||||
"""End-to-end: ``_handle_inner_workflow`` must reject malicious ids on
|
||||
both the restore sink (``previous_response_id`` / ``conversation_id``)
|
||||
and the write sink (``response_id``) without creating any directories.
|
||||
"""
|
||||
from unittest.mock import patch
|
||||
|
||||
from agent_framework import WorkflowAgent
|
||||
from azure.ai.agentserver.responses import ResponseContext
|
||||
from azure.ai.agentserver.responses.models import CreateResponse
|
||||
|
||||
# Build a mock that satisfies isinstance(agent, WorkflowAgent) and the
|
||||
# constructor's "no existing checkpointing" guard.
|
||||
agent = MagicMock(spec=WorkflowAgent)
|
||||
agent.id = "wf-agent"
|
||||
agent.name = "wf"
|
||||
agent.description = ""
|
||||
agent.context_providers = []
|
||||
agent.workflow = MagicMock()
|
||||
agent.workflow.name = "wf"
|
||||
agent.workflow._runner_context.has_checkpointing = MagicMock(return_value=False)
|
||||
|
||||
# Constructor inspects WorkflowAgent.workflow internals; bypass setup
|
||||
# by feeding a configured mock through a normal init.
|
||||
server = ResponsesHostServer(agent, store=InMemoryResponseProvider())
|
||||
# Re-root checkpoint storage at our isolated tmp_path so we can detect
|
||||
# any escape attempt on the filesystem.
|
||||
root = tmp_path / "root"
|
||||
root.mkdir()
|
||||
server._checkpoint_storage_path = str(root) # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
# Build a ResponseContext with the malicious id targeting the chosen sink.
|
||||
kwargs: dict[str, Any] = {
|
||||
"response_id": "resp_" + "a" * 48,
|
||||
"mode_flags": MagicMock(),
|
||||
}
|
||||
if context_field == "previous_response_id":
|
||||
request = CreateResponse(model="m", input="hi", previous_response_id=bad_id)
|
||||
kwargs["previous_response_id"] = bad_id
|
||||
elif context_field == "conversation_id":
|
||||
request = CreateResponse(model="m", input="hi")
|
||||
kwargs["conversation_id"] = bad_id
|
||||
else: # response_id (write sink)
|
||||
request = CreateResponse(model="m", input="hi")
|
||||
kwargs["response_id"] = bad_id
|
||||
|
||||
# Avoid invoking the real input-resolution machinery, which would need
|
||||
# a configured provider; we never reach the workflow run on rejection.
|
||||
with patch.object(ResponseContext, "get_input_items", new=AsyncMock(return_value=[])):
|
||||
context = ResponseContext(**kwargs)
|
||||
before = sorted(p.name for p in tmp_path.iterdir())
|
||||
with pytest.raises(RuntimeError, match="Invalid checkpoint context id"):
|
||||
async for _ in server._handle_inner_workflow(request, context): # pyright: ignore[reportPrivateUsage]
|
||||
pass
|
||||
after = sorted(p.name for p in tmp_path.iterdir())
|
||||
|
||||
assert before == after, f"Unexpected filesystem artifacts created for {context_field}={bad_id!r}"
|
||||
assert list(root.iterdir()) == [], f"Checkpoint dir created inside root for {context_field}={bad_id!r}"
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
Reference in New Issue
Block a user