From 4f9acd3f691096665f847cd645a1e5c718e55c03 Mon Sep 17 00:00:00 2001 From: 2977094657 <2977094657@qq.com> Date: Sat, 11 Apr 2026 16:57:01 +0800 Subject: [PATCH] =?UTF-8?q?fix(decrypt):=20=E5=8A=A0=E5=BC=BA=E8=A7=A3?= =?UTF-8?q?=E5=AF=86=E7=BB=93=E6=9E=9C=E6=A0=A1=E9=AA=8C=E5=B9=B6=E8=BF=87?= =?UTF-8?q?=E6=BB=A4=E5=86=85=E9=83=A8=E6=95=B0=E6=8D=AE=E5=BA=93?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增可用 SQLite 校验,解密失败时返回更明确提示并清理无效输出 - 统一过滤 key_info、FTS 索引库和内部缓存库,修正数据库扫描与账号统计 - 补充解密流和数据库过滤相关测试 --- src/wechat_decrypt_tool/chat_helpers.py | 10 +- src/wechat_decrypt_tool/database_filters.py | 61 ++++++++++ src/wechat_decrypt_tool/media_helpers.py | 9 +- src/wechat_decrypt_tool/routers/chat.py | 3 +- src/wechat_decrypt_tool/routers/decrypt.py | 13 +- src/wechat_decrypt_tool/sqlite_diagnostics.py | 34 ++++++ src/wechat_decrypt_tool/wechat_decrypt.py | 80 ++++++++++++- src/wechat_decrypt_tool/wechat_detection.py | 8 +- tests/test_database_filters.py | 112 ++++++++++++++++++ tests/test_decrypt_stream_sse.py | 101 +++++++++++++++- tests/test_decrypted_account_validation.py | 61 ++++++++++ 11 files changed, 462 insertions(+), 30 deletions(-) create mode 100644 src/wechat_decrypt_tool/database_filters.py create mode 100644 tests/test_database_filters.py create mode 100644 tests/test_decrypted_account_validation.py diff --git a/src/wechat_decrypt_tool/chat_helpers.py b/src/wechat_decrypt_tool/chat_helpers.py index 40b0bad..90267aa 100644 --- a/src/wechat_decrypt_tool/chat_helpers.py +++ b/src/wechat_decrypt_tool/chat_helpers.py @@ -14,7 +14,7 @@ from fastapi import HTTPException from .app_paths import get_output_databases_dir from .logging_config import get_logger -from .sqlite_diagnostics import collect_sqlite_diagnostics, format_sqlite_diagnostics +from .sqlite_diagnostics import collect_sqlite_diagnostics, format_sqlite_diagnostics, is_usable_sqlite_db try: import zstandard as zstd # type: ignore @@ -29,13 +29,7 @@ _SQLITE_HEADER = b"SQLite format 3\x00" def _is_valid_decrypted_sqlite(path: Path) -> bool: - try: - if not path.exists() or (not path.is_file()): - return False - with path.open("rb") as f: - return f.read(len(_SQLITE_HEADER)) == _SQLITE_HEADER - except Exception: - return False + return is_usable_sqlite_db(path) def _list_decrypted_accounts() -> list[str]: diff --git a/src/wechat_decrypt_tool/database_filters.py b/src/wechat_decrypt_tool/database_filters.py new file mode 100644 index 0000000..e1a9099 --- /dev/null +++ b/src/wechat_decrypt_tool/database_filters.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +from pathlib import Path + + +_IGNORED_SOURCE_DATABASE_NAMES = frozenset({"key_info.db"}) +_INDEX_DATABASE_NAMES = frozenset({"chat_search_index.db", "chat_search_index.tmp.db"}) +_INDEX_DATABASE_SUFFIXES = ("_fts.db",) +_INTERNAL_OUTPUT_DATABASE_NAMES = frozenset( + { + "chat_search_index.db", + "chat_search_index.tmp.db", + "session_last_message.db", + } +) + + +def normalize_database_file_name(file_name: str | Path) -> str: + return Path(str(file_name or "")).name.strip().lower() + + +def is_index_database_file(file_name: str | Path) -> bool: + lower_name = normalize_database_file_name(file_name) + if not lower_name: + return False + if lower_name in _INDEX_DATABASE_NAMES: + return True + return lower_name.endswith(_INDEX_DATABASE_SUFFIXES) + + +def should_skip_source_database(file_name: str | Path) -> bool: + lower_name = normalize_database_file_name(file_name) + if not lower_name: + return True + if lower_name in _IGNORED_SOURCE_DATABASE_NAMES: + return True + return is_index_database_file(lower_name) + + +def should_include_in_database_count(file_name: str | Path) -> bool: + lower_name = normalize_database_file_name(file_name) + if not lower_name.endswith(".db"): + return False + if should_skip_source_database(lower_name): + return False + if lower_name in _INTERNAL_OUTPUT_DATABASE_NAMES: + return False + return True + + +def list_countable_database_names(account_dir: Path) -> list[str]: + if not account_dir.exists(): + return [] + + db_files = [ + path.name + for path in account_dir.glob("*.db") + if path.is_file() and should_include_in_database_count(path.name) + ] + db_files.sort() + return db_files diff --git a/src/wechat_decrypt_tool/media_helpers.py b/src/wechat_decrypt_tool/media_helpers.py index f9ee3d2..cec1049 100644 --- a/src/wechat_decrypt_tool/media_helpers.py +++ b/src/wechat_decrypt_tool/media_helpers.py @@ -18,6 +18,7 @@ from fastapi import HTTPException from .app_paths import get_output_databases_dir from .logging_config import get_logger +from .sqlite_diagnostics import is_usable_sqlite_db logger = get_logger(__name__) @@ -28,13 +29,7 @@ _SQLITE_HEADER = b"SQLite format 3\x00" def _is_valid_decrypted_sqlite(path: Path) -> bool: - try: - if not path.exists() or (not path.is_file()): - return False - with path.open("rb") as f: - return f.read(len(_SQLITE_HEADER)) == _SQLITE_HEADER - except Exception: - return False + return is_usable_sqlite_db(path) def _list_decrypted_accounts() -> list[str]: diff --git a/src/wechat_decrypt_tool/routers/chat.py b/src/wechat_decrypt_tool/routers/chat.py index c4878fc..8ab0e92 100644 --- a/src/wechat_decrypt_tool/routers/chat.py +++ b/src/wechat_decrypt_tool/routers/chat.py @@ -70,6 +70,7 @@ from ..chat_helpers import ( from ..media_helpers import _resolve_account_db_storage_dir, _try_find_decrypted_resource from .. import chat_edit_store from ..app_paths import get_output_dir +from ..database_filters import list_countable_database_names from ..key_store import remove_account_keys_from_store from ..path_fix import PathFixRoute from ..session_last_message import ( @@ -3892,7 +3893,7 @@ async def list_chat_accounts(): @router.get("/api/chat/account_info", summary="获取当前账号信息") def get_chat_account_info(account: Optional[str] = None): account_dir = _resolve_account_dir(account) - db_files = sorted([p.name for p in account_dir.glob("*.db") if p.is_file()]) + db_files = list_countable_database_names(account_dir) session_db = account_dir / "session.db" session_updated_at = 0 diff --git a/src/wechat_decrypt_tool/routers/decrypt.py b/src/wechat_decrypt_tool/routers/decrypt.py index 14a9957..59f7958 100644 --- a/src/wechat_decrypt_tool/routers/decrypt.py +++ b/src/wechat_decrypt_tool/routers/decrypt.py @@ -16,7 +16,12 @@ from ..chat_realtime_autosync import CHAT_REALTIME_AUTOSYNC from ..logging_config import get_logger from ..path_fix import PathFixRoute from ..key_store import upsert_account_keys_in_store -from ..wechat_decrypt import WeChatDatabaseDecryptor, decrypt_wechat_databases, scan_account_databases_from_path +from ..wechat_decrypt import ( + WeChatDatabaseDecryptor, + build_decrypt_summary_message, + decrypt_wechat_databases, + scan_account_databases_from_path, +) logger = get_logger(__name__) @@ -463,7 +468,11 @@ async def decrypt_databases_stream( "success_count": success_count, "failure_count": total_databases - success_count, "output_directory": str(base_output_dir.absolute()), - "message": f"解密完成: 成功 {success_count}/{total_databases}", + "message": build_decrypt_summary_message( + success_count=success_count, + total_databases=total_databases, + diagnostic_warning_count=diagnostic_warning_count, + ), "processed_files": processed_files, "failed_files": failed_files, "account_results": account_results, diff --git a/src/wechat_decrypt_tool/sqlite_diagnostics.py b/src/wechat_decrypt_tool/sqlite_diagnostics.py index 2cfae67..1700596 100644 --- a/src/wechat_decrypt_tool/sqlite_diagnostics.py +++ b/src/wechat_decrypt_tool/sqlite_diagnostics.py @@ -123,6 +123,40 @@ def collect_sqlite_diagnostics( return diagnostics +def is_usable_sqlite_db(path: str | Path) -> bool: + db_path = Path(path) + if not db_path.exists() or (not db_path.is_file()): + return False + + try: + if int(db_path.stat().st_size) <= len(SQLITE_HEADER): + return False + except Exception: + return False + + try: + with db_path.open("rb") as f: + if f.read(len(SQLITE_HEADER)) != SQLITE_HEADER: + return False + except Exception: + return False + + conn: sqlite3.Connection | None = None + try: + conn = sqlite3.connect(str(db_path)) + conn.execute("PRAGMA schema_version").fetchone() + row = conn.execute("SELECT name FROM sqlite_master WHERE type='table' LIMIT 1").fetchone() + return row is not None + except Exception: + return False + finally: + if conn is not None: + try: + conn.close() + except Exception: + pass + + def sqlite_diagnostics_status(diagnostics: Mapping[str, Any]) -> str: if not diagnostics: return "not_run" diff --git a/src/wechat_decrypt_tool/wechat_decrypt.py b/src/wechat_decrypt_tool/wechat_decrypt.py index e714e7d..529f733 100644 --- a/src/wechat_decrypt_tool/wechat_decrypt.py +++ b/src/wechat_decrypt_tool/wechat_decrypt.py @@ -21,6 +21,7 @@ from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC from .app_paths import get_output_databases_dir +from .database_filters import should_skip_source_database from .sqlite_diagnostics import collect_sqlite_diagnostics, sqlite_diagnostics_status # 注意:不再支持默认密钥,所有密钥必须通过参数传入 @@ -64,6 +65,59 @@ def _derive_account_name_from_path(path: Path) -> str: return "unknown_account" +def _build_decrypt_failure_message(result: dict) -> str: + failed_pages = int(result.get("failed_pages") or 0) + successful_pages = int(result.get("successful_pages") or 0) + diagnostic_status = str(result.get("diagnostic_status") or "").strip() + diagnostics = dict(result.get("diagnostics") or {}) + + detail = ( + diagnostics.get("quick_check_error") + or diagnostics.get("connect_error") + or diagnostics.get("table_list_error") + or diagnostics.get("page_count_error") + or diagnostics.get("quick_check") + or diagnostic_status + ) + detail_text = " ".join(str(detail or "").split()).strip() + + if failed_pages > 0 and successful_pages == 0: + if detail_text: + return f"数据库校验未通过,密钥可能不匹配当前账号: {detail_text}" + return "数据库校验未通过,密钥可能不匹配当前账号" + + if diagnostic_status and diagnostic_status != "ok": + if detail_text: + return f"解密输出不是有效的 SQLite 数据库: {detail_text}" + return "解密输出不是有效的 SQLite 数据库" + + if failed_pages > 0: + return "解密输出包含页失败,结果不完整" + + return "" + + +def build_decrypt_summary_message(*, success_count: int, total_databases: int, diagnostic_warning_count: int) -> str: + success_count = int(success_count or 0) + total_databases = int(total_databases or 0) + diagnostic_warning_count = int(diagnostic_warning_count or 0) + + if total_databases <= 0: + return "未找到可解密的数据库" + + if success_count <= 0: + if diagnostic_warning_count > 0: + return "解密失败:数据库校验未通过,密钥可能不匹配当前账号。" + return "解密失败:未能成功解密任何数据库。" + + if success_count < total_databases: + if diagnostic_warning_count > 0: + return f"解密部分成功:成功 {success_count}/{total_databases},其余数据库校验未通过。" + return f"解密部分成功:成功 {success_count}/{total_databases}。" + + return f"解密完成: 成功 {success_count}/{total_databases}" + + def _resolve_db_storage_roots(storage_path: Path) -> list[Path]: try: target = storage_path.resolve() @@ -158,7 +212,7 @@ def scan_account_databases_from_path(db_storage_path: str) -> dict: for file_name in files: if not file_name.endswith(".db"): continue - if file_name in ["key_info.db"]: + if should_skip_source_database(file_name): continue db_path = os.path.join(root, file_name) databases.append( @@ -266,7 +320,8 @@ class WeChatDatabaseDecryptor: result["failed_page_samples"].append(item) def _finalize(success: bool, error: str = "") -> bool: - result["success"] = bool(success) + normalized_success = bool(success) + result["success"] = normalized_success if error: result["error"] = " ".join(str(error).split()).strip() @@ -281,6 +336,19 @@ class WeChatDatabaseDecryptor: result["diagnostics"] = diagnostics result["diagnostic_status"] = sqlite_diagnostics_status(diagnostics) + if normalized_success: + failure_message = _build_decrypt_failure_message(result) + if failure_message: + normalized_success = False + result["success"] = False + if not result["error"]: + result["error"] = failure_message + if output_file.exists(): + try: + output_file.unlink() + except Exception as exc: + logger.warning("删除无效解密输出失败: %s, 错误: %s", output_file, exc) + payload = { "db_name": result["db_name"], "db_path": result["db_path"], @@ -307,7 +375,7 @@ class WeChatDatabaseDecryptor: log_fn = logger.warning log_fn("[decrypt.diagnostic] %s", json.dumps(payload, ensure_ascii=False, sort_keys=True)) self.last_result = result - return bool(success) + return bool(result["success"]) logger.info(f"开始解密数据库: {db_path}") @@ -693,7 +761,11 @@ def decrypt_wechat_databases(db_storage_path: str = None, key: str = None) -> di # 返回结果 result = { "status": "success" if success_count > 0 else "error", - "message": f"解密完成: 成功 {success_count}/{total_databases}", + "message": build_decrypt_summary_message( + success_count=success_count, + total_databases=total_databases, + diagnostic_warning_count=diagnostic_warning_count, + ), "total_databases": total_databases, "successful_count": success_count, "failed_count": total_databases - success_count, diff --git a/src/wechat_decrypt_tool/wechat_detection.py b/src/wechat_decrypt_tool/wechat_detection.py index d00dc1d..2794579 100644 --- a/src/wechat_decrypt_tool/wechat_detection.py +++ b/src/wechat_decrypt_tool/wechat_detection.py @@ -13,6 +13,8 @@ from typing import List, Dict, Any, Union from ctypes import wintypes from datetime import datetime +from .database_filters import should_skip_source_database + def get_wx_db(msg_dir: str = None, db_types: Union[List[str], str] = None, @@ -59,8 +61,7 @@ def get_wx_db(msg_dir: str = None, for file_name in files: if not file_name.endswith(".db"): continue - # 排除不需要解密的数据库 - if file_name in ["key_info.db"]: + if should_skip_source_database(file_name): continue db_type = re.sub(r"\d*\.db$", "", file_name) if db_types and db_type not in db_types: # 如果指定db_type,则过滤掉其他db_type @@ -672,8 +673,7 @@ def collect_account_databases(data_dir: str, account_name: str) -> List[Dict[str if not file_name.endswith('.db'): continue - # 排除不需要解密的数据库 - if file_name in ["key_info.db"]: + if should_skip_source_database(file_name): continue db_path = os.path.join(root, file_name) diff --git a/tests/test_database_filters.py b/tests/test_database_filters.py new file mode 100644 index 0000000..fff02ae --- /dev/null +++ b/tests/test_database_filters.py @@ -0,0 +1,112 @@ +import importlib +import logging +import os +import sqlite3 +import sys +import unittest +from pathlib import Path +from tempfile import TemporaryDirectory + + +ROOT = Path(__file__).resolve().parents[1] +sys.path.insert(0, str(ROOT / "src")) + + +def _close_logging_handlers() -> None: + for logger_name in ("", "uvicorn", "uvicorn.access", "uvicorn.error", "fastapi"): + lg = logging.getLogger(logger_name) + for handler in lg.handlers[:]: + try: + handler.close() + except Exception: + pass + try: + lg.removeHandler(handler) + except Exception: + pass + + +def _seed_sqlite(path: Path, table_name: str = "demo") -> None: + conn = sqlite3.connect(str(path)) + try: + conn.execute(f"CREATE TABLE {table_name}(id INTEGER PRIMARY KEY, value TEXT)") + conn.execute(f"INSERT INTO {table_name}(value) VALUES ('ok')") + conn.commit() + finally: + conn.close() + + +class TestDatabaseFilters(unittest.TestCase): + def test_scan_account_databases_skips_index_databases(self): + from wechat_decrypt_tool.wechat_decrypt import scan_account_databases_from_path + + with TemporaryDirectory() as td: + db_storage = Path(td) / "xwechat_files" / "wxid_demo_user" / "db_storage" + db_storage.mkdir(parents=True, exist_ok=True) + + _seed_sqlite(db_storage / "MSG0.db") + _seed_sqlite(db_storage / "contact_fts.db") + _seed_sqlite(db_storage / "favorite_fts.db") + _seed_sqlite(db_storage / "message_fts.db") + _seed_sqlite(db_storage / "key_info.db") + + result = scan_account_databases_from_path(str(db_storage)) + + self.assertEqual(result["status"], "success") + self.assertEqual(list(result["account_databases"].keys()), ["wxid_demo"]) + db_names = sorted(db["name"] for db in result["account_databases"]["wxid_demo"]) + self.assertEqual(db_names, ["MSG0.db"]) + + def test_collect_account_databases_skips_index_databases(self): + from wechat_decrypt_tool.wechat_detection import collect_account_databases + + with TemporaryDirectory() as td: + data_dir = Path(td) / "wxid_demo_user" + data_dir.mkdir(parents=True, exist_ok=True) + + _seed_sqlite(data_dir / "contact.db", "contact") + _seed_sqlite(data_dir / "contact_fts.db") + _seed_sqlite(data_dir / "favorite_fts.db") + _seed_sqlite(data_dir / "message_fts.db") + + databases = collect_account_databases(str(data_dir), "wxid_demo") + db_names = sorted(db["name"] for db in databases) + self.assertEqual(db_names, ["contact.db"]) + + def test_chat_account_info_hides_index_and_internal_databases(self): + with TemporaryDirectory() as td: + root = Path(td) + prev_data_dir = os.environ.get("WECHAT_TOOL_DATA_DIR") + try: + os.environ["WECHAT_TOOL_DATA_DIR"] = str(root) + + import wechat_decrypt_tool.app_paths as app_paths + import wechat_decrypt_tool.routers.chat as chat_router + + importlib.reload(app_paths) + importlib.reload(chat_router) + + account_dir = root / "output" / "databases" / "wxid_demo" + account_dir.mkdir(parents=True, exist_ok=True) + + _seed_sqlite(account_dir / "contact.db", "contact") + _seed_sqlite(account_dir / "session.db", "session_table") + _seed_sqlite(account_dir / "message_fts.db") + _seed_sqlite(account_dir / "chat_search_index.db") + _seed_sqlite(account_dir / "session_last_message.db") + + result = chat_router.get_chat_account_info("wxid_demo") + + self.assertEqual(result["status"], "success") + self.assertEqual(result["database_count"], 2) + self.assertEqual(result["databases"], ["contact.db", "session.db"]) + finally: + _close_logging_handlers() + if prev_data_dir is None: + os.environ.pop("WECHAT_TOOL_DATA_DIR", None) + else: + os.environ["WECHAT_TOOL_DATA_DIR"] = prev_data_dir + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_decrypt_stream_sse.py b/tests/test_decrypt_stream_sse.py index c041630..2dc0c66 100644 --- a/tests/test_decrypt_stream_sse.py +++ b/tests/test_decrypt_stream_sse.py @@ -1,5 +1,7 @@ import json +import logging import os +import sqlite3 import sys import unittest import importlib @@ -11,13 +13,25 @@ ROOT = Path(__file__).resolve().parents[1] sys.path.insert(0, str(ROOT / "src")) +def _close_logging_handlers() -> None: + for logger_name in ("", "uvicorn", "uvicorn.access", "uvicorn.error", "fastapi"): + lg = logging.getLogger(logger_name) + for handler in lg.handlers[:]: + try: + handler.close() + except Exception: + pass + try: + lg.removeHandler(handler) + except Exception: + pass + + class TestDecryptStreamSSE(unittest.TestCase): def test_decrypt_stream_reports_progress(self): from fastapi import FastAPI from fastapi.testclient import TestClient - from wechat_decrypt_tool.wechat_decrypt import SQLITE_HEADER - with TemporaryDirectory() as td: root = Path(td) @@ -36,8 +50,14 @@ class TestDecryptStreamSSE(unittest.TestCase): db_storage = root / "xwechat_files" / "wxid_foo_bar" / "db_storage" db_storage.mkdir(parents=True, exist_ok=True) - # Fake a decrypted sqlite db (>= 4096 bytes) so decryptor falls back to copy. - (db_storage / "MSG0.db").write_bytes(SQLITE_HEADER + b"\x00" * (4096 - len(SQLITE_HEADER))) + db_path = db_storage / "MSG0.db" + conn = sqlite3.connect(str(db_path)) + try: + conn.execute("CREATE TABLE demo(id INTEGER PRIMARY KEY, value TEXT)") + conn.execute("INSERT INTO demo(value) VALUES ('ok')") + conn.commit() + finally: + conn.close() app = FastAPI() app.include_router(decrypt_router.router) @@ -72,10 +92,83 @@ class TestDecryptStreamSSE(unittest.TestCase): self.assertIn("start", types) self.assertIn("progress", types) self.assertEqual(events[-1].get("type"), "complete") + self.assertEqual(events[-1].get("status"), "completed") out = root / "output" / "databases" / "wxid_foo" / "MSG0.db" self.assertTrue(out.exists()) finally: + _close_logging_handlers() + if prev_data_dir is None: + os.environ.pop("WECHAT_TOOL_DATA_DIR", None) + else: + os.environ["WECHAT_TOOL_DATA_DIR"] = prev_data_dir + if prev_build_cache is None: + os.environ.pop("WECHAT_TOOL_BUILD_SESSION_LAST_MESSAGE", None) + else: + os.environ["WECHAT_TOOL_BUILD_SESSION_LAST_MESSAGE"] = prev_build_cache + + def test_decrypt_stream_marks_invalid_output_as_failed(self): + from fastapi import FastAPI + from fastapi.testclient import TestClient + + with TemporaryDirectory() as td: + root = Path(td) + + prev_data_dir = os.environ.get("WECHAT_TOOL_DATA_DIR") + prev_build_cache = os.environ.get("WECHAT_TOOL_BUILD_SESSION_LAST_MESSAGE") + try: + os.environ["WECHAT_TOOL_DATA_DIR"] = str(root) + os.environ["WECHAT_TOOL_BUILD_SESSION_LAST_MESSAGE"] = "0" + + import wechat_decrypt_tool.app_paths as app_paths + import wechat_decrypt_tool.routers.decrypt as decrypt_router + + importlib.reload(app_paths) + importlib.reload(decrypt_router) + + db_storage = root / "xwechat_files" / "wxid_bad_case" / "db_storage" + db_storage.mkdir(parents=True, exist_ok=True) + (db_storage / "MSG0.db").write_bytes(b"\x01" * 4096) + + app = FastAPI() + app.include_router(decrypt_router.router) + client = TestClient(app) + + events: list[dict] = [] + with client.stream( + "GET", + "/api/decrypt_stream", + params={"key": "00" * 32, "db_storage_path": str(db_storage)}, + ) as resp: + self.assertEqual(resp.status_code, 200) + self.assertIn("text/event-stream", resp.headers.get("content-type", "")) + + for line in resp.iter_lines(): + if not line: + continue + if isinstance(line, bytes): + line = line.decode("utf-8", errors="ignore") + line = str(line) + + if line.startswith(":"): + continue + if not line.startswith("data: "): + continue + payload = json.loads(line[len("data: ") :]) + events.append(payload) + if payload.get("type") in {"complete", "error"}: + break + + self.assertEqual(events[-1].get("type"), "complete") + self.assertEqual(events[-1].get("status"), "failed") + self.assertEqual(events[-1].get("success_count"), 0) + self.assertEqual(events[-1].get("failure_count"), 1) + self.assertIn("密钥可能不匹配", str(events[-1].get("message") or "")) + + out = root / "output" / "databases" / "wxid_bad" / "MSG0.db" + self.assertFalse(out.exists()) + finally: + _close_logging_handlers() if prev_data_dir is None: os.environ.pop("WECHAT_TOOL_DATA_DIR", None) else: diff --git a/tests/test_decrypted_account_validation.py b/tests/test_decrypted_account_validation.py new file mode 100644 index 0000000..1bbb414 --- /dev/null +++ b/tests/test_decrypted_account_validation.py @@ -0,0 +1,61 @@ +import importlib +import os +import sqlite3 +import sys +import unittest +from pathlib import Path +from tempfile import TemporaryDirectory + + +ROOT = Path(__file__).resolve().parents[1] +sys.path.insert(0, str(ROOT / "src")) + + +def _seed_sqlite(path: Path, table_name: str) -> None: + conn = sqlite3.connect(str(path)) + try: + conn.execute(f"CREATE TABLE {table_name}(id INTEGER PRIMARY KEY, value TEXT)") + conn.execute(f"INSERT INTO {table_name}(value) VALUES ('ok')") + conn.commit() + finally: + conn.close() + + +class TestDecryptedAccountValidation(unittest.TestCase): + def test_invalid_header_only_databases_are_ignored(self): + with TemporaryDirectory() as td: + root = Path(td) + prev_data_dir = os.environ.get("WECHAT_TOOL_DATA_DIR") + try: + os.environ["WECHAT_TOOL_DATA_DIR"] = str(root) + + import wechat_decrypt_tool.app_paths as app_paths + import wechat_decrypt_tool.chat_helpers as chat_helpers + import wechat_decrypt_tool.media_helpers as media_helpers + + importlib.reload(app_paths) + importlib.reload(chat_helpers) + importlib.reload(media_helpers) + + output_dir = root / "output" / "databases" + bad_dir = output_dir / "wxid_bad" + bad_dir.mkdir(parents=True, exist_ok=True) + (bad_dir / "session.db").write_bytes(b"SQLite format 3\x00") + (bad_dir / "contact.db").write_bytes(b"SQLite format 3\x00") + + good_dir = output_dir / "wxid_good" + good_dir.mkdir(parents=True, exist_ok=True) + _seed_sqlite(good_dir / "session.db", "SessionTable") + _seed_sqlite(good_dir / "contact.db", "contact") + + self.assertEqual(chat_helpers._list_decrypted_accounts(), ["wxid_good"]) + self.assertEqual(media_helpers._list_decrypted_accounts(), ["wxid_good"]) + finally: + if prev_data_dir is None: + os.environ.pop("WECHAT_TOOL_DATA_DIR", None) + else: + os.environ["WECHAT_TOOL_DATA_DIR"] = prev_data_dir + + +if __name__ == "__main__": + unittest.main()