feat(decrypt): 解密支持 SSE 实时进度

- 新增 /api/decrypt_stream(GET + SSE):扫描 db_storage,逐库解密并推送 start/progress/complete/error

- 前端解密页优先使用 SSE 展示实时进度,不支持时回退到原 POST(无进度)

- 增加流式接口单测:验证事件序列与输出落盘
This commit is contained in:
2977094657
2026-02-18 16:54:25 +08:00
parent a14f8de6d0
commit 5d9fcede2f
3 changed files with 574 additions and 33 deletions

View File

@@ -125,6 +125,40 @@
</button> </button>
</div> </div>
</div> </div>
<!-- 解密进度 -->
<div v-if="loading || dbDecryptProgress.total > 0" class="mt-6">
<div class="flex items-center justify-between mb-2">
<div class="text-sm text-[#7F7F7F]">
{{ dbDecryptProgress.message || (loading ? '解密中...' : '') }}
</div>
<div v-if="dbDecryptProgress.total > 0" class="text-sm font-mono text-[#000000e6]">
{{ dbDecryptProgress.current }} / {{ dbDecryptProgress.total }}
</div>
</div>
<div class="w-full bg-gray-200 rounded-full h-2 overflow-hidden">
<div
class="h-full bg-[#07C160] transition-all duration-300"
:style="{ width: dbProgressPercent + '%' }"
></div>
</div>
<div v-if="dbDecryptProgress.current_file" class="mt-2 text-xs text-[#7F7F7F] truncate font-mono">
{{ dbDecryptProgress.current_file }}
</div>
<div v-if="dbDecryptProgress.total > 0" class="mt-3 grid grid-cols-2 gap-4 text-center">
<div class="bg-gray-50 rounded-lg p-3">
<div class="text-lg font-bold text-[#07C160]">{{ dbDecryptProgress.success_count }}</div>
<div class="text-xs text-[#7F7F7F]">成功</div>
</div>
<div class="bg-gray-50 rounded-lg p-3">
<div class="text-lg font-bold text-[#FA5151]">{{ dbDecryptProgress.fail_count }}</div>
<div class="text-xs text-[#7F7F7F]">失败</div>
</div>
</div>
</div>
</form> </form>
</div> </div>
</div> </div>
@@ -413,7 +447,7 @@
</style> </style>
<script setup> <script setup>
import { ref, reactive, computed, onMounted } from 'vue' import { ref, reactive, computed, onMounted, onBeforeUnmount } from 'vue'
import { useApi } from '~/composables/useApi' import { useApi } from '~/composables/useApi'
const { decryptDatabase, saveMediaKeys, getSavedKeys, getDbKey, getImageKey, getWxStatus } = useApi() const { decryptDatabase, saveMediaKeys, getSavedKeys, getDbKey, getImageKey, getWxStatus } = useApi()
@@ -625,6 +659,22 @@ const clearManualKeys = () => {
const mediaDecryptResult = ref(null) const mediaDecryptResult = ref(null)
const mediaDecrypting = ref(false) const mediaDecrypting = ref(false)
// 数据库解密进度SSE
const dbDecryptProgress = reactive({
current: 0,
total: 0,
success_count: 0,
fail_count: 0,
current_file: '',
status: '',
message: ''
})
const dbProgressPercent = computed(() => {
if (dbDecryptProgress.total === 0) return 0
return Math.round((dbDecryptProgress.current / dbDecryptProgress.total) * 100)
})
// 实时解密进度 // 实时解密进度
const decryptProgress = reactive({ const decryptProgress = reactive({
current: 0, current: 0,
@@ -673,6 +723,27 @@ const validateForm = () => {
return isValid return isValid
} }
let dbDecryptEventSource = null
onBeforeUnmount(() => {
try {
if (dbDecryptEventSource) dbDecryptEventSource.close()
} catch (e) {
// ignore
} finally {
dbDecryptEventSource = null
}
})
const resetDbDecryptProgress = () => {
dbDecryptProgress.current = 0
dbDecryptProgress.total = 0
dbDecryptProgress.success_count = 0
dbDecryptProgress.fail_count = 0
dbDecryptProgress.current_file = ''
dbDecryptProgress.status = ''
dbDecryptProgress.message = ''
}
// 处理解密 // 处理解密
const handleDecrypt = async () => { const handleDecrypt = async () => {
if (!validateForm()) { if (!validateForm()) {
@@ -682,43 +753,142 @@ const handleDecrypt = async () => {
loading.value = true loading.value = true
error.value = '' error.value = ''
warning.value = '' warning.value = ''
resetDbDecryptProgress()
try { try {
const result = await decryptDatabase({ const canSse = process.client && typeof window !== 'undefined' && typeof EventSource !== 'undefined'
key: formData.key,
db_storage_path: formData.db_storage_path // Fallback: 如果环境不支持 SSE则使用普通 POST无进度
}) if (!canSse) {
const result = await decryptDatabase({
if (result.status === 'completed') { key: formData.key,
// 解密成功,保存结果并进入下一步 db_storage_path: formData.db_storage_path
decryptResult.value = result })
if (process.client && typeof window !== 'undefined') {
sessionStorage.setItem('decryptResult', JSON.stringify(result)) if (result.status === 'completed') {
} decryptResult.value = result
// 记录当前账号(用于图片解密/密钥保存) if (process.client && typeof window !== 'undefined') {
try { sessionStorage.setItem('decryptResult', JSON.stringify(result))
const accounts = Object.keys(result.account_results || {}) }
if (accounts.length > 0) mediaAccount.value = accounts[0] try {
} catch (e) { const accounts = Object.keys(result.account_results || {})
// ignore if (accounts.length > 0) mediaAccount.value = accounts[0]
} catch (e) {}
clearManualKeys()
currentStep.value = 1
await prefillKeysForAccount(mediaAccount.value)
} else if (result.status === 'failed') {
if (result.failure_count > 0 && result.success_count === 0) {
error.value = result.message || '所有文件解密失败'
} else {
error.value = '部分文件解密失败,请检查密钥是否正确'
}
} else {
error.value = result.message || '解密失败,请检查输入信息'
} }
// 进入图片密钥填写步骤 loading.value = false
clearManualKeys() return
currentStep.value = 1 }
await prefillKeysForAccount(mediaAccount.value)
} else if (result.status === 'failed') { // SSE: 解密过程实时推送进度
if (result.failure_count > 0 && result.success_count === 0) { if (dbDecryptEventSource) {
error.value = result.message || '所有文件解密失败' try {
} else { dbDecryptEventSource.close()
error.value = '部分文件解密失败,请检查密钥是否正确' } catch (e) {}
dbDecryptEventSource = null
}
const params = new URLSearchParams()
params.set('key', formData.key)
params.set('db_storage_path', formData.db_storage_path)
const url = `http://localhost:8000/api/decrypt_stream?${params.toString()}`
dbDecryptProgress.message = '连接中...'
const eventSource = new EventSource(url)
dbDecryptEventSource = eventSource
eventSource.onmessage = async (event) => {
try {
const data = JSON.parse(event.data)
if (data.type === 'scanning') {
dbDecryptProgress.message = data.message || '正在扫描数据库文件...'
} else if (data.type === 'start') {
dbDecryptProgress.total = data.total || 0
dbDecryptProgress.message = data.message || '开始解密...'
} else if (data.type === 'progress') {
dbDecryptProgress.current = data.current || 0
dbDecryptProgress.total = data.total || 0
dbDecryptProgress.success_count = data.success_count || 0
dbDecryptProgress.fail_count = data.fail_count || 0
dbDecryptProgress.current_file = data.current_file || ''
dbDecryptProgress.status = data.status || ''
dbDecryptProgress.message = data.message || ''
} else if (data.type === 'phase') {
// e.g. building cache
dbDecryptProgress.message = data.message || ''
} else if (data.type === 'complete') {
dbDecryptProgress.status = 'complete'
dbDecryptProgress.current = data.total_databases || dbDecryptProgress.total
dbDecryptProgress.total = data.total_databases || dbDecryptProgress.total
dbDecryptProgress.success_count = data.success_count || 0
dbDecryptProgress.fail_count = data.failure_count || 0
dbDecryptProgress.message = data.message || '解密完成'
decryptResult.value = data
if (process.client && typeof window !== 'undefined') {
sessionStorage.setItem('decryptResult', JSON.stringify(data))
}
try {
const accounts = Object.keys(data.account_results || {})
if (accounts.length > 0) mediaAccount.value = accounts[0]
} catch (e) {}
try {
eventSource.close()
} catch (e) {}
dbDecryptEventSource = null
loading.value = false
if (data.status === 'completed') {
clearManualKeys()
currentStep.value = 1
await prefillKeysForAccount(mediaAccount.value)
} else if (data.status === 'failed') {
error.value = data.message || '所有文件解密失败'
} else {
error.value = data.message || '解密失败,请检查输入信息'
}
} else if (data.type === 'error') {
error.value = data.message || '解密失败,请检查输入信息'
try {
eventSource.close()
} catch (e) {}
dbDecryptEventSource = null
loading.value = false
}
} catch (e) {
console.error('解析SSE消息失败:', e)
}
}
eventSource.onerror = (e) => {
console.error('SSE连接错误:', e)
try {
eventSource.close()
} catch (err) {}
dbDecryptEventSource = null
if (loading.value) {
error.value = 'SSE连接中断请重试'
loading.value = false
} }
} else {
error.value = result.message || '解密失败,请检查输入信息'
} }
} catch (err) { } catch (err) {
error.value = err.message || '解密过程中发生错误' error.value = err.message || '解密过程中发生错误'
} finally {
loading.value = false loading.value = false
} }
} }

View File

@@ -1,10 +1,20 @@
from fastapi import APIRouter, HTTPException from __future__ import annotations
from pydantic import BaseModel, Field
import asyncio
import json
import os
import time
from pathlib import Path
from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel, Field
from starlette.responses import StreamingResponse
from ..app_paths import get_output_databases_dir
from ..logging_config import get_logger from ..logging_config import get_logger
from ..path_fix import PathFixRoute from ..path_fix import PathFixRoute
from ..key_store import upsert_account_keys_in_store from ..key_store import upsert_account_keys_in_store
from ..wechat_decrypt import decrypt_wechat_databases from ..wechat_decrypt import WeChatDatabaseDecryptor, decrypt_wechat_databases
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -72,3 +82,273 @@ async def decrypt_databases(request: DecryptRequest):
except Exception as e: except Exception as e:
logger.error(f"解密API异常: {str(e)}") logger.error(f"解密API异常: {str(e)}")
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@router.get("/api/decrypt_stream", summary="解密微信数据库SSE实时进度")
async def decrypt_databases_stream(
request: Request,
key: str | None = None,
db_storage_path: str | None = None,
):
"""通过SSE实时推送数据库解密进度。
注意EventSource 只支持 GET因此参数通过 querystring 传递。
"""
def _sse(payload: dict) -> str:
return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
async def generate_progress():
# 1) Basic validation (keep 200 + SSE error event, avoid 422 breaking EventSource).
k = str(key or "").strip()
p = str(db_storage_path or "").strip()
if not k or len(k) != 64:
yield _sse({"type": "error", "message": "密钥格式无效必须是64位十六进制字符串"})
return
try:
bytes.fromhex(k)
except Exception:
yield _sse({"type": "error", "message": "密钥必须是有效的十六进制字符串"})
return
if not p:
yield _sse({"type": "error", "message": "请提供 db_storage_path 参数"})
return
storage_path = Path(p)
if not storage_path.exists():
yield _sse({"type": "error", "message": f"指定的数据库路径不存在: {p}"})
return
# 2) Scan databases.
yield _sse({"type": "scanning", "message": "正在扫描数据库文件..."})
await asyncio.sleep(0)
account_name = "unknown_account"
path_parts = storage_path.parts
account_patterns = ["wxid_"]
for part in path_parts:
for pattern in account_patterns:
if part.startswith(pattern):
parts = part.split("_")
if len(parts) >= 3:
account_name = "_".join(parts[:-1])
else:
account_name = part
break
if account_name != "unknown_account":
break
if account_name == "unknown_account":
for part in reversed(path_parts):
if part != "db_storage" and len(part) > 3:
account_name = part
break
databases: list[dict] = []
for root, _dirs, files in os.walk(storage_path):
if "db_storage" not in str(root):
continue
for file_name in files:
if not file_name.endswith(".db"):
continue
if file_name in ["key_info.db"]:
continue
db_path = os.path.join(root, file_name)
databases.append({"path": db_path, "name": file_name, "account": account_name})
if not databases:
yield _sse({"type": "error", "message": "未找到微信数据库文件!请检查 db_storage_path 是否正确"})
return
account_databases = {account_name: databases}
total_databases = sum(len(dbs) for dbs in account_databases.values())
yield _sse({"type": "start", "total": total_databases, "message": f"开始解密 {total_databases} 个数据库"})
await asyncio.sleep(0)
# 3) Init output dir & decryptor.
base_output_dir = get_output_databases_dir()
base_output_dir.mkdir(parents=True, exist_ok=True)
try:
decryptor = WeChatDatabaseDecryptor(k)
except ValueError as e:
yield _sse({"type": "error", "message": f"密钥错误: {e}"})
return
# 4) Decrypt per account, stream progress.
success_count = 0
fail_count = 0
processed_files: list[str] = []
failed_files: list[str] = []
account_results: dict = {}
overall_current = 0
for account, dbs in account_databases.items():
account_output_dir = base_output_dir / account
account_output_dir.mkdir(parents=True, exist_ok=True)
# Save a hint for later UI (same as non-stream endpoint).
try:
source_db_storage_path = p
wxid_dir = ""
if storage_path.name.lower() == "db_storage":
wxid_dir = str(storage_path.parent)
else:
wxid_dir = str(storage_path)
(account_output_dir / "_source.json").write_text(
json.dumps({"db_storage_path": source_db_storage_path, "wxid_dir": wxid_dir}, ensure_ascii=False, indent=2),
encoding="utf-8",
)
except Exception:
pass
account_success = 0
account_processed: list[str] = []
account_failed: list[str] = []
for db_info in dbs:
if await request.is_disconnected():
return
overall_current += 1
db_path = str(db_info.get("path") or "")
db_name = str(db_info.get("name") or "")
current_file = f"{account}/{db_name}" if account else db_name
# Emit a "processing" event so UI updates immediately for large db files.
yield _sse(
{
"type": "progress",
"current": overall_current,
"total": total_databases,
"success_count": success_count,
"fail_count": fail_count,
"current_file": current_file,
"status": "processing",
"message": "解密中...",
}
)
output_path = account_output_dir / db_name
task = asyncio.create_task(asyncio.to_thread(decryptor.decrypt_database, db_path, str(output_path)))
# Wait with heartbeat (can't yield while awaiting the thread directly).
last_heartbeat = time.time()
while not task.done():
if await request.is_disconnected():
return
now = time.time()
if now - last_heartbeat > 15:
last_heartbeat = now
# SSE comment heartbeat; browsers ignore but keeps proxies alive.
yield ": ping\n\n"
await asyncio.sleep(0.6)
try:
ok = bool(task.result())
except Exception:
ok = False
if ok:
account_success += 1
success_count += 1
account_processed.append(str(output_path))
processed_files.append(str(output_path))
status = "success"
msg = "解密成功"
else:
account_failed.append(db_path)
failed_files.append(db_path)
fail_count += 1
status = "fail"
msg = "解密失败"
yield _sse(
{
"type": "progress",
"current": overall_current,
"total": total_databases,
"success_count": success_count,
"fail_count": fail_count,
"current_file": current_file,
"status": status,
"message": msg,
}
)
if overall_current % 5 == 0:
await asyncio.sleep(0)
account_results[account] = {
"total": len(dbs),
"success": account_success,
"failed": len(dbs) - account_success,
"output_dir": str(account_output_dir),
"processed_files": account_processed,
"failed_files": account_failed,
}
# Build cache table (keep behavior consistent with the POST endpoint).
if os.environ.get("WECHAT_TOOL_BUILD_SESSION_LAST_MESSAGE", "1") != "0":
yield _sse(
{
"type": "phase",
"phase": "session_last_message",
"account": account,
"message": "正在构建会话缓存(最后一条消息)...",
}
)
await asyncio.sleep(0)
try:
from ..session_last_message import build_session_last_message_table
task = asyncio.create_task(
asyncio.to_thread(
build_session_last_message_table,
account_output_dir,
rebuild=True,
include_hidden=True,
include_official=True,
)
)
last_heartbeat = time.time()
while not task.done():
if await request.is_disconnected():
return
now = time.time()
if now - last_heartbeat > 15:
last_heartbeat = now
yield ": ping\n\n"
await asyncio.sleep(0.6)
account_results[account]["session_last_message"] = task.result()
except Exception as e:
account_results[account]["session_last_message"] = {"status": "error", "message": str(e)}
status = "completed" if success_count > 0 else "failed"
result = {
"status": status,
"total_databases": total_databases,
"success_count": success_count,
"failure_count": total_databases - success_count,
"output_directory": str(base_output_dir.absolute()),
"message": f"解密完成: 成功 {success_count}/{total_databases}",
"processed_files": processed_files,
"failed_files": failed_files,
"account_results": account_results,
}
# Save db key for frontend autofill.
try:
for account in (account_results or {}).keys():
upsert_account_keys_in_store(str(account), db_key=k)
except Exception:
pass
yield _sse({"type": "complete", **result})
headers = {"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"}
return StreamingResponse(generate_progress(), media_type="text/event-stream", headers=headers)

View File

@@ -0,0 +1,91 @@
import json
import os
import sys
import unittest
import importlib
from pathlib import Path
from tempfile import TemporaryDirectory
ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT / "src"))
class TestDecryptStreamSSE(unittest.TestCase):
def test_decrypt_stream_reports_progress(self):
from fastapi import FastAPI
from fastapi.testclient import TestClient
from wechat_decrypt_tool.wechat_decrypt import SQLITE_HEADER
with TemporaryDirectory() as td:
root = Path(td)
prev_data_dir = os.environ.get("WECHAT_TOOL_DATA_DIR")
prev_build_cache = os.environ.get("WECHAT_TOOL_BUILD_SESSION_LAST_MESSAGE")
try:
os.environ["WECHAT_TOOL_DATA_DIR"] = str(root)
os.environ["WECHAT_TOOL_BUILD_SESSION_LAST_MESSAGE"] = "0"
import wechat_decrypt_tool.app_paths as app_paths
import wechat_decrypt_tool.routers.decrypt as decrypt_router
importlib.reload(app_paths)
importlib.reload(decrypt_router)
db_storage = root / "xwechat_files" / "wxid_foo_bar" / "db_storage"
db_storage.mkdir(parents=True, exist_ok=True)
# Fake a decrypted sqlite db (>= 4096 bytes) so decryptor falls back to copy.
(db_storage / "MSG0.db").write_bytes(SQLITE_HEADER + b"\x00" * (4096 - len(SQLITE_HEADER)))
app = FastAPI()
app.include_router(decrypt_router.router)
client = TestClient(app)
events: list[dict] = []
with client.stream(
"GET",
"/api/decrypt_stream",
params={"key": "00" * 32, "db_storage_path": str(db_storage)},
) as resp:
self.assertEqual(resp.status_code, 200)
self.assertIn("text/event-stream", resp.headers.get("content-type", ""))
for line in resp.iter_lines():
if not line:
continue
if isinstance(line, bytes):
line = line.decode("utf-8", errors="ignore")
line = str(line)
if line.startswith(":"):
continue
if not line.startswith("data: "):
continue
payload = json.loads(line[len("data: ") :])
events.append(payload)
if payload.get("type") in {"complete", "error"}:
break
types = {e.get("type") for e in events}
self.assertIn("start", types)
self.assertIn("progress", types)
self.assertEqual(events[-1].get("type"), "complete")
out = root / "output" / "databases" / "wxid_foo" / "MSG0.db"
self.assertTrue(out.exists())
finally:
if prev_data_dir is None:
os.environ.pop("WECHAT_TOOL_DATA_DIR", None)
else:
os.environ["WECHAT_TOOL_DATA_DIR"] = prev_data_dir
if prev_build_cache is None:
os.environ.pop("WECHAT_TOOL_BUILD_SESSION_LAST_MESSAGE", None)
else:
os.environ["WECHAT_TOOL_BUILD_SESSION_LAST_MESSAGE"] = prev_build_cache
if __name__ == "__main__":
unittest.main()