fix(decrypt): 加强解密结果校验并过滤内部数据库

- 新增可用 SQLite 校验,解密失败时返回更明确提示并清理无效输出

- 统一过滤 key_info、FTS 索引库和内部缓存库,修正数据库扫描与账号统计

- 补充解密流和数据库过滤相关测试
This commit is contained in:
2977094657
2026-04-11 16:57:01 +08:00
Unverified
parent 604e01eb50
commit 4f9acd3f69
11 changed files with 462 additions and 30 deletions
+2 -8
View File
@@ -14,7 +14,7 @@ from fastapi import HTTPException
from .app_paths import get_output_databases_dir
from .logging_config import get_logger
from .sqlite_diagnostics import collect_sqlite_diagnostics, format_sqlite_diagnostics
from .sqlite_diagnostics import collect_sqlite_diagnostics, format_sqlite_diagnostics, is_usable_sqlite_db
try:
import zstandard as zstd # type: ignore
@@ -29,13 +29,7 @@ _SQLITE_HEADER = b"SQLite format 3\x00"
def _is_valid_decrypted_sqlite(path: Path) -> bool:
try:
if not path.exists() or (not path.is_file()):
return False
with path.open("rb") as f:
return f.read(len(_SQLITE_HEADER)) == _SQLITE_HEADER
except Exception:
return False
return is_usable_sqlite_db(path)
def _list_decrypted_accounts() -> list[str]:
@@ -0,0 +1,61 @@
from __future__ import annotations
from pathlib import Path
_IGNORED_SOURCE_DATABASE_NAMES = frozenset({"key_info.db"})
_INDEX_DATABASE_NAMES = frozenset({"chat_search_index.db", "chat_search_index.tmp.db"})
_INDEX_DATABASE_SUFFIXES = ("_fts.db",)
_INTERNAL_OUTPUT_DATABASE_NAMES = frozenset(
{
"chat_search_index.db",
"chat_search_index.tmp.db",
"session_last_message.db",
}
)
def normalize_database_file_name(file_name: str | Path) -> str:
return Path(str(file_name or "")).name.strip().lower()
def is_index_database_file(file_name: str | Path) -> bool:
lower_name = normalize_database_file_name(file_name)
if not lower_name:
return False
if lower_name in _INDEX_DATABASE_NAMES:
return True
return lower_name.endswith(_INDEX_DATABASE_SUFFIXES)
def should_skip_source_database(file_name: str | Path) -> bool:
lower_name = normalize_database_file_name(file_name)
if not lower_name:
return True
if lower_name in _IGNORED_SOURCE_DATABASE_NAMES:
return True
return is_index_database_file(lower_name)
def should_include_in_database_count(file_name: str | Path) -> bool:
lower_name = normalize_database_file_name(file_name)
if not lower_name.endswith(".db"):
return False
if should_skip_source_database(lower_name):
return False
if lower_name in _INTERNAL_OUTPUT_DATABASE_NAMES:
return False
return True
def list_countable_database_names(account_dir: Path) -> list[str]:
if not account_dir.exists():
return []
db_files = [
path.name
for path in account_dir.glob("*.db")
if path.is_file() and should_include_in_database_count(path.name)
]
db_files.sort()
return db_files
+2 -7
View File
@@ -18,6 +18,7 @@ from fastapi import HTTPException
from .app_paths import get_output_databases_dir
from .logging_config import get_logger
from .sqlite_diagnostics import is_usable_sqlite_db
logger = get_logger(__name__)
@@ -28,13 +29,7 @@ _SQLITE_HEADER = b"SQLite format 3\x00"
def _is_valid_decrypted_sqlite(path: Path) -> bool:
try:
if not path.exists() or (not path.is_file()):
return False
with path.open("rb") as f:
return f.read(len(_SQLITE_HEADER)) == _SQLITE_HEADER
except Exception:
return False
return is_usable_sqlite_db(path)
def _list_decrypted_accounts() -> list[str]:
+2 -1
View File
@@ -70,6 +70,7 @@ from ..chat_helpers import (
from ..media_helpers import _resolve_account_db_storage_dir, _try_find_decrypted_resource
from .. import chat_edit_store
from ..app_paths import get_output_dir
from ..database_filters import list_countable_database_names
from ..key_store import remove_account_keys_from_store
from ..path_fix import PathFixRoute
from ..session_last_message import (
@@ -3892,7 +3893,7 @@ async def list_chat_accounts():
@router.get("/api/chat/account_info", summary="获取当前账号信息")
def get_chat_account_info(account: Optional[str] = None):
account_dir = _resolve_account_dir(account)
db_files = sorted([p.name for p in account_dir.glob("*.db") if p.is_file()])
db_files = list_countable_database_names(account_dir)
session_db = account_dir / "session.db"
session_updated_at = 0
+11 -2
View File
@@ -16,7 +16,12 @@ from ..chat_realtime_autosync import CHAT_REALTIME_AUTOSYNC
from ..logging_config import get_logger
from ..path_fix import PathFixRoute
from ..key_store import upsert_account_keys_in_store
from ..wechat_decrypt import WeChatDatabaseDecryptor, decrypt_wechat_databases, scan_account_databases_from_path
from ..wechat_decrypt import (
WeChatDatabaseDecryptor,
build_decrypt_summary_message,
decrypt_wechat_databases,
scan_account_databases_from_path,
)
logger = get_logger(__name__)
@@ -463,7 +468,11 @@ async def decrypt_databases_stream(
"success_count": success_count,
"failure_count": total_databases - success_count,
"output_directory": str(base_output_dir.absolute()),
"message": f"解密完成: 成功 {success_count}/{total_databases}",
"message": build_decrypt_summary_message(
success_count=success_count,
total_databases=total_databases,
diagnostic_warning_count=diagnostic_warning_count,
),
"processed_files": processed_files,
"failed_files": failed_files,
"account_results": account_results,
@@ -123,6 +123,40 @@ def collect_sqlite_diagnostics(
return diagnostics
def is_usable_sqlite_db(path: str | Path) -> bool:
db_path = Path(path)
if not db_path.exists() or (not db_path.is_file()):
return False
try:
if int(db_path.stat().st_size) <= len(SQLITE_HEADER):
return False
except Exception:
return False
try:
with db_path.open("rb") as f:
if f.read(len(SQLITE_HEADER)) != SQLITE_HEADER:
return False
except Exception:
return False
conn: sqlite3.Connection | None = None
try:
conn = sqlite3.connect(str(db_path))
conn.execute("PRAGMA schema_version").fetchone()
row = conn.execute("SELECT name FROM sqlite_master WHERE type='table' LIMIT 1").fetchone()
return row is not None
except Exception:
return False
finally:
if conn is not None:
try:
conn.close()
except Exception:
pass
def sqlite_diagnostics_status(diagnostics: Mapping[str, Any]) -> str:
if not diagnostics:
return "not_run"
+76 -4
View File
@@ -21,6 +21,7 @@ from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
from .app_paths import get_output_databases_dir
from .database_filters import should_skip_source_database
from .sqlite_diagnostics import collect_sqlite_diagnostics, sqlite_diagnostics_status
# 注意:不再支持默认密钥,所有密钥必须通过参数传入
@@ -64,6 +65,59 @@ def _derive_account_name_from_path(path: Path) -> str:
return "unknown_account"
def _build_decrypt_failure_message(result: dict) -> str:
failed_pages = int(result.get("failed_pages") or 0)
successful_pages = int(result.get("successful_pages") or 0)
diagnostic_status = str(result.get("diagnostic_status") or "").strip()
diagnostics = dict(result.get("diagnostics") or {})
detail = (
diagnostics.get("quick_check_error")
or diagnostics.get("connect_error")
or diagnostics.get("table_list_error")
or diagnostics.get("page_count_error")
or diagnostics.get("quick_check")
or diagnostic_status
)
detail_text = " ".join(str(detail or "").split()).strip()
if failed_pages > 0 and successful_pages == 0:
if detail_text:
return f"数据库校验未通过,密钥可能不匹配当前账号: {detail_text}"
return "数据库校验未通过,密钥可能不匹配当前账号"
if diagnostic_status and diagnostic_status != "ok":
if detail_text:
return f"解密输出不是有效的 SQLite 数据库: {detail_text}"
return "解密输出不是有效的 SQLite 数据库"
if failed_pages > 0:
return "解密输出包含页失败,结果不完整"
return ""
def build_decrypt_summary_message(*, success_count: int, total_databases: int, diagnostic_warning_count: int) -> str:
success_count = int(success_count or 0)
total_databases = int(total_databases or 0)
diagnostic_warning_count = int(diagnostic_warning_count or 0)
if total_databases <= 0:
return "未找到可解密的数据库"
if success_count <= 0:
if diagnostic_warning_count > 0:
return "解密失败:数据库校验未通过,密钥可能不匹配当前账号。"
return "解密失败:未能成功解密任何数据库。"
if success_count < total_databases:
if diagnostic_warning_count > 0:
return f"解密部分成功:成功 {success_count}/{total_databases},其余数据库校验未通过。"
return f"解密部分成功:成功 {success_count}/{total_databases}"
return f"解密完成: 成功 {success_count}/{total_databases}"
def _resolve_db_storage_roots(storage_path: Path) -> list[Path]:
try:
target = storage_path.resolve()
@@ -158,7 +212,7 @@ def scan_account_databases_from_path(db_storage_path: str) -> dict:
for file_name in files:
if not file_name.endswith(".db"):
continue
if file_name in ["key_info.db"]:
if should_skip_source_database(file_name):
continue
db_path = os.path.join(root, file_name)
databases.append(
@@ -266,7 +320,8 @@ class WeChatDatabaseDecryptor:
result["failed_page_samples"].append(item)
def _finalize(success: bool, error: str = "") -> bool:
result["success"] = bool(success)
normalized_success = bool(success)
result["success"] = normalized_success
if error:
result["error"] = " ".join(str(error).split()).strip()
@@ -281,6 +336,19 @@ class WeChatDatabaseDecryptor:
result["diagnostics"] = diagnostics
result["diagnostic_status"] = sqlite_diagnostics_status(diagnostics)
if normalized_success:
failure_message = _build_decrypt_failure_message(result)
if failure_message:
normalized_success = False
result["success"] = False
if not result["error"]:
result["error"] = failure_message
if output_file.exists():
try:
output_file.unlink()
except Exception as exc:
logger.warning("删除无效解密输出失败: %s, 错误: %s", output_file, exc)
payload = {
"db_name": result["db_name"],
"db_path": result["db_path"],
@@ -307,7 +375,7 @@ class WeChatDatabaseDecryptor:
log_fn = logger.warning
log_fn("[decrypt.diagnostic] %s", json.dumps(payload, ensure_ascii=False, sort_keys=True))
self.last_result = result
return bool(success)
return bool(result["success"])
logger.info(f"开始解密数据库: {db_path}")
@@ -693,7 +761,11 @@ def decrypt_wechat_databases(db_storage_path: str = None, key: str = None) -> di
# 返回结果
result = {
"status": "success" if success_count > 0 else "error",
"message": f"解密完成: 成功 {success_count}/{total_databases}",
"message": build_decrypt_summary_message(
success_count=success_count,
total_databases=total_databases,
diagnostic_warning_count=diagnostic_warning_count,
),
"total_databases": total_databases,
"successful_count": success_count,
"failed_count": total_databases - success_count,
+4 -4
View File
@@ -13,6 +13,8 @@ from typing import List, Dict, Any, Union
from ctypes import wintypes
from datetime import datetime
from .database_filters import should_skip_source_database
def get_wx_db(msg_dir: str = None,
db_types: Union[List[str], str] = None,
@@ -59,8 +61,7 @@ def get_wx_db(msg_dir: str = None,
for file_name in files:
if not file_name.endswith(".db"):
continue
# 排除不需要解密的数据库
if file_name in ["key_info.db"]:
if should_skip_source_database(file_name):
continue
db_type = re.sub(r"\d*\.db$", "", file_name)
if db_types and db_type not in db_types: # 如果指定db_type,则过滤掉其他db_type
@@ -672,8 +673,7 @@ def collect_account_databases(data_dir: str, account_name: str) -> List[Dict[str
if not file_name.endswith('.db'):
continue
# 排除不需要解密的数据库
if file_name in ["key_info.db"]:
if should_skip_source_database(file_name):
continue
db_path = os.path.join(root, file_name)
+112
View File
@@ -0,0 +1,112 @@
import importlib
import logging
import os
import sqlite3
import sys
import unittest
from pathlib import Path
from tempfile import TemporaryDirectory
ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT / "src"))
def _close_logging_handlers() -> None:
for logger_name in ("", "uvicorn", "uvicorn.access", "uvicorn.error", "fastapi"):
lg = logging.getLogger(logger_name)
for handler in lg.handlers[:]:
try:
handler.close()
except Exception:
pass
try:
lg.removeHandler(handler)
except Exception:
pass
def _seed_sqlite(path: Path, table_name: str = "demo") -> None:
conn = sqlite3.connect(str(path))
try:
conn.execute(f"CREATE TABLE {table_name}(id INTEGER PRIMARY KEY, value TEXT)")
conn.execute(f"INSERT INTO {table_name}(value) VALUES ('ok')")
conn.commit()
finally:
conn.close()
class TestDatabaseFilters(unittest.TestCase):
def test_scan_account_databases_skips_index_databases(self):
from wechat_decrypt_tool.wechat_decrypt import scan_account_databases_from_path
with TemporaryDirectory() as td:
db_storage = Path(td) / "xwechat_files" / "wxid_demo_user" / "db_storage"
db_storage.mkdir(parents=True, exist_ok=True)
_seed_sqlite(db_storage / "MSG0.db")
_seed_sqlite(db_storage / "contact_fts.db")
_seed_sqlite(db_storage / "favorite_fts.db")
_seed_sqlite(db_storage / "message_fts.db")
_seed_sqlite(db_storage / "key_info.db")
result = scan_account_databases_from_path(str(db_storage))
self.assertEqual(result["status"], "success")
self.assertEqual(list(result["account_databases"].keys()), ["wxid_demo"])
db_names = sorted(db["name"] for db in result["account_databases"]["wxid_demo"])
self.assertEqual(db_names, ["MSG0.db"])
def test_collect_account_databases_skips_index_databases(self):
from wechat_decrypt_tool.wechat_detection import collect_account_databases
with TemporaryDirectory() as td:
data_dir = Path(td) / "wxid_demo_user"
data_dir.mkdir(parents=True, exist_ok=True)
_seed_sqlite(data_dir / "contact.db", "contact")
_seed_sqlite(data_dir / "contact_fts.db")
_seed_sqlite(data_dir / "favorite_fts.db")
_seed_sqlite(data_dir / "message_fts.db")
databases = collect_account_databases(str(data_dir), "wxid_demo")
db_names = sorted(db["name"] for db in databases)
self.assertEqual(db_names, ["contact.db"])
def test_chat_account_info_hides_index_and_internal_databases(self):
with TemporaryDirectory() as td:
root = Path(td)
prev_data_dir = os.environ.get("WECHAT_TOOL_DATA_DIR")
try:
os.environ["WECHAT_TOOL_DATA_DIR"] = str(root)
import wechat_decrypt_tool.app_paths as app_paths
import wechat_decrypt_tool.routers.chat as chat_router
importlib.reload(app_paths)
importlib.reload(chat_router)
account_dir = root / "output" / "databases" / "wxid_demo"
account_dir.mkdir(parents=True, exist_ok=True)
_seed_sqlite(account_dir / "contact.db", "contact")
_seed_sqlite(account_dir / "session.db", "session_table")
_seed_sqlite(account_dir / "message_fts.db")
_seed_sqlite(account_dir / "chat_search_index.db")
_seed_sqlite(account_dir / "session_last_message.db")
result = chat_router.get_chat_account_info("wxid_demo")
self.assertEqual(result["status"], "success")
self.assertEqual(result["database_count"], 2)
self.assertEqual(result["databases"], ["contact.db", "session.db"])
finally:
_close_logging_handlers()
if prev_data_dir is None:
os.environ.pop("WECHAT_TOOL_DATA_DIR", None)
else:
os.environ["WECHAT_TOOL_DATA_DIR"] = prev_data_dir
if __name__ == "__main__":
unittest.main()
+97 -4
View File
@@ -1,5 +1,7 @@
import json
import logging
import os
import sqlite3
import sys
import unittest
import importlib
@@ -11,13 +13,25 @@ ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT / "src"))
def _close_logging_handlers() -> None:
for logger_name in ("", "uvicorn", "uvicorn.access", "uvicorn.error", "fastapi"):
lg = logging.getLogger(logger_name)
for handler in lg.handlers[:]:
try:
handler.close()
except Exception:
pass
try:
lg.removeHandler(handler)
except Exception:
pass
class TestDecryptStreamSSE(unittest.TestCase):
def test_decrypt_stream_reports_progress(self):
from fastapi import FastAPI
from fastapi.testclient import TestClient
from wechat_decrypt_tool.wechat_decrypt import SQLITE_HEADER
with TemporaryDirectory() as td:
root = Path(td)
@@ -36,8 +50,14 @@ class TestDecryptStreamSSE(unittest.TestCase):
db_storage = root / "xwechat_files" / "wxid_foo_bar" / "db_storage"
db_storage.mkdir(parents=True, exist_ok=True)
# Fake a decrypted sqlite db (>= 4096 bytes) so decryptor falls back to copy.
(db_storage / "MSG0.db").write_bytes(SQLITE_HEADER + b"\x00" * (4096 - len(SQLITE_HEADER)))
db_path = db_storage / "MSG0.db"
conn = sqlite3.connect(str(db_path))
try:
conn.execute("CREATE TABLE demo(id INTEGER PRIMARY KEY, value TEXT)")
conn.execute("INSERT INTO demo(value) VALUES ('ok')")
conn.commit()
finally:
conn.close()
app = FastAPI()
app.include_router(decrypt_router.router)
@@ -72,10 +92,83 @@ class TestDecryptStreamSSE(unittest.TestCase):
self.assertIn("start", types)
self.assertIn("progress", types)
self.assertEqual(events[-1].get("type"), "complete")
self.assertEqual(events[-1].get("status"), "completed")
out = root / "output" / "databases" / "wxid_foo" / "MSG0.db"
self.assertTrue(out.exists())
finally:
_close_logging_handlers()
if prev_data_dir is None:
os.environ.pop("WECHAT_TOOL_DATA_DIR", None)
else:
os.environ["WECHAT_TOOL_DATA_DIR"] = prev_data_dir
if prev_build_cache is None:
os.environ.pop("WECHAT_TOOL_BUILD_SESSION_LAST_MESSAGE", None)
else:
os.environ["WECHAT_TOOL_BUILD_SESSION_LAST_MESSAGE"] = prev_build_cache
def test_decrypt_stream_marks_invalid_output_as_failed(self):
from fastapi import FastAPI
from fastapi.testclient import TestClient
with TemporaryDirectory() as td:
root = Path(td)
prev_data_dir = os.environ.get("WECHAT_TOOL_DATA_DIR")
prev_build_cache = os.environ.get("WECHAT_TOOL_BUILD_SESSION_LAST_MESSAGE")
try:
os.environ["WECHAT_TOOL_DATA_DIR"] = str(root)
os.environ["WECHAT_TOOL_BUILD_SESSION_LAST_MESSAGE"] = "0"
import wechat_decrypt_tool.app_paths as app_paths
import wechat_decrypt_tool.routers.decrypt as decrypt_router
importlib.reload(app_paths)
importlib.reload(decrypt_router)
db_storage = root / "xwechat_files" / "wxid_bad_case" / "db_storage"
db_storage.mkdir(parents=True, exist_ok=True)
(db_storage / "MSG0.db").write_bytes(b"\x01" * 4096)
app = FastAPI()
app.include_router(decrypt_router.router)
client = TestClient(app)
events: list[dict] = []
with client.stream(
"GET",
"/api/decrypt_stream",
params={"key": "00" * 32, "db_storage_path": str(db_storage)},
) as resp:
self.assertEqual(resp.status_code, 200)
self.assertIn("text/event-stream", resp.headers.get("content-type", ""))
for line in resp.iter_lines():
if not line:
continue
if isinstance(line, bytes):
line = line.decode("utf-8", errors="ignore")
line = str(line)
if line.startswith(":"):
continue
if not line.startswith("data: "):
continue
payload = json.loads(line[len("data: ") :])
events.append(payload)
if payload.get("type") in {"complete", "error"}:
break
self.assertEqual(events[-1].get("type"), "complete")
self.assertEqual(events[-1].get("status"), "failed")
self.assertEqual(events[-1].get("success_count"), 0)
self.assertEqual(events[-1].get("failure_count"), 1)
self.assertIn("密钥可能不匹配", str(events[-1].get("message") or ""))
out = root / "output" / "databases" / "wxid_bad" / "MSG0.db"
self.assertFalse(out.exists())
finally:
_close_logging_handlers()
if prev_data_dir is None:
os.environ.pop("WECHAT_TOOL_DATA_DIR", None)
else:
@@ -0,0 +1,61 @@
import importlib
import os
import sqlite3
import sys
import unittest
from pathlib import Path
from tempfile import TemporaryDirectory
ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT / "src"))
def _seed_sqlite(path: Path, table_name: str) -> None:
conn = sqlite3.connect(str(path))
try:
conn.execute(f"CREATE TABLE {table_name}(id INTEGER PRIMARY KEY, value TEXT)")
conn.execute(f"INSERT INTO {table_name}(value) VALUES ('ok')")
conn.commit()
finally:
conn.close()
class TestDecryptedAccountValidation(unittest.TestCase):
def test_invalid_header_only_databases_are_ignored(self):
with TemporaryDirectory() as td:
root = Path(td)
prev_data_dir = os.environ.get("WECHAT_TOOL_DATA_DIR")
try:
os.environ["WECHAT_TOOL_DATA_DIR"] = str(root)
import wechat_decrypt_tool.app_paths as app_paths
import wechat_decrypt_tool.chat_helpers as chat_helpers
import wechat_decrypt_tool.media_helpers as media_helpers
importlib.reload(app_paths)
importlib.reload(chat_helpers)
importlib.reload(media_helpers)
output_dir = root / "output" / "databases"
bad_dir = output_dir / "wxid_bad"
bad_dir.mkdir(parents=True, exist_ok=True)
(bad_dir / "session.db").write_bytes(b"SQLite format 3\x00")
(bad_dir / "contact.db").write_bytes(b"SQLite format 3\x00")
good_dir = output_dir / "wxid_good"
good_dir.mkdir(parents=True, exist_ok=True)
_seed_sqlite(good_dir / "session.db", "SessionTable")
_seed_sqlite(good_dir / "contact.db", "contact")
self.assertEqual(chat_helpers._list_decrypted_accounts(), ["wxid_good"])
self.assertEqual(media_helpers._list_decrypted_accounts(), ["wxid_good"])
finally:
if prev_data_dir is None:
os.environ.pop("WECHAT_TOOL_DATA_DIR", None)
else:
os.environ["WECHAT_TOOL_DATA_DIR"] = prev_data_dir
if __name__ == "__main__":
unittest.main()