mirror of
https://github.com/LifeArchiveProject/WeChatDataAnalysis.git
synced 2026-02-19 06:10:52 +08:00
feat(decrypt): 解密支持 SSE 实时进度
- 新增 /api/decrypt_stream(GET + SSE):扫描 db_storage,逐库解密并推送 start/progress/complete/error - 前端解密页优先使用 SSE 展示实时进度,不支持时回退到原 POST(无进度) - 增加流式接口单测:验证事件序列与输出落盘
This commit is contained in:
@@ -125,6 +125,40 @@
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 解密进度 -->
|
||||
<div v-if="loading || dbDecryptProgress.total > 0" class="mt-6">
|
||||
<div class="flex items-center justify-between mb-2">
|
||||
<div class="text-sm text-[#7F7F7F]">
|
||||
{{ dbDecryptProgress.message || (loading ? '解密中...' : '') }}
|
||||
</div>
|
||||
<div v-if="dbDecryptProgress.total > 0" class="text-sm font-mono text-[#000000e6]">
|
||||
{{ dbDecryptProgress.current }} / {{ dbDecryptProgress.total }}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="w-full bg-gray-200 rounded-full h-2 overflow-hidden">
|
||||
<div
|
||||
class="h-full bg-[#07C160] transition-all duration-300"
|
||||
:style="{ width: dbProgressPercent + '%' }"
|
||||
></div>
|
||||
</div>
|
||||
|
||||
<div v-if="dbDecryptProgress.current_file" class="mt-2 text-xs text-[#7F7F7F] truncate font-mono">
|
||||
{{ dbDecryptProgress.current_file }}
|
||||
</div>
|
||||
|
||||
<div v-if="dbDecryptProgress.total > 0" class="mt-3 grid grid-cols-2 gap-4 text-center">
|
||||
<div class="bg-gray-50 rounded-lg p-3">
|
||||
<div class="text-lg font-bold text-[#07C160]">{{ dbDecryptProgress.success_count }}</div>
|
||||
<div class="text-xs text-[#7F7F7F]">成功</div>
|
||||
</div>
|
||||
<div class="bg-gray-50 rounded-lg p-3">
|
||||
<div class="text-lg font-bold text-[#FA5151]">{{ dbDecryptProgress.fail_count }}</div>
|
||||
<div class="text-xs text-[#7F7F7F]">失败</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</form>
|
||||
</div>
|
||||
</div>
|
||||
@@ -413,7 +447,7 @@
|
||||
</style>
|
||||
|
||||
<script setup>
|
||||
import { ref, reactive, computed, onMounted } from 'vue'
|
||||
import { ref, reactive, computed, onMounted, onBeforeUnmount } from 'vue'
|
||||
import { useApi } from '~/composables/useApi'
|
||||
|
||||
const { decryptDatabase, saveMediaKeys, getSavedKeys, getDbKey, getImageKey, getWxStatus } = useApi()
|
||||
@@ -625,6 +659,22 @@ const clearManualKeys = () => {
|
||||
const mediaDecryptResult = ref(null)
|
||||
const mediaDecrypting = ref(false)
|
||||
|
||||
// 数据库解密进度(SSE)
|
||||
const dbDecryptProgress = reactive({
|
||||
current: 0,
|
||||
total: 0,
|
||||
success_count: 0,
|
||||
fail_count: 0,
|
||||
current_file: '',
|
||||
status: '',
|
||||
message: ''
|
||||
})
|
||||
|
||||
const dbProgressPercent = computed(() => {
|
||||
if (dbDecryptProgress.total === 0) return 0
|
||||
return Math.round((dbDecryptProgress.current / dbDecryptProgress.total) * 100)
|
||||
})
|
||||
|
||||
// 实时解密进度
|
||||
const decryptProgress = reactive({
|
||||
current: 0,
|
||||
@@ -673,6 +723,27 @@ const validateForm = () => {
|
||||
return isValid
|
||||
}
|
||||
|
||||
let dbDecryptEventSource = null
|
||||
onBeforeUnmount(() => {
|
||||
try {
|
||||
if (dbDecryptEventSource) dbDecryptEventSource.close()
|
||||
} catch (e) {
|
||||
// ignore
|
||||
} finally {
|
||||
dbDecryptEventSource = null
|
||||
}
|
||||
})
|
||||
|
||||
const resetDbDecryptProgress = () => {
|
||||
dbDecryptProgress.current = 0
|
||||
dbDecryptProgress.total = 0
|
||||
dbDecryptProgress.success_count = 0
|
||||
dbDecryptProgress.fail_count = 0
|
||||
dbDecryptProgress.current_file = ''
|
||||
dbDecryptProgress.status = ''
|
||||
dbDecryptProgress.message = ''
|
||||
}
|
||||
|
||||
// 处理解密
|
||||
const handleDecrypt = async () => {
|
||||
if (!validateForm()) {
|
||||
@@ -682,43 +753,142 @@ const handleDecrypt = async () => {
|
||||
loading.value = true
|
||||
error.value = ''
|
||||
warning.value = ''
|
||||
|
||||
resetDbDecryptProgress()
|
||||
|
||||
try {
|
||||
const result = await decryptDatabase({
|
||||
key: formData.key,
|
||||
db_storage_path: formData.db_storage_path
|
||||
})
|
||||
|
||||
if (result.status === 'completed') {
|
||||
// 解密成功,保存结果并进入下一步
|
||||
decryptResult.value = result
|
||||
if (process.client && typeof window !== 'undefined') {
|
||||
sessionStorage.setItem('decryptResult', JSON.stringify(result))
|
||||
}
|
||||
// 记录当前账号(用于图片解密/密钥保存)
|
||||
try {
|
||||
const accounts = Object.keys(result.account_results || {})
|
||||
if (accounts.length > 0) mediaAccount.value = accounts[0]
|
||||
} catch (e) {
|
||||
// ignore
|
||||
const canSse = process.client && typeof window !== 'undefined' && typeof EventSource !== 'undefined'
|
||||
|
||||
// Fallback: 如果环境不支持 SSE,则使用普通 POST(无进度)。
|
||||
if (!canSse) {
|
||||
const result = await decryptDatabase({
|
||||
key: formData.key,
|
||||
db_storage_path: formData.db_storage_path
|
||||
})
|
||||
|
||||
if (result.status === 'completed') {
|
||||
decryptResult.value = result
|
||||
if (process.client && typeof window !== 'undefined') {
|
||||
sessionStorage.setItem('decryptResult', JSON.stringify(result))
|
||||
}
|
||||
try {
|
||||
const accounts = Object.keys(result.account_results || {})
|
||||
if (accounts.length > 0) mediaAccount.value = accounts[0]
|
||||
} catch (e) {}
|
||||
|
||||
clearManualKeys()
|
||||
currentStep.value = 1
|
||||
await prefillKeysForAccount(mediaAccount.value)
|
||||
} else if (result.status === 'failed') {
|
||||
if (result.failure_count > 0 && result.success_count === 0) {
|
||||
error.value = result.message || '所有文件解密失败'
|
||||
} else {
|
||||
error.value = '部分文件解密失败,请检查密钥是否正确'
|
||||
}
|
||||
} else {
|
||||
error.value = result.message || '解密失败,请检查输入信息'
|
||||
}
|
||||
|
||||
// 进入图片密钥填写步骤
|
||||
clearManualKeys()
|
||||
currentStep.value = 1
|
||||
await prefillKeysForAccount(mediaAccount.value)
|
||||
} else if (result.status === 'failed') {
|
||||
if (result.failure_count > 0 && result.success_count === 0) {
|
||||
error.value = result.message || '所有文件解密失败'
|
||||
} else {
|
||||
error.value = '部分文件解密失败,请检查密钥是否正确'
|
||||
loading.value = false
|
||||
return
|
||||
}
|
||||
|
||||
// SSE: 解密过程实时推送进度
|
||||
if (dbDecryptEventSource) {
|
||||
try {
|
||||
dbDecryptEventSource.close()
|
||||
} catch (e) {}
|
||||
dbDecryptEventSource = null
|
||||
}
|
||||
|
||||
const params = new URLSearchParams()
|
||||
params.set('key', formData.key)
|
||||
params.set('db_storage_path', formData.db_storage_path)
|
||||
const url = `http://localhost:8000/api/decrypt_stream?${params.toString()}`
|
||||
|
||||
dbDecryptProgress.message = '连接中...'
|
||||
const eventSource = new EventSource(url)
|
||||
dbDecryptEventSource = eventSource
|
||||
|
||||
eventSource.onmessage = async (event) => {
|
||||
try {
|
||||
const data = JSON.parse(event.data)
|
||||
|
||||
if (data.type === 'scanning') {
|
||||
dbDecryptProgress.message = data.message || '正在扫描数据库文件...'
|
||||
} else if (data.type === 'start') {
|
||||
dbDecryptProgress.total = data.total || 0
|
||||
dbDecryptProgress.message = data.message || '开始解密...'
|
||||
} else if (data.type === 'progress') {
|
||||
dbDecryptProgress.current = data.current || 0
|
||||
dbDecryptProgress.total = data.total || 0
|
||||
dbDecryptProgress.success_count = data.success_count || 0
|
||||
dbDecryptProgress.fail_count = data.fail_count || 0
|
||||
dbDecryptProgress.current_file = data.current_file || ''
|
||||
dbDecryptProgress.status = data.status || ''
|
||||
dbDecryptProgress.message = data.message || ''
|
||||
} else if (data.type === 'phase') {
|
||||
// e.g. building cache
|
||||
dbDecryptProgress.message = data.message || ''
|
||||
} else if (data.type === 'complete') {
|
||||
dbDecryptProgress.status = 'complete'
|
||||
dbDecryptProgress.current = data.total_databases || dbDecryptProgress.total
|
||||
dbDecryptProgress.total = data.total_databases || dbDecryptProgress.total
|
||||
dbDecryptProgress.success_count = data.success_count || 0
|
||||
dbDecryptProgress.fail_count = data.failure_count || 0
|
||||
dbDecryptProgress.message = data.message || '解密完成'
|
||||
|
||||
decryptResult.value = data
|
||||
if (process.client && typeof window !== 'undefined') {
|
||||
sessionStorage.setItem('decryptResult', JSON.stringify(data))
|
||||
}
|
||||
|
||||
try {
|
||||
const accounts = Object.keys(data.account_results || {})
|
||||
if (accounts.length > 0) mediaAccount.value = accounts[0]
|
||||
} catch (e) {}
|
||||
|
||||
try {
|
||||
eventSource.close()
|
||||
} catch (e) {}
|
||||
dbDecryptEventSource = null
|
||||
loading.value = false
|
||||
|
||||
if (data.status === 'completed') {
|
||||
clearManualKeys()
|
||||
currentStep.value = 1
|
||||
await prefillKeysForAccount(mediaAccount.value)
|
||||
} else if (data.status === 'failed') {
|
||||
error.value = data.message || '所有文件解密失败'
|
||||
} else {
|
||||
error.value = data.message || '解密失败,请检查输入信息'
|
||||
}
|
||||
} else if (data.type === 'error') {
|
||||
error.value = data.message || '解密失败,请检查输入信息'
|
||||
try {
|
||||
eventSource.close()
|
||||
} catch (e) {}
|
||||
dbDecryptEventSource = null
|
||||
loading.value = false
|
||||
}
|
||||
} catch (e) {
|
||||
console.error('解析SSE消息失败:', e)
|
||||
}
|
||||
}
|
||||
|
||||
eventSource.onerror = (e) => {
|
||||
console.error('SSE连接错误:', e)
|
||||
try {
|
||||
eventSource.close()
|
||||
} catch (err) {}
|
||||
dbDecryptEventSource = null
|
||||
if (loading.value) {
|
||||
error.value = 'SSE连接中断,请重试'
|
||||
loading.value = false
|
||||
}
|
||||
} else {
|
||||
error.value = result.message || '解密失败,请检查输入信息'
|
||||
}
|
||||
} catch (err) {
|
||||
error.value = err.message || '解密过程中发生错误'
|
||||
} finally {
|
||||
loading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,10 +1,20 @@
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from pydantic import BaseModel, Field
|
||||
from starlette.responses import StreamingResponse
|
||||
|
||||
from ..app_paths import get_output_databases_dir
|
||||
from ..logging_config import get_logger
|
||||
from ..path_fix import PathFixRoute
|
||||
from ..key_store import upsert_account_keys_in_store
|
||||
from ..wechat_decrypt import decrypt_wechat_databases
|
||||
from ..wechat_decrypt import WeChatDatabaseDecryptor, decrypt_wechat_databases
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -72,3 +82,273 @@ async def decrypt_databases(request: DecryptRequest):
|
||||
except Exception as e:
|
||||
logger.error(f"解密API异常: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/api/decrypt_stream", summary="解密微信数据库(SSE实时进度)")
|
||||
async def decrypt_databases_stream(
|
||||
request: Request,
|
||||
key: str | None = None,
|
||||
db_storage_path: str | None = None,
|
||||
):
|
||||
"""通过SSE实时推送数据库解密进度。
|
||||
|
||||
注意:EventSource 只支持 GET,因此参数通过 querystring 传递。
|
||||
"""
|
||||
|
||||
def _sse(payload: dict) -> str:
|
||||
return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
|
||||
|
||||
async def generate_progress():
|
||||
# 1) Basic validation (keep 200 + SSE error event, avoid 422 breaking EventSource).
|
||||
k = str(key or "").strip()
|
||||
p = str(db_storage_path or "").strip()
|
||||
|
||||
if not k or len(k) != 64:
|
||||
yield _sse({"type": "error", "message": "密钥格式无效,必须是64位十六进制字符串"})
|
||||
return
|
||||
|
||||
try:
|
||||
bytes.fromhex(k)
|
||||
except Exception:
|
||||
yield _sse({"type": "error", "message": "密钥必须是有效的十六进制字符串"})
|
||||
return
|
||||
|
||||
if not p:
|
||||
yield _sse({"type": "error", "message": "请提供 db_storage_path 参数"})
|
||||
return
|
||||
|
||||
storage_path = Path(p)
|
||||
if not storage_path.exists():
|
||||
yield _sse({"type": "error", "message": f"指定的数据库路径不存在: {p}"})
|
||||
return
|
||||
|
||||
# 2) Scan databases.
|
||||
yield _sse({"type": "scanning", "message": "正在扫描数据库文件..."})
|
||||
await asyncio.sleep(0)
|
||||
|
||||
account_name = "unknown_account"
|
||||
path_parts = storage_path.parts
|
||||
account_patterns = ["wxid_"]
|
||||
for part in path_parts:
|
||||
for pattern in account_patterns:
|
||||
if part.startswith(pattern):
|
||||
parts = part.split("_")
|
||||
if len(parts) >= 3:
|
||||
account_name = "_".join(parts[:-1])
|
||||
else:
|
||||
account_name = part
|
||||
break
|
||||
if account_name != "unknown_account":
|
||||
break
|
||||
|
||||
if account_name == "unknown_account":
|
||||
for part in reversed(path_parts):
|
||||
if part != "db_storage" and len(part) > 3:
|
||||
account_name = part
|
||||
break
|
||||
|
||||
databases: list[dict] = []
|
||||
for root, _dirs, files in os.walk(storage_path):
|
||||
if "db_storage" not in str(root):
|
||||
continue
|
||||
for file_name in files:
|
||||
if not file_name.endswith(".db"):
|
||||
continue
|
||||
if file_name in ["key_info.db"]:
|
||||
continue
|
||||
db_path = os.path.join(root, file_name)
|
||||
databases.append({"path": db_path, "name": file_name, "account": account_name})
|
||||
|
||||
if not databases:
|
||||
yield _sse({"type": "error", "message": "未找到微信数据库文件!请检查 db_storage_path 是否正确"})
|
||||
return
|
||||
|
||||
account_databases = {account_name: databases}
|
||||
total_databases = sum(len(dbs) for dbs in account_databases.values())
|
||||
|
||||
yield _sse({"type": "start", "total": total_databases, "message": f"开始解密 {total_databases} 个数据库"})
|
||||
await asyncio.sleep(0)
|
||||
|
||||
# 3) Init output dir & decryptor.
|
||||
base_output_dir = get_output_databases_dir()
|
||||
base_output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
try:
|
||||
decryptor = WeChatDatabaseDecryptor(k)
|
||||
except ValueError as e:
|
||||
yield _sse({"type": "error", "message": f"密钥错误: {e}"})
|
||||
return
|
||||
|
||||
# 4) Decrypt per account, stream progress.
|
||||
success_count = 0
|
||||
fail_count = 0
|
||||
processed_files: list[str] = []
|
||||
failed_files: list[str] = []
|
||||
account_results: dict = {}
|
||||
overall_current = 0
|
||||
|
||||
for account, dbs in account_databases.items():
|
||||
account_output_dir = base_output_dir / account
|
||||
account_output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Save a hint for later UI (same as non-stream endpoint).
|
||||
try:
|
||||
source_db_storage_path = p
|
||||
wxid_dir = ""
|
||||
if storage_path.name.lower() == "db_storage":
|
||||
wxid_dir = str(storage_path.parent)
|
||||
else:
|
||||
wxid_dir = str(storage_path)
|
||||
(account_output_dir / "_source.json").write_text(
|
||||
json.dumps({"db_storage_path": source_db_storage_path, "wxid_dir": wxid_dir}, ensure_ascii=False, indent=2),
|
||||
encoding="utf-8",
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
account_success = 0
|
||||
account_processed: list[str] = []
|
||||
account_failed: list[str] = []
|
||||
|
||||
for db_info in dbs:
|
||||
if await request.is_disconnected():
|
||||
return
|
||||
|
||||
overall_current += 1
|
||||
db_path = str(db_info.get("path") or "")
|
||||
db_name = str(db_info.get("name") or "")
|
||||
current_file = f"{account}/{db_name}" if account else db_name
|
||||
|
||||
# Emit a "processing" event so UI updates immediately for large db files.
|
||||
yield _sse(
|
||||
{
|
||||
"type": "progress",
|
||||
"current": overall_current,
|
||||
"total": total_databases,
|
||||
"success_count": success_count,
|
||||
"fail_count": fail_count,
|
||||
"current_file": current_file,
|
||||
"status": "processing",
|
||||
"message": "解密中...",
|
||||
}
|
||||
)
|
||||
|
||||
output_path = account_output_dir / db_name
|
||||
task = asyncio.create_task(asyncio.to_thread(decryptor.decrypt_database, db_path, str(output_path)))
|
||||
|
||||
# Wait with heartbeat (can't yield while awaiting the thread directly).
|
||||
last_heartbeat = time.time()
|
||||
while not task.done():
|
||||
if await request.is_disconnected():
|
||||
return
|
||||
now = time.time()
|
||||
if now - last_heartbeat > 15:
|
||||
last_heartbeat = now
|
||||
# SSE comment heartbeat; browsers ignore but keeps proxies alive.
|
||||
yield ": ping\n\n"
|
||||
await asyncio.sleep(0.6)
|
||||
try:
|
||||
ok = bool(task.result())
|
||||
except Exception:
|
||||
ok = False
|
||||
|
||||
if ok:
|
||||
account_success += 1
|
||||
success_count += 1
|
||||
account_processed.append(str(output_path))
|
||||
processed_files.append(str(output_path))
|
||||
status = "success"
|
||||
msg = "解密成功"
|
||||
else:
|
||||
account_failed.append(db_path)
|
||||
failed_files.append(db_path)
|
||||
fail_count += 1
|
||||
status = "fail"
|
||||
msg = "解密失败"
|
||||
|
||||
yield _sse(
|
||||
{
|
||||
"type": "progress",
|
||||
"current": overall_current,
|
||||
"total": total_databases,
|
||||
"success_count": success_count,
|
||||
"fail_count": fail_count,
|
||||
"current_file": current_file,
|
||||
"status": status,
|
||||
"message": msg,
|
||||
}
|
||||
)
|
||||
|
||||
if overall_current % 5 == 0:
|
||||
await asyncio.sleep(0)
|
||||
|
||||
account_results[account] = {
|
||||
"total": len(dbs),
|
||||
"success": account_success,
|
||||
"failed": len(dbs) - account_success,
|
||||
"output_dir": str(account_output_dir),
|
||||
"processed_files": account_processed,
|
||||
"failed_files": account_failed,
|
||||
}
|
||||
|
||||
# Build cache table (keep behavior consistent with the POST endpoint).
|
||||
if os.environ.get("WECHAT_TOOL_BUILD_SESSION_LAST_MESSAGE", "1") != "0":
|
||||
yield _sse(
|
||||
{
|
||||
"type": "phase",
|
||||
"phase": "session_last_message",
|
||||
"account": account,
|
||||
"message": "正在构建会话缓存(最后一条消息)...",
|
||||
}
|
||||
)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
try:
|
||||
from ..session_last_message import build_session_last_message_table
|
||||
|
||||
task = asyncio.create_task(
|
||||
asyncio.to_thread(
|
||||
build_session_last_message_table,
|
||||
account_output_dir,
|
||||
rebuild=True,
|
||||
include_hidden=True,
|
||||
include_official=True,
|
||||
)
|
||||
)
|
||||
last_heartbeat = time.time()
|
||||
while not task.done():
|
||||
if await request.is_disconnected():
|
||||
return
|
||||
now = time.time()
|
||||
if now - last_heartbeat > 15:
|
||||
last_heartbeat = now
|
||||
yield ": ping\n\n"
|
||||
await asyncio.sleep(0.6)
|
||||
account_results[account]["session_last_message"] = task.result()
|
||||
except Exception as e:
|
||||
account_results[account]["session_last_message"] = {"status": "error", "message": str(e)}
|
||||
|
||||
status = "completed" if success_count > 0 else "failed"
|
||||
result = {
|
||||
"status": status,
|
||||
"total_databases": total_databases,
|
||||
"success_count": success_count,
|
||||
"failure_count": total_databases - success_count,
|
||||
"output_directory": str(base_output_dir.absolute()),
|
||||
"message": f"解密完成: 成功 {success_count}/{total_databases}",
|
||||
"processed_files": processed_files,
|
||||
"failed_files": failed_files,
|
||||
"account_results": account_results,
|
||||
}
|
||||
|
||||
# Save db key for frontend autofill.
|
||||
try:
|
||||
for account in (account_results or {}).keys():
|
||||
upsert_account_keys_in_store(str(account), db_key=k)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
yield _sse({"type": "complete", **result})
|
||||
|
||||
headers = {"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"}
|
||||
return StreamingResponse(generate_progress(), media_type="text/event-stream", headers=headers)
|
||||
|
||||
91
tests/test_decrypt_stream_sse.py
Normal file
91
tests/test_decrypt_stream_sse.py
Normal file
@@ -0,0 +1,91 @@
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
import importlib
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
sys.path.insert(0, str(ROOT / "src"))
|
||||
|
||||
|
||||
class TestDecryptStreamSSE(unittest.TestCase):
|
||||
def test_decrypt_stream_reports_progress(self):
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from wechat_decrypt_tool.wechat_decrypt import SQLITE_HEADER
|
||||
|
||||
with TemporaryDirectory() as td:
|
||||
root = Path(td)
|
||||
|
||||
prev_data_dir = os.environ.get("WECHAT_TOOL_DATA_DIR")
|
||||
prev_build_cache = os.environ.get("WECHAT_TOOL_BUILD_SESSION_LAST_MESSAGE")
|
||||
try:
|
||||
os.environ["WECHAT_TOOL_DATA_DIR"] = str(root)
|
||||
os.environ["WECHAT_TOOL_BUILD_SESSION_LAST_MESSAGE"] = "0"
|
||||
|
||||
import wechat_decrypt_tool.app_paths as app_paths
|
||||
import wechat_decrypt_tool.routers.decrypt as decrypt_router
|
||||
|
||||
importlib.reload(app_paths)
|
||||
importlib.reload(decrypt_router)
|
||||
|
||||
db_storage = root / "xwechat_files" / "wxid_foo_bar" / "db_storage"
|
||||
db_storage.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Fake a decrypted sqlite db (>= 4096 bytes) so decryptor falls back to copy.
|
||||
(db_storage / "MSG0.db").write_bytes(SQLITE_HEADER + b"\x00" * (4096 - len(SQLITE_HEADER)))
|
||||
|
||||
app = FastAPI()
|
||||
app.include_router(decrypt_router.router)
|
||||
client = TestClient(app)
|
||||
|
||||
events: list[dict] = []
|
||||
with client.stream(
|
||||
"GET",
|
||||
"/api/decrypt_stream",
|
||||
params={"key": "00" * 32, "db_storage_path": str(db_storage)},
|
||||
) as resp:
|
||||
self.assertEqual(resp.status_code, 200)
|
||||
self.assertIn("text/event-stream", resp.headers.get("content-type", ""))
|
||||
|
||||
for line in resp.iter_lines():
|
||||
if not line:
|
||||
continue
|
||||
if isinstance(line, bytes):
|
||||
line = line.decode("utf-8", errors="ignore")
|
||||
line = str(line)
|
||||
|
||||
if line.startswith(":"):
|
||||
continue
|
||||
if not line.startswith("data: "):
|
||||
continue
|
||||
payload = json.loads(line[len("data: ") :])
|
||||
events.append(payload)
|
||||
if payload.get("type") in {"complete", "error"}:
|
||||
break
|
||||
|
||||
types = {e.get("type") for e in events}
|
||||
self.assertIn("start", types)
|
||||
self.assertIn("progress", types)
|
||||
self.assertEqual(events[-1].get("type"), "complete")
|
||||
|
||||
out = root / "output" / "databases" / "wxid_foo" / "MSG0.db"
|
||||
self.assertTrue(out.exists())
|
||||
finally:
|
||||
if prev_data_dir is None:
|
||||
os.environ.pop("WECHAT_TOOL_DATA_DIR", None)
|
||||
else:
|
||||
os.environ["WECHAT_TOOL_DATA_DIR"] = prev_data_dir
|
||||
if prev_build_cache is None:
|
||||
os.environ.pop("WECHAT_TOOL_BUILD_SESSION_LAST_MESSAGE", None)
|
||||
else:
|
||||
os.environ["WECHAT_TOOL_BUILD_SESSION_LAST_MESSAGE"] = prev_build_cache
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user