diff --git a/frontend/pages/decrypt.vue b/frontend/pages/decrypt.vue index 2abede2..479bc39 100644 --- a/frontend/pages/decrypt.vue +++ b/frontend/pages/decrypt.vue @@ -272,7 +272,7 @@
@@ -366,6 +366,13 @@ {{ mediaDecrypting ? '解密中...' : (mediaDecryptResult ? '重新解密' : '开始解密图片') }} + @@ -53,9 +63,19 @@
- - - +

正在检测微信数据...

@@ -395,4 +415,4 @@ onMounted(() => { } startDetection() }) - \ No newline at end of file + diff --git a/src/wechat_decrypt_tool/routers/media.py b/src/wechat_decrypt_tool/routers/media.py index 6fb9350..cae8e99 100644 --- a/src/wechat_decrypt_tool/routers/media.py +++ b/src/wechat_decrypt_tool/routers/media.py @@ -2,7 +2,7 @@ import asyncio import json from typing import Optional -from fastapi import APIRouter, HTTPException +from fastapi import APIRouter, HTTPException, Request from fastapi.responses import Response, StreamingResponse from pydantic import BaseModel, Field @@ -226,6 +226,7 @@ async def get_decrypted_resource(md5: str, account: Optional[str] = None): @router.get("/api/media/decrypt_all_stream", summary="批量解密所有图片资源(SSE实时进度)") async def decrypt_all_media_stream( + request: Request, account: Optional[str] = None, xor_key: Optional[str] = None, aes_key: Optional[str] = None, @@ -252,8 +253,18 @@ async def decrypt_all_media_stream( - 解密后非有效图片格式 """ + async def is_client_disconnected() -> bool: + try: + return await request.is_disconnected() + except Exception: + return False + async def generate_progress(): try: + if await is_client_disconnected(): + logger.info("[SSE] 客户端已断开,取消图片解密任务") + return + account_dir = _resolve_account_dir(account) wxid_dir = _resolve_account_wxid_dir(account_dir) @@ -301,6 +312,10 @@ async def decrypt_all_media_stream( total_files = len(dat_files) logger.info(f"[SSE] 共发现 {total_files} 个.dat文件(仅图片)") + if await is_client_disconnected(): + logger.info("[SSE] 扫描完成后客户端已断开,停止图片解密任务") + return + if total_files == 0: yield f"data: {json.dumps({'type': 'complete', 'message': '未发现需要解密的图片文件', 'total': 0, 'success_count': 0, 'skip_count': 0, 'fail_count': 0})}\n\n" return @@ -319,6 +334,10 @@ async def decrypt_all_media_stream( resource_dir.mkdir(parents=True, exist_ok=True) for i, (dat_path, md5) in enumerate(dat_files): + if await is_client_disconnected(): + logger.info("[SSE] 客户端已断开,停止图片解密任务") + return + current = i + 1 file_name = dat_path.name diff --git a/src/wechat_decrypt_tool/wechat_detection.py b/src/wechat_decrypt_tool/wechat_detection.py index 2794579..e61d79d 100644 --- a/src/wechat_decrypt_tool/wechat_detection.py +++ b/src/wechat_decrypt_tool/wechat_detection.py @@ -16,6 +16,37 @@ from datetime import datetime from .database_filters import should_skip_source_database +COMMON_WECHAT_PATTERNS = [ + "WeChat Files", + "Weixin Files", + "wechat_files", + "xwechat_files", + "wechatMSG", + "WeChat", + "微信", + "Weixin", + "wechat", +] + +SYSTEM_SCAN_SKIP_NAMES = { + "$recycle.bin", + "$winreagent", + "config.msi", + "documents and settings", + "intel", + "onedrivetemp", + "perflogs", + "program files", + "program files (x86)", + "programdata", + "recovery", + "system volume information", + "windows", + "windows.old", + "windows.old(1)", +} + + def get_wx_db(msg_dir: str = None, db_types: Union[List[str], str] = None, wxids: Union[List[str], str] = None) -> List[dict]: @@ -285,6 +316,87 @@ def get_process_list(): return process_list +def _is_wechat_dir_candidate_name(name: str) -> bool: + normalized = str(name or "").strip().lower() + if not normalized: + return False + return any(pattern.lower() in normalized for pattern in COMMON_WECHAT_PATTERNS) + + +def _safe_iter_subdirs(directory: str) -> List[tuple[str, str]]: + items: List[tuple[str, str]] = [] + try: + with os.scandir(directory) as entries: + for entry in entries: + try: + if entry.is_dir(): + items.append((entry.name, entry.path)) + except OSError: + continue + except (PermissionError, OSError): + return [] + return items + + +def _append_detected_dir(detected_dirs: List[str], candidate: str) -> None: + if not candidate: + return + normalized = os.path.normpath(candidate) + if normalized not in detected_dirs: + detected_dirs.append(normalized) + + +def _build_auto_detect_scan_paths() -> List[str]: + scan_paths: List[str] = [] + seen_paths = set() + + def add(path_value: str | None) -> None: + raw = str(path_value or "").strip() + if not raw: + return + normalized = os.path.normpath(raw) + key = normalized.lower() + if key in seen_paths: + return + seen_paths.add(key) + scan_paths.append(normalized) + + home_dir = str(Path.home()) + add(home_dir) + add(os.path.join(home_dir, "Documents")) + add(os.path.join(home_dir, "Desktop")) + add(os.path.join(home_dir, "Downloads")) + + user_profile = str(os.environ.get("USERPROFILE") or "").strip() + if user_profile: + add(user_profile) + add(os.path.join(user_profile, "Documents")) + add(os.path.join(user_profile, "Desktop")) + add(os.path.join(user_profile, "Downloads")) + + for drive in ("C:", "D:", "E:", "F:"): + drive_root = drive + os.sep + if not os.path.exists(drive_root): + continue + + add(drive_root) + + for child_name, child_path in _safe_iter_subdirs(drive_root): + if child_name.strip().lower() in SYSTEM_SCAN_SKIP_NAMES: + continue + add(child_path) + + users_dir = os.path.join(drive_root, "Users") + add(users_dir) + for _user_name, user_dir in _safe_iter_subdirs(users_dir): + add(user_dir) + add(os.path.join(user_dir, "Documents")) + add(os.path.join(user_dir, "Desktop")) + add(os.path.join(user_dir, "Downloads")) + + return scan_paths + + def auto_detect_wechat_data_dirs(): """ 自动检测微信数据目录 - 多策略组合检测 @@ -292,52 +404,27 @@ def auto_detect_wechat_data_dirs(): """ detected_dirs = [] - # 策略1:注册表检测已移除 - - # 策略2和策略3:注册表相关检测已移除 - - # 策略1:常见驱动器扫描微信相关目录 - common_wechat_patterns = [ - "WeChat Files", "wechat_files", "xwechat_files", "wechatMSG", - "WeChat", "微信", "Weixin", "wechat" - ] - - # 扫描常见驱动器 - drives = ['C:', 'D:', 'E:', 'F:'] - for drive in drives: - if not os.path.exists(drive): + # 策略1:常见驱动器 / 用户目录 / 自定义目录的浅层扫描。 + # 这里既检查扫描根目录本身,也检查其直接子目录,兼容: + # - C:\Users\\Documents\WeChat Files + # - D:\wechatMSG\xwechat_files + # - D:\abc\wechatMSG\xwechat_files + for scan_path in _build_auto_detect_scan_paths(): + if not os.path.exists(scan_path): continue - try: - # 扫描驱动器根目录和常见目录 - scan_paths = [ - drive + os.sep, - os.path.join(drive + os.sep, "Users"), - ] + scan_name = os.path.basename(os.path.normpath(scan_path)) + if _is_wechat_dir_candidate_name(scan_name) and has_wxid_directories(scan_path): + _append_detected_dir(detected_dirs, scan_path) + print(f"[DEBUG] 目录扫描检测成功: {scan_path}") - for scan_path in scan_paths: - if not os.path.exists(scan_path): - continue - - try: - for item in os.listdir(scan_path): - item_path = os.path.join(scan_path, item) - if not os.path.isdir(item_path): - continue - - # 检查是否匹配微信目录模式 - for pattern in common_wechat_patterns: - if pattern.lower() in item.lower(): - # 检查是否包含wxid目录 - if has_wxid_directories(item_path): - if item_path not in detected_dirs: - detected_dirs.append(item_path) - print(f"[DEBUG] 目录扫描检测成功: {item_path}") - break - except (PermissionError, OSError): - continue - except (PermissionError, OSError): - continue + for item_name, item_path in _safe_iter_subdirs(scan_path): + if not _is_wechat_dir_candidate_name(item_name): + continue + if not has_wxid_directories(item_path): + continue + _append_detected_dir(detected_dirs, item_path) + print(f"[DEBUG] 目录扫描检测成功: {item_path}") # 策略2:进程内存分析(简化版) try: @@ -361,12 +448,11 @@ def auto_detect_wechat_data_dirs(): break for parent_dir in parent_dirs: - for pattern in common_wechat_patterns: + for pattern in COMMON_WECHAT_PATTERNS: potential_dir = os.path.join(parent_dir, pattern) if os.path.exists(potential_dir) and has_wxid_directories(potential_dir): - if potential_dir not in detected_dirs: - detected_dirs.append(potential_dir) - print(f"[DEBUG] 进程分析检测成功: {potential_dir}") + _append_detected_dir(detected_dirs, potential_dir) + print(f"[DEBUG] 进程分析检测成功: {potential_dir}") except: pass except: diff --git a/tests/test_media_decrypt_stream_cancel.py b/tests/test_media_decrypt_stream_cancel.py new file mode 100644 index 0000000..0324b55 --- /dev/null +++ b/tests/test_media_decrypt_stream_cancel.py @@ -0,0 +1,74 @@ +import asyncio +import json +import sys +import unittest +from pathlib import Path +from tempfile import TemporaryDirectory +from unittest import mock + + +ROOT = Path(__file__).resolve().parents[1] +sys.path.insert(0, str(ROOT / "src")) + + +from wechat_decrypt_tool.routers import media as media_router # noqa: E402 pylint: disable=wrong-import-position + + +class _FakeDisconnectingRequest: + def __init__(self, disconnect_after: int): + self._disconnect_after = disconnect_after + self._calls = 0 + + async def is_disconnected(self): + self._calls += 1 + return self._calls >= self._disconnect_after + + +class TestMediaDecryptStreamCancel(unittest.TestCase): + def test_stream_stops_processing_when_client_disconnects(self): + with TemporaryDirectory() as td: + root = Path(td) + account_dir = root / "account" + wxid_dir = root / "wxid" + dat_path = wxid_dir / "image.dat" + resource_dir = account_dir / "resource" + wxid_dir.mkdir(parents=True, exist_ok=True) + dat_path.write_bytes(b"encrypted") + + request = _FakeDisconnectingRequest(disconnect_after=3) + decrypt_mock = mock.Mock(return_value=(True, "ok")) + + with mock.patch.object(media_router, "_resolve_account_dir", return_value=account_dir): + with mock.patch.object(media_router, "_resolve_account_wxid_dir", return_value=wxid_dir): + with mock.patch.object(media_router, "_load_media_keys", return_value={"xor": 0xA5, "aes": ""}): + with mock.patch.object(media_router, "_collect_all_dat_files", return_value=[(dat_path, "abc123")]): + with mock.patch.object(media_router, "_get_resource_dir", return_value=resource_dir): + with mock.patch.object(media_router, "_try_find_decrypted_resource", return_value=None): + with mock.patch.object(media_router, "_decrypt_and_save_resource", decrypt_mock): + response = asyncio.run( + media_router.decrypt_all_media_stream( + request=request, + account="wxid_demo", + ) + ) + + async def _read_chunks(): + chunks = [] + async for chunk in response.body_iterator: + chunks.append(chunk.decode("utf-8") if isinstance(chunk, bytes) else str(chunk)) + return chunks + + chunks = asyncio.run(_read_chunks()) + + events = [] + for chunk in chunks: + for line in chunk.splitlines(): + if line.startswith("data: "): + events.append(json.loads(line[len("data: ") :])) + + self.assertEqual([event.get("type") for event in events], ["scanning", "start"]) + decrypt_mock.assert_not_called() + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_wechat_detection_auto_detect.py b/tests/test_wechat_detection_auto_detect.py new file mode 100644 index 0000000..18cf66a --- /dev/null +++ b/tests/test_wechat_detection_auto_detect.py @@ -0,0 +1,44 @@ +import sys +import unittest +from pathlib import Path +from tempfile import TemporaryDirectory +from unittest.mock import patch + + +ROOT = Path(__file__).resolve().parents[1] +sys.path.insert(0, str(ROOT / "src")) + + +class TestWechatDetectionAutoDetect(unittest.TestCase): + def test_detect_wechat_installation_finds_nested_custom_data_root(self): + from wechat_decrypt_tool import wechat_detection as wd + + with TemporaryDirectory() as td: + nested_scan_root = Path(td) / "abc" + wechat_parent = nested_scan_root / "wechatMSG" + xwechat_root = wechat_parent / "xwechat_files" + + login_dir = xwechat_root / "all_users" / "login" / "wxid_demo" + login_dir.mkdir(parents=True, exist_ok=True) + (login_dir / "key_info.db").write_bytes(b"demo") + + account_dir = xwechat_root / "wxid_demo_nested" + account_dir.mkdir(parents=True, exist_ok=True) + (account_dir / "contact.db").write_bytes(b"demo") + + with ( + patch.object(wd, "_build_auto_detect_scan_paths", return_value=[str(nested_scan_root)]), + patch.object(wd, "get_process_list", return_value=[]), + ): + detected_dirs = wd.auto_detect_wechat_data_dirs() + result = wd.detect_wechat_installation() + + self.assertEqual(detected_dirs, [str(wechat_parent)]) + self.assertEqual(result["total_accounts"], 1) + self.assertEqual(result["accounts"][0]["account_name"], "wxid_demo") + self.assertEqual(result["accounts"][0]["data_dir"], str(account_dir)) + self.assertEqual(result["total_databases"], 1) + + +if __name__ == "__main__": + unittest.main()