mirror of
https://github.com/LifeArchiveProject/WeChatDataAnalysis.git
synced 2026-06-18 15:54:08 +08:00
fix(decrypt): 加强解密结果校验并过滤内部数据库
- 新增可用 SQLite 校验,解密失败时返回更明确提示并清理无效输出 - 统一过滤 key_info、FTS 索引库和内部缓存库,修正数据库扫描与账号统计 - 补充解密流和数据库过滤相关测试
This commit is contained in:
@@ -14,7 +14,7 @@ from fastapi import HTTPException
|
||||
|
||||
from .app_paths import get_output_databases_dir
|
||||
from .logging_config import get_logger
|
||||
from .sqlite_diagnostics import collect_sqlite_diagnostics, format_sqlite_diagnostics
|
||||
from .sqlite_diagnostics import collect_sqlite_diagnostics, format_sqlite_diagnostics, is_usable_sqlite_db
|
||||
|
||||
try:
|
||||
import zstandard as zstd # type: ignore
|
||||
@@ -29,13 +29,7 @@ _SQLITE_HEADER = b"SQLite format 3\x00"
|
||||
|
||||
|
||||
def _is_valid_decrypted_sqlite(path: Path) -> bool:
|
||||
try:
|
||||
if not path.exists() or (not path.is_file()):
|
||||
return False
|
||||
with path.open("rb") as f:
|
||||
return f.read(len(_SQLITE_HEADER)) == _SQLITE_HEADER
|
||||
except Exception:
|
||||
return False
|
||||
return is_usable_sqlite_db(path)
|
||||
|
||||
|
||||
def _list_decrypted_accounts() -> list[str]:
|
||||
|
||||
@@ -0,0 +1,61 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
_IGNORED_SOURCE_DATABASE_NAMES = frozenset({"key_info.db"})
|
||||
_INDEX_DATABASE_NAMES = frozenset({"chat_search_index.db", "chat_search_index.tmp.db"})
|
||||
_INDEX_DATABASE_SUFFIXES = ("_fts.db",)
|
||||
_INTERNAL_OUTPUT_DATABASE_NAMES = frozenset(
|
||||
{
|
||||
"chat_search_index.db",
|
||||
"chat_search_index.tmp.db",
|
||||
"session_last_message.db",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def normalize_database_file_name(file_name: str | Path) -> str:
|
||||
return Path(str(file_name or "")).name.strip().lower()
|
||||
|
||||
|
||||
def is_index_database_file(file_name: str | Path) -> bool:
|
||||
lower_name = normalize_database_file_name(file_name)
|
||||
if not lower_name:
|
||||
return False
|
||||
if lower_name in _INDEX_DATABASE_NAMES:
|
||||
return True
|
||||
return lower_name.endswith(_INDEX_DATABASE_SUFFIXES)
|
||||
|
||||
|
||||
def should_skip_source_database(file_name: str | Path) -> bool:
|
||||
lower_name = normalize_database_file_name(file_name)
|
||||
if not lower_name:
|
||||
return True
|
||||
if lower_name in _IGNORED_SOURCE_DATABASE_NAMES:
|
||||
return True
|
||||
return is_index_database_file(lower_name)
|
||||
|
||||
|
||||
def should_include_in_database_count(file_name: str | Path) -> bool:
|
||||
lower_name = normalize_database_file_name(file_name)
|
||||
if not lower_name.endswith(".db"):
|
||||
return False
|
||||
if should_skip_source_database(lower_name):
|
||||
return False
|
||||
if lower_name in _INTERNAL_OUTPUT_DATABASE_NAMES:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def list_countable_database_names(account_dir: Path) -> list[str]:
|
||||
if not account_dir.exists():
|
||||
return []
|
||||
|
||||
db_files = [
|
||||
path.name
|
||||
for path in account_dir.glob("*.db")
|
||||
if path.is_file() and should_include_in_database_count(path.name)
|
||||
]
|
||||
db_files.sort()
|
||||
return db_files
|
||||
@@ -18,6 +18,7 @@ from fastapi import HTTPException
|
||||
|
||||
from .app_paths import get_output_databases_dir
|
||||
from .logging_config import get_logger
|
||||
from .sqlite_diagnostics import is_usable_sqlite_db
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -28,13 +29,7 @@ _SQLITE_HEADER = b"SQLite format 3\x00"
|
||||
|
||||
|
||||
def _is_valid_decrypted_sqlite(path: Path) -> bool:
|
||||
try:
|
||||
if not path.exists() or (not path.is_file()):
|
||||
return False
|
||||
with path.open("rb") as f:
|
||||
return f.read(len(_SQLITE_HEADER)) == _SQLITE_HEADER
|
||||
except Exception:
|
||||
return False
|
||||
return is_usable_sqlite_db(path)
|
||||
|
||||
|
||||
def _list_decrypted_accounts() -> list[str]:
|
||||
|
||||
@@ -70,6 +70,7 @@ from ..chat_helpers import (
|
||||
from ..media_helpers import _resolve_account_db_storage_dir, _try_find_decrypted_resource
|
||||
from .. import chat_edit_store
|
||||
from ..app_paths import get_output_dir
|
||||
from ..database_filters import list_countable_database_names
|
||||
from ..key_store import remove_account_keys_from_store
|
||||
from ..path_fix import PathFixRoute
|
||||
from ..session_last_message import (
|
||||
@@ -3892,7 +3893,7 @@ async def list_chat_accounts():
|
||||
@router.get("/api/chat/account_info", summary="获取当前账号信息")
|
||||
def get_chat_account_info(account: Optional[str] = None):
|
||||
account_dir = _resolve_account_dir(account)
|
||||
db_files = sorted([p.name for p in account_dir.glob("*.db") if p.is_file()])
|
||||
db_files = list_countable_database_names(account_dir)
|
||||
|
||||
session_db = account_dir / "session.db"
|
||||
session_updated_at = 0
|
||||
|
||||
@@ -16,7 +16,12 @@ from ..chat_realtime_autosync import CHAT_REALTIME_AUTOSYNC
|
||||
from ..logging_config import get_logger
|
||||
from ..path_fix import PathFixRoute
|
||||
from ..key_store import upsert_account_keys_in_store
|
||||
from ..wechat_decrypt import WeChatDatabaseDecryptor, decrypt_wechat_databases, scan_account_databases_from_path
|
||||
from ..wechat_decrypt import (
|
||||
WeChatDatabaseDecryptor,
|
||||
build_decrypt_summary_message,
|
||||
decrypt_wechat_databases,
|
||||
scan_account_databases_from_path,
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -463,7 +468,11 @@ async def decrypt_databases_stream(
|
||||
"success_count": success_count,
|
||||
"failure_count": total_databases - success_count,
|
||||
"output_directory": str(base_output_dir.absolute()),
|
||||
"message": f"解密完成: 成功 {success_count}/{total_databases}",
|
||||
"message": build_decrypt_summary_message(
|
||||
success_count=success_count,
|
||||
total_databases=total_databases,
|
||||
diagnostic_warning_count=diagnostic_warning_count,
|
||||
),
|
||||
"processed_files": processed_files,
|
||||
"failed_files": failed_files,
|
||||
"account_results": account_results,
|
||||
|
||||
@@ -123,6 +123,40 @@ def collect_sqlite_diagnostics(
|
||||
return diagnostics
|
||||
|
||||
|
||||
def is_usable_sqlite_db(path: str | Path) -> bool:
|
||||
db_path = Path(path)
|
||||
if not db_path.exists() or (not db_path.is_file()):
|
||||
return False
|
||||
|
||||
try:
|
||||
if int(db_path.stat().st_size) <= len(SQLITE_HEADER):
|
||||
return False
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
try:
|
||||
with db_path.open("rb") as f:
|
||||
if f.read(len(SQLITE_HEADER)) != SQLITE_HEADER:
|
||||
return False
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
conn: sqlite3.Connection | None = None
|
||||
try:
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
conn.execute("PRAGMA schema_version").fetchone()
|
||||
row = conn.execute("SELECT name FROM sqlite_master WHERE type='table' LIMIT 1").fetchone()
|
||||
return row is not None
|
||||
except Exception:
|
||||
return False
|
||||
finally:
|
||||
if conn is not None:
|
||||
try:
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def sqlite_diagnostics_status(diagnostics: Mapping[str, Any]) -> str:
|
||||
if not diagnostics:
|
||||
return "not_run"
|
||||
|
||||
@@ -21,6 +21,7 @@ from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
|
||||
|
||||
from .app_paths import get_output_databases_dir
|
||||
from .database_filters import should_skip_source_database
|
||||
from .sqlite_diagnostics import collect_sqlite_diagnostics, sqlite_diagnostics_status
|
||||
|
||||
# 注意:不再支持默认密钥,所有密钥必须通过参数传入
|
||||
@@ -64,6 +65,59 @@ def _derive_account_name_from_path(path: Path) -> str:
|
||||
return "unknown_account"
|
||||
|
||||
|
||||
def _build_decrypt_failure_message(result: dict) -> str:
|
||||
failed_pages = int(result.get("failed_pages") or 0)
|
||||
successful_pages = int(result.get("successful_pages") or 0)
|
||||
diagnostic_status = str(result.get("diagnostic_status") or "").strip()
|
||||
diagnostics = dict(result.get("diagnostics") or {})
|
||||
|
||||
detail = (
|
||||
diagnostics.get("quick_check_error")
|
||||
or diagnostics.get("connect_error")
|
||||
or diagnostics.get("table_list_error")
|
||||
or diagnostics.get("page_count_error")
|
||||
or diagnostics.get("quick_check")
|
||||
or diagnostic_status
|
||||
)
|
||||
detail_text = " ".join(str(detail or "").split()).strip()
|
||||
|
||||
if failed_pages > 0 and successful_pages == 0:
|
||||
if detail_text:
|
||||
return f"数据库校验未通过,密钥可能不匹配当前账号: {detail_text}"
|
||||
return "数据库校验未通过,密钥可能不匹配当前账号"
|
||||
|
||||
if diagnostic_status and diagnostic_status != "ok":
|
||||
if detail_text:
|
||||
return f"解密输出不是有效的 SQLite 数据库: {detail_text}"
|
||||
return "解密输出不是有效的 SQLite 数据库"
|
||||
|
||||
if failed_pages > 0:
|
||||
return "解密输出包含页失败,结果不完整"
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
def build_decrypt_summary_message(*, success_count: int, total_databases: int, diagnostic_warning_count: int) -> str:
|
||||
success_count = int(success_count or 0)
|
||||
total_databases = int(total_databases or 0)
|
||||
diagnostic_warning_count = int(diagnostic_warning_count or 0)
|
||||
|
||||
if total_databases <= 0:
|
||||
return "未找到可解密的数据库"
|
||||
|
||||
if success_count <= 0:
|
||||
if diagnostic_warning_count > 0:
|
||||
return "解密失败:数据库校验未通过,密钥可能不匹配当前账号。"
|
||||
return "解密失败:未能成功解密任何数据库。"
|
||||
|
||||
if success_count < total_databases:
|
||||
if diagnostic_warning_count > 0:
|
||||
return f"解密部分成功:成功 {success_count}/{total_databases},其余数据库校验未通过。"
|
||||
return f"解密部分成功:成功 {success_count}/{total_databases}。"
|
||||
|
||||
return f"解密完成: 成功 {success_count}/{total_databases}"
|
||||
|
||||
|
||||
def _resolve_db_storage_roots(storage_path: Path) -> list[Path]:
|
||||
try:
|
||||
target = storage_path.resolve()
|
||||
@@ -158,7 +212,7 @@ def scan_account_databases_from_path(db_storage_path: str) -> dict:
|
||||
for file_name in files:
|
||||
if not file_name.endswith(".db"):
|
||||
continue
|
||||
if file_name in ["key_info.db"]:
|
||||
if should_skip_source_database(file_name):
|
||||
continue
|
||||
db_path = os.path.join(root, file_name)
|
||||
databases.append(
|
||||
@@ -266,7 +320,8 @@ class WeChatDatabaseDecryptor:
|
||||
result["failed_page_samples"].append(item)
|
||||
|
||||
def _finalize(success: bool, error: str = "") -> bool:
|
||||
result["success"] = bool(success)
|
||||
normalized_success = bool(success)
|
||||
result["success"] = normalized_success
|
||||
if error:
|
||||
result["error"] = " ".join(str(error).split()).strip()
|
||||
|
||||
@@ -281,6 +336,19 @@ class WeChatDatabaseDecryptor:
|
||||
result["diagnostics"] = diagnostics
|
||||
result["diagnostic_status"] = sqlite_diagnostics_status(diagnostics)
|
||||
|
||||
if normalized_success:
|
||||
failure_message = _build_decrypt_failure_message(result)
|
||||
if failure_message:
|
||||
normalized_success = False
|
||||
result["success"] = False
|
||||
if not result["error"]:
|
||||
result["error"] = failure_message
|
||||
if output_file.exists():
|
||||
try:
|
||||
output_file.unlink()
|
||||
except Exception as exc:
|
||||
logger.warning("删除无效解密输出失败: %s, 错误: %s", output_file, exc)
|
||||
|
||||
payload = {
|
||||
"db_name": result["db_name"],
|
||||
"db_path": result["db_path"],
|
||||
@@ -307,7 +375,7 @@ class WeChatDatabaseDecryptor:
|
||||
log_fn = logger.warning
|
||||
log_fn("[decrypt.diagnostic] %s", json.dumps(payload, ensure_ascii=False, sort_keys=True))
|
||||
self.last_result = result
|
||||
return bool(success)
|
||||
return bool(result["success"])
|
||||
|
||||
logger.info(f"开始解密数据库: {db_path}")
|
||||
|
||||
@@ -693,7 +761,11 @@ def decrypt_wechat_databases(db_storage_path: str = None, key: str = None) -> di
|
||||
# 返回结果
|
||||
result = {
|
||||
"status": "success" if success_count > 0 else "error",
|
||||
"message": f"解密完成: 成功 {success_count}/{total_databases}",
|
||||
"message": build_decrypt_summary_message(
|
||||
success_count=success_count,
|
||||
total_databases=total_databases,
|
||||
diagnostic_warning_count=diagnostic_warning_count,
|
||||
),
|
||||
"total_databases": total_databases,
|
||||
"successful_count": success_count,
|
||||
"failed_count": total_databases - success_count,
|
||||
|
||||
@@ -13,6 +13,8 @@ from typing import List, Dict, Any, Union
|
||||
from ctypes import wintypes
|
||||
from datetime import datetime
|
||||
|
||||
from .database_filters import should_skip_source_database
|
||||
|
||||
|
||||
def get_wx_db(msg_dir: str = None,
|
||||
db_types: Union[List[str], str] = None,
|
||||
@@ -59,8 +61,7 @@ def get_wx_db(msg_dir: str = None,
|
||||
for file_name in files:
|
||||
if not file_name.endswith(".db"):
|
||||
continue
|
||||
# 排除不需要解密的数据库
|
||||
if file_name in ["key_info.db"]:
|
||||
if should_skip_source_database(file_name):
|
||||
continue
|
||||
db_type = re.sub(r"\d*\.db$", "", file_name)
|
||||
if db_types and db_type not in db_types: # 如果指定db_type,则过滤掉其他db_type
|
||||
@@ -672,8 +673,7 @@ def collect_account_databases(data_dir: str, account_name: str) -> List[Dict[str
|
||||
if not file_name.endswith('.db'):
|
||||
continue
|
||||
|
||||
# 排除不需要解密的数据库
|
||||
if file_name in ["key_info.db"]:
|
||||
if should_skip_source_database(file_name):
|
||||
continue
|
||||
|
||||
db_path = os.path.join(root, file_name)
|
||||
|
||||
@@ -0,0 +1,112 @@
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
import sqlite3
|
||||
import sys
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
sys.path.insert(0, str(ROOT / "src"))
|
||||
|
||||
|
||||
def _close_logging_handlers() -> None:
|
||||
for logger_name in ("", "uvicorn", "uvicorn.access", "uvicorn.error", "fastapi"):
|
||||
lg = logging.getLogger(logger_name)
|
||||
for handler in lg.handlers[:]:
|
||||
try:
|
||||
handler.close()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
lg.removeHandler(handler)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _seed_sqlite(path: Path, table_name: str = "demo") -> None:
|
||||
conn = sqlite3.connect(str(path))
|
||||
try:
|
||||
conn.execute(f"CREATE TABLE {table_name}(id INTEGER PRIMARY KEY, value TEXT)")
|
||||
conn.execute(f"INSERT INTO {table_name}(value) VALUES ('ok')")
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
class TestDatabaseFilters(unittest.TestCase):
|
||||
def test_scan_account_databases_skips_index_databases(self):
|
||||
from wechat_decrypt_tool.wechat_decrypt import scan_account_databases_from_path
|
||||
|
||||
with TemporaryDirectory() as td:
|
||||
db_storage = Path(td) / "xwechat_files" / "wxid_demo_user" / "db_storage"
|
||||
db_storage.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
_seed_sqlite(db_storage / "MSG0.db")
|
||||
_seed_sqlite(db_storage / "contact_fts.db")
|
||||
_seed_sqlite(db_storage / "favorite_fts.db")
|
||||
_seed_sqlite(db_storage / "message_fts.db")
|
||||
_seed_sqlite(db_storage / "key_info.db")
|
||||
|
||||
result = scan_account_databases_from_path(str(db_storage))
|
||||
|
||||
self.assertEqual(result["status"], "success")
|
||||
self.assertEqual(list(result["account_databases"].keys()), ["wxid_demo"])
|
||||
db_names = sorted(db["name"] for db in result["account_databases"]["wxid_demo"])
|
||||
self.assertEqual(db_names, ["MSG0.db"])
|
||||
|
||||
def test_collect_account_databases_skips_index_databases(self):
|
||||
from wechat_decrypt_tool.wechat_detection import collect_account_databases
|
||||
|
||||
with TemporaryDirectory() as td:
|
||||
data_dir = Path(td) / "wxid_demo_user"
|
||||
data_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
_seed_sqlite(data_dir / "contact.db", "contact")
|
||||
_seed_sqlite(data_dir / "contact_fts.db")
|
||||
_seed_sqlite(data_dir / "favorite_fts.db")
|
||||
_seed_sqlite(data_dir / "message_fts.db")
|
||||
|
||||
databases = collect_account_databases(str(data_dir), "wxid_demo")
|
||||
db_names = sorted(db["name"] for db in databases)
|
||||
self.assertEqual(db_names, ["contact.db"])
|
||||
|
||||
def test_chat_account_info_hides_index_and_internal_databases(self):
|
||||
with TemporaryDirectory() as td:
|
||||
root = Path(td)
|
||||
prev_data_dir = os.environ.get("WECHAT_TOOL_DATA_DIR")
|
||||
try:
|
||||
os.environ["WECHAT_TOOL_DATA_DIR"] = str(root)
|
||||
|
||||
import wechat_decrypt_tool.app_paths as app_paths
|
||||
import wechat_decrypt_tool.routers.chat as chat_router
|
||||
|
||||
importlib.reload(app_paths)
|
||||
importlib.reload(chat_router)
|
||||
|
||||
account_dir = root / "output" / "databases" / "wxid_demo"
|
||||
account_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
_seed_sqlite(account_dir / "contact.db", "contact")
|
||||
_seed_sqlite(account_dir / "session.db", "session_table")
|
||||
_seed_sqlite(account_dir / "message_fts.db")
|
||||
_seed_sqlite(account_dir / "chat_search_index.db")
|
||||
_seed_sqlite(account_dir / "session_last_message.db")
|
||||
|
||||
result = chat_router.get_chat_account_info("wxid_demo")
|
||||
|
||||
self.assertEqual(result["status"], "success")
|
||||
self.assertEqual(result["database_count"], 2)
|
||||
self.assertEqual(result["databases"], ["contact.db", "session.db"])
|
||||
finally:
|
||||
_close_logging_handlers()
|
||||
if prev_data_dir is None:
|
||||
os.environ.pop("WECHAT_TOOL_DATA_DIR", None)
|
||||
else:
|
||||
os.environ["WECHAT_TOOL_DATA_DIR"] = prev_data_dir
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -1,5 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sqlite3
|
||||
import sys
|
||||
import unittest
|
||||
import importlib
|
||||
@@ -11,13 +13,25 @@ ROOT = Path(__file__).resolve().parents[1]
|
||||
sys.path.insert(0, str(ROOT / "src"))
|
||||
|
||||
|
||||
def _close_logging_handlers() -> None:
|
||||
for logger_name in ("", "uvicorn", "uvicorn.access", "uvicorn.error", "fastapi"):
|
||||
lg = logging.getLogger(logger_name)
|
||||
for handler in lg.handlers[:]:
|
||||
try:
|
||||
handler.close()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
lg.removeHandler(handler)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
class TestDecryptStreamSSE(unittest.TestCase):
|
||||
def test_decrypt_stream_reports_progress(self):
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from wechat_decrypt_tool.wechat_decrypt import SQLITE_HEADER
|
||||
|
||||
with TemporaryDirectory() as td:
|
||||
root = Path(td)
|
||||
|
||||
@@ -36,8 +50,14 @@ class TestDecryptStreamSSE(unittest.TestCase):
|
||||
db_storage = root / "xwechat_files" / "wxid_foo_bar" / "db_storage"
|
||||
db_storage.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Fake a decrypted sqlite db (>= 4096 bytes) so decryptor falls back to copy.
|
||||
(db_storage / "MSG0.db").write_bytes(SQLITE_HEADER + b"\x00" * (4096 - len(SQLITE_HEADER)))
|
||||
db_path = db_storage / "MSG0.db"
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
try:
|
||||
conn.execute("CREATE TABLE demo(id INTEGER PRIMARY KEY, value TEXT)")
|
||||
conn.execute("INSERT INTO demo(value) VALUES ('ok')")
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
app = FastAPI()
|
||||
app.include_router(decrypt_router.router)
|
||||
@@ -72,10 +92,83 @@ class TestDecryptStreamSSE(unittest.TestCase):
|
||||
self.assertIn("start", types)
|
||||
self.assertIn("progress", types)
|
||||
self.assertEqual(events[-1].get("type"), "complete")
|
||||
self.assertEqual(events[-1].get("status"), "completed")
|
||||
|
||||
out = root / "output" / "databases" / "wxid_foo" / "MSG0.db"
|
||||
self.assertTrue(out.exists())
|
||||
finally:
|
||||
_close_logging_handlers()
|
||||
if prev_data_dir is None:
|
||||
os.environ.pop("WECHAT_TOOL_DATA_DIR", None)
|
||||
else:
|
||||
os.environ["WECHAT_TOOL_DATA_DIR"] = prev_data_dir
|
||||
if prev_build_cache is None:
|
||||
os.environ.pop("WECHAT_TOOL_BUILD_SESSION_LAST_MESSAGE", None)
|
||||
else:
|
||||
os.environ["WECHAT_TOOL_BUILD_SESSION_LAST_MESSAGE"] = prev_build_cache
|
||||
|
||||
def test_decrypt_stream_marks_invalid_output_as_failed(self):
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
with TemporaryDirectory() as td:
|
||||
root = Path(td)
|
||||
|
||||
prev_data_dir = os.environ.get("WECHAT_TOOL_DATA_DIR")
|
||||
prev_build_cache = os.environ.get("WECHAT_TOOL_BUILD_SESSION_LAST_MESSAGE")
|
||||
try:
|
||||
os.environ["WECHAT_TOOL_DATA_DIR"] = str(root)
|
||||
os.environ["WECHAT_TOOL_BUILD_SESSION_LAST_MESSAGE"] = "0"
|
||||
|
||||
import wechat_decrypt_tool.app_paths as app_paths
|
||||
import wechat_decrypt_tool.routers.decrypt as decrypt_router
|
||||
|
||||
importlib.reload(app_paths)
|
||||
importlib.reload(decrypt_router)
|
||||
|
||||
db_storage = root / "xwechat_files" / "wxid_bad_case" / "db_storage"
|
||||
db_storage.mkdir(parents=True, exist_ok=True)
|
||||
(db_storage / "MSG0.db").write_bytes(b"\x01" * 4096)
|
||||
|
||||
app = FastAPI()
|
||||
app.include_router(decrypt_router.router)
|
||||
client = TestClient(app)
|
||||
|
||||
events: list[dict] = []
|
||||
with client.stream(
|
||||
"GET",
|
||||
"/api/decrypt_stream",
|
||||
params={"key": "00" * 32, "db_storage_path": str(db_storage)},
|
||||
) as resp:
|
||||
self.assertEqual(resp.status_code, 200)
|
||||
self.assertIn("text/event-stream", resp.headers.get("content-type", ""))
|
||||
|
||||
for line in resp.iter_lines():
|
||||
if not line:
|
||||
continue
|
||||
if isinstance(line, bytes):
|
||||
line = line.decode("utf-8", errors="ignore")
|
||||
line = str(line)
|
||||
|
||||
if line.startswith(":"):
|
||||
continue
|
||||
if not line.startswith("data: "):
|
||||
continue
|
||||
payload = json.loads(line[len("data: ") :])
|
||||
events.append(payload)
|
||||
if payload.get("type") in {"complete", "error"}:
|
||||
break
|
||||
|
||||
self.assertEqual(events[-1].get("type"), "complete")
|
||||
self.assertEqual(events[-1].get("status"), "failed")
|
||||
self.assertEqual(events[-1].get("success_count"), 0)
|
||||
self.assertEqual(events[-1].get("failure_count"), 1)
|
||||
self.assertIn("密钥可能不匹配", str(events[-1].get("message") or ""))
|
||||
|
||||
out = root / "output" / "databases" / "wxid_bad" / "MSG0.db"
|
||||
self.assertFalse(out.exists())
|
||||
finally:
|
||||
_close_logging_handlers()
|
||||
if prev_data_dir is None:
|
||||
os.environ.pop("WECHAT_TOOL_DATA_DIR", None)
|
||||
else:
|
||||
|
||||
@@ -0,0 +1,61 @@
|
||||
import importlib
|
||||
import os
|
||||
import sqlite3
|
||||
import sys
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
sys.path.insert(0, str(ROOT / "src"))
|
||||
|
||||
|
||||
def _seed_sqlite(path: Path, table_name: str) -> None:
|
||||
conn = sqlite3.connect(str(path))
|
||||
try:
|
||||
conn.execute(f"CREATE TABLE {table_name}(id INTEGER PRIMARY KEY, value TEXT)")
|
||||
conn.execute(f"INSERT INTO {table_name}(value) VALUES ('ok')")
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
class TestDecryptedAccountValidation(unittest.TestCase):
|
||||
def test_invalid_header_only_databases_are_ignored(self):
|
||||
with TemporaryDirectory() as td:
|
||||
root = Path(td)
|
||||
prev_data_dir = os.environ.get("WECHAT_TOOL_DATA_DIR")
|
||||
try:
|
||||
os.environ["WECHAT_TOOL_DATA_DIR"] = str(root)
|
||||
|
||||
import wechat_decrypt_tool.app_paths as app_paths
|
||||
import wechat_decrypt_tool.chat_helpers as chat_helpers
|
||||
import wechat_decrypt_tool.media_helpers as media_helpers
|
||||
|
||||
importlib.reload(app_paths)
|
||||
importlib.reload(chat_helpers)
|
||||
importlib.reload(media_helpers)
|
||||
|
||||
output_dir = root / "output" / "databases"
|
||||
bad_dir = output_dir / "wxid_bad"
|
||||
bad_dir.mkdir(parents=True, exist_ok=True)
|
||||
(bad_dir / "session.db").write_bytes(b"SQLite format 3\x00")
|
||||
(bad_dir / "contact.db").write_bytes(b"SQLite format 3\x00")
|
||||
|
||||
good_dir = output_dir / "wxid_good"
|
||||
good_dir.mkdir(parents=True, exist_ok=True)
|
||||
_seed_sqlite(good_dir / "session.db", "SessionTable")
|
||||
_seed_sqlite(good_dir / "contact.db", "contact")
|
||||
|
||||
self.assertEqual(chat_helpers._list_decrypted_accounts(), ["wxid_good"])
|
||||
self.assertEqual(media_helpers._list_decrypted_accounts(), ["wxid_good"])
|
||||
finally:
|
||||
if prev_data_dir is None:
|
||||
os.environ.pop("WECHAT_TOOL_DATA_DIR", None)
|
||||
else:
|
||||
os.environ["WECHAT_TOOL_DATA_DIR"] = prev_data_dir
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user