fix(key): 支持手动指定微信安装目录并校验 db key 来源

- /api/get_keys 支持传入 wechat_install_path,兼容安装目录与 Weixin.exe / WeChat.exe

- 解密完成后保存 db key 的来源路径与别名,避免历史密钥被错误账号复用

- 解密页按 account + db_storage_path 回填已保存密钥,并补充相关测试覆盖
This commit is contained in:
2977094657
2026-04-23 21:32:02 +08:00
Unverified
parent ec2a84af18
commit 0987167c4a
12 changed files with 729 additions and 77 deletions
+53 -42
View File
@@ -7,6 +7,7 @@ import unittest
import importlib
from pathlib import Path
from tempfile import TemporaryDirectory
from unittest import mock
ROOT = Path(__file__).resolve().parents[1]
@@ -64,35 +65,43 @@ class TestDecryptStreamSSE(unittest.TestCase):
client = TestClient(app)
events: list[dict] = []
with client.stream(
"GET",
"/api/decrypt_stream",
params={"key": "00" * 32, "db_storage_path": str(db_storage)},
) as resp:
self.assertEqual(resp.status_code, 200)
self.assertIn("text/event-stream", resp.headers.get("content-type", ""))
with mock.patch.object(decrypt_router, "upsert_account_keys_in_store") as upsert_mock:
with client.stream(
"GET",
"/api/decrypt_stream",
params={"key": "00" * 32, "db_storage_path": str(db_storage)},
) as resp:
self.assertEqual(resp.status_code, 200)
self.assertIn("text/event-stream", resp.headers.get("content-type", ""))
for line in resp.iter_lines():
if not line:
continue
if isinstance(line, bytes):
line = line.decode("utf-8", errors="ignore")
line = str(line)
for line in resp.iter_lines():
if not line:
continue
if isinstance(line, bytes):
line = line.decode("utf-8", errors="ignore")
line = str(line)
if line.startswith(":"):
continue
if not line.startswith("data: "):
continue
payload = json.loads(line[len("data: ") :])
events.append(payload)
if payload.get("type") in {"complete", "error"}:
break
if line.startswith(":"):
continue
if not line.startswith("data: "):
continue
payload = json.loads(line[len("data: ") :])
events.append(payload)
if payload.get("type") in {"complete", "error"}:
break
types = {e.get("type") for e in events}
self.assertIn("start", types)
self.assertIn("progress", types)
self.assertEqual(events[-1].get("type"), "complete")
self.assertEqual(events[-1].get("status"), "completed")
upsert_mock.assert_called_once_with(
"wxid_foo",
db_key="00" * 32,
aliases=["wxid_foo_bar"],
db_key_source_wxid_dir=str(db_storage.parent),
db_key_source_db_storage_path=str(db_storage),
)
out = root / "output" / "databases" / "wxid_foo" / "MSG0.db"
self.assertTrue(out.exists())
@@ -135,35 +144,37 @@ class TestDecryptStreamSSE(unittest.TestCase):
client = TestClient(app)
events: list[dict] = []
with client.stream(
"GET",
"/api/decrypt_stream",
params={"key": "00" * 32, "db_storage_path": str(db_storage)},
) as resp:
self.assertEqual(resp.status_code, 200)
self.assertIn("text/event-stream", resp.headers.get("content-type", ""))
with mock.patch.object(decrypt_router, "upsert_account_keys_in_store") as upsert_mock:
with client.stream(
"GET",
"/api/decrypt_stream",
params={"key": "00" * 32, "db_storage_path": str(db_storage)},
) as resp:
self.assertEqual(resp.status_code, 200)
self.assertIn("text/event-stream", resp.headers.get("content-type", ""))
for line in resp.iter_lines():
if not line:
continue
if isinstance(line, bytes):
line = line.decode("utf-8", errors="ignore")
line = str(line)
for line in resp.iter_lines():
if not line:
continue
if isinstance(line, bytes):
line = line.decode("utf-8", errors="ignore")
line = str(line)
if line.startswith(":"):
continue
if not line.startswith("data: "):
continue
payload = json.loads(line[len("data: ") :])
events.append(payload)
if payload.get("type") in {"complete", "error"}:
break
if line.startswith(":"):
continue
if not line.startswith("data: "):
continue
payload = json.loads(line[len("data: ") :])
events.append(payload)
if payload.get("type") in {"complete", "error"}:
break
self.assertEqual(events[-1].get("type"), "complete")
self.assertEqual(events[-1].get("status"), "failed")
self.assertEqual(events[-1].get("success_count"), 0)
self.assertEqual(events[-1].get("failure_count"), 1)
self.assertIn("密钥可能不匹配", str(events[-1].get("message") or ""))
upsert_mock.assert_not_called()
out = root / "output" / "databases" / "wxid_bad" / "MSG0.db"
self.assertFalse(out.exists())
@@ -0,0 +1,115 @@
import sys
import unittest
from pathlib import Path
from tempfile import TemporaryDirectory
from unittest import mock
ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT / "src"))
import wechat_decrypt_tool.key_service as key_service
class _FakeWxKey:
def __init__(self, key: str) -> None:
self.key = key
self.initialize_calls: list[int] = []
self.cleanup_calls = 0
def initialize_hook(self, pid: int) -> bool:
self.initialize_calls.append(pid)
return True
def get_last_error_msg(self) -> str:
return ""
def poll_key_data(self):
return {"key": self.key}
def get_status_message(self):
return None, None
def cleanup_hook(self) -> None:
self.cleanup_calls += 1
class TestKeyServiceManualWechatInstallPath(unittest.TestCase):
def test_get_db_key_workflow_can_use_manual_install_directory(self) -> None:
fake_wx_key = _FakeWxKey("a" * 64)
with TemporaryDirectory() as temp_dir:
install_dir = Path(temp_dir)
exe_path = install_dir / "WeChat.exe"
exe_path.write_bytes(b"")
with mock.patch.object(
key_service,
"wx_key",
fake_wx_key,
), mock.patch.object(
key_service,
"detect_wechat_installation",
side_effect=AssertionError("should not auto-detect when manual path is provided"),
), mock.patch.object(
key_service,
"_read_wechat_version_from_exe",
return_value="",
), mock.patch.object(
key_service.WeChatKeyFetcher,
"kill_wechat",
autospec=True,
) as kill_mock, mock.patch.object(
key_service.WeChatKeyFetcher,
"launch_wechat",
autospec=True,
return_value=4321,
) as launch_mock:
result = key_service.get_db_key_workflow(wechat_install_path=str(install_dir))
self.assertEqual(result["db_key"], "a" * 64)
kill_mock.assert_called_once()
launch_mock.assert_called_once()
_, used_exe_path = launch_mock.call_args.args
self.assertEqual(used_exe_path, str(exe_path))
self.assertEqual(fake_wx_key.initialize_calls, [4321])
self.assertEqual(fake_wx_key.cleanup_calls, 1)
def test_get_db_key_workflow_does_not_require_detected_version(self) -> None:
fake_wx_key = _FakeWxKey("b" * 64)
with TemporaryDirectory() as temp_dir:
exe_path = Path(temp_dir) / "Weixin.exe"
exe_path.write_bytes(b"")
with mock.patch.object(
key_service,
"wx_key",
fake_wx_key,
), mock.patch.object(
key_service,
"detect_wechat_installation",
return_value={
"wechat_exe_path": str(exe_path),
"wechat_version": "",
},
), mock.patch.object(
key_service.WeChatKeyFetcher,
"kill_wechat",
autospec=True,
), mock.patch.object(
key_service.WeChatKeyFetcher,
"launch_wechat",
autospec=True,
return_value=2468,
):
result = key_service.get_db_key_workflow()
self.assertEqual(result["db_key"], "b" * 64)
self.assertEqual(fake_wx_key.initialize_calls, [2468])
self.assertEqual(fake_wx_key.cleanup_calls, 1)
if __name__ == "__main__":
unittest.main()
@@ -0,0 +1,107 @@
import asyncio
import importlib
import logging
import os
import sys
import unittest
from pathlib import Path
from tempfile import TemporaryDirectory
ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT / "src"))
def _close_logging_handlers() -> None:
for logger_name in ("", "uvicorn", "uvicorn.access", "uvicorn.error", "fastapi"):
lg = logging.getLogger(logger_name)
for handler in lg.handlers[:]:
try:
handler.close()
except Exception:
pass
try:
lg.removeHandler(handler)
except Exception:
pass
class TestSavedDbKeySourceValidation(unittest.TestCase):
def test_get_saved_keys_blocks_legacy_db_key_for_suffixed_wxid_dir(self) -> None:
with TemporaryDirectory() as td:
root = Path(td)
db_storage = root / "xwechat_files" / "wxid_demo_abcd" / "db_storage"
db_storage.mkdir(parents=True, exist_ok=True)
prev_data_dir = os.environ.get("WECHAT_TOOL_DATA_DIR")
try:
os.environ["WECHAT_TOOL_DATA_DIR"] = str(root)
import wechat_decrypt_tool.app_paths as app_paths
import wechat_decrypt_tool.key_store as key_store
import wechat_decrypt_tool.routers.keys as keys_router
importlib.reload(app_paths)
importlib.reload(key_store)
importlib.reload(keys_router)
key_store.upsert_account_keys_in_store("wxid_demo", db_key="A" * 64)
result = asyncio.run(
keys_router.get_saved_keys(account="wxid_demo", db_storage_path=str(db_storage))
)
self.assertEqual(result["status"], "success")
self.assertEqual(result["keys"]["db_key"], "")
self.assertIn("Legacy saved db key is ambiguous", result["keys"]["db_key_blocked_reason"])
finally:
_close_logging_handlers()
if prev_data_dir is None:
os.environ.pop("WECHAT_TOOL_DATA_DIR", None)
else:
os.environ["WECHAT_TOOL_DATA_DIR"] = prev_data_dir
def test_get_saved_keys_accepts_source_matched_db_key(self) -> None:
with TemporaryDirectory() as td:
root = Path(td)
db_storage = root / "xwechat_files" / "wxid_demo_abcd" / "db_storage"
db_storage.mkdir(parents=True, exist_ok=True)
prev_data_dir = os.environ.get("WECHAT_TOOL_DATA_DIR")
try:
os.environ["WECHAT_TOOL_DATA_DIR"] = str(root)
import wechat_decrypt_tool.app_paths as app_paths
import wechat_decrypt_tool.key_store as key_store
import wechat_decrypt_tool.routers.keys as keys_router
importlib.reload(app_paths)
importlib.reload(key_store)
importlib.reload(keys_router)
key_store.upsert_account_keys_in_store(
"wxid_demo",
db_key="B" * 64,
aliases=["wxid_demo_abcd"],
db_key_source_wxid_dir=str(db_storage.parent),
db_key_source_db_storage_path=str(db_storage),
)
result = asyncio.run(
keys_router.get_saved_keys(account="wxid_demo", db_storage_path=str(db_storage))
)
self.assertEqual(result["status"], "success")
self.assertEqual(result["keys"]["db_key"], "B" * 64)
self.assertEqual(result["keys"]["db_key_store_account"], "wxid_demo_abcd")
self.assertEqual(result["keys"]["db_key_source_wxid_dir"], str(db_storage.parent))
self.assertEqual(result["keys"]["db_key_source_db_storage_path"], str(db_storage))
self.assertEqual(result["keys"]["db_key_blocked_reason"], "")
finally:
_close_logging_handlers()
if prev_data_dir is None:
os.environ.pop("WECHAT_TOOL_DATA_DIR", None)
else:
os.environ["WECHAT_TOOL_DATA_DIR"] = prev_data_dir
if __name__ == "__main__":
unittest.main()