diff --git a/frontend/pages/decrypt.vue b/frontend/pages/decrypt.vue
index 2abede2..479bc39 100644
--- a/frontend/pages/decrypt.vue
+++ b/frontend/pages/decrypt.vue
@@ -272,7 +272,7 @@
@@ -395,4 +415,4 @@ onMounted(() => {
}
startDetection()
})
-
\ No newline at end of file
+
diff --git a/src/wechat_decrypt_tool/routers/media.py b/src/wechat_decrypt_tool/routers/media.py
index 6fb9350..cae8e99 100644
--- a/src/wechat_decrypt_tool/routers/media.py
+++ b/src/wechat_decrypt_tool/routers/media.py
@@ -2,7 +2,7 @@ import asyncio
import json
from typing import Optional
-from fastapi import APIRouter, HTTPException
+from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import Response, StreamingResponse
from pydantic import BaseModel, Field
@@ -226,6 +226,7 @@ async def get_decrypted_resource(md5: str, account: Optional[str] = None):
@router.get("/api/media/decrypt_all_stream", summary="批量解密所有图片资源(SSE实时进度)")
async def decrypt_all_media_stream(
+ request: Request,
account: Optional[str] = None,
xor_key: Optional[str] = None,
aes_key: Optional[str] = None,
@@ -252,8 +253,18 @@ async def decrypt_all_media_stream(
- 解密后非有效图片格式
"""
+ async def is_client_disconnected() -> bool:
+ try:
+ return await request.is_disconnected()
+ except Exception:
+ return False
+
async def generate_progress():
try:
+ if await is_client_disconnected():
+ logger.info("[SSE] 客户端已断开,取消图片解密任务")
+ return
+
account_dir = _resolve_account_dir(account)
wxid_dir = _resolve_account_wxid_dir(account_dir)
@@ -301,6 +312,10 @@ async def decrypt_all_media_stream(
total_files = len(dat_files)
logger.info(f"[SSE] 共发现 {total_files} 个.dat文件(仅图片)")
+ if await is_client_disconnected():
+ logger.info("[SSE] 扫描完成后客户端已断开,停止图片解密任务")
+ return
+
if total_files == 0:
yield f"data: {json.dumps({'type': 'complete', 'message': '未发现需要解密的图片文件', 'total': 0, 'success_count': 0, 'skip_count': 0, 'fail_count': 0})}\n\n"
return
@@ -319,6 +334,10 @@ async def decrypt_all_media_stream(
resource_dir.mkdir(parents=True, exist_ok=True)
for i, (dat_path, md5) in enumerate(dat_files):
+ if await is_client_disconnected():
+ logger.info("[SSE] 客户端已断开,停止图片解密任务")
+ return
+
current = i + 1
file_name = dat_path.name
diff --git a/src/wechat_decrypt_tool/wechat_detection.py b/src/wechat_decrypt_tool/wechat_detection.py
index 2794579..e61d79d 100644
--- a/src/wechat_decrypt_tool/wechat_detection.py
+++ b/src/wechat_decrypt_tool/wechat_detection.py
@@ -16,6 +16,37 @@ from datetime import datetime
from .database_filters import should_skip_source_database
+COMMON_WECHAT_PATTERNS = [
+ "WeChat Files",
+ "Weixin Files",
+ "wechat_files",
+ "xwechat_files",
+ "wechatMSG",
+ "WeChat",
+ "微信",
+ "Weixin",
+ "wechat",
+]
+
+SYSTEM_SCAN_SKIP_NAMES = {
+ "$recycle.bin",
+ "$winreagent",
+ "config.msi",
+ "documents and settings",
+ "intel",
+ "onedrivetemp",
+ "perflogs",
+ "program files",
+ "program files (x86)",
+ "programdata",
+ "recovery",
+ "system volume information",
+ "windows",
+ "windows.old",
+ "windows.old(1)",
+}
+
+
def get_wx_db(msg_dir: str = None,
db_types: Union[List[str], str] = None,
wxids: Union[List[str], str] = None) -> List[dict]:
@@ -285,6 +316,87 @@ def get_process_list():
return process_list
+def _is_wechat_dir_candidate_name(name: str) -> bool:
+ normalized = str(name or "").strip().lower()
+ if not normalized:
+ return False
+ return any(pattern.lower() in normalized for pattern in COMMON_WECHAT_PATTERNS)
+
+
+def _safe_iter_subdirs(directory: str) -> List[tuple[str, str]]:
+ items: List[tuple[str, str]] = []
+ try:
+ with os.scandir(directory) as entries:
+ for entry in entries:
+ try:
+ if entry.is_dir():
+ items.append((entry.name, entry.path))
+ except OSError:
+ continue
+ except (PermissionError, OSError):
+ return []
+ return items
+
+
+def _append_detected_dir(detected_dirs: List[str], candidate: str) -> None:
+ if not candidate:
+ return
+ normalized = os.path.normpath(candidate)
+ if normalized not in detected_dirs:
+ detected_dirs.append(normalized)
+
+
+def _build_auto_detect_scan_paths() -> List[str]:
+ scan_paths: List[str] = []
+ seen_paths = set()
+
+ def add(path_value: str | None) -> None:
+ raw = str(path_value or "").strip()
+ if not raw:
+ return
+ normalized = os.path.normpath(raw)
+ key = normalized.lower()
+ if key in seen_paths:
+ return
+ seen_paths.add(key)
+ scan_paths.append(normalized)
+
+ home_dir = str(Path.home())
+ add(home_dir)
+ add(os.path.join(home_dir, "Documents"))
+ add(os.path.join(home_dir, "Desktop"))
+ add(os.path.join(home_dir, "Downloads"))
+
+ user_profile = str(os.environ.get("USERPROFILE") or "").strip()
+ if user_profile:
+ add(user_profile)
+ add(os.path.join(user_profile, "Documents"))
+ add(os.path.join(user_profile, "Desktop"))
+ add(os.path.join(user_profile, "Downloads"))
+
+ for drive in ("C:", "D:", "E:", "F:"):
+ drive_root = drive + os.sep
+ if not os.path.exists(drive_root):
+ continue
+
+ add(drive_root)
+
+ for child_name, child_path in _safe_iter_subdirs(drive_root):
+ if child_name.strip().lower() in SYSTEM_SCAN_SKIP_NAMES:
+ continue
+ add(child_path)
+
+ users_dir = os.path.join(drive_root, "Users")
+ add(users_dir)
+ for _user_name, user_dir in _safe_iter_subdirs(users_dir):
+ add(user_dir)
+ add(os.path.join(user_dir, "Documents"))
+ add(os.path.join(user_dir, "Desktop"))
+ add(os.path.join(user_dir, "Downloads"))
+
+ return scan_paths
+
+
def auto_detect_wechat_data_dirs():
"""
自动检测微信数据目录 - 多策略组合检测
@@ -292,52 +404,27 @@ def auto_detect_wechat_data_dirs():
"""
detected_dirs = []
- # 策略1:注册表检测已移除
-
- # 策略2和策略3:注册表相关检测已移除
-
- # 策略1:常见驱动器扫描微信相关目录
- common_wechat_patterns = [
- "WeChat Files", "wechat_files", "xwechat_files", "wechatMSG",
- "WeChat", "微信", "Weixin", "wechat"
- ]
-
- # 扫描常见驱动器
- drives = ['C:', 'D:', 'E:', 'F:']
- for drive in drives:
- if not os.path.exists(drive):
+ # 策略1:常见驱动器 / 用户目录 / 自定义目录的浅层扫描。
+ # 这里既检查扫描根目录本身,也检查其直接子目录,兼容:
+ # - C:\Users\
\Documents\WeChat Files
+ # - D:\wechatMSG\xwechat_files
+ # - D:\abc\wechatMSG\xwechat_files
+ for scan_path in _build_auto_detect_scan_paths():
+ if not os.path.exists(scan_path):
continue
- try:
- # 扫描驱动器根目录和常见目录
- scan_paths = [
- drive + os.sep,
- os.path.join(drive + os.sep, "Users"),
- ]
+ scan_name = os.path.basename(os.path.normpath(scan_path))
+ if _is_wechat_dir_candidate_name(scan_name) and has_wxid_directories(scan_path):
+ _append_detected_dir(detected_dirs, scan_path)
+ print(f"[DEBUG] 目录扫描检测成功: {scan_path}")
- for scan_path in scan_paths:
- if not os.path.exists(scan_path):
- continue
-
- try:
- for item in os.listdir(scan_path):
- item_path = os.path.join(scan_path, item)
- if not os.path.isdir(item_path):
- continue
-
- # 检查是否匹配微信目录模式
- for pattern in common_wechat_patterns:
- if pattern.lower() in item.lower():
- # 检查是否包含wxid目录
- if has_wxid_directories(item_path):
- if item_path not in detected_dirs:
- detected_dirs.append(item_path)
- print(f"[DEBUG] 目录扫描检测成功: {item_path}")
- break
- except (PermissionError, OSError):
- continue
- except (PermissionError, OSError):
- continue
+ for item_name, item_path in _safe_iter_subdirs(scan_path):
+ if not _is_wechat_dir_candidate_name(item_name):
+ continue
+ if not has_wxid_directories(item_path):
+ continue
+ _append_detected_dir(detected_dirs, item_path)
+ print(f"[DEBUG] 目录扫描检测成功: {item_path}")
# 策略2:进程内存分析(简化版)
try:
@@ -361,12 +448,11 @@ def auto_detect_wechat_data_dirs():
break
for parent_dir in parent_dirs:
- for pattern in common_wechat_patterns:
+ for pattern in COMMON_WECHAT_PATTERNS:
potential_dir = os.path.join(parent_dir, pattern)
if os.path.exists(potential_dir) and has_wxid_directories(potential_dir):
- if potential_dir not in detected_dirs:
- detected_dirs.append(potential_dir)
- print(f"[DEBUG] 进程分析检测成功: {potential_dir}")
+ _append_detected_dir(detected_dirs, potential_dir)
+ print(f"[DEBUG] 进程分析检测成功: {potential_dir}")
except:
pass
except:
diff --git a/tests/test_media_decrypt_stream_cancel.py b/tests/test_media_decrypt_stream_cancel.py
new file mode 100644
index 0000000..0324b55
--- /dev/null
+++ b/tests/test_media_decrypt_stream_cancel.py
@@ -0,0 +1,74 @@
+import asyncio
+import json
+import sys
+import unittest
+from pathlib import Path
+from tempfile import TemporaryDirectory
+from unittest import mock
+
+
+ROOT = Path(__file__).resolve().parents[1]
+sys.path.insert(0, str(ROOT / "src"))
+
+
+from wechat_decrypt_tool.routers import media as media_router # noqa: E402 pylint: disable=wrong-import-position
+
+
+class _FakeDisconnectingRequest:
+ def __init__(self, disconnect_after: int):
+ self._disconnect_after = disconnect_after
+ self._calls = 0
+
+ async def is_disconnected(self):
+ self._calls += 1
+ return self._calls >= self._disconnect_after
+
+
+class TestMediaDecryptStreamCancel(unittest.TestCase):
+ def test_stream_stops_processing_when_client_disconnects(self):
+ with TemporaryDirectory() as td:
+ root = Path(td)
+ account_dir = root / "account"
+ wxid_dir = root / "wxid"
+ dat_path = wxid_dir / "image.dat"
+ resource_dir = account_dir / "resource"
+ wxid_dir.mkdir(parents=True, exist_ok=True)
+ dat_path.write_bytes(b"encrypted")
+
+ request = _FakeDisconnectingRequest(disconnect_after=3)
+ decrypt_mock = mock.Mock(return_value=(True, "ok"))
+
+ with mock.patch.object(media_router, "_resolve_account_dir", return_value=account_dir):
+ with mock.patch.object(media_router, "_resolve_account_wxid_dir", return_value=wxid_dir):
+ with mock.patch.object(media_router, "_load_media_keys", return_value={"xor": 0xA5, "aes": ""}):
+ with mock.patch.object(media_router, "_collect_all_dat_files", return_value=[(dat_path, "abc123")]):
+ with mock.patch.object(media_router, "_get_resource_dir", return_value=resource_dir):
+ with mock.patch.object(media_router, "_try_find_decrypted_resource", return_value=None):
+ with mock.patch.object(media_router, "_decrypt_and_save_resource", decrypt_mock):
+ response = asyncio.run(
+ media_router.decrypt_all_media_stream(
+ request=request,
+ account="wxid_demo",
+ )
+ )
+
+ async def _read_chunks():
+ chunks = []
+ async for chunk in response.body_iterator:
+ chunks.append(chunk.decode("utf-8") if isinstance(chunk, bytes) else str(chunk))
+ return chunks
+
+ chunks = asyncio.run(_read_chunks())
+
+ events = []
+ for chunk in chunks:
+ for line in chunk.splitlines():
+ if line.startswith("data: "):
+ events.append(json.loads(line[len("data: ") :]))
+
+ self.assertEqual([event.get("type") for event in events], ["scanning", "start"])
+ decrypt_mock.assert_not_called()
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_wechat_detection_auto_detect.py b/tests/test_wechat_detection_auto_detect.py
new file mode 100644
index 0000000..18cf66a
--- /dev/null
+++ b/tests/test_wechat_detection_auto_detect.py
@@ -0,0 +1,44 @@
+import sys
+import unittest
+from pathlib import Path
+from tempfile import TemporaryDirectory
+from unittest.mock import patch
+
+
+ROOT = Path(__file__).resolve().parents[1]
+sys.path.insert(0, str(ROOT / "src"))
+
+
+class TestWechatDetectionAutoDetect(unittest.TestCase):
+ def test_detect_wechat_installation_finds_nested_custom_data_root(self):
+ from wechat_decrypt_tool import wechat_detection as wd
+
+ with TemporaryDirectory() as td:
+ nested_scan_root = Path(td) / "abc"
+ wechat_parent = nested_scan_root / "wechatMSG"
+ xwechat_root = wechat_parent / "xwechat_files"
+
+ login_dir = xwechat_root / "all_users" / "login" / "wxid_demo"
+ login_dir.mkdir(parents=True, exist_ok=True)
+ (login_dir / "key_info.db").write_bytes(b"demo")
+
+ account_dir = xwechat_root / "wxid_demo_nested"
+ account_dir.mkdir(parents=True, exist_ok=True)
+ (account_dir / "contact.db").write_bytes(b"demo")
+
+ with (
+ patch.object(wd, "_build_auto_detect_scan_paths", return_value=[str(nested_scan_root)]),
+ patch.object(wd, "get_process_list", return_value=[]),
+ ):
+ detected_dirs = wd.auto_detect_wechat_data_dirs()
+ result = wd.detect_wechat_installation()
+
+ self.assertEqual(detected_dirs, [str(wechat_parent)])
+ self.assertEqual(result["total_accounts"], 1)
+ self.assertEqual(result["accounts"][0]["account_name"], "wxid_demo")
+ self.assertEqual(result["accounts"][0]["data_dir"], str(account_dir))
+ self.assertEqual(result["total_databases"], 1)
+
+
+if __name__ == "__main__":
+ unittest.main()