mirror of
https://github.com/LifeArchiveProject/WeChatDataAnalysis.git
synced 2026-06-18 15:54:08 +08:00
fix(key): 支持手动指定微信安装目录并校验 db key 来源
- /api/get_keys 支持传入 wechat_install_path,兼容安装目录与 Weixin.exe / WeChat.exe - 解密完成后保存 db key 的来源路径与别名,避免历史密钥被错误账号复用 - 解密页按 account + db_storage_path 回填已保存密钥,并补充相关测试覆盖
This commit is contained in:
@@ -397,6 +397,7 @@ export const useApi = () => {
|
||||
const getSavedKeys = async (params = {}) => {
|
||||
const query = new URLSearchParams()
|
||||
if (params && params.account) query.set('account', params.account)
|
||||
if (params && params.db_storage_path) query.set('db_storage_path', params.db_storage_path)
|
||||
const url = '/keys' + (query.toString() ? `?${query.toString()}` : '')
|
||||
return await request(url)
|
||||
}
|
||||
@@ -547,8 +548,11 @@ export const useApi = () => {
|
||||
}
|
||||
|
||||
// 获取数据库密钥
|
||||
const getKeys = async () => {
|
||||
return await request('/get_keys')
|
||||
const getKeys = async (params = {}) => {
|
||||
const query = new URLSearchParams()
|
||||
if (params && params.wechat_install_path) query.set('wechat_install_path', params.wechat_install_path)
|
||||
const url = '/get_keys' + (query.toString() ? `?${query.toString()}` : '')
|
||||
return await request(url)
|
||||
}
|
||||
|
||||
// 获取图片密钥
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
export const WECHAT_INSTALL_PATH_STORAGE_KEY = 'decrypt.wechatInstallPath'
|
||||
|
||||
export const normalizeWechatInstallPath = (value) => String(value || '').trim()
|
||||
|
||||
export const readStoredWechatInstallPath = () => {
|
||||
if (!process.client || typeof window === 'undefined') return ''
|
||||
try {
|
||||
return normalizeWechatInstallPath(window.localStorage.getItem(WECHAT_INSTALL_PATH_STORAGE_KEY) || '')
|
||||
} catch {
|
||||
return ''
|
||||
}
|
||||
}
|
||||
|
||||
export const writeStoredWechatInstallPath = (value) => {
|
||||
if (!process.client || typeof window === 'undefined') return
|
||||
try {
|
||||
const normalized = normalizeWechatInstallPath(value)
|
||||
if (normalized) {
|
||||
window.localStorage.setItem(WECHAT_INSTALL_PATH_STORAGE_KEY, normalized)
|
||||
} else {
|
||||
window.localStorage.removeItem(WECHAT_INSTALL_PATH_STORAGE_KEY)
|
||||
}
|
||||
} catch {}
|
||||
}
|
||||
@@ -73,6 +73,12 @@
|
||||
</svg>
|
||||
点击按钮将自动获取【数据库解密密钥】。您也可以手动输入已知的64位密钥。
|
||||
</p>
|
||||
<p v-if="formData.wechat_install_path" class="mt-2 text-xs text-[#7F7F7F] flex items-start">
|
||||
<svg class="w-4 h-4 mr-1 mt-0.5 text-[#10AEEF]" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M13 16h-1v-4h-1m1-4h.01M21 12a9 9 0 11-18 0 9 9 0 0118 0z"/>
|
||||
</svg>
|
||||
<span>当前将使用第一步检测时保存的微信安装目录:<span class="font-mono break-all">{{ formData.wechat_install_path }}</span>。</span>
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<!-- 数据库路径输入 -->
|
||||
@@ -655,6 +661,7 @@
|
||||
<script setup>
|
||||
import { ref, reactive, computed, onMounted, onBeforeUnmount } from 'vue'
|
||||
import { useApi } from '~/composables/useApi'
|
||||
import { normalizeWechatInstallPath, readStoredWechatInstallPath } from '~/lib/wechat-install-path'
|
||||
|
||||
const { decryptDatabase, saveMediaKeys, getSavedKeys, getKeys, getImageKey, getWxStatus } = useApi()
|
||||
|
||||
@@ -677,7 +684,8 @@ const steps = [
|
||||
// 表单数据
|
||||
const formData = reactive({
|
||||
key: '',
|
||||
db_storage_path: ''
|
||||
db_storage_path: '',
|
||||
wechat_install_path: ''
|
||||
})
|
||||
|
||||
// 表单错误
|
||||
@@ -764,7 +772,10 @@ const prefillKeysForAccount = async (account) => {
|
||||
if (!acc) return
|
||||
logDecryptDebug('prefill:start', { account: acc })
|
||||
try {
|
||||
const resp = await getSavedKeys({ account: acc })
|
||||
const resp = await getSavedKeys({
|
||||
account: acc,
|
||||
db_storage_path: String(formData.db_storage_path || '').trim()
|
||||
})
|
||||
if (!resp || resp.status !== 'success') return
|
||||
const keys = resp.keys || {}
|
||||
|
||||
@@ -786,6 +797,9 @@ const prefillKeysForAccount = async (account) => {
|
||||
request_account: acc,
|
||||
response_account: String(resp.account || '').trim(),
|
||||
db_key_present: !!dbKey,
|
||||
db_key_store_account: String(keys.db_key_store_account || '').trim(),
|
||||
db_key_source_wxid_dir: String(keys.db_key_source_wxid_dir || '').trim(),
|
||||
db_key_blocked_reason: String(keys.db_key_blocked_reason || '').trim(),
|
||||
...summarizeKeyStateForLog(
|
||||
String(keys.image_xor_key || '').trim(),
|
||||
String(keys.image_aes_key || '').trim()
|
||||
@@ -873,6 +887,8 @@ const handleGetDbKey = async () => {
|
||||
formErrors.key = ''
|
||||
|
||||
try {
|
||||
const wechatInstallPath = normalizeWechatInstallPath(formData.wechat_install_path || readStoredWechatInstallPath())
|
||||
formData.wechat_install_path = wechatInstallPath
|
||||
const statusRes = await getWxStatus()
|
||||
const wxStatus = statusRes?.wx_status
|
||||
|
||||
@@ -883,7 +899,9 @@ const handleGetDbKey = async () => {
|
||||
|
||||
warning.value = '正在启动微信,请确保微信未开启“自动登录”,并在弹窗中正常登录。'
|
||||
|
||||
const res = await getKeys()
|
||||
const res = await getKeys({
|
||||
wechat_install_path: wechatInstallPath
|
||||
})
|
||||
|
||||
if (res && res.status === 0) {
|
||||
if (res.data?.db_key) {
|
||||
@@ -1617,6 +1635,7 @@ const skipToChat = async () => {
|
||||
// 页面加载时检查是否有选中的账户
|
||||
onMounted(async () => {
|
||||
if (process.client && typeof window !== 'undefined') {
|
||||
formData.wechat_install_path = readStoredWechatInstallPath()
|
||||
const selectedAccount = sessionStorage.getItem('selectedAccount')
|
||||
logDecryptDebug('mounted:selected-account-raw', { raw: selectedAccount || '' })
|
||||
if (selectedAccount) {
|
||||
|
||||
@@ -35,6 +35,30 @@
|
||||
<span v-if="customPath">当前指定检测路径:<span class="font-mono bg-gray-50 px-1 rounded text-[#000000e6]">{{ customPath }}</span></span>
|
||||
<span v-else>如果自动检测漏了,您可以手动指定微信数据根目录 (通常名为 xwechat_files) 让系统重新扫描。</span>
|
||||
</p>
|
||||
<div class="mt-3">
|
||||
<label for="wechatInstallPath" class="block text-xs font-medium text-[#000000e6]">微信安装目录(可选)</label>
|
||||
<div class="mt-2 flex flex-col lg:flex-row gap-3">
|
||||
<input
|
||||
id="wechatInstallPath"
|
||||
v-model="wechatInstallPath"
|
||||
type="text"
|
||||
placeholder="例如: D:\Program Files\Tencent\WeChat 或 D:\Program Files\Tencent\WeChat\Weixin.exe"
|
||||
class="flex-1 px-3 py-2 bg-white border border-[#EDEDED] rounded-lg font-mono text-xs focus:outline-none focus:ring-2 focus:ring-[#07C160] focus:border-transparent transition-all duration-200"
|
||||
@blur="persistWechatInstallPath"
|
||||
/>
|
||||
<button
|
||||
type="button"
|
||||
@click="pickWechatInstallDirectory"
|
||||
:disabled="isPickingWechatInstallPath"
|
||||
class="shrink-0 px-4 py-2 bg-white border border-[#EDEDED] text-[#000000e6] rounded-lg text-xs font-medium hover:bg-gray-50 disabled:opacity-50 disabled:cursor-wait transition-all duration-200"
|
||||
>
|
||||
{{ isPickingWechatInstallPath ? '选择中...' : '选择微信目录' }}
|
||||
</button>
|
||||
</div>
|
||||
<p class="text-xs text-[#7F7F7F] mt-2">
|
||||
一键获取数据库密钥会优先使用这里填写的路径。支持安装目录或 Weixin.exe / WeChat.exe 路径。
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
<button @click="handlePickDirectory" :disabled="loading"
|
||||
class="shrink-0 px-5 py-2.5 bg-[#07C160] text-white rounded-xl text-sm font-medium hover:bg-[#06AD56] focus:ring-2 focus:ring-[#07C160] focus:ring-offset-1 disabled:opacity-50 transition-all duration-200 flex items-center justify-center">
|
||||
@@ -238,6 +262,7 @@
|
||||
<script setup>
|
||||
import {computed, onMounted, ref} from 'vue'
|
||||
import {useApi} from '~/composables/useApi'
|
||||
import {normalizeWechatInstallPath, readStoredWechatInstallPath, writeStoredWechatInstallPath} from '~/lib/wechat-install-path'
|
||||
import {useAppStore} from '~/stores/app'
|
||||
|
||||
const { detectWechat, pickSystemDirectory } = useApi()
|
||||
@@ -245,6 +270,8 @@ const appStore = useAppStore()
|
||||
const loading = ref(false)
|
||||
const detectionResult = ref(null)
|
||||
const customPath = ref('')
|
||||
const wechatInstallPath = ref('')
|
||||
const isPickingWechatInstallPath = ref(false)
|
||||
const STORAGE_KEY = 'wechat_data_root_path'
|
||||
|
||||
const isDesktopShell = () => {
|
||||
@@ -289,6 +316,52 @@ const handlePickDirectory = async () => {
|
||||
}
|
||||
|
||||
// 计算属性:将当前登录账号排在第一位
|
||||
const persistWechatInstallPath = () => {
|
||||
const normalized = normalizeWechatInstallPath(wechatInstallPath.value)
|
||||
wechatInstallPath.value = normalized
|
||||
writeStoredWechatInstallPath(normalized)
|
||||
}
|
||||
|
||||
const pickWechatInstallDirectory = async () => {
|
||||
if (isPickingWechatInstallPath.value) return
|
||||
isPickingWechatInstallPath.value = true
|
||||
|
||||
try {
|
||||
let path = ''
|
||||
|
||||
if (isDesktopShell()) {
|
||||
const res = await window.wechatDesktop.chooseDirectory({
|
||||
title: '请选择微信安装目录'
|
||||
})
|
||||
if (!res || res.canceled || !res.filePaths?.length) return
|
||||
path = res.filePaths[0]
|
||||
} else {
|
||||
try {
|
||||
const res = await pickSystemDirectory({
|
||||
title: '请选择微信安装目录',
|
||||
initial_dir: normalizeWechatInstallPath(wechatInstallPath.value)
|
||||
})
|
||||
if (!res || !res.path) return
|
||||
path = res.path
|
||||
} catch (e) {
|
||||
console.error('通过API唤起微信安装目录选择器失败:', e)
|
||||
path = window.prompt('无法直接唤起窗口,请输入微信安装目录或 Weixin.exe / WeChat.exe 的绝对路径:')
|
||||
if (!path) return
|
||||
}
|
||||
}
|
||||
|
||||
const normalized = normalizeWechatInstallPath(path)
|
||||
if (!normalized) return
|
||||
wechatInstallPath.value = normalized
|
||||
persistWechatInstallPath()
|
||||
} catch (e) {
|
||||
console.error('选择微信安装目录失败:', e)
|
||||
} finally {
|
||||
isPickingWechatInstallPath.value = false
|
||||
}
|
||||
}
|
||||
|
||||
// ?????????????????
|
||||
const sortedAccounts = computed(() => {
|
||||
if (!detectionResult.value?.data?.accounts) return []
|
||||
const accounts = [...detectionResult.value.data.accounts]
|
||||
@@ -384,6 +457,8 @@ const startDetection = async () => {
|
||||
|
||||
// 跳转到解密页面并传递账户信息
|
||||
const goToDecrypt = (account) => {
|
||||
persistWechatInstallPath()
|
||||
|
||||
if (process.client && typeof window !== 'undefined') {
|
||||
sessionStorage.setItem('selectedAccount', JSON.stringify({
|
||||
account_name: account.account_name,
|
||||
@@ -412,6 +487,7 @@ onMounted(() => {
|
||||
const saved = String(localStorage.getItem(STORAGE_KEY) || '').trim()
|
||||
if (saved) customPath.value = saved
|
||||
} catch {}
|
||||
wechatInstallPath.value = readStoredWechatInstallPath()
|
||||
}
|
||||
startDetection()
|
||||
})
|
||||
|
||||
@@ -29,6 +29,8 @@ from .media_helpers import _resolve_account_dir, _resolve_account_wxid_dir
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
WECHAT_EXECUTABLE_NAMES = ("Weixin.exe", "WeChat.exe")
|
||||
|
||||
|
||||
def _summarize_aes_key(value: Any) -> str:
|
||||
raw = str(value or "").strip()
|
||||
@@ -109,19 +111,72 @@ def _resolve_wxid_dir_for_image_key(
|
||||
raise FileNotFoundError("无法定位该账号的 wxid_dir,请传入有效的 db_storage_path 或先完成数据库解密")
|
||||
|
||||
|
||||
def _normalize_user_path(value: Any) -> str:
|
||||
raw = str(value or "").strip().strip('"').strip("'")
|
||||
if not raw:
|
||||
return ""
|
||||
try:
|
||||
return os.path.normpath(os.path.expandvars(raw))
|
||||
except Exception:
|
||||
return raw
|
||||
|
||||
|
||||
def _read_wechat_version_from_exe(exe_path: str) -> str:
|
||||
normalized = _normalize_user_path(exe_path)
|
||||
if not normalized:
|
||||
return ""
|
||||
try:
|
||||
import win32api
|
||||
|
||||
version_info = win32api.GetFileVersionInfo(normalized, "\\")
|
||||
return (
|
||||
f"{version_info['FileVersionMS'] >> 16}."
|
||||
f"{version_info['FileVersionMS'] & 0xFFFF}."
|
||||
f"{version_info['FileVersionLS'] >> 16}."
|
||||
f"{version_info['FileVersionLS'] & 0xFFFF}"
|
||||
)
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
|
||||
def _resolve_manual_wechat_exe_path(wechat_install_path: Optional[str] = None) -> str:
|
||||
normalized = _normalize_user_path(wechat_install_path)
|
||||
if not normalized:
|
||||
return ""
|
||||
|
||||
candidate = Path(normalized).expanduser()
|
||||
executable_names = {name.lower() for name in WECHAT_EXECUTABLE_NAMES}
|
||||
if candidate.is_file():
|
||||
if candidate.name.lower() not in executable_names:
|
||||
raise RuntimeError("手动路径必须指向微信安装目录,或直接指向 Weixin.exe / WeChat.exe")
|
||||
return str(candidate)
|
||||
|
||||
if candidate.is_dir():
|
||||
for exe_name in WECHAT_EXECUTABLE_NAMES:
|
||||
exe_path = candidate / exe_name
|
||||
if exe_path.is_file():
|
||||
return str(exe_path)
|
||||
raise RuntimeError("手动指定的微信安装目录中未找到 Weixin.exe 或 WeChat.exe")
|
||||
|
||||
raise RuntimeError(f"手动指定的微信安装目录不存在: {candidate}")
|
||||
|
||||
|
||||
# ====================== 以下是hook逻辑 ======================================
|
||||
|
||||
class WeChatKeyFetcher:
|
||||
def __init__(self):
|
||||
self.process_name = "Weixin.exe"
|
||||
self.process_names = {name.lower() for name in WECHAT_EXECUTABLE_NAMES}
|
||||
self.timeout_seconds = 60
|
||||
|
||||
def _is_wechat_process(self, name: Any) -> bool:
|
||||
return str(name or "").strip().lower() in self.process_names
|
||||
|
||||
def kill_wechat(self):
|
||||
"""检测并查杀微信进程"""
|
||||
killed = False
|
||||
for proc in psutil.process_iter(['pid', 'name']):
|
||||
try:
|
||||
if proc.info['name'] == self.process_name:
|
||||
if self._is_wechat_process(proc.info['name']):
|
||||
logger.info(f"Killing WeChat process: {proc.info['pid']}")
|
||||
proc.terminate()
|
||||
killed = True
|
||||
@@ -134,11 +189,14 @@ class WeChatKeyFetcher:
|
||||
def launch_wechat(self, exe_path: str) -> int:
|
||||
"""启动微信并返回 PID"""
|
||||
try:
|
||||
process = subprocess.Popen(exe_path)
|
||||
normalized_exe_path = _normalize_user_path(exe_path)
|
||||
process = subprocess.Popen(normalized_exe_path)
|
||||
time.sleep(2)
|
||||
candidates = []
|
||||
target_process_name = Path(normalized_exe_path).name.lower()
|
||||
for proc in psutil.process_iter(['pid', 'name', 'create_time']):
|
||||
if proc.info['name'] == self.process_name:
|
||||
proc_name = str(proc.info.get('name') or "").strip().lower()
|
||||
if proc_name == target_process_name or self._is_wechat_process(proc_name):
|
||||
candidates.append(proc)
|
||||
|
||||
if candidates:
|
||||
@@ -152,19 +210,32 @@ class WeChatKeyFetcher:
|
||||
logger.error(f"启动微信失败: {e}")
|
||||
raise RuntimeError(f"无法启动微信: {e}")
|
||||
|
||||
def fetch_db_key(self) -> dict:
|
||||
def fetch_db_key(self, wechat_install_path: Optional[str] = None) -> dict:
|
||||
"""调用 wx_key 仅获取数据库密钥 (Hook 模式)"""
|
||||
if wx_key is None:
|
||||
raise RuntimeError("wx_key 模块未安装或加载失败")
|
||||
|
||||
install_info = detect_wechat_installation()
|
||||
exe_path = install_info.get('wechat_exe_path')
|
||||
version = install_info.get('wechat_version')
|
||||
manual_path = _normalize_user_path(wechat_install_path)
|
||||
if manual_path:
|
||||
exe_path = _resolve_manual_wechat_exe_path(manual_path)
|
||||
version = _read_wechat_version_from_exe(exe_path)
|
||||
logger.info(
|
||||
"[db_key] 使用手动指定的微信安装路径: input=%s exe_path=%s version=%s",
|
||||
manual_path,
|
||||
exe_path,
|
||||
version or "unknown",
|
||||
)
|
||||
else:
|
||||
install_info = detect_wechat_installation()
|
||||
exe_path = _normalize_user_path(install_info.get('wechat_exe_path'))
|
||||
version = str(install_info.get('wechat_version') or "").strip()
|
||||
|
||||
if not exe_path or not version:
|
||||
raise RuntimeError("无法自动定位微信安装路径或版本")
|
||||
if not exe_path:
|
||||
raise RuntimeError("无法自动定位微信安装路径,请手动填写微信安装目录")
|
||||
if not Path(exe_path).is_file():
|
||||
raise RuntimeError(f"微信可执行文件不存在: {exe_path}")
|
||||
|
||||
logger.info(f"Detect WeChat: {version} at {exe_path}")
|
||||
logger.info(f"Detect WeChat: {version or 'unknown'} at {exe_path}")
|
||||
|
||||
self.kill_wechat()
|
||||
pid = self.launch_wechat(exe_path)
|
||||
@@ -204,9 +275,9 @@ class WeChatKeyFetcher:
|
||||
"db_key": found_db_key
|
||||
}
|
||||
|
||||
def get_db_key_workflow():
|
||||
def get_db_key_workflow(wechat_install_path: Optional[str] = None):
|
||||
fetcher = WeChatKeyFetcher()
|
||||
return fetcher.fetch_db_key()
|
||||
return fetcher.fetch_db_key(wechat_install_path=wechat_install_path)
|
||||
|
||||
|
||||
# ============================== 以下是图片密钥逻辑 =====================================
|
||||
|
||||
@@ -1,13 +1,41 @@
|
||||
import datetime
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Iterable, Optional
|
||||
|
||||
from .app_paths import get_account_keys_path
|
||||
|
||||
_KEY_STORE_PATH = get_account_keys_path()
|
||||
|
||||
|
||||
def normalize_key_store_path(path_value: Optional[str]) -> str:
|
||||
raw = str(path_value or "").strip()
|
||||
if not raw:
|
||||
return ""
|
||||
|
||||
try:
|
||||
return str(Path(raw).expanduser().resolve())
|
||||
except Exception:
|
||||
try:
|
||||
return str(Path(raw).expanduser())
|
||||
except Exception:
|
||||
return raw
|
||||
|
||||
|
||||
def _normalize_account_aliases(*values: Optional[str], aliases: Optional[Iterable[str]] = None) -> list[str]:
|
||||
out: list[str] = []
|
||||
seen: set[str] = set()
|
||||
|
||||
for value in [*values, *(list(aliases or []))]:
|
||||
key = str(value or "").strip()
|
||||
if (not key) or (key in seen):
|
||||
continue
|
||||
seen.add(key)
|
||||
out.append(key)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def _atomic_write_json(path: Path, payload: Any) -> None:
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
tmp = path.with_suffix(path.suffix + ".tmp")
|
||||
@@ -40,25 +68,36 @@ def upsert_account_keys_in_store(
|
||||
db_key: Optional[str] = None,
|
||||
image_xor_key: Optional[str] = None,
|
||||
image_aes_key: Optional[str] = None,
|
||||
aliases: Optional[Iterable[str]] = None,
|
||||
db_key_source_wxid_dir: Optional[str] = None,
|
||||
db_key_source_db_storage_path: Optional[str] = None,
|
||||
) -> dict[str, Any]:
|
||||
account = str(account or "").strip()
|
||||
if not account:
|
||||
return {}
|
||||
|
||||
store = load_account_keys_store()
|
||||
item = store.get(account, {})
|
||||
if not isinstance(item, dict):
|
||||
item = {}
|
||||
target_accounts = _normalize_account_aliases(account, aliases=aliases)
|
||||
|
||||
item: dict[str, Any] = {}
|
||||
for target_account in target_accounts:
|
||||
existing = store.get(target_account, {})
|
||||
if isinstance(existing, dict) and existing:
|
||||
item = dict(existing)
|
||||
break
|
||||
|
||||
if db_key is not None:
|
||||
item["db_key"] = str(db_key)
|
||||
item["db_key_source_wxid_dir"] = normalize_key_store_path(db_key_source_wxid_dir)
|
||||
item["db_key_source_db_storage_path"] = normalize_key_store_path(db_key_source_db_storage_path)
|
||||
if image_xor_key is not None:
|
||||
item["image_xor_key"] = str(image_xor_key)
|
||||
if image_aes_key is not None:
|
||||
item["image_aes_key"] = str(image_aes_key)
|
||||
|
||||
item["updated_at"] = datetime.datetime.now().isoformat(timespec="seconds")
|
||||
store[account] = item
|
||||
for target_account in target_accounts:
|
||||
store[target_account] = dict(item)
|
||||
|
||||
try:
|
||||
_atomic_write_json(_KEY_STORE_PATH, store)
|
||||
|
||||
@@ -118,6 +118,38 @@ def _acquire_decrypt_account_guards(accounts: Any, *, reason: str) -> list[tuple
|
||||
return guards
|
||||
|
||||
|
||||
def _save_db_key_for_account(account: str, key: str, account_result: dict[str, Any] | None) -> None:
|
||||
payload = dict(account_result or {})
|
||||
success_count = int(payload.get("success") or 0)
|
||||
if success_count <= 0:
|
||||
logger.info("[decrypt] skip saving db key for failed account=%s success=%s", account, success_count)
|
||||
return
|
||||
|
||||
source_wxid_dir = str(payload.get("source_wxid_dir") or "").strip()
|
||||
source_db_storage_path = str(payload.get("source_db_storage_path") or "").strip()
|
||||
aliases: list[str] = []
|
||||
|
||||
if source_wxid_dir:
|
||||
wxid_dir_name = str(Path(source_wxid_dir).name or "").strip()
|
||||
if wxid_dir_name and wxid_dir_name != str(account or "").strip():
|
||||
aliases.append(wxid_dir_name)
|
||||
|
||||
upsert_account_keys_in_store(
|
||||
str(account),
|
||||
db_key=key,
|
||||
aliases=aliases,
|
||||
db_key_source_wxid_dir=source_wxid_dir or None,
|
||||
db_key_source_db_storage_path=source_db_storage_path or None,
|
||||
)
|
||||
logger.info(
|
||||
"[decrypt] saved db key account=%s aliases=%s source_wxid_dir=%s source_db_storage_path=%s",
|
||||
str(account),
|
||||
aliases,
|
||||
source_wxid_dir,
|
||||
source_db_storage_path,
|
||||
)
|
||||
|
||||
|
||||
class DecryptRequest(BaseModel):
|
||||
"""解密请求模型"""
|
||||
|
||||
@@ -170,8 +202,8 @@ async def decrypt_databases(request: DecryptRequest):
|
||||
|
||||
# 成功解密后,按账号保存数据库密钥(用于前端自动回填)
|
||||
try:
|
||||
for account_name in (results.get("account_results") or {}).keys():
|
||||
upsert_account_keys_in_store(str(account_name), db_key=request.key)
|
||||
for account_name, account_result in (results.get("account_results") or {}).items():
|
||||
_save_db_key_for_account(str(account_name), request.key, account_result)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -417,6 +449,8 @@ async def decrypt_databases_stream(
|
||||
"success": account_success,
|
||||
"failed": len(dbs) - account_success,
|
||||
"output_dir": str(account_output_dir),
|
||||
"source_db_storage_path": str(source_db_storage_path),
|
||||
"source_wxid_dir": str(wxid_dir),
|
||||
"processed_files": account_processed,
|
||||
"failed_files": account_failed,
|
||||
"db_diagnostics": account_db_diagnostics,
|
||||
@@ -481,8 +515,8 @@ async def decrypt_databases_stream(
|
||||
|
||||
# Save db key for frontend autofill.
|
||||
try:
|
||||
for account in (account_results or {}).keys():
|
||||
upsert_account_keys_in_store(str(account), db_key=k)
|
||||
for account, account_result in (account_results or {}).items():
|
||||
_save_db_key_for_account(str(account), k, account_result)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from ..logging_config import get_logger
|
||||
from ..key_store import get_account_keys_from_store
|
||||
from ..key_store import get_account_keys_from_store, normalize_key_store_path
|
||||
from ..key_service import get_db_key_workflow, get_image_key_integrated_workflow
|
||||
from ..media_helpers import _load_media_keys, _resolve_account_dir
|
||||
from ..path_fix import PathFixRoute
|
||||
@@ -21,8 +22,106 @@ def _summarize_aes_key(value: str) -> str:
|
||||
return f"{raw[:4]}...{raw[-4:]}(len={len(raw)})"
|
||||
|
||||
|
||||
def _resolve_requested_wxid_dir(*, db_storage_path: Optional[str] = None, wxid_dir: Optional[str] = None) -> str:
|
||||
explicit_wxid_dir = str(wxid_dir or "").strip()
|
||||
if explicit_wxid_dir:
|
||||
return normalize_key_store_path(explicit_wxid_dir)
|
||||
|
||||
raw_db_storage_path = str(db_storage_path or "").strip()
|
||||
if not raw_db_storage_path:
|
||||
return ""
|
||||
|
||||
candidate = Path(raw_db_storage_path).expanduser()
|
||||
try:
|
||||
if str(candidate.name or "").lower() == "db_storage":
|
||||
return normalize_key_store_path(str(candidate.parent))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
if str((candidate / "db_storage").name or "").lower() == "db_storage":
|
||||
return normalize_key_store_path(str(candidate))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
def _build_saved_key_candidates(account_name: Optional[str], request_account: Optional[str], request_wxid_dir: str) -> list[str]:
|
||||
out: list[str] = []
|
||||
seen: set[str] = set()
|
||||
|
||||
for value in [
|
||||
Path(request_wxid_dir).name if request_wxid_dir else "",
|
||||
str(account_name or "").strip(),
|
||||
str(request_account or "").strip(),
|
||||
]:
|
||||
key = str(value or "").strip()
|
||||
if (not key) or (key in seen):
|
||||
continue
|
||||
seen.add(key)
|
||||
out.append(key)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def _evaluate_db_key_candidate(
|
||||
*,
|
||||
store_account: str,
|
||||
keys: dict,
|
||||
account_name: Optional[str],
|
||||
request_wxid_dir: str,
|
||||
request_db_storage_path: str,
|
||||
) -> tuple[bool, int, str]:
|
||||
db_key = str(keys.get("db_key") or "").strip()
|
||||
if not db_key:
|
||||
return False, -1, ""
|
||||
|
||||
source_wxid_dir = normalize_key_store_path(keys.get("db_key_source_wxid_dir"))
|
||||
source_db_storage_path = normalize_key_store_path(keys.get("db_key_source_db_storage_path"))
|
||||
request_wxid_dir_name = Path(request_wxid_dir).name if request_wxid_dir else ""
|
||||
source_wxid_dir_name = Path(source_wxid_dir).name if source_wxid_dir else ""
|
||||
|
||||
if request_db_storage_path and source_db_storage_path:
|
||||
if source_db_storage_path == request_db_storage_path:
|
||||
return True, 400, ""
|
||||
return (
|
||||
False,
|
||||
0,
|
||||
f"Saved db key source does not match current db_storage_path. request={request_db_storage_path} stored={source_db_storage_path}",
|
||||
)
|
||||
|
||||
if request_wxid_dir and source_wxid_dir:
|
||||
if (source_wxid_dir == request_wxid_dir) or (
|
||||
source_wxid_dir_name and source_wxid_dir_name == request_wxid_dir_name
|
||||
):
|
||||
return True, 300, ""
|
||||
return (
|
||||
False,
|
||||
0,
|
||||
f"Saved db key source does not match current wxid_dir. request={request_wxid_dir_name} stored={source_wxid_dir_name or source_wxid_dir}",
|
||||
)
|
||||
|
||||
if request_wxid_dir_name:
|
||||
if store_account == request_wxid_dir_name:
|
||||
return True, 200, ""
|
||||
if account_name and request_wxid_dir_name == str(account_name or "").strip():
|
||||
return True, 100, ""
|
||||
return (
|
||||
False,
|
||||
0,
|
||||
f"Legacy saved db key is ambiguous for current wxid_dir={request_wxid_dir_name}. Please fetch a fresh db key.",
|
||||
)
|
||||
|
||||
return True, 50, ""
|
||||
|
||||
|
||||
@router.get("/api/keys", summary="获取账号已保存的密钥")
|
||||
async def get_saved_keys(account: Optional[str] = None):
|
||||
async def get_saved_keys(
|
||||
account: Optional[str] = None,
|
||||
db_storage_path: Optional[str] = None,
|
||||
wxid_dir: Optional[str] = None,
|
||||
):
|
||||
"""获取账号的数据库密钥与图片密钥(用于前端自动回填)"""
|
||||
account_name: Optional[str] = None
|
||||
account_dir = None
|
||||
@@ -34,16 +133,56 @@ async def get_saved_keys(account: Optional[str] = None):
|
||||
# 账号可能尚未解密;仍允许从全局 store 读取(如果传入了 account)
|
||||
account_name = str(account or "").strip() or None
|
||||
|
||||
request_db_storage_path = normalize_key_store_path(db_storage_path)
|
||||
request_wxid_dir = _resolve_requested_wxid_dir(db_storage_path=db_storage_path, wxid_dir=wxid_dir)
|
||||
candidate_accounts = _build_saved_key_candidates(account_name, account, request_wxid_dir)
|
||||
|
||||
logger.info(
|
||||
"[keys] get_saved_keys start: request_account=%s resolved_account=%s account_dir=%s",
|
||||
"[keys] get_saved_keys start: request_account=%s resolved_account=%s account_dir=%s db_storage_path=%s wxid_dir=%s candidates=%s",
|
||||
str(account or "").strip(),
|
||||
str(account_name or ""),
|
||||
str(account_dir) if account_dir else "",
|
||||
request_db_storage_path,
|
||||
request_wxid_dir,
|
||||
candidate_accounts,
|
||||
)
|
||||
|
||||
keys: dict = {}
|
||||
if account_name:
|
||||
keys = get_account_keys_from_store(account_name)
|
||||
selected_db_key_account = ""
|
||||
selected_db_key_score = -1
|
||||
db_key_blocked_reason = ""
|
||||
db_key_source_wxid_dir = ""
|
||||
db_key_source_db_storage_path = ""
|
||||
|
||||
for candidate_account in candidate_accounts:
|
||||
candidate_keys = get_account_keys_from_store(candidate_account)
|
||||
if not isinstance(candidate_keys, dict) or not candidate_keys:
|
||||
continue
|
||||
|
||||
if not str(keys.get("image_xor_key") or "").strip():
|
||||
keys["image_xor_key"] = str(candidate_keys.get("image_xor_key") or "").strip()
|
||||
if not str(keys.get("image_aes_key") or "").strip():
|
||||
keys["image_aes_key"] = str(candidate_keys.get("image_aes_key") or "").strip()
|
||||
if not str(keys.get("updated_at") or "").strip():
|
||||
keys["updated_at"] = str(candidate_keys.get("updated_at") or "").strip()
|
||||
|
||||
ok, score, blocked_reason = _evaluate_db_key_candidate(
|
||||
store_account=candidate_account,
|
||||
keys=candidate_keys,
|
||||
account_name=account_name,
|
||||
request_wxid_dir=request_wxid_dir,
|
||||
request_db_storage_path=request_db_storage_path,
|
||||
)
|
||||
if ok and score > selected_db_key_score:
|
||||
selected_db_key_score = score
|
||||
selected_db_key_account = candidate_account
|
||||
keys["db_key"] = str(candidate_keys.get("db_key") or "").strip()
|
||||
db_key_source_wxid_dir = normalize_key_store_path(candidate_keys.get("db_key_source_wxid_dir"))
|
||||
db_key_source_db_storage_path = normalize_key_store_path(candidate_keys.get("db_key_source_db_storage_path"))
|
||||
if str(candidate_keys.get("updated_at") or "").strip():
|
||||
keys["updated_at"] = str(candidate_keys.get("updated_at") or "").strip()
|
||||
elif (not ok) and blocked_reason and (not db_key_blocked_reason):
|
||||
db_key_blocked_reason = blocked_reason
|
||||
|
||||
# 兼容:如果 store 里没有图片密钥,尝试从账号目录的 _media_keys.json 读取
|
||||
if account_dir and isinstance(keys, dict):
|
||||
@@ -62,11 +201,18 @@ async def get_saved_keys(account: Optional[str] = None):
|
||||
"image_xor_key": str(keys.get("image_xor_key") or "").strip(),
|
||||
"image_aes_key": str(keys.get("image_aes_key") or "").strip(),
|
||||
"updated_at": str(keys.get("updated_at") or "").strip(),
|
||||
"db_key_source_wxid_dir": db_key_source_wxid_dir,
|
||||
"db_key_source_db_storage_path": db_key_source_db_storage_path,
|
||||
"db_key_store_account": selected_db_key_account,
|
||||
"db_key_blocked_reason": db_key_blocked_reason,
|
||||
}
|
||||
logger.info(
|
||||
"[keys] get_saved_keys done: account=%s db_key_present=%s xor_key=%s aes_key=%s updated_at=%s",
|
||||
"[keys] get_saved_keys done: account=%s db_key_present=%s db_key_store_account=%s db_key_source_wxid_dir=%s blocked_reason=%s xor_key=%s aes_key=%s updated_at=%s",
|
||||
str(account_name or ""),
|
||||
bool(result["db_key"]),
|
||||
result["db_key_store_account"],
|
||||
result["db_key_source_wxid_dir"],
|
||||
result["db_key_blocked_reason"],
|
||||
result["image_xor_key"],
|
||||
_summarize_aes_key(result["image_aes_key"]),
|
||||
result["updated_at"],
|
||||
@@ -80,7 +226,7 @@ async def get_saved_keys(account: Optional[str] = None):
|
||||
|
||||
|
||||
@router.get("/api/get_keys", summary="自动获取微信数据库与图片密钥")
|
||||
async def get_wechat_db_key():
|
||||
async def get_wechat_db_key(wechat_install_path: Optional[str] = None):
|
||||
"""
|
||||
自动流程:
|
||||
1. 结束微信进程
|
||||
@@ -89,7 +235,11 @@ async def get_wechat_db_key():
|
||||
4. 抓取 DB 与 图片密钥(AES + XOR)并返回
|
||||
"""
|
||||
try:
|
||||
keys_data = get_db_key_workflow()
|
||||
logger.info(
|
||||
"[keys] get_wechat_db_key start: wechat_install_path=%s",
|
||||
str(wechat_install_path or "").strip(),
|
||||
)
|
||||
keys_data = get_db_key_workflow(wechat_install_path=wechat_install_path)
|
||||
|
||||
return {
|
||||
"status": 0,
|
||||
|
||||
@@ -731,6 +731,8 @@ def decrypt_wechat_databases(db_storage_path: str = None, key: str = None) -> di
|
||||
"success": account_success,
|
||||
"failed": len(databases) - account_success,
|
||||
"output_dir": str(account_output_dir),
|
||||
"source_db_storage_path": str(source_db_storage_path),
|
||||
"source_wxid_dir": str(wxid_dir),
|
||||
"processed_files": account_processed,
|
||||
"failed_files": account_failed,
|
||||
"db_diagnostics": account_db_diagnostics,
|
||||
|
||||
@@ -7,6 +7,7 @@ import unittest
|
||||
import importlib
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from unittest import mock
|
||||
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
@@ -64,35 +65,43 @@ class TestDecryptStreamSSE(unittest.TestCase):
|
||||
client = TestClient(app)
|
||||
|
||||
events: list[dict] = []
|
||||
with client.stream(
|
||||
"GET",
|
||||
"/api/decrypt_stream",
|
||||
params={"key": "00" * 32, "db_storage_path": str(db_storage)},
|
||||
) as resp:
|
||||
self.assertEqual(resp.status_code, 200)
|
||||
self.assertIn("text/event-stream", resp.headers.get("content-type", ""))
|
||||
with mock.patch.object(decrypt_router, "upsert_account_keys_in_store") as upsert_mock:
|
||||
with client.stream(
|
||||
"GET",
|
||||
"/api/decrypt_stream",
|
||||
params={"key": "00" * 32, "db_storage_path": str(db_storage)},
|
||||
) as resp:
|
||||
self.assertEqual(resp.status_code, 200)
|
||||
self.assertIn("text/event-stream", resp.headers.get("content-type", ""))
|
||||
|
||||
for line in resp.iter_lines():
|
||||
if not line:
|
||||
continue
|
||||
if isinstance(line, bytes):
|
||||
line = line.decode("utf-8", errors="ignore")
|
||||
line = str(line)
|
||||
for line in resp.iter_lines():
|
||||
if not line:
|
||||
continue
|
||||
if isinstance(line, bytes):
|
||||
line = line.decode("utf-8", errors="ignore")
|
||||
line = str(line)
|
||||
|
||||
if line.startswith(":"):
|
||||
continue
|
||||
if not line.startswith("data: "):
|
||||
continue
|
||||
payload = json.loads(line[len("data: ") :])
|
||||
events.append(payload)
|
||||
if payload.get("type") in {"complete", "error"}:
|
||||
break
|
||||
if line.startswith(":"):
|
||||
continue
|
||||
if not line.startswith("data: "):
|
||||
continue
|
||||
payload = json.loads(line[len("data: ") :])
|
||||
events.append(payload)
|
||||
if payload.get("type") in {"complete", "error"}:
|
||||
break
|
||||
|
||||
types = {e.get("type") for e in events}
|
||||
self.assertIn("start", types)
|
||||
self.assertIn("progress", types)
|
||||
self.assertEqual(events[-1].get("type"), "complete")
|
||||
self.assertEqual(events[-1].get("status"), "completed")
|
||||
upsert_mock.assert_called_once_with(
|
||||
"wxid_foo",
|
||||
db_key="00" * 32,
|
||||
aliases=["wxid_foo_bar"],
|
||||
db_key_source_wxid_dir=str(db_storage.parent),
|
||||
db_key_source_db_storage_path=str(db_storage),
|
||||
)
|
||||
|
||||
out = root / "output" / "databases" / "wxid_foo" / "MSG0.db"
|
||||
self.assertTrue(out.exists())
|
||||
@@ -135,35 +144,37 @@ class TestDecryptStreamSSE(unittest.TestCase):
|
||||
client = TestClient(app)
|
||||
|
||||
events: list[dict] = []
|
||||
with client.stream(
|
||||
"GET",
|
||||
"/api/decrypt_stream",
|
||||
params={"key": "00" * 32, "db_storage_path": str(db_storage)},
|
||||
) as resp:
|
||||
self.assertEqual(resp.status_code, 200)
|
||||
self.assertIn("text/event-stream", resp.headers.get("content-type", ""))
|
||||
with mock.patch.object(decrypt_router, "upsert_account_keys_in_store") as upsert_mock:
|
||||
with client.stream(
|
||||
"GET",
|
||||
"/api/decrypt_stream",
|
||||
params={"key": "00" * 32, "db_storage_path": str(db_storage)},
|
||||
) as resp:
|
||||
self.assertEqual(resp.status_code, 200)
|
||||
self.assertIn("text/event-stream", resp.headers.get("content-type", ""))
|
||||
|
||||
for line in resp.iter_lines():
|
||||
if not line:
|
||||
continue
|
||||
if isinstance(line, bytes):
|
||||
line = line.decode("utf-8", errors="ignore")
|
||||
line = str(line)
|
||||
for line in resp.iter_lines():
|
||||
if not line:
|
||||
continue
|
||||
if isinstance(line, bytes):
|
||||
line = line.decode("utf-8", errors="ignore")
|
||||
line = str(line)
|
||||
|
||||
if line.startswith(":"):
|
||||
continue
|
||||
if not line.startswith("data: "):
|
||||
continue
|
||||
payload = json.loads(line[len("data: ") :])
|
||||
events.append(payload)
|
||||
if payload.get("type") in {"complete", "error"}:
|
||||
break
|
||||
if line.startswith(":"):
|
||||
continue
|
||||
if not line.startswith("data: "):
|
||||
continue
|
||||
payload = json.loads(line[len("data: ") :])
|
||||
events.append(payload)
|
||||
if payload.get("type") in {"complete", "error"}:
|
||||
break
|
||||
|
||||
self.assertEqual(events[-1].get("type"), "complete")
|
||||
self.assertEqual(events[-1].get("status"), "failed")
|
||||
self.assertEqual(events[-1].get("success_count"), 0)
|
||||
self.assertEqual(events[-1].get("failure_count"), 1)
|
||||
self.assertIn("密钥可能不匹配", str(events[-1].get("message") or ""))
|
||||
upsert_mock.assert_not_called()
|
||||
|
||||
out = root / "output" / "databases" / "wxid_bad" / "MSG0.db"
|
||||
self.assertFalse(out.exists())
|
||||
|
||||
@@ -0,0 +1,115 @@
|
||||
import sys
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from unittest import mock
|
||||
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
sys.path.insert(0, str(ROOT / "src"))
|
||||
|
||||
|
||||
import wechat_decrypt_tool.key_service as key_service
|
||||
|
||||
|
||||
class _FakeWxKey:
|
||||
def __init__(self, key: str) -> None:
|
||||
self.key = key
|
||||
self.initialize_calls: list[int] = []
|
||||
self.cleanup_calls = 0
|
||||
|
||||
def initialize_hook(self, pid: int) -> bool:
|
||||
self.initialize_calls.append(pid)
|
||||
return True
|
||||
|
||||
def get_last_error_msg(self) -> str:
|
||||
return ""
|
||||
|
||||
def poll_key_data(self):
|
||||
return {"key": self.key}
|
||||
|
||||
def get_status_message(self):
|
||||
return None, None
|
||||
|
||||
def cleanup_hook(self) -> None:
|
||||
self.cleanup_calls += 1
|
||||
|
||||
|
||||
class TestKeyServiceManualWechatInstallPath(unittest.TestCase):
|
||||
def test_get_db_key_workflow_can_use_manual_install_directory(self) -> None:
|
||||
fake_wx_key = _FakeWxKey("a" * 64)
|
||||
|
||||
with TemporaryDirectory() as temp_dir:
|
||||
install_dir = Path(temp_dir)
|
||||
exe_path = install_dir / "WeChat.exe"
|
||||
exe_path.write_bytes(b"")
|
||||
|
||||
with mock.patch.object(
|
||||
key_service,
|
||||
"wx_key",
|
||||
fake_wx_key,
|
||||
), mock.patch.object(
|
||||
key_service,
|
||||
"detect_wechat_installation",
|
||||
side_effect=AssertionError("should not auto-detect when manual path is provided"),
|
||||
), mock.patch.object(
|
||||
key_service,
|
||||
"_read_wechat_version_from_exe",
|
||||
return_value="",
|
||||
), mock.patch.object(
|
||||
key_service.WeChatKeyFetcher,
|
||||
"kill_wechat",
|
||||
autospec=True,
|
||||
) as kill_mock, mock.patch.object(
|
||||
key_service.WeChatKeyFetcher,
|
||||
"launch_wechat",
|
||||
autospec=True,
|
||||
return_value=4321,
|
||||
) as launch_mock:
|
||||
result = key_service.get_db_key_workflow(wechat_install_path=str(install_dir))
|
||||
|
||||
self.assertEqual(result["db_key"], "a" * 64)
|
||||
kill_mock.assert_called_once()
|
||||
launch_mock.assert_called_once()
|
||||
_, used_exe_path = launch_mock.call_args.args
|
||||
self.assertEqual(used_exe_path, str(exe_path))
|
||||
self.assertEqual(fake_wx_key.initialize_calls, [4321])
|
||||
self.assertEqual(fake_wx_key.cleanup_calls, 1)
|
||||
|
||||
def test_get_db_key_workflow_does_not_require_detected_version(self) -> None:
|
||||
fake_wx_key = _FakeWxKey("b" * 64)
|
||||
|
||||
with TemporaryDirectory() as temp_dir:
|
||||
exe_path = Path(temp_dir) / "Weixin.exe"
|
||||
exe_path.write_bytes(b"")
|
||||
|
||||
with mock.patch.object(
|
||||
key_service,
|
||||
"wx_key",
|
||||
fake_wx_key,
|
||||
), mock.patch.object(
|
||||
key_service,
|
||||
"detect_wechat_installation",
|
||||
return_value={
|
||||
"wechat_exe_path": str(exe_path),
|
||||
"wechat_version": "",
|
||||
},
|
||||
), mock.patch.object(
|
||||
key_service.WeChatKeyFetcher,
|
||||
"kill_wechat",
|
||||
autospec=True,
|
||||
), mock.patch.object(
|
||||
key_service.WeChatKeyFetcher,
|
||||
"launch_wechat",
|
||||
autospec=True,
|
||||
return_value=2468,
|
||||
):
|
||||
result = key_service.get_db_key_workflow()
|
||||
|
||||
self.assertEqual(result["db_key"], "b" * 64)
|
||||
self.assertEqual(fake_wx_key.initialize_calls, [2468])
|
||||
self.assertEqual(fake_wx_key.cleanup_calls, 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -0,0 +1,107 @@
|
||||
import asyncio
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
sys.path.insert(0, str(ROOT / "src"))
|
||||
|
||||
|
||||
def _close_logging_handlers() -> None:
|
||||
for logger_name in ("", "uvicorn", "uvicorn.access", "uvicorn.error", "fastapi"):
|
||||
lg = logging.getLogger(logger_name)
|
||||
for handler in lg.handlers[:]:
|
||||
try:
|
||||
handler.close()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
lg.removeHandler(handler)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
class TestSavedDbKeySourceValidation(unittest.TestCase):
|
||||
def test_get_saved_keys_blocks_legacy_db_key_for_suffixed_wxid_dir(self) -> None:
|
||||
with TemporaryDirectory() as td:
|
||||
root = Path(td)
|
||||
db_storage = root / "xwechat_files" / "wxid_demo_abcd" / "db_storage"
|
||||
db_storage.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
prev_data_dir = os.environ.get("WECHAT_TOOL_DATA_DIR")
|
||||
try:
|
||||
os.environ["WECHAT_TOOL_DATA_DIR"] = str(root)
|
||||
|
||||
import wechat_decrypt_tool.app_paths as app_paths
|
||||
import wechat_decrypt_tool.key_store as key_store
|
||||
import wechat_decrypt_tool.routers.keys as keys_router
|
||||
|
||||
importlib.reload(app_paths)
|
||||
importlib.reload(key_store)
|
||||
importlib.reload(keys_router)
|
||||
|
||||
key_store.upsert_account_keys_in_store("wxid_demo", db_key="A" * 64)
|
||||
result = asyncio.run(
|
||||
keys_router.get_saved_keys(account="wxid_demo", db_storage_path=str(db_storage))
|
||||
)
|
||||
|
||||
self.assertEqual(result["status"], "success")
|
||||
self.assertEqual(result["keys"]["db_key"], "")
|
||||
self.assertIn("Legacy saved db key is ambiguous", result["keys"]["db_key_blocked_reason"])
|
||||
finally:
|
||||
_close_logging_handlers()
|
||||
if prev_data_dir is None:
|
||||
os.environ.pop("WECHAT_TOOL_DATA_DIR", None)
|
||||
else:
|
||||
os.environ["WECHAT_TOOL_DATA_DIR"] = prev_data_dir
|
||||
|
||||
def test_get_saved_keys_accepts_source_matched_db_key(self) -> None:
|
||||
with TemporaryDirectory() as td:
|
||||
root = Path(td)
|
||||
db_storage = root / "xwechat_files" / "wxid_demo_abcd" / "db_storage"
|
||||
db_storage.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
prev_data_dir = os.environ.get("WECHAT_TOOL_DATA_DIR")
|
||||
try:
|
||||
os.environ["WECHAT_TOOL_DATA_DIR"] = str(root)
|
||||
|
||||
import wechat_decrypt_tool.app_paths as app_paths
|
||||
import wechat_decrypt_tool.key_store as key_store
|
||||
import wechat_decrypt_tool.routers.keys as keys_router
|
||||
|
||||
importlib.reload(app_paths)
|
||||
importlib.reload(key_store)
|
||||
importlib.reload(keys_router)
|
||||
|
||||
key_store.upsert_account_keys_in_store(
|
||||
"wxid_demo",
|
||||
db_key="B" * 64,
|
||||
aliases=["wxid_demo_abcd"],
|
||||
db_key_source_wxid_dir=str(db_storage.parent),
|
||||
db_key_source_db_storage_path=str(db_storage),
|
||||
)
|
||||
result = asyncio.run(
|
||||
keys_router.get_saved_keys(account="wxid_demo", db_storage_path=str(db_storage))
|
||||
)
|
||||
|
||||
self.assertEqual(result["status"], "success")
|
||||
self.assertEqual(result["keys"]["db_key"], "B" * 64)
|
||||
self.assertEqual(result["keys"]["db_key_store_account"], "wxid_demo_abcd")
|
||||
self.assertEqual(result["keys"]["db_key_source_wxid_dir"], str(db_storage.parent))
|
||||
self.assertEqual(result["keys"]["db_key_source_db_storage_path"], str(db_storage))
|
||||
self.assertEqual(result["keys"]["db_key_blocked_reason"], "")
|
||||
finally:
|
||||
_close_logging_handlers()
|
||||
if prev_data_dir is None:
|
||||
os.environ.pop("WECHAT_TOOL_DATA_DIR", None)
|
||||
else:
|
||||
os.environ["WECHAT_TOOL_DATA_DIR"] = prev_data_dir
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user