mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Merge branch 'main' into local-branch-python-enable-observability-by-default
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
// Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Diagnostics.CodeAnalysis;
|
||||
using System.Threading.Tasks;
|
||||
@@ -140,7 +141,15 @@ public class MagenticWorkflowBuilder(AIAgent managerAgent)
|
||||
}
|
||||
|
||||
/// <inheritdoc cref="WorkflowBuilder.Build"/>
|
||||
public Workflow Build() => this.ReduceToWorkflowBuilder().Build();
|
||||
public Workflow Build()
|
||||
{
|
||||
if (this._team.Count == 0)
|
||||
{
|
||||
throw new InvalidOperationException("At least one participant must be added via AddParticipants() before building the workflow.");
|
||||
}
|
||||
|
||||
return this.ReduceToWorkflowBuilder().Build();
|
||||
}
|
||||
|
||||
private TaskLimits Limits => new(
|
||||
MaxRoundCount: this._maxRounds,
|
||||
|
||||
+24
-9
@@ -101,6 +101,7 @@ internal class MagenticOrchestrator(AIAgent managerAgent, List<AIAgent> team, Ta
|
||||
return base.ConfigureProtocol(protocolBuilder)
|
||||
.SendsMessage<ChatMessage>()
|
||||
.SendsMessage<ResetChatSignal>()
|
||||
.YieldsOutput<List<ChatMessage>>()
|
||||
.ConfigureRoutes(ConfigureRoutes);
|
||||
|
||||
void ConfigureRoutes(RouteBuilder routeBuilder) => routeBuilder.AddPortHandler<MagenticPlanReviewRequest, MagenticPlanReviewResponse>(
|
||||
@@ -109,7 +110,7 @@ internal class MagenticOrchestrator(AIAgent managerAgent, List<AIAgent> team, Ta
|
||||
out this._planReviewPort);
|
||||
}
|
||||
|
||||
private ValueTask SubmitPlanReviewRequestAsync(MagenticTaskContext taskContext, IWorkflowContext workflowContext)
|
||||
private ValueTask SubmitPlanReviewRequestAsync(MagenticTaskContext taskContext, IWorkflowContext workflowContext, bool replanAfterStall = false)
|
||||
{
|
||||
MagenticProgressLedger? progressLedger = taskContext.ProgressLedger;
|
||||
if (progressLedger?.IsStarted is not true)
|
||||
@@ -117,7 +118,7 @@ internal class MagenticOrchestrator(AIAgent managerAgent, List<AIAgent> team, Ta
|
||||
progressLedger = null;
|
||||
}
|
||||
|
||||
MagenticPlanReviewRequest request = new(taskContext.TaskLedger!.CurrentPlan, progressLedger, taskContext.IsStalled);
|
||||
MagenticPlanReviewRequest request = new(taskContext.TaskLedger!.CurrentPlan, progressLedger, replanAfterStall);
|
||||
|
||||
return this._planReviewPort!.PostRequestAsync(request);
|
||||
}
|
||||
@@ -146,7 +147,7 @@ internal class MagenticOrchestrator(AIAgent managerAgent, List<AIAgent> team, Ta
|
||||
|
||||
if (this._taskContext.IsTerminated)
|
||||
{
|
||||
throw new InvalidOperationException("Magentic Orchestration has already been terminated and cannot process new messages. Please start a new session.");
|
||||
throw new InvalidOperationException("This Magentic orchestration has already terminated. To process new messages, create a new workflow instance.");
|
||||
}
|
||||
|
||||
if (response.IsApproved)
|
||||
@@ -161,7 +162,7 @@ internal class MagenticOrchestrator(AIAgent managerAgent, List<AIAgent> team, Ta
|
||||
}
|
||||
}
|
||||
|
||||
private async ValueTask UpdatePlanAndDelegateAsync(MagenticTaskContext taskContext, IWorkflowContext context, CancellationToken cancellationToken)
|
||||
private async ValueTask UpdatePlanAndDelegateAsync(MagenticTaskContext taskContext, IWorkflowContext context, CancellationToken cancellationToken, bool replanAfterStall = false)
|
||||
{
|
||||
bool isReplan = taskContext.TaskLedger != null;
|
||||
|
||||
@@ -177,7 +178,7 @@ internal class MagenticOrchestrator(AIAgent managerAgent, List<AIAgent> team, Ta
|
||||
|
||||
if (requirePlanSignoff)
|
||||
{
|
||||
await this.SubmitPlanReviewRequestAsync(taskContext, context).ConfigureAwait(false);
|
||||
await this.SubmitPlanReviewRequestAsync(taskContext, context, replanAfterStall).ConfigureAwait(false);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -187,9 +188,22 @@ internal class MagenticOrchestrator(AIAgent managerAgent, List<AIAgent> team, Ta
|
||||
|
||||
protected override async ValueTask TakeTurnAsync(List<ChatMessage> messages, IWorkflowContext context, bool? emitEvents, CancellationToken cancellationToken = default)
|
||||
{
|
||||
// First Turn: Initialize the task context and send the initial messages to the planner agent
|
||||
this._taskContext ??= new(messages, team, limits, emitEvents, []);
|
||||
await this.UpdatePlanAndDelegateAsync(this._taskContext, context, cancellationToken).ConfigureAwait(false);
|
||||
if (this._taskContext?.IsTerminated == true)
|
||||
{
|
||||
throw new InvalidOperationException("This Magentic orchestration has already terminated. To process new messages, create a new workflow instance.");
|
||||
}
|
||||
|
||||
if (this._taskContext == null)
|
||||
{
|
||||
// First Turn: Initialize the task context and create the initial plan
|
||||
this._taskContext = new(messages, team, limits, emitEvents, []);
|
||||
await this.UpdatePlanAndDelegateAsync(this._taskContext, context, cancellationToken).ConfigureAwait(false);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Subsequent turns: agent returned control, go directly to coordination (progress ledger only, no replan)
|
||||
await this.RunCoordinationRoundAsync(this._taskContext, context, cancellationToken).ConfigureAwait(false);
|
||||
}
|
||||
}
|
||||
|
||||
private ChatMessage? _fullTaskLedgerMessage;
|
||||
@@ -288,10 +302,11 @@ internal class MagenticOrchestrator(AIAgent managerAgent, List<AIAgent> team, Ta
|
||||
|
||||
private async ValueTask ResetAndReplanAsync(MagenticTaskContext taskContext, IWorkflowContext context, CancellationToken cancellationToken)
|
||||
{
|
||||
bool wasStalled = taskContext.IsStalled;
|
||||
taskContext.Reset();
|
||||
await context.SendMessageAsync(new ResetChatSignal(), cancellationToken: cancellationToken).ConfigureAwait(false);
|
||||
|
||||
await this.UpdatePlanAndDelegateAsync(taskContext, context, cancellationToken).ConfigureAwait(false);
|
||||
await this.UpdatePlanAndDelegateAsync(taskContext, context, cancellationToken, replanAfterStall: wasStalled).ConfigureAwait(false);
|
||||
}
|
||||
|
||||
private async ValueTask PrepareFinalAnswerAsync(MagenticTaskContext taskContext, IWorkflowContext context, CancellationToken cancellationToken)
|
||||
|
||||
+1
-1
@@ -58,7 +58,7 @@ internal class MagenticTaskContext(List<ChatMessage> taskDefinition, List<AIAgen
|
||||
|
||||
public bool IsTerminated { get; internal set; }
|
||||
|
||||
public bool IsStalled => this.TaskCounters.StallCount >= this.TaskLimits.MaxStallCount;
|
||||
public bool IsStalled => this.TaskCounters.StallCount > this.TaskLimits.MaxStallCount;
|
||||
|
||||
public (bool HitRoundLimit, bool HitResetLimit) CheckLimits()
|
||||
{
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -133,31 +133,31 @@ public sealed class ObservabilityTests : IDisposable
|
||||
activityEvents.Should().Contain(e => e.Name == EventNames.WorkflowCompleted, "activity should have workflow completed event");
|
||||
}
|
||||
|
||||
[Fact(Skip = "Flaky test - temporarily disabled.")]
|
||||
[Fact]
|
||||
public async Task CreatesWorkflowEndToEndActivities_WithCorrectName_DefaultAsync()
|
||||
{
|
||||
await this.TestWorkflowEndToEndActivitiesAsync("Default");
|
||||
}
|
||||
|
||||
[Fact(Skip = "Flaky test - temporarily disabled.")]
|
||||
[Fact]
|
||||
public async Task CreatesWorkflowEndToEndActivities_WithCorrectName_OffThreadAsync()
|
||||
{
|
||||
await this.TestWorkflowEndToEndActivitiesAsync("OffThread");
|
||||
}
|
||||
|
||||
[Fact(Skip = "Flaky test - temporarily disabled.")]
|
||||
[Fact]
|
||||
public async Task CreatesWorkflowEndToEndActivities_WithCorrectName_ConcurrentAsync()
|
||||
{
|
||||
await this.TestWorkflowEndToEndActivitiesAsync("Concurrent");
|
||||
}
|
||||
|
||||
[Fact(Skip = "Flaky test - temporarily disabled.")]
|
||||
[Fact]
|
||||
public async Task CreatesWorkflowEndToEndActivities_WithCorrectName_LockstepAsync()
|
||||
{
|
||||
await this.TestWorkflowEndToEndActivitiesAsync("Lockstep");
|
||||
}
|
||||
|
||||
[Fact(Skip = "Flaky test - temporarily disabled.")]
|
||||
[Fact]
|
||||
public async Task CreatesWorkflowActivities_WithCorrectNameAsync()
|
||||
{
|
||||
// Arrange
|
||||
@@ -182,7 +182,7 @@ public sealed class ObservabilityTests : IDisposable
|
||||
tags.Should().ContainKey(Tags.WorkflowDefinition);
|
||||
}
|
||||
|
||||
[Fact(Skip = "Flaky test - temporarily disabled.")]
|
||||
[Fact]
|
||||
public async Task TelemetryDisabledByDefault_CreatesNoActivitiesAsync()
|
||||
{
|
||||
// Arrange
|
||||
@@ -200,7 +200,7 @@ public sealed class ObservabilityTests : IDisposable
|
||||
capturedActivities.Should().BeEmpty("No activities should be created when telemetry is disabled (default).");
|
||||
}
|
||||
|
||||
[Fact(Skip = "Flaky test - temporarily disabled.")]
|
||||
[Fact]
|
||||
public async Task WithOpenTelemetry_UsesProvidedActivitySourceAsync()
|
||||
{
|
||||
// Arrange
|
||||
@@ -235,7 +235,7 @@ public sealed class ObservabilityTests : IDisposable
|
||||
"All activities should come from the user-provided ActivitySource.");
|
||||
}
|
||||
|
||||
[Fact(Skip = "Flaky test - temporarily disabled.")]
|
||||
[Fact]
|
||||
public async Task DisableWorkflowBuild_PreventsWorkflowBuildActivityAsync()
|
||||
{
|
||||
// Arrange
|
||||
@@ -255,7 +255,7 @@ public sealed class ObservabilityTests : IDisposable
|
||||
"WorkflowBuild activity should be disabled.");
|
||||
}
|
||||
|
||||
[Fact(Skip = "Flaky test - temporarily disabled.")]
|
||||
[Fact]
|
||||
public async Task DisableWorkflowRun_PreventsWorkflowRunActivityAsync()
|
||||
{
|
||||
// Arrange
|
||||
@@ -285,7 +285,7 @@ public sealed class ObservabilityTests : IDisposable
|
||||
"Other activities should still be created.");
|
||||
}
|
||||
|
||||
[Fact(Skip = "Flaky test - temporarily disabled.")]
|
||||
[Fact]
|
||||
public async Task DisableExecutorProcess_PreventsExecutorProcessActivityAsync()
|
||||
{
|
||||
// Arrange
|
||||
@@ -312,7 +312,7 @@ public sealed class ObservabilityTests : IDisposable
|
||||
"Other activities should still be created.");
|
||||
}
|
||||
|
||||
[Fact(Skip = "Flaky test - temporarily disabled.")]
|
||||
[Fact]
|
||||
public async Task DisableEdgeGroupProcess_PreventsEdgeGroupProcessActivityAsync()
|
||||
{
|
||||
// Arrange
|
||||
@@ -333,7 +333,7 @@ public sealed class ObservabilityTests : IDisposable
|
||||
"Other activities should still be created.");
|
||||
}
|
||||
|
||||
[Fact(Skip = "Flaky test - temporarily disabled.")]
|
||||
[Fact]
|
||||
public async Task DisableMessageSend_PreventsMessageSendActivityAsync()
|
||||
{
|
||||
// Arrange
|
||||
@@ -382,7 +382,7 @@ public sealed class ObservabilityTests : IDisposable
|
||||
return builder.WithOpenTelemetry(configure: opts => opts.DisableMessageSend = true).Build();
|
||||
}
|
||||
|
||||
[Fact(Skip = "Flaky test - temporarily disabled.")]
|
||||
[Fact]
|
||||
public async Task EnableSensitiveData_LogsExecutorInputAndOutputAsync()
|
||||
{
|
||||
// Arrange
|
||||
@@ -413,7 +413,7 @@ public sealed class ObservabilityTests : IDisposable
|
||||
tags[Tags.ExecutorOutput].Should().Contain("HELLO", "Output should contain the transformed value.");
|
||||
}
|
||||
|
||||
[Fact(Skip = "Flaky test - temporarily disabled.")]
|
||||
[Fact]
|
||||
public async Task EnableSensitiveData_Disabled_DoesNotLogInputOutputAsync()
|
||||
{
|
||||
// Arrange
|
||||
@@ -442,7 +442,7 @@ public sealed class ObservabilityTests : IDisposable
|
||||
tags.Should().NotContainKey(Tags.ExecutorOutput, "Output should NOT be logged when EnableSensitiveData is false.");
|
||||
}
|
||||
|
||||
[Fact(Skip = "Flaky test - temporarily disabled.")]
|
||||
[Fact]
|
||||
public async Task EnableSensitiveData_LogsMessageSendContentAsync()
|
||||
{
|
||||
// Arrange
|
||||
@@ -474,7 +474,7 @@ public sealed class ObservabilityTests : IDisposable
|
||||
tags.Should().ContainKey(Tags.MessageSourceId, "Source ID should be logged.");
|
||||
}
|
||||
|
||||
[Fact(Skip = "Flaky test - temporarily disabled.")]
|
||||
[Fact]
|
||||
public async Task EnableSensitiveData_Disabled_DoesNotLogMessageContentAsync()
|
||||
{
|
||||
// Arrange
|
||||
|
||||
+6
-6
@@ -67,7 +67,7 @@ public sealed class WorkflowRunActivityStopTests : IDisposable
|
||||
/// Bug: The Activity created by LockstepRunEventStream.TakeEventStreamAsync is never
|
||||
/// disposed because yield break in async iterators does not trigger using disposal.
|
||||
/// </summary>
|
||||
[Fact(Skip = "Flaky test - temporarily disabled.")]
|
||||
[Fact]
|
||||
public async Task WorkflowRunActivity_IsStopped_LockstepAsync()
|
||||
{
|
||||
// Arrange
|
||||
@@ -111,7 +111,7 @@ public sealed class WorkflowRunActivityStopTests : IDisposable
|
||||
/// Verifies that the workflow_invoke Activity is stopped when using the OffThread (Default)
|
||||
/// execution environment (StreamingRunEventStream).
|
||||
/// </summary>
|
||||
[Fact(Skip = "Flaky test - temporarily disabled.")]
|
||||
[Fact]
|
||||
public async Task WorkflowRunActivity_IsStopped_OffThreadAsync()
|
||||
{
|
||||
// Arrange
|
||||
@@ -156,7 +156,7 @@ public sealed class WorkflowRunActivityStopTests : IDisposable
|
||||
/// (StreamingRun.WatchStreamAsync) with the OffThread execution environment.
|
||||
/// This matches the exact usage pattern described in the issue.
|
||||
/// </summary>
|
||||
[Fact(Skip = "Flaky test - temporarily disabled.")]
|
||||
[Fact]
|
||||
public async Task WorkflowRunActivity_IsStopped_Streaming_OffThreadAsync()
|
||||
{
|
||||
// Arrange
|
||||
@@ -203,7 +203,7 @@ public sealed class WorkflowRunActivityStopTests : IDisposable
|
||||
/// streaming invocation, even when using the same workflow in a multi-turn pattern,
|
||||
/// and that each session gets its own session activity.
|
||||
/// </summary>
|
||||
[Fact(Skip = "Flaky test - temporarily disabled.")]
|
||||
[Fact]
|
||||
public async Task WorkflowRunActivity_IsStopped_Streaming_OffThread_MultiTurnAsync()
|
||||
{
|
||||
// Arrange
|
||||
@@ -264,7 +264,7 @@ public sealed class WorkflowRunActivityStopTests : IDisposable
|
||||
/// Verifies that all started activities (not just workflow_invoke) are properly stopped.
|
||||
/// This ensures no spans are "leaked" without being exported.
|
||||
/// </summary>
|
||||
[Fact(Skip = "Flaky test - temporarily disabled.")]
|
||||
[Fact]
|
||||
public async Task AllActivities_AreStopped_AfterWorkflowCompletionAsync()
|
||||
{
|
||||
// Arrange
|
||||
@@ -305,7 +305,7 @@ public sealed class WorkflowRunActivityStopTests : IDisposable
|
||||
/// be parented under the workflow session span. The run activity should
|
||||
/// still nest correctly under the session.
|
||||
/// </summary>
|
||||
[Fact(Skip = "Flaky test - temporarily disabled.")]
|
||||
[Fact]
|
||||
public async Task Lockstep_SessionActivity_DoesNotLeak_IntoCaller_ActivityCurrentAsync()
|
||||
{
|
||||
// Arrange
|
||||
|
||||
@@ -42,7 +42,7 @@ request_handler = DefaultRequestHandler(
|
||||
app = Starlette(
|
||||
routes=[
|
||||
*create_agent_card_routes(my_agent_card),
|
||||
*create_jsonrpc_routes(request_handler),
|
||||
*create_jsonrpc_routes(request_handler, "/"),
|
||||
]
|
||||
)
|
||||
```
|
||||
|
||||
@@ -78,7 +78,7 @@ class A2AExecutor(AgentExecutor):
|
||||
app = Starlette(
|
||||
routes=[
|
||||
*create_agent_card_routes(public_agent_card),
|
||||
*create_jsonrpc_routes(request_handler),
|
||||
*create_jsonrpc_routes(request_handler, "/"),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@@ -365,6 +365,10 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
|
||||
|
||||
all_updates: list[AgentResponseUpdate] = []
|
||||
streamed_artifact_ids_by_task: dict[str, set[str]] = {}
|
||||
# In non-streaming mode, accumulate intermediate status content so it
|
||||
# can be surfaced when the terminal event arrives (mirroring v0.3.x
|
||||
# behavior where the full Task history was available at completion).
|
||||
pending_updates_by_task: dict[str, list[AgentResponseUpdate]] = {}
|
||||
async for item in a2a_stream:
|
||||
payload_type = item.WhichOneof("payload")
|
||||
if payload_type == "message":
|
||||
@@ -391,27 +395,55 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
|
||||
)
|
||||
if task.status.state in TERMINAL_TASK_STATES:
|
||||
streamed_artifact_ids_by_task.pop(task.id, None)
|
||||
# If the terminal Task has no content, flush accumulated updates
|
||||
if not updates or all(not u.contents for u in updates):
|
||||
pending = pending_updates_by_task.pop(task.id, [])
|
||||
for update in pending:
|
||||
all_updates.append(update)
|
||||
yield update
|
||||
else:
|
||||
pending_updates_by_task.pop(task.id, None)
|
||||
for update in updates:
|
||||
all_updates.append(update)
|
||||
yield update
|
||||
elif payload_type == "status_update":
|
||||
status_event = item.status_update
|
||||
updates = self._updates_from_task_update_event(status_event)
|
||||
is_terminal = status_event.status.state in TERMINAL_TASK_STATES
|
||||
if emit_intermediate:
|
||||
for update in updates:
|
||||
all_updates.append(update)
|
||||
yield update
|
||||
elif is_terminal:
|
||||
if updates:
|
||||
# Terminal event with content — discard accumulated intermediates
|
||||
pending_updates_by_task.pop(status_event.task_id, None)
|
||||
for update in updates:
|
||||
all_updates.append(update)
|
||||
yield update
|
||||
else:
|
||||
# Terminal event with NO content — flush accumulated updates
|
||||
pending = pending_updates_by_task.pop(status_event.task_id, [])
|
||||
for update in pending:
|
||||
all_updates.append(update)
|
||||
yield update
|
||||
else:
|
||||
# Non-streaming intermediate: accumulate for later
|
||||
if updates:
|
||||
pending_updates_by_task.setdefault(status_event.task_id, []).extend(updates)
|
||||
elif payload_type == "artifact_update":
|
||||
artifact_event = item.artifact_update
|
||||
updates = self._updates_from_task_update_event(artifact_event)
|
||||
# Always yield artifact updates — they carry actual response
|
||||
# content (files, data). Track IDs so that a subsequent
|
||||
# terminal Task doesn't duplicate the same artifacts.
|
||||
if updates:
|
||||
streamed_artifact_ids_by_task.setdefault(artifact_event.task_id, set()).add(
|
||||
artifact_event.artifact.artifact_id
|
||||
)
|
||||
if emit_intermediate:
|
||||
for update in updates:
|
||||
all_updates.append(update)
|
||||
yield update
|
||||
for update in updates:
|
||||
all_updates.append(update)
|
||||
yield update
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported StreamResponse payload: {payload_type}")
|
||||
|
||||
|
||||
@@ -1570,4 +1570,102 @@ async def test_none_metadata_leaves_additional_properties_empty(
|
||||
assert not response.additional_properties
|
||||
|
||||
|
||||
async def test_non_streaming_terminal_status_update_surfaces_content(
|
||||
a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient
|
||||
) -> None:
|
||||
"""Non-streaming run() should surface content from terminal status_update events."""
|
||||
completed_msg = A2AMessage(
|
||||
message_id="msg-complete",
|
||||
role=A2ARole.ROLE_AGENT,
|
||||
parts=[Part(text="Done! Here is your answer.")],
|
||||
)
|
||||
status = TaskStatus(state=TaskState.TASK_STATE_COMPLETED, message=completed_msg)
|
||||
event = TaskStatusUpdateEvent(task_id="task-ts", context_id="ctx-ts", status=status)
|
||||
mock_a2a_client.responses.append(StreamResponse(status_update=event))
|
||||
|
||||
response = await a2a_agent.run("Hello")
|
||||
|
||||
assert len(response.messages) == 1
|
||||
assert response.messages[0].text == "Done! Here is your answer."
|
||||
|
||||
|
||||
async def test_non_streaming_accumulates_working_content_for_empty_terminal(
|
||||
a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient
|
||||
) -> None:
|
||||
"""Non-streaming run() accumulates WORKING content and flushes on empty terminal event."""
|
||||
# Intermediate WORKING event with content
|
||||
working_msg = A2AMessage(
|
||||
message_id="msg-working",
|
||||
role=A2ARole.ROLE_AGENT,
|
||||
parts=[Part(text="Here is your answer from working state.")],
|
||||
)
|
||||
working_status = TaskStatus(state=TaskState.TASK_STATE_WORKING, message=working_msg)
|
||||
working_event = TaskStatusUpdateEvent(task_id="task-acc", context_id="ctx-acc", status=working_status)
|
||||
mock_a2a_client.responses.append(StreamResponse(status_update=working_event))
|
||||
|
||||
# Terminal COMPLETED event with NO content
|
||||
completed_status = TaskStatus(state=TaskState.TASK_STATE_COMPLETED)
|
||||
completed_event = TaskStatusUpdateEvent(task_id="task-acc", context_id="ctx-acc", status=completed_status)
|
||||
mock_a2a_client.responses.append(StreamResponse(status_update=completed_event))
|
||||
|
||||
response = await a2a_agent.run("Hello")
|
||||
|
||||
# The accumulated WORKING content is flushed when terminal arrives empty
|
||||
assert len(response.messages) == 1
|
||||
assert response.messages[0].text == "Here is your answer from working state."
|
||||
|
||||
|
||||
async def test_non_streaming_intermediate_discarded_when_terminal_has_content(
|
||||
a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient
|
||||
) -> None:
|
||||
"""Non-streaming: if terminal event has content, intermediate content is discarded."""
|
||||
# Intermediate WORKING event
|
||||
working_msg = A2AMessage(
|
||||
message_id="msg-working",
|
||||
role=A2ARole.ROLE_AGENT,
|
||||
parts=[Part(text="Still thinking...")],
|
||||
)
|
||||
working_status = TaskStatus(state=TaskState.TASK_STATE_WORKING, message=working_msg)
|
||||
working_event = TaskStatusUpdateEvent(task_id="task-wi", context_id="ctx-wi", status=working_status)
|
||||
mock_a2a_client.responses.append(StreamResponse(status_update=working_event))
|
||||
|
||||
# Terminal COMPLETED event WITH content
|
||||
completed_msg = A2AMessage(
|
||||
message_id="msg-final",
|
||||
role=A2ARole.ROLE_AGENT,
|
||||
parts=[Part(text="Final answer")],
|
||||
)
|
||||
completed_status = TaskStatus(state=TaskState.TASK_STATE_COMPLETED, message=completed_msg)
|
||||
completed_event = TaskStatusUpdateEvent(task_id="task-wi", context_id="ctx-wi", status=completed_status)
|
||||
mock_a2a_client.responses.append(StreamResponse(status_update=completed_event))
|
||||
|
||||
response = await a2a_agent.run("Hello")
|
||||
|
||||
# Terminal content supersedes accumulated intermediates
|
||||
assert len(response.messages) == 1
|
||||
assert response.messages[0].text == "Final answer"
|
||||
|
||||
|
||||
async def test_non_streaming_artifact_update_surfaces_content(
|
||||
a2a_agent: A2AAgent, mock_a2a_client: MockA2AClient
|
||||
) -> None:
|
||||
"""Non-streaming run() should surface content from artifact_update events."""
|
||||
artifact = Artifact(
|
||||
artifact_id="art-ns",
|
||||
parts=[Part(text="Artifact content")],
|
||||
)
|
||||
event = TaskArtifactUpdateEvent(task_id="task-anu", context_id="ctx-anu", artifact=artifact, append=False)
|
||||
mock_a2a_client.responses.append(StreamResponse(artifact_update=event))
|
||||
|
||||
# Terminal task with the same artifact ID — should be deduped
|
||||
mock_a2a_client.add_task_response("task-anu", [{"id": "art-ns", "content": "Artifact content"}])
|
||||
|
||||
response = await a2a_agent.run("Hello")
|
||||
|
||||
# Artifact update + terminal task with same artifact ID = content emitted once from
|
||||
# the artifact_update, then the duplicate from the task is filtered by streamed_artifact_ids
|
||||
assert len(response.messages) == 1
|
||||
assert response.messages[0].text == "Artifact content"
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
@@ -261,6 +261,7 @@ class MCPTool:
|
||||
self.request_timeout = request_timeout
|
||||
self.client = client
|
||||
self._functions: list[FunctionTool] = []
|
||||
self._tool_call_meta_by_name: dict[str, dict[str, Any]] = {}
|
||||
self.is_connected: bool = False
|
||||
self._tools_loaded: bool = False
|
||||
self._prompts_loaded: bool = False
|
||||
@@ -1026,6 +1027,7 @@ class MCPTool:
|
||||
|
||||
# Track existing function names to prevent duplicates
|
||||
existing_names = {func.name for func in self._functions}
|
||||
self._tool_call_meta_by_name.clear()
|
||||
|
||||
params: types.PaginatedRequestParams | None = None
|
||||
while True:
|
||||
@@ -1035,6 +1037,9 @@ class MCPTool:
|
||||
tool_list = await self.session.list_tools(params=params) # type: ignore[union-attr]
|
||||
|
||||
for tool in tool_list.tools:
|
||||
if tool.meta is not None:
|
||||
self._tool_call_meta_by_name[tool.name] = dict(tool.meta)
|
||||
|
||||
normalized_name = _normalize_mcp_name(tool.name)
|
||||
local_name = _build_prefixed_mcp_name(normalized_name, self.tool_name_prefix)
|
||||
|
||||
@@ -1185,14 +1190,15 @@ class MCPTool:
|
||||
}
|
||||
}
|
||||
|
||||
# Inject OpenTelemetry trace context into MCP _meta for distributed tracing.
|
||||
otel_meta = _inject_otel_into_mcp_meta()
|
||||
# Some MCP proxies require their tools/list metadata to be echoed on tools/call.
|
||||
tool_meta = self._tool_call_meta_by_name.get(tool_name)
|
||||
meta = _inject_otel_into_mcp_meta(dict(tool_meta) if tool_meta is not None else None)
|
||||
|
||||
parser = self.parse_tool_results or self._parse_tool_result_from_mcp
|
||||
# Try the operation, reconnecting once if the connection is closed
|
||||
for attempt in range(2):
|
||||
try:
|
||||
result = await self.session.call_tool(tool_name, arguments=filtered_kwargs, meta=otel_meta) # type: ignore
|
||||
result = await self.session.call_tool(tool_name, arguments=filtered_kwargs, meta=meta) # type: ignore
|
||||
if result.isError:
|
||||
parsed = parser(result)
|
||||
text = (
|
||||
|
||||
@@ -4194,6 +4194,57 @@ async def test_mcp_tool_call_tool_otel_meta(use_span, expect_traceparent, span_e
|
||||
assert meta is None
|
||||
|
||||
|
||||
async def test_mcp_tool_call_tool_forwards_tool_list_meta():
|
||||
"""call_tool echoes per-tool metadata returned by tools/list."""
|
||||
from opentelemetry import trace
|
||||
|
||||
tool_meta = {
|
||||
"tool_configuration": {
|
||||
"name": "WorkIQSharePoint.readSmallBinaryFile",
|
||||
"type": "foundry_toolbox",
|
||||
}
|
||||
}
|
||||
|
||||
class TestServer(MCPTool):
|
||||
async def connect(self):
|
||||
self.session = Mock(spec=ClientSession)
|
||||
self.session.list_tools = AsyncMock(
|
||||
return_value=types.ListToolsResult(
|
||||
tools=[
|
||||
types.Tool(
|
||||
name="WorkIQSharePoint.readSmallBinaryFile",
|
||||
description="Read a binary file",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {"fileId": {"type": "string"}},
|
||||
"required": ["fileId"],
|
||||
},
|
||||
_meta=tool_meta,
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
self.session.call_tool = AsyncMock(
|
||||
return_value=types.CallToolResult(content=[types.TextContent(type="text", text="result")])
|
||||
)
|
||||
self.session.list_prompts = AsyncMock(
|
||||
return_value=types.ListPromptsResult(prompts=[])
|
||||
)
|
||||
|
||||
def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
|
||||
return None
|
||||
|
||||
server = TestServer(name="test_server")
|
||||
async with server:
|
||||
await server.load_tools()
|
||||
await server.load_prompts()
|
||||
|
||||
with trace.use_span(trace.NonRecordingSpan(trace.INVALID_SPAN_CONTEXT)):
|
||||
await server.call_tool("WorkIQSharePoint.readSmallBinaryFile", fileId="file-1")
|
||||
|
||||
assert server.session.call_tool.call_args.kwargs["meta"] == tool_meta
|
||||
|
||||
|
||||
async def test_mcp_streamable_http_tool_hook_not_duplicated_on_repeated_get_mcp_client():
|
||||
"""Test that calling get_mcp_client multiple times does not accumulate duplicate hooks."""
|
||||
tool = MCPStreamableHTTPTool(
|
||||
|
||||
@@ -11,6 +11,7 @@ import tempfile
|
||||
import threading
|
||||
from collections.abc import AsyncIterable, AsyncIterator, Generator, Mapping, Sequence
|
||||
from contextlib import suppress
|
||||
from pathlib import Path
|
||||
from typing import Protocol, cast
|
||||
|
||||
from agent_framework import (
|
||||
@@ -205,6 +206,47 @@ class FileBasedFunctionApprovalStorage:
|
||||
return await asyncio.to_thread(self._load_sync, approval_request_id)
|
||||
|
||||
|
||||
def _checkpoint_storage_for_context(root: str, context_id: str) -> FileCheckpointStorage:
|
||||
"""Build a ``FileCheckpointStorage`` for ``context_id`` rooted under ``root``.
|
||||
|
||||
``context_id`` originates from caller-controlled fields such as
|
||||
``previous_response_id`` or from server-generated fields such as
|
||||
``conversation_id`` / ``response_id``. In every case it must be treated as
|
||||
an untrusted single path segment: path separators, drive letters, parent
|
||||
references and similar would otherwise let the resulting directory escape
|
||||
the configured checkpoint root (CWE-22). The check resolves the joined
|
||||
path and verifies it stays under the resolved root before any directory is
|
||||
created on disk.
|
||||
"""
|
||||
if not isinstance(context_id, str) or not context_id:
|
||||
raise RuntimeError("Invalid checkpoint context id: must be a non-empty string.")
|
||||
# Reject any segment that is not a single safe path component. This covers
|
||||
# POSIX/Windows separators, NUL bytes, drive letters, and all-dot segments
|
||||
# (``.``, ``..``, ``...``, ...). We deliberately do not URL-decode the id
|
||||
# here: the hosting layer never decodes context ids before joining them, so
|
||||
# forms such as ``%2e%2e`` are accepted as literal directory names. Do NOT
|
||||
# add decoding here without re-validating after the decode -- decode-then-
|
||||
# join is exactly the pattern that reintroduces traversal. We also do not
|
||||
# attempt to "sanitize" by stripping characters because that can introduce
|
||||
# collisions between distinct ids.
|
||||
if (
|
||||
"/" in context_id
|
||||
or "\\" in context_id
|
||||
or "\x00" in context_id
|
||||
# All-dot segments (``.``, ``..``, ``...``, ...) reduce to "" after stripping dots.
|
||||
or context_id.strip(".") == ""
|
||||
or os.path.isabs(context_id)
|
||||
or os.path.splitdrive(context_id)[0]
|
||||
):
|
||||
raise RuntimeError(f"Invalid checkpoint context id: {context_id!r}")
|
||||
|
||||
root_path = Path(root).resolve()
|
||||
storage_path = (root_path / context_id).resolve()
|
||||
if not storage_path.is_relative_to(root_path):
|
||||
raise RuntimeError(f"Invalid checkpoint context id: {context_id!r}")
|
||||
return FileCheckpointStorage(storage_path)
|
||||
|
||||
|
||||
class ResponsesHostServer(ResponsesAgentServerHost):
|
||||
"""A responses server host for an agent."""
|
||||
|
||||
@@ -400,7 +442,7 @@ class ResponsesHostServer(ResponsesAgentServerHost):
|
||||
latest_checkpoint_id: str | None = None
|
||||
restore_storage: FileCheckpointStorage | None = None
|
||||
if context_id is not None:
|
||||
restore_storage = FileCheckpointStorage(os.path.join(self._checkpoint_storage_path, context_id))
|
||||
restore_storage = _checkpoint_storage_for_context(self._checkpoint_storage_path, context_id)
|
||||
latest_checkpoint = await restore_storage.get_latest(workflow_name=self._agent.workflow.name)
|
||||
if latest_checkpoint is not None:
|
||||
latest_checkpoint_id = latest_checkpoint.checkpoint_id
|
||||
@@ -414,7 +456,7 @@ class ResponsesHostServer(ResponsesAgentServerHost):
|
||||
# supplied, restore_storage points at the *prior* response's
|
||||
# directory and write_storage points at the *current* response's.
|
||||
write_context_id = context.conversation_id or context.response_id
|
||||
write_storage = FileCheckpointStorage(os.path.join(self._checkpoint_storage_path, write_context_id))
|
||||
write_storage = _checkpoint_storage_for_context(self._checkpoint_storage_path, write_context_id)
|
||||
|
||||
# Multi-turn pattern: when we have a prior checkpoint, restore it
|
||||
# first (drive the workflow back to idle with prior state intact),
|
||||
|
||||
@@ -11,7 +11,7 @@ the registered _handle_create handler.
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import AsyncIterator
|
||||
from collections.abc import AsyncIterator, Callable
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import httpx
|
||||
@@ -20,6 +20,7 @@ from agent_framework import (
|
||||
AgentResponse,
|
||||
AgentResponseUpdate,
|
||||
Content,
|
||||
FileCheckpointStorage,
|
||||
HistoryProvider,
|
||||
Message,
|
||||
RawAgent,
|
||||
@@ -2652,3 +2653,241 @@ class TestFunctionApprovalRoundTrip:
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
# region Checkpoint context path validation
|
||||
|
||||
|
||||
class TestCheckpointContextPathValidation:
|
||||
"""Regression tests for the path-traversal hardening of checkpoint storage.
|
||||
|
||||
These tests guard against CWE-22 in the workflow hosting path. The hosting
|
||||
code joins caller-supplied identifiers (``previous_response_id``) and
|
||||
server-generated identifiers (``conversation_id`` / ``response_id``) under
|
||||
the configured checkpoint root. Without validation, traversal segments
|
||||
such as ``../../escape`` or absolute paths cause directory creation
|
||||
outside the intended root.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _helper() -> Callable[[str, str], FileCheckpointStorage]:
|
||||
from agent_framework_foundry_hosting._responses import ( # pyright: ignore[reportPrivateUsage]
|
||||
_checkpoint_storage_for_context,
|
||||
)
|
||||
|
||||
return _checkpoint_storage_for_context
|
||||
|
||||
def test_valid_segment_creates_storage_under_root(self, tmp_path: Any) -> None:
|
||||
helper = self._helper()
|
||||
root = tmp_path / "root"
|
||||
root.mkdir()
|
||||
storage = helper(str(root), "resp_abc123")
|
||||
assert storage.storage_path.is_dir()
|
||||
assert storage.storage_path.parent == root.resolve()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"bad_id",
|
||||
[
|
||||
# Original MSRC repro: traversal embedded inside an id-shaped value.
|
||||
# The 14 ``A``s pad the suffix to mimic the exact length of the
|
||||
# ``api-made-dir<14-char-suffix>`` segment from the original report.
|
||||
"caresp_x/../../service-data/api-made-dir" + "A" * 14,
|
||||
# Variant report repros.
|
||||
"../../escape",
|
||||
"..",
|
||||
".",
|
||||
"...",
|
||||
"/tmp/escape",
|
||||
"/absolute/path",
|
||||
"C:\\temp\\escape",
|
||||
"..\\..\\escape",
|
||||
"foo\\..\\bar",
|
||||
"foo/bar",
|
||||
"with\x00null",
|
||||
"",
|
||||
],
|
||||
)
|
||||
def test_traversal_and_separator_payloads_are_rejected(self, tmp_path: Any, bad_id: str) -> None:
|
||||
helper = self._helper()
|
||||
# Use a dedicated root *inside* tmp_path so we can assert that nothing
|
||||
# was created anywhere under tmp_path (root, siblings, or above).
|
||||
# Asserting against tmp_path.parent would be flaky under parallel test
|
||||
# execution because tmp_path.parent is shared across tests.
|
||||
root = tmp_path / "root"
|
||||
root.mkdir()
|
||||
before = sorted(p.name for p in tmp_path.iterdir())
|
||||
with pytest.raises(RuntimeError):
|
||||
helper(str(root), bad_id)
|
||||
# No sibling/escape directory should have been created next to the root.
|
||||
after = sorted(p.name for p in tmp_path.iterdir())
|
||||
assert before == after, f"Unexpected filesystem artifacts created for payload {bad_id!r}"
|
||||
# And nothing inside the root either.
|
||||
assert list(root.iterdir()) == []
|
||||
|
||||
def test_non_string_context_id_is_rejected(self, tmp_path: Any) -> None:
|
||||
helper = self._helper()
|
||||
with pytest.raises(RuntimeError):
|
||||
helper(str(tmp_path), None) # type: ignore[arg-type]
|
||||
|
||||
def test_url_encoded_traversal_is_treated_as_literal_segment(self, tmp_path: Any) -> None:
|
||||
"""URL-encoded traversal should not decode to traversal at the filesystem layer.
|
||||
|
||||
The hosting layer never URL-decodes ids before using them; the helper
|
||||
should accept ``%2e%2e`` as a single literal segment (no escape).
|
||||
"""
|
||||
helper = self._helper()
|
||||
root = tmp_path / "root"
|
||||
root.mkdir()
|
||||
storage = helper(str(root), "%2e%2e")
|
||||
assert storage.storage_path.parent == root.resolve()
|
||||
assert storage.storage_path.name == "%2e%2e"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"context_field,bad_id",
|
||||
[
|
||||
# Restore sink: caller-controlled previous_response_id.
|
||||
("previous_response_id", "../../escape"),
|
||||
("previous_response_id", "/tmp/escape-abs"),
|
||||
("previous_response_id", "caresp_x/../../service-data/api-made-dir" + "A" * 14),
|
||||
# Restore sink: server-issued conversation_id (defense in depth).
|
||||
("conversation_id", "../../escape"),
|
||||
# Write sink: malicious response_id (defense in depth).
|
||||
("response_id", "../../escape"),
|
||||
],
|
||||
)
|
||||
async def test_handle_inner_workflow_rejects_malicious_context_id(
|
||||
self, tmp_path: Any, context_field: str, bad_id: str
|
||||
) -> None:
|
||||
"""End-to-end: ``_handle_inner_workflow`` must reject malicious ids on
|
||||
both the restore sink (``previous_response_id`` / ``conversation_id``)
|
||||
and the write sink (``response_id``) without creating any directories.
|
||||
"""
|
||||
from unittest.mock import patch
|
||||
|
||||
from agent_framework import WorkflowAgent
|
||||
from azure.ai.agentserver.responses import ResponseContext
|
||||
from azure.ai.agentserver.responses.models import CreateResponse
|
||||
|
||||
# Build a mock that satisfies isinstance(agent, WorkflowAgent) and the
|
||||
# constructor's "no existing checkpointing" guard.
|
||||
agent = MagicMock(spec=WorkflowAgent)
|
||||
agent.id = "wf-agent"
|
||||
agent.name = "wf"
|
||||
agent.description = ""
|
||||
agent.context_providers = []
|
||||
agent.workflow = MagicMock()
|
||||
agent.workflow.name = "wf"
|
||||
agent.workflow._runner_context.has_checkpointing = MagicMock(return_value=False)
|
||||
|
||||
# Constructor inspects WorkflowAgent.workflow internals; bypass setup
|
||||
# by feeding a configured mock through a normal init.
|
||||
server = ResponsesHostServer(agent, store=InMemoryResponseProvider())
|
||||
# Re-root checkpoint storage at our isolated tmp_path so we can detect
|
||||
# any escape attempt on the filesystem.
|
||||
root = tmp_path / "root"
|
||||
root.mkdir()
|
||||
server._checkpoint_storage_path = str(root) # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
# Build a ResponseContext with the malicious id targeting the chosen sink.
|
||||
kwargs: dict[str, Any] = {
|
||||
"response_id": "resp_" + "a" * 48,
|
||||
"mode_flags": MagicMock(),
|
||||
}
|
||||
if context_field == "previous_response_id":
|
||||
request = CreateResponse(model="m", input="hi", previous_response_id=bad_id)
|
||||
kwargs["previous_response_id"] = bad_id
|
||||
elif context_field == "conversation_id":
|
||||
request = CreateResponse(model="m", input="hi")
|
||||
kwargs["conversation_id"] = bad_id
|
||||
else: # response_id (write sink)
|
||||
request = CreateResponse(model="m", input="hi")
|
||||
kwargs["response_id"] = bad_id
|
||||
|
||||
# Avoid invoking the real input-resolution machinery, which would need
|
||||
# a configured provider; we never reach the workflow run on rejection.
|
||||
with patch.object(ResponseContext, "get_input_items", new=AsyncMock(return_value=[])):
|
||||
context = ResponseContext(**kwargs)
|
||||
before = sorted(p.name for p in tmp_path.iterdir())
|
||||
with pytest.raises(RuntimeError, match="Invalid checkpoint context id"):
|
||||
async for _ in server._handle_inner_workflow(request, context): # pyright: ignore[reportPrivateUsage]
|
||||
pass
|
||||
after = sorted(p.name for p in tmp_path.iterdir())
|
||||
|
||||
assert before == after, f"Unexpected filesystem artifacts created for {context_field}={bad_id!r}"
|
||||
assert list(root.iterdir()) == [], f"Checkpoint dir created inside root for {context_field}={bad_id!r}"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"context_field,bad_id",
|
||||
[
|
||||
# Restore sink: caller-controlled previous_response_id. These are
|
||||
# rejected by request validation (HTTP 400) before the checkpoint
|
||||
# code is reached.
|
||||
("previous_response_id", "../../escape"),
|
||||
("previous_response_id", "/tmp/escape-abs"),
|
||||
("previous_response_id", "caresp_x/../../service-data/api-made-dir" + "A" * 14),
|
||||
# Restore sink: server-issued conversation id (defense in depth).
|
||||
# Reaches the checkpoint code and is rejected there, surfacing as
|
||||
# an HTTP 5xx without creating any filesystem artifacts.
|
||||
("conversation", "../../escape"),
|
||||
("conversation", "/tmp/escape-abs"),
|
||||
],
|
||||
)
|
||||
async def test_malicious_context_id_rejected_e2e(self, tmp_path: Any, context_field: str, bad_id: str) -> None:
|
||||
"""End-to-end (ASGI-in-process): malicious context ids must be rejected
|
||||
through the full HTTP pipeline, and no checkpoint directory may be
|
||||
created on disk for either the validation-layer rejection
|
||||
(``previous_response_id``) or the deeper checkpoint-layer rejection
|
||||
(``conversation``).
|
||||
|
||||
The ``response_id`` write-sink is server-generated and not reachable
|
||||
via the public HTTP surface, so its defense-in-depth check is covered
|
||||
by the helper-level test above.
|
||||
"""
|
||||
from agent_framework import WorkflowAgent
|
||||
|
||||
# Build a mock that satisfies isinstance(agent, WorkflowAgent) and the
|
||||
# constructor's "no existing checkpointing" guard.
|
||||
agent = MagicMock(spec=WorkflowAgent)
|
||||
agent.id = "wf-agent"
|
||||
agent.name = "wf"
|
||||
agent.description = ""
|
||||
agent.context_providers = []
|
||||
agent.workflow = MagicMock()
|
||||
agent.workflow.name = "wf"
|
||||
agent.workflow._runner_context.has_checkpointing = MagicMock( # pyright: ignore[reportPrivateUsage]
|
||||
return_value=False
|
||||
)
|
||||
|
||||
server = ResponsesHostServer(agent, store=InMemoryResponseProvider())
|
||||
# Re-root checkpoint storage at our isolated tmp_path so we can detect
|
||||
# any escape attempt on the filesystem.
|
||||
root = tmp_path / "root"
|
||||
root.mkdir()
|
||||
server._checkpoint_storage_path = str(root) # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
payload: dict[str, Any] = {"model": "m", "input": "hi"}
|
||||
if context_field == "previous_response_id":
|
||||
payload["previous_response_id"] = bad_id
|
||||
else: # conversation
|
||||
payload["conversation"] = bad_id
|
||||
|
||||
before = sorted(p.name for p in tmp_path.iterdir())
|
||||
transport = httpx.ASGITransport(app=server)
|
||||
async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
resp = await client.post("/responses", json=payload)
|
||||
after = sorted(p.name for p in tmp_path.iterdir())
|
||||
|
||||
# The request must not succeed; either request validation rejects it
|
||||
# (4xx) or the checkpoint layer raises and the server returns 5xx.
|
||||
# Either way, no successful response may be produced.
|
||||
assert resp.status_code >= 400, (
|
||||
f"Expected non-2xx for {context_field}={bad_id!r}, got {resp.status_code}: {resp.text[:200]}"
|
||||
)
|
||||
assert before == after, (
|
||||
f"Unexpected filesystem artifacts under tmp_path for {context_field}={bad_id!r}: "
|
||||
f"before={before} after={after}"
|
||||
)
|
||||
assert list(root.iterdir()) == [], f"Checkpoint directory created inside root for {context_field}={bad_id!r}"
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
@@ -103,7 +103,7 @@ def main() -> None:
|
||||
app = Starlette(
|
||||
routes=[
|
||||
*create_agent_card_routes(agent_card),
|
||||
*create_jsonrpc_routes(request_handler),
|
||||
*create_jsonrpc_routes(request_handler, "/"),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@@ -8,16 +8,11 @@ AgentCards for the invoice, policy, and logistics agent types.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from a2a.types import AgentCapabilities, AgentCard, AgentInterface, AgentSkill
|
||||
from agent_framework import Agent
|
||||
from agent_framework.foundry import FoundryChatClient
|
||||
from invoice_data import query_by_invoice_id, query_by_transaction_id, query_invoices
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from agent_framework import Agent
|
||||
from agent_framework.foundry import FoundryChatClient
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Agent instructions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -10,18 +10,12 @@ published back through the a2a-sdk event queue.
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from a2a.helpers import new_task_from_user_message
|
||||
from a2a.server.agent_execution.agent_executor import AgentExecutor
|
||||
from a2a.types import (
|
||||
Message,
|
||||
Part,
|
||||
Role,
|
||||
TaskState,
|
||||
TaskStatus,
|
||||
TaskStatusUpdateEvent,
|
||||
)
|
||||
from a2a.server.tasks import TaskUpdater
|
||||
from a2a.types import Part, TaskState
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from a2a.server.agent_execution.context import RequestContext
|
||||
@@ -47,17 +41,17 @@ class AgentFrameworkExecutor(AgentExecutor):
|
||||
if not user_text:
|
||||
user_text = "Hello"
|
||||
|
||||
task_id = context.task_id or str(uuid.uuid4())
|
||||
context_id = context.context_id or str(uuid.uuid4())
|
||||
# v1.0 requires a Task object in the queue before any TaskStatusUpdateEvent
|
||||
task = context.current_task
|
||||
if not task and context.message:
|
||||
task = new_task_from_user_message(context.message)
|
||||
await event_queue.enqueue_event(task)
|
||||
|
||||
task_id = task.id if task else context.task_id
|
||||
updater = TaskUpdater(event_queue, task_id, context.context_id)
|
||||
|
||||
# Signal that the agent is working
|
||||
await event_queue.enqueue_event(
|
||||
TaskStatusUpdateEvent(
|
||||
task_id=task_id,
|
||||
context_id=context_id,
|
||||
status=TaskStatus(state=TaskState.TASK_STATE_WORKING),
|
||||
)
|
||||
)
|
||||
await updater.start_work()
|
||||
|
||||
try:
|
||||
response = await self.agent.run(user_text)
|
||||
@@ -71,48 +65,19 @@ class AgentFrameworkExecutor(AgentExecutor):
|
||||
if not response_parts:
|
||||
response_parts.append(Part(text=str(response)))
|
||||
|
||||
# Publish the agent's response as a completed message
|
||||
await event_queue.enqueue_event(
|
||||
TaskStatusUpdateEvent(
|
||||
task_id=task_id,
|
||||
context_id=context_id,
|
||||
status=TaskStatus(
|
||||
state=TaskState.TASK_STATE_COMPLETED,
|
||||
message=Message(
|
||||
message_id=str(uuid.uuid4()),
|
||||
role=Role.ROLE_AGENT,
|
||||
parts=response_parts,
|
||||
),
|
||||
),
|
||||
)
|
||||
# Publish the agent's response and mark as completed
|
||||
await updater.complete(
|
||||
message=updater.new_agent_message(response_parts),
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as e:
|
||||
await event_queue.enqueue_event(
|
||||
TaskStatusUpdateEvent(
|
||||
task_id=task_id,
|
||||
context_id=context_id,
|
||||
status=TaskStatus(
|
||||
state=TaskState.TASK_STATE_FAILED,
|
||||
message=Message(
|
||||
message_id=str(uuid.uuid4()),
|
||||
role=Role.ROLE_AGENT,
|
||||
parts=[Part(text=f"Agent error: {e}")],
|
||||
),
|
||||
),
|
||||
)
|
||||
await updater.update_status(
|
||||
state=TaskState.TASK_STATE_FAILED,
|
||||
message=updater.new_agent_message([Part(text=f"Agent error: {e}")]),
|
||||
)
|
||||
|
||||
async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None:
|
||||
"""Handle cancellation by publishing a canceled status."""
|
||||
task_id = context.task_id or str(uuid.uuid4())
|
||||
context_id = context.context_id or str(uuid.uuid4())
|
||||
|
||||
await event_queue.enqueue_event(
|
||||
TaskStatusUpdateEvent(
|
||||
task_id=task_id,
|
||||
context_id=context_id,
|
||||
status=TaskStatus(state=TaskState.TASK_STATE_CANCELED),
|
||||
)
|
||||
)
|
||||
updater = TaskUpdater(event_queue, context.task_id, context.context_id)
|
||||
await updater.update_status(state=TaskState.TASK_STATE_CANCELED)
|
||||
|
||||
@@ -65,7 +65,7 @@ if __name__ == "__main__":
|
||||
server = Starlette(
|
||||
routes=[
|
||||
*create_agent_card_routes(public_agent_card),
|
||||
*create_jsonrpc_routes(request_handler),
|
||||
*create_jsonrpc_routes(request_handler, "/"),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user