mirror of
https://github.com/LifeArchiveProject/WeChatDataAnalysis.git
synced 2026-02-19 22:30:49 +08:00
feat(decrypt): 解密支持 SSE 实时进度
- 新增 /api/decrypt_stream(GET + SSE):扫描 db_storage,逐库解密并推送 start/progress/complete/error - 前端解密页优先使用 SSE 展示实时进度,不支持时回退到原 POST(无进度) - 增加流式接口单测:验证事件序列与输出落盘
This commit is contained in:
@@ -125,6 +125,40 @@
|
|||||||
</button>
|
</button>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<!-- 解密进度 -->
|
||||||
|
<div v-if="loading || dbDecryptProgress.total > 0" class="mt-6">
|
||||||
|
<div class="flex items-center justify-between mb-2">
|
||||||
|
<div class="text-sm text-[#7F7F7F]">
|
||||||
|
{{ dbDecryptProgress.message || (loading ? '解密中...' : '') }}
|
||||||
|
</div>
|
||||||
|
<div v-if="dbDecryptProgress.total > 0" class="text-sm font-mono text-[#000000e6]">
|
||||||
|
{{ dbDecryptProgress.current }} / {{ dbDecryptProgress.total }}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="w-full bg-gray-200 rounded-full h-2 overflow-hidden">
|
||||||
|
<div
|
||||||
|
class="h-full bg-[#07C160] transition-all duration-300"
|
||||||
|
:style="{ width: dbProgressPercent + '%' }"
|
||||||
|
></div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div v-if="dbDecryptProgress.current_file" class="mt-2 text-xs text-[#7F7F7F] truncate font-mono">
|
||||||
|
{{ dbDecryptProgress.current_file }}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div v-if="dbDecryptProgress.total > 0" class="mt-3 grid grid-cols-2 gap-4 text-center">
|
||||||
|
<div class="bg-gray-50 rounded-lg p-3">
|
||||||
|
<div class="text-lg font-bold text-[#07C160]">{{ dbDecryptProgress.success_count }}</div>
|
||||||
|
<div class="text-xs text-[#7F7F7F]">成功</div>
|
||||||
|
</div>
|
||||||
|
<div class="bg-gray-50 rounded-lg p-3">
|
||||||
|
<div class="text-lg font-bold text-[#FA5151]">{{ dbDecryptProgress.fail_count }}</div>
|
||||||
|
<div class="text-xs text-[#7F7F7F]">失败</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
</form>
|
</form>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
@@ -413,7 +447,7 @@
|
|||||||
</style>
|
</style>
|
||||||
|
|
||||||
<script setup>
|
<script setup>
|
||||||
import { ref, reactive, computed, onMounted } from 'vue'
|
import { ref, reactive, computed, onMounted, onBeforeUnmount } from 'vue'
|
||||||
import { useApi } from '~/composables/useApi'
|
import { useApi } from '~/composables/useApi'
|
||||||
|
|
||||||
const { decryptDatabase, saveMediaKeys, getSavedKeys, getDbKey, getImageKey, getWxStatus } = useApi()
|
const { decryptDatabase, saveMediaKeys, getSavedKeys, getDbKey, getImageKey, getWxStatus } = useApi()
|
||||||
@@ -625,6 +659,22 @@ const clearManualKeys = () => {
|
|||||||
const mediaDecryptResult = ref(null)
|
const mediaDecryptResult = ref(null)
|
||||||
const mediaDecrypting = ref(false)
|
const mediaDecrypting = ref(false)
|
||||||
|
|
||||||
|
// 数据库解密进度(SSE)
|
||||||
|
const dbDecryptProgress = reactive({
|
||||||
|
current: 0,
|
||||||
|
total: 0,
|
||||||
|
success_count: 0,
|
||||||
|
fail_count: 0,
|
||||||
|
current_file: '',
|
||||||
|
status: '',
|
||||||
|
message: ''
|
||||||
|
})
|
||||||
|
|
||||||
|
const dbProgressPercent = computed(() => {
|
||||||
|
if (dbDecryptProgress.total === 0) return 0
|
||||||
|
return Math.round((dbDecryptProgress.current / dbDecryptProgress.total) * 100)
|
||||||
|
})
|
||||||
|
|
||||||
// 实时解密进度
|
// 实时解密进度
|
||||||
const decryptProgress = reactive({
|
const decryptProgress = reactive({
|
||||||
current: 0,
|
current: 0,
|
||||||
@@ -673,6 +723,27 @@ const validateForm = () => {
|
|||||||
return isValid
|
return isValid
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let dbDecryptEventSource = null
|
||||||
|
onBeforeUnmount(() => {
|
||||||
|
try {
|
||||||
|
if (dbDecryptEventSource) dbDecryptEventSource.close()
|
||||||
|
} catch (e) {
|
||||||
|
// ignore
|
||||||
|
} finally {
|
||||||
|
dbDecryptEventSource = null
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
const resetDbDecryptProgress = () => {
|
||||||
|
dbDecryptProgress.current = 0
|
||||||
|
dbDecryptProgress.total = 0
|
||||||
|
dbDecryptProgress.success_count = 0
|
||||||
|
dbDecryptProgress.fail_count = 0
|
||||||
|
dbDecryptProgress.current_file = ''
|
||||||
|
dbDecryptProgress.status = ''
|
||||||
|
dbDecryptProgress.message = ''
|
||||||
|
}
|
||||||
|
|
||||||
// 处理解密
|
// 处理解密
|
||||||
const handleDecrypt = async () => {
|
const handleDecrypt = async () => {
|
||||||
if (!validateForm()) {
|
if (!validateForm()) {
|
||||||
@@ -682,43 +753,142 @@ const handleDecrypt = async () => {
|
|||||||
loading.value = true
|
loading.value = true
|
||||||
error.value = ''
|
error.value = ''
|
||||||
warning.value = ''
|
warning.value = ''
|
||||||
|
|
||||||
|
resetDbDecryptProgress()
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const result = await decryptDatabase({
|
const canSse = process.client && typeof window !== 'undefined' && typeof EventSource !== 'undefined'
|
||||||
key: formData.key,
|
|
||||||
db_storage_path: formData.db_storage_path
|
// Fallback: 如果环境不支持 SSE,则使用普通 POST(无进度)。
|
||||||
})
|
if (!canSse) {
|
||||||
|
const result = await decryptDatabase({
|
||||||
if (result.status === 'completed') {
|
key: formData.key,
|
||||||
// 解密成功,保存结果并进入下一步
|
db_storage_path: formData.db_storage_path
|
||||||
decryptResult.value = result
|
})
|
||||||
if (process.client && typeof window !== 'undefined') {
|
|
||||||
sessionStorage.setItem('decryptResult', JSON.stringify(result))
|
if (result.status === 'completed') {
|
||||||
}
|
decryptResult.value = result
|
||||||
// 记录当前账号(用于图片解密/密钥保存)
|
if (process.client && typeof window !== 'undefined') {
|
||||||
try {
|
sessionStorage.setItem('decryptResult', JSON.stringify(result))
|
||||||
const accounts = Object.keys(result.account_results || {})
|
}
|
||||||
if (accounts.length > 0) mediaAccount.value = accounts[0]
|
try {
|
||||||
} catch (e) {
|
const accounts = Object.keys(result.account_results || {})
|
||||||
// ignore
|
if (accounts.length > 0) mediaAccount.value = accounts[0]
|
||||||
|
} catch (e) {}
|
||||||
|
|
||||||
|
clearManualKeys()
|
||||||
|
currentStep.value = 1
|
||||||
|
await prefillKeysForAccount(mediaAccount.value)
|
||||||
|
} else if (result.status === 'failed') {
|
||||||
|
if (result.failure_count > 0 && result.success_count === 0) {
|
||||||
|
error.value = result.message || '所有文件解密失败'
|
||||||
|
} else {
|
||||||
|
error.value = '部分文件解密失败,请检查密钥是否正确'
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
error.value = result.message || '解密失败,请检查输入信息'
|
||||||
}
|
}
|
||||||
|
|
||||||
// 进入图片密钥填写步骤
|
loading.value = false
|
||||||
clearManualKeys()
|
return
|
||||||
currentStep.value = 1
|
}
|
||||||
await prefillKeysForAccount(mediaAccount.value)
|
|
||||||
} else if (result.status === 'failed') {
|
// SSE: 解密过程实时推送进度
|
||||||
if (result.failure_count > 0 && result.success_count === 0) {
|
if (dbDecryptEventSource) {
|
||||||
error.value = result.message || '所有文件解密失败'
|
try {
|
||||||
} else {
|
dbDecryptEventSource.close()
|
||||||
error.value = '部分文件解密失败,请检查密钥是否正确'
|
} catch (e) {}
|
||||||
|
dbDecryptEventSource = null
|
||||||
|
}
|
||||||
|
|
||||||
|
const params = new URLSearchParams()
|
||||||
|
params.set('key', formData.key)
|
||||||
|
params.set('db_storage_path', formData.db_storage_path)
|
||||||
|
const url = `http://localhost:8000/api/decrypt_stream?${params.toString()}`
|
||||||
|
|
||||||
|
dbDecryptProgress.message = '连接中...'
|
||||||
|
const eventSource = new EventSource(url)
|
||||||
|
dbDecryptEventSource = eventSource
|
||||||
|
|
||||||
|
eventSource.onmessage = async (event) => {
|
||||||
|
try {
|
||||||
|
const data = JSON.parse(event.data)
|
||||||
|
|
||||||
|
if (data.type === 'scanning') {
|
||||||
|
dbDecryptProgress.message = data.message || '正在扫描数据库文件...'
|
||||||
|
} else if (data.type === 'start') {
|
||||||
|
dbDecryptProgress.total = data.total || 0
|
||||||
|
dbDecryptProgress.message = data.message || '开始解密...'
|
||||||
|
} else if (data.type === 'progress') {
|
||||||
|
dbDecryptProgress.current = data.current || 0
|
||||||
|
dbDecryptProgress.total = data.total || 0
|
||||||
|
dbDecryptProgress.success_count = data.success_count || 0
|
||||||
|
dbDecryptProgress.fail_count = data.fail_count || 0
|
||||||
|
dbDecryptProgress.current_file = data.current_file || ''
|
||||||
|
dbDecryptProgress.status = data.status || ''
|
||||||
|
dbDecryptProgress.message = data.message || ''
|
||||||
|
} else if (data.type === 'phase') {
|
||||||
|
// e.g. building cache
|
||||||
|
dbDecryptProgress.message = data.message || ''
|
||||||
|
} else if (data.type === 'complete') {
|
||||||
|
dbDecryptProgress.status = 'complete'
|
||||||
|
dbDecryptProgress.current = data.total_databases || dbDecryptProgress.total
|
||||||
|
dbDecryptProgress.total = data.total_databases || dbDecryptProgress.total
|
||||||
|
dbDecryptProgress.success_count = data.success_count || 0
|
||||||
|
dbDecryptProgress.fail_count = data.failure_count || 0
|
||||||
|
dbDecryptProgress.message = data.message || '解密完成'
|
||||||
|
|
||||||
|
decryptResult.value = data
|
||||||
|
if (process.client && typeof window !== 'undefined') {
|
||||||
|
sessionStorage.setItem('decryptResult', JSON.stringify(data))
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
const accounts = Object.keys(data.account_results || {})
|
||||||
|
if (accounts.length > 0) mediaAccount.value = accounts[0]
|
||||||
|
} catch (e) {}
|
||||||
|
|
||||||
|
try {
|
||||||
|
eventSource.close()
|
||||||
|
} catch (e) {}
|
||||||
|
dbDecryptEventSource = null
|
||||||
|
loading.value = false
|
||||||
|
|
||||||
|
if (data.status === 'completed') {
|
||||||
|
clearManualKeys()
|
||||||
|
currentStep.value = 1
|
||||||
|
await prefillKeysForAccount(mediaAccount.value)
|
||||||
|
} else if (data.status === 'failed') {
|
||||||
|
error.value = data.message || '所有文件解密失败'
|
||||||
|
} else {
|
||||||
|
error.value = data.message || '解密失败,请检查输入信息'
|
||||||
|
}
|
||||||
|
} else if (data.type === 'error') {
|
||||||
|
error.value = data.message || '解密失败,请检查输入信息'
|
||||||
|
try {
|
||||||
|
eventSource.close()
|
||||||
|
} catch (e) {}
|
||||||
|
dbDecryptEventSource = null
|
||||||
|
loading.value = false
|
||||||
|
}
|
||||||
|
} catch (e) {
|
||||||
|
console.error('解析SSE消息失败:', e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
eventSource.onerror = (e) => {
|
||||||
|
console.error('SSE连接错误:', e)
|
||||||
|
try {
|
||||||
|
eventSource.close()
|
||||||
|
} catch (err) {}
|
||||||
|
dbDecryptEventSource = null
|
||||||
|
if (loading.value) {
|
||||||
|
error.value = 'SSE连接中断,请重试'
|
||||||
|
loading.value = false
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
error.value = result.message || '解密失败,请检查输入信息'
|
|
||||||
}
|
}
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
error.value = err.message || '解密过程中发生错误'
|
error.value = err.message || '解密过程中发生错误'
|
||||||
} finally {
|
|
||||||
loading.value = false
|
loading.value = false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,10 +1,20 @@
|
|||||||
from fastapi import APIRouter, HTTPException
|
from __future__ import annotations
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException, Request
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from starlette.responses import StreamingResponse
|
||||||
|
|
||||||
|
from ..app_paths import get_output_databases_dir
|
||||||
from ..logging_config import get_logger
|
from ..logging_config import get_logger
|
||||||
from ..path_fix import PathFixRoute
|
from ..path_fix import PathFixRoute
|
||||||
from ..key_store import upsert_account_keys_in_store
|
from ..key_store import upsert_account_keys_in_store
|
||||||
from ..wechat_decrypt import decrypt_wechat_databases
|
from ..wechat_decrypt import WeChatDatabaseDecryptor, decrypt_wechat_databases
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
@@ -72,3 +82,273 @@ async def decrypt_databases(request: DecryptRequest):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"解密API异常: {str(e)}")
|
logger.error(f"解密API异常: {str(e)}")
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/api/decrypt_stream", summary="解密微信数据库(SSE实时进度)")
|
||||||
|
async def decrypt_databases_stream(
|
||||||
|
request: Request,
|
||||||
|
key: str | None = None,
|
||||||
|
db_storage_path: str | None = None,
|
||||||
|
):
|
||||||
|
"""通过SSE实时推送数据库解密进度。
|
||||||
|
|
||||||
|
注意:EventSource 只支持 GET,因此参数通过 querystring 传递。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _sse(payload: dict) -> str:
|
||||||
|
return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
|
async def generate_progress():
|
||||||
|
# 1) Basic validation (keep 200 + SSE error event, avoid 422 breaking EventSource).
|
||||||
|
k = str(key or "").strip()
|
||||||
|
p = str(db_storage_path or "").strip()
|
||||||
|
|
||||||
|
if not k or len(k) != 64:
|
||||||
|
yield _sse({"type": "error", "message": "密钥格式无效,必须是64位十六进制字符串"})
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
bytes.fromhex(k)
|
||||||
|
except Exception:
|
||||||
|
yield _sse({"type": "error", "message": "密钥必须是有效的十六进制字符串"})
|
||||||
|
return
|
||||||
|
|
||||||
|
if not p:
|
||||||
|
yield _sse({"type": "error", "message": "请提供 db_storage_path 参数"})
|
||||||
|
return
|
||||||
|
|
||||||
|
storage_path = Path(p)
|
||||||
|
if not storage_path.exists():
|
||||||
|
yield _sse({"type": "error", "message": f"指定的数据库路径不存在: {p}"})
|
||||||
|
return
|
||||||
|
|
||||||
|
# 2) Scan databases.
|
||||||
|
yield _sse({"type": "scanning", "message": "正在扫描数据库文件..."})
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
account_name = "unknown_account"
|
||||||
|
path_parts = storage_path.parts
|
||||||
|
account_patterns = ["wxid_"]
|
||||||
|
for part in path_parts:
|
||||||
|
for pattern in account_patterns:
|
||||||
|
if part.startswith(pattern):
|
||||||
|
parts = part.split("_")
|
||||||
|
if len(parts) >= 3:
|
||||||
|
account_name = "_".join(parts[:-1])
|
||||||
|
else:
|
||||||
|
account_name = part
|
||||||
|
break
|
||||||
|
if account_name != "unknown_account":
|
||||||
|
break
|
||||||
|
|
||||||
|
if account_name == "unknown_account":
|
||||||
|
for part in reversed(path_parts):
|
||||||
|
if part != "db_storage" and len(part) > 3:
|
||||||
|
account_name = part
|
||||||
|
break
|
||||||
|
|
||||||
|
databases: list[dict] = []
|
||||||
|
for root, _dirs, files in os.walk(storage_path):
|
||||||
|
if "db_storage" not in str(root):
|
||||||
|
continue
|
||||||
|
for file_name in files:
|
||||||
|
if not file_name.endswith(".db"):
|
||||||
|
continue
|
||||||
|
if file_name in ["key_info.db"]:
|
||||||
|
continue
|
||||||
|
db_path = os.path.join(root, file_name)
|
||||||
|
databases.append({"path": db_path, "name": file_name, "account": account_name})
|
||||||
|
|
||||||
|
if not databases:
|
||||||
|
yield _sse({"type": "error", "message": "未找到微信数据库文件!请检查 db_storage_path 是否正确"})
|
||||||
|
return
|
||||||
|
|
||||||
|
account_databases = {account_name: databases}
|
||||||
|
total_databases = sum(len(dbs) for dbs in account_databases.values())
|
||||||
|
|
||||||
|
yield _sse({"type": "start", "total": total_databases, "message": f"开始解密 {total_databases} 个数据库"})
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
# 3) Init output dir & decryptor.
|
||||||
|
base_output_dir = get_output_databases_dir()
|
||||||
|
base_output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
try:
|
||||||
|
decryptor = WeChatDatabaseDecryptor(k)
|
||||||
|
except ValueError as e:
|
||||||
|
yield _sse({"type": "error", "message": f"密钥错误: {e}"})
|
||||||
|
return
|
||||||
|
|
||||||
|
# 4) Decrypt per account, stream progress.
|
||||||
|
success_count = 0
|
||||||
|
fail_count = 0
|
||||||
|
processed_files: list[str] = []
|
||||||
|
failed_files: list[str] = []
|
||||||
|
account_results: dict = {}
|
||||||
|
overall_current = 0
|
||||||
|
|
||||||
|
for account, dbs in account_databases.items():
|
||||||
|
account_output_dir = base_output_dir / account
|
||||||
|
account_output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Save a hint for later UI (same as non-stream endpoint).
|
||||||
|
try:
|
||||||
|
source_db_storage_path = p
|
||||||
|
wxid_dir = ""
|
||||||
|
if storage_path.name.lower() == "db_storage":
|
||||||
|
wxid_dir = str(storage_path.parent)
|
||||||
|
else:
|
||||||
|
wxid_dir = str(storage_path)
|
||||||
|
(account_output_dir / "_source.json").write_text(
|
||||||
|
json.dumps({"db_storage_path": source_db_storage_path, "wxid_dir": wxid_dir}, ensure_ascii=False, indent=2),
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
account_success = 0
|
||||||
|
account_processed: list[str] = []
|
||||||
|
account_failed: list[str] = []
|
||||||
|
|
||||||
|
for db_info in dbs:
|
||||||
|
if await request.is_disconnected():
|
||||||
|
return
|
||||||
|
|
||||||
|
overall_current += 1
|
||||||
|
db_path = str(db_info.get("path") or "")
|
||||||
|
db_name = str(db_info.get("name") or "")
|
||||||
|
current_file = f"{account}/{db_name}" if account else db_name
|
||||||
|
|
||||||
|
# Emit a "processing" event so UI updates immediately for large db files.
|
||||||
|
yield _sse(
|
||||||
|
{
|
||||||
|
"type": "progress",
|
||||||
|
"current": overall_current,
|
||||||
|
"total": total_databases,
|
||||||
|
"success_count": success_count,
|
||||||
|
"fail_count": fail_count,
|
||||||
|
"current_file": current_file,
|
||||||
|
"status": "processing",
|
||||||
|
"message": "解密中...",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
output_path = account_output_dir / db_name
|
||||||
|
task = asyncio.create_task(asyncio.to_thread(decryptor.decrypt_database, db_path, str(output_path)))
|
||||||
|
|
||||||
|
# Wait with heartbeat (can't yield while awaiting the thread directly).
|
||||||
|
last_heartbeat = time.time()
|
||||||
|
while not task.done():
|
||||||
|
if await request.is_disconnected():
|
||||||
|
return
|
||||||
|
now = time.time()
|
||||||
|
if now - last_heartbeat > 15:
|
||||||
|
last_heartbeat = now
|
||||||
|
# SSE comment heartbeat; browsers ignore but keeps proxies alive.
|
||||||
|
yield ": ping\n\n"
|
||||||
|
await asyncio.sleep(0.6)
|
||||||
|
try:
|
||||||
|
ok = bool(task.result())
|
||||||
|
except Exception:
|
||||||
|
ok = False
|
||||||
|
|
||||||
|
if ok:
|
||||||
|
account_success += 1
|
||||||
|
success_count += 1
|
||||||
|
account_processed.append(str(output_path))
|
||||||
|
processed_files.append(str(output_path))
|
||||||
|
status = "success"
|
||||||
|
msg = "解密成功"
|
||||||
|
else:
|
||||||
|
account_failed.append(db_path)
|
||||||
|
failed_files.append(db_path)
|
||||||
|
fail_count += 1
|
||||||
|
status = "fail"
|
||||||
|
msg = "解密失败"
|
||||||
|
|
||||||
|
yield _sse(
|
||||||
|
{
|
||||||
|
"type": "progress",
|
||||||
|
"current": overall_current,
|
||||||
|
"total": total_databases,
|
||||||
|
"success_count": success_count,
|
||||||
|
"fail_count": fail_count,
|
||||||
|
"current_file": current_file,
|
||||||
|
"status": status,
|
||||||
|
"message": msg,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if overall_current % 5 == 0:
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
account_results[account] = {
|
||||||
|
"total": len(dbs),
|
||||||
|
"success": account_success,
|
||||||
|
"failed": len(dbs) - account_success,
|
||||||
|
"output_dir": str(account_output_dir),
|
||||||
|
"processed_files": account_processed,
|
||||||
|
"failed_files": account_failed,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Build cache table (keep behavior consistent with the POST endpoint).
|
||||||
|
if os.environ.get("WECHAT_TOOL_BUILD_SESSION_LAST_MESSAGE", "1") != "0":
|
||||||
|
yield _sse(
|
||||||
|
{
|
||||||
|
"type": "phase",
|
||||||
|
"phase": "session_last_message",
|
||||||
|
"account": account,
|
||||||
|
"message": "正在构建会话缓存(最后一条消息)...",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from ..session_last_message import build_session_last_message_table
|
||||||
|
|
||||||
|
task = asyncio.create_task(
|
||||||
|
asyncio.to_thread(
|
||||||
|
build_session_last_message_table,
|
||||||
|
account_output_dir,
|
||||||
|
rebuild=True,
|
||||||
|
include_hidden=True,
|
||||||
|
include_official=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
last_heartbeat = time.time()
|
||||||
|
while not task.done():
|
||||||
|
if await request.is_disconnected():
|
||||||
|
return
|
||||||
|
now = time.time()
|
||||||
|
if now - last_heartbeat > 15:
|
||||||
|
last_heartbeat = now
|
||||||
|
yield ": ping\n\n"
|
||||||
|
await asyncio.sleep(0.6)
|
||||||
|
account_results[account]["session_last_message"] = task.result()
|
||||||
|
except Exception as e:
|
||||||
|
account_results[account]["session_last_message"] = {"status": "error", "message": str(e)}
|
||||||
|
|
||||||
|
status = "completed" if success_count > 0 else "failed"
|
||||||
|
result = {
|
||||||
|
"status": status,
|
||||||
|
"total_databases": total_databases,
|
||||||
|
"success_count": success_count,
|
||||||
|
"failure_count": total_databases - success_count,
|
||||||
|
"output_directory": str(base_output_dir.absolute()),
|
||||||
|
"message": f"解密完成: 成功 {success_count}/{total_databases}",
|
||||||
|
"processed_files": processed_files,
|
||||||
|
"failed_files": failed_files,
|
||||||
|
"account_results": account_results,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Save db key for frontend autofill.
|
||||||
|
try:
|
||||||
|
for account in (account_results or {}).keys():
|
||||||
|
upsert_account_keys_in_store(str(account), db_key=k)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
yield _sse({"type": "complete", **result})
|
||||||
|
|
||||||
|
headers = {"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"}
|
||||||
|
return StreamingResponse(generate_progress(), media_type="text/event-stream", headers=headers)
|
||||||
|
|||||||
91
tests/test_decrypt_stream_sse.py
Normal file
91
tests/test_decrypt_stream_sse.py
Normal file
@@ -0,0 +1,91 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import unittest
|
||||||
|
import importlib
|
||||||
|
from pathlib import Path
|
||||||
|
from tempfile import TemporaryDirectory
|
||||||
|
|
||||||
|
|
||||||
|
ROOT = Path(__file__).resolve().parents[1]
|
||||||
|
sys.path.insert(0, str(ROOT / "src"))
|
||||||
|
|
||||||
|
|
||||||
|
class TestDecryptStreamSSE(unittest.TestCase):
|
||||||
|
def test_decrypt_stream_reports_progress(self):
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
from wechat_decrypt_tool.wechat_decrypt import SQLITE_HEADER
|
||||||
|
|
||||||
|
with TemporaryDirectory() as td:
|
||||||
|
root = Path(td)
|
||||||
|
|
||||||
|
prev_data_dir = os.environ.get("WECHAT_TOOL_DATA_DIR")
|
||||||
|
prev_build_cache = os.environ.get("WECHAT_TOOL_BUILD_SESSION_LAST_MESSAGE")
|
||||||
|
try:
|
||||||
|
os.environ["WECHAT_TOOL_DATA_DIR"] = str(root)
|
||||||
|
os.environ["WECHAT_TOOL_BUILD_SESSION_LAST_MESSAGE"] = "0"
|
||||||
|
|
||||||
|
import wechat_decrypt_tool.app_paths as app_paths
|
||||||
|
import wechat_decrypt_tool.routers.decrypt as decrypt_router
|
||||||
|
|
||||||
|
importlib.reload(app_paths)
|
||||||
|
importlib.reload(decrypt_router)
|
||||||
|
|
||||||
|
db_storage = root / "xwechat_files" / "wxid_foo_bar" / "db_storage"
|
||||||
|
db_storage.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Fake a decrypted sqlite db (>= 4096 bytes) so decryptor falls back to copy.
|
||||||
|
(db_storage / "MSG0.db").write_bytes(SQLITE_HEADER + b"\x00" * (4096 - len(SQLITE_HEADER)))
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
app.include_router(decrypt_router.router)
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
events: list[dict] = []
|
||||||
|
with client.stream(
|
||||||
|
"GET",
|
||||||
|
"/api/decrypt_stream",
|
||||||
|
params={"key": "00" * 32, "db_storage_path": str(db_storage)},
|
||||||
|
) as resp:
|
||||||
|
self.assertEqual(resp.status_code, 200)
|
||||||
|
self.assertIn("text/event-stream", resp.headers.get("content-type", ""))
|
||||||
|
|
||||||
|
for line in resp.iter_lines():
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
if isinstance(line, bytes):
|
||||||
|
line = line.decode("utf-8", errors="ignore")
|
||||||
|
line = str(line)
|
||||||
|
|
||||||
|
if line.startswith(":"):
|
||||||
|
continue
|
||||||
|
if not line.startswith("data: "):
|
||||||
|
continue
|
||||||
|
payload = json.loads(line[len("data: ") :])
|
||||||
|
events.append(payload)
|
||||||
|
if payload.get("type") in {"complete", "error"}:
|
||||||
|
break
|
||||||
|
|
||||||
|
types = {e.get("type") for e in events}
|
||||||
|
self.assertIn("start", types)
|
||||||
|
self.assertIn("progress", types)
|
||||||
|
self.assertEqual(events[-1].get("type"), "complete")
|
||||||
|
|
||||||
|
out = root / "output" / "databases" / "wxid_foo" / "MSG0.db"
|
||||||
|
self.assertTrue(out.exists())
|
||||||
|
finally:
|
||||||
|
if prev_data_dir is None:
|
||||||
|
os.environ.pop("WECHAT_TOOL_DATA_DIR", None)
|
||||||
|
else:
|
||||||
|
os.environ["WECHAT_TOOL_DATA_DIR"] = prev_data_dir
|
||||||
|
if prev_build_cache is None:
|
||||||
|
os.environ.pop("WECHAT_TOOL_BUILD_SESSION_LAST_MESSAGE", None)
|
||||||
|
else:
|
||||||
|
os.environ["WECHAT_TOOL_BUILD_SESSION_LAST_MESSAGE"] = prev_build_cache
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
|
|
||||||
Reference in New Issue
Block a user