Python: Add checkpoint save and restore hooks to executor (#2097)

* Add checkpoint hooks

* Deprecate get_executor_state and set_executor_state

* Fix tests and samples

* Add doc strings

* Add sample

* Fix import

* Address comments and fix tests

* Address comments

* conditional import
This commit is contained in:
Tao Chen
2025-11-17 10:19:01 -08:00
committed by GitHub
Unverified
parent 132597957a
commit c361ad8d33
22 changed files with 508 additions and 723 deletions
@@ -17,7 +17,7 @@ Workflow Steps:
import asyncio
import logging
from dataclasses import dataclass
from typing import Literal, Annotated
from typing import Literal
from agent_framework import (
Case,
@@ -31,9 +31,11 @@ from agent_framework import (
from pydantic import BaseModel, Field
from typing_extensions import Never
# Define response model with clear user guidance
class SpamDecision(BaseModel):
"""User's decision on whether the email is spam."""
decision: Literal["spam", "not spam"] = Field(
description="Enter 'spam' to mark as spam, or 'not spam' to mark as legitimate"
)
@@ -71,10 +73,11 @@ class SpamDetectorResponse:
class SpamApprovalRequest:
"""Human-in-the-loop approval request for spam classification."""
email_message: str = ""
detected_as_spam: bool = False
confidence: float = 0.0
reasons: str = ""
email_message: str
detected_as_spam: bool
confidence: float
reasons: list[str]
full_email_content: EmailContent
@dataclass
@@ -128,8 +131,6 @@ class EmailPreprocessor(Executor):
await ctx.send_message(result)
class SpamDetector(Executor):
"""Step 2: An executor that analyzes content and determines if a message is spam."""
@@ -139,7 +140,9 @@ class SpamDetector(Executor):
self._spam_keywords = spam_keywords
@handler
async def handle_email_content(self, email_content: EmailContent, ctx: WorkflowContext[SpamApprovalRequest]) -> None:
async def handle_email_content(
self, email_content: EmailContent, ctx: WorkflowContext[SpamApprovalRequest]
) -> None:
"""Analyze email content and determine if the message is spam, then request human approval."""
await asyncio.sleep(2.0) # Simulate analysis and detection time
@@ -186,25 +189,13 @@ class SpamDetector(Executor):
is_spam = spam_score >= 0.5
# Store detection result in executor state for later use
# Store minimal data needed (not complex objects that don't serialize well)
await ctx.set_executor_state({
"original_message": email_content.original_message,
"cleaned_message": email_content.cleaned_message,
"word_count": email_content.word_count,
"has_suspicious_patterns": email_content.has_suspicious_patterns,
"is_spam": is_spam,
"ai_original_classification": is_spam, # Store original AI decision
"confidence_score": spam_score,
"spam_reasons": spam_reasons
})
# Request human approval before proceeding using new API
approval_request = SpamApprovalRequest(
email_message=email_text[:200], # First 200 chars
detected_as_spam=is_spam,
confidence=spam_score,
reasons=", ".join(spam_reasons) if spam_reasons else "no specific reasons"
reasons=spam_reasons,
full_email_content=email_content,
)
await ctx.request_info(
@@ -214,20 +205,15 @@ class SpamDetector(Executor):
@response_handler
async def handle_human_response(
self,
original_request: SpamApprovalRequest,
response: SpamDecision,
ctx: WorkflowContext[SpamDetectorResponse]
self, original_request: SpamApprovalRequest, response: SpamDecision, ctx: WorkflowContext[SpamDetectorResponse]
) -> None:
"""Process human approval response and continue workflow."""
print(f"[SpamDetector] handle_human_response called with response: {response}")
# Get stored detection result
state = await ctx.get_executor_state() or {}
print(f"[SpamDetector] Retrieved state: {state}")
ai_original = state.get("ai_original_classification", False)
confidence_score = state.get("confidence_score", 0.0)
spam_reasons = state.get("spam_reasons", [])
ai_original = original_request.detected_as_spam
confidence_score = original_request.confidence
spam_reasons = original_request.reasons
# Parse human decision from the response model
human_decision = response.decision.strip().lower()
@@ -241,27 +227,21 @@ class SpamDetector(Executor):
# Default to AI decision if unclear
is_spam = ai_original
# Reconstruct EmailContent from stored primitives
email_content = EmailContent(
original_message=state.get("original_message", ""),
cleaned_message=state.get("cleaned_message", ""),
word_count=state.get("word_count", 0),
has_suspicious_patterns=state.get("has_suspicious_patterns", False)
)
result = SpamDetectorResponse(
email_content=email_content,
email_content=original_request.full_email_content,
is_spam=is_spam,
confidence_score=confidence_score,
spam_reasons=spam_reasons,
human_reviewed=True,
human_decision=response.decision,
ai_original_classification=ai_original
ai_original_classification=ai_original,
)
print(f"[SpamDetector] Sending SpamDetectorResponse: is_spam={is_spam}, confidence={confidence_score}, human_reviewed=True")
print(
f"[SpamDetector] Sending SpamDetectorResponse: is_spam={is_spam}, confidence={confidence_score}, human_reviewed=True"
)
await ctx.send_message(result)
print(f"[SpamDetector] Message sent successfully")
print("[SpamDetector] Message sent successfully")
class SpamHandler(Executor):
@@ -427,7 +407,9 @@ workflow = (
spam_detector,
[
Case(condition=lambda x: isinstance(x, SpamDetectorResponse) and x.is_spam, target=spam_handler),
Default(target=legitimate_message_handler), # Default handles non-spam and non-SpamDetectorResponse messages
Default(
target=legitimate_message_handler
), # Default handles non-spam and non-SpamDetectorResponse messages
],
)
.add_edge(spam_handler, final_processor)