mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Add checkpoint save and restore hooks to executor (#2097)
* Add checkpoint hooks * Deprecate get_executor_state and set_executor_state * Fix tests and samples * Add doc strings * Add sample * Fix import * Address comments and fix tests * Address comments * conditional import
This commit is contained in:
committed by
GitHub
Unverified
parent
132597957a
commit
c361ad8d33
@@ -17,7 +17,7 @@ Workflow Steps:
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal, Annotated
|
||||
from typing import Literal
|
||||
|
||||
from agent_framework import (
|
||||
Case,
|
||||
@@ -31,9 +31,11 @@ from agent_framework import (
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Never
|
||||
|
||||
|
||||
# Define response model with clear user guidance
|
||||
class SpamDecision(BaseModel):
|
||||
"""User's decision on whether the email is spam."""
|
||||
|
||||
decision: Literal["spam", "not spam"] = Field(
|
||||
description="Enter 'spam' to mark as spam, or 'not spam' to mark as legitimate"
|
||||
)
|
||||
@@ -71,10 +73,11 @@ class SpamDetectorResponse:
|
||||
class SpamApprovalRequest:
|
||||
"""Human-in-the-loop approval request for spam classification."""
|
||||
|
||||
email_message: str = ""
|
||||
detected_as_spam: bool = False
|
||||
confidence: float = 0.0
|
||||
reasons: str = ""
|
||||
email_message: str
|
||||
detected_as_spam: bool
|
||||
confidence: float
|
||||
reasons: list[str]
|
||||
full_email_content: EmailContent
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -128,8 +131,6 @@ class EmailPreprocessor(Executor):
|
||||
await ctx.send_message(result)
|
||||
|
||||
|
||||
|
||||
|
||||
class SpamDetector(Executor):
|
||||
"""Step 2: An executor that analyzes content and determines if a message is spam."""
|
||||
|
||||
@@ -139,7 +140,9 @@ class SpamDetector(Executor):
|
||||
self._spam_keywords = spam_keywords
|
||||
|
||||
@handler
|
||||
async def handle_email_content(self, email_content: EmailContent, ctx: WorkflowContext[SpamApprovalRequest]) -> None:
|
||||
async def handle_email_content(
|
||||
self, email_content: EmailContent, ctx: WorkflowContext[SpamApprovalRequest]
|
||||
) -> None:
|
||||
"""Analyze email content and determine if the message is spam, then request human approval."""
|
||||
await asyncio.sleep(2.0) # Simulate analysis and detection time
|
||||
|
||||
@@ -186,25 +189,13 @@ class SpamDetector(Executor):
|
||||
|
||||
is_spam = spam_score >= 0.5
|
||||
|
||||
# Store detection result in executor state for later use
|
||||
# Store minimal data needed (not complex objects that don't serialize well)
|
||||
await ctx.set_executor_state({
|
||||
"original_message": email_content.original_message,
|
||||
"cleaned_message": email_content.cleaned_message,
|
||||
"word_count": email_content.word_count,
|
||||
"has_suspicious_patterns": email_content.has_suspicious_patterns,
|
||||
"is_spam": is_spam,
|
||||
"ai_original_classification": is_spam, # Store original AI decision
|
||||
"confidence_score": spam_score,
|
||||
"spam_reasons": spam_reasons
|
||||
})
|
||||
|
||||
# Request human approval before proceeding using new API
|
||||
approval_request = SpamApprovalRequest(
|
||||
email_message=email_text[:200], # First 200 chars
|
||||
detected_as_spam=is_spam,
|
||||
confidence=spam_score,
|
||||
reasons=", ".join(spam_reasons) if spam_reasons else "no specific reasons"
|
||||
reasons=spam_reasons,
|
||||
full_email_content=email_content,
|
||||
)
|
||||
|
||||
await ctx.request_info(
|
||||
@@ -214,20 +205,15 @@ class SpamDetector(Executor):
|
||||
|
||||
@response_handler
|
||||
async def handle_human_response(
|
||||
self,
|
||||
original_request: SpamApprovalRequest,
|
||||
response: SpamDecision,
|
||||
ctx: WorkflowContext[SpamDetectorResponse]
|
||||
self, original_request: SpamApprovalRequest, response: SpamDecision, ctx: WorkflowContext[SpamDetectorResponse]
|
||||
) -> None:
|
||||
"""Process human approval response and continue workflow."""
|
||||
print(f"[SpamDetector] handle_human_response called with response: {response}")
|
||||
|
||||
# Get stored detection result
|
||||
state = await ctx.get_executor_state() or {}
|
||||
print(f"[SpamDetector] Retrieved state: {state}")
|
||||
ai_original = state.get("ai_original_classification", False)
|
||||
confidence_score = state.get("confidence_score", 0.0)
|
||||
spam_reasons = state.get("spam_reasons", [])
|
||||
ai_original = original_request.detected_as_spam
|
||||
confidence_score = original_request.confidence
|
||||
spam_reasons = original_request.reasons
|
||||
|
||||
# Parse human decision from the response model
|
||||
human_decision = response.decision.strip().lower()
|
||||
@@ -241,27 +227,21 @@ class SpamDetector(Executor):
|
||||
# Default to AI decision if unclear
|
||||
is_spam = ai_original
|
||||
|
||||
# Reconstruct EmailContent from stored primitives
|
||||
email_content = EmailContent(
|
||||
original_message=state.get("original_message", ""),
|
||||
cleaned_message=state.get("cleaned_message", ""),
|
||||
word_count=state.get("word_count", 0),
|
||||
has_suspicious_patterns=state.get("has_suspicious_patterns", False)
|
||||
)
|
||||
|
||||
result = SpamDetectorResponse(
|
||||
email_content=email_content,
|
||||
email_content=original_request.full_email_content,
|
||||
is_spam=is_spam,
|
||||
confidence_score=confidence_score,
|
||||
spam_reasons=spam_reasons,
|
||||
human_reviewed=True,
|
||||
human_decision=response.decision,
|
||||
ai_original_classification=ai_original
|
||||
ai_original_classification=ai_original,
|
||||
)
|
||||
|
||||
print(f"[SpamDetector] Sending SpamDetectorResponse: is_spam={is_spam}, confidence={confidence_score}, human_reviewed=True")
|
||||
print(
|
||||
f"[SpamDetector] Sending SpamDetectorResponse: is_spam={is_spam}, confidence={confidence_score}, human_reviewed=True"
|
||||
)
|
||||
await ctx.send_message(result)
|
||||
print(f"[SpamDetector] Message sent successfully")
|
||||
print("[SpamDetector] Message sent successfully")
|
||||
|
||||
|
||||
class SpamHandler(Executor):
|
||||
@@ -427,7 +407,9 @@ workflow = (
|
||||
spam_detector,
|
||||
[
|
||||
Case(condition=lambda x: isinstance(x, SpamDetectorResponse) and x.is_spam, target=spam_handler),
|
||||
Default(target=legitimate_message_handler), # Default handles non-spam and non-SpamDetectorResponse messages
|
||||
Default(
|
||||
target=legitimate_message_handler
|
||||
), # Default handles non-spam and non-SpamDetectorResponse messages
|
||||
],
|
||||
)
|
||||
.add_edge(spam_handler, final_processor)
|
||||
|
||||
Reference in New Issue
Block a user