feat(media): 新增表情批量下载步骤并支持并发配置

- 解密页新增表情下载步骤,支持开始/停止、进度展示和结果统计
- 图片解密与表情下载接口支持并发配置,补充 SSE 进度与结果信息
- 增加表情目录聚合、缓存校验与媒体下载相关测试
This commit is contained in:
2977094657
2026-04-15 01:26:44 +08:00
Unverified
parent 1fc937ed7c
commit 23932bf89c
6 changed files with 3528 additions and 500 deletions
+81 -21
View File
@@ -24,7 +24,80 @@ class _FakeDisconnectingRequest:
return self._calls >= self._disconnect_after
async def _read_sse_events(response) -> list[dict]:
chunks = []
async for chunk in response.body_iterator:
chunks.append(chunk.decode("utf-8") if isinstance(chunk, bytes) else str(chunk))
events = []
for chunk in chunks:
for line in chunk.splitlines():
if line.startswith("data: "):
events.append(json.loads(line[len("data: ") :]))
return events
class TestMediaDecryptStreamCancel(unittest.TestCase):
    """Checks the SSE events emitted by ``media_router.decrypt_all_media_stream``."""

    def test_stream_uses_default_concurrency(self):
        """With no ``concurrency`` argument, start/progress/complete report 10."""
        with TemporaryDirectory() as tmp:
            base = Path(tmp)
            account_dir = base / "account"
            wxid_dir = base / "wxid"
            resource_dir = account_dir / "resource"
            dat_path = wxid_dir / "image.dat"
            wxid_dir.mkdir(parents=True, exist_ok=True)
            dat_path.write_bytes(b"encrypted")

            # Patches are entered left to right and stay active while the
            # stream is consumed, because the response body is produced lazily.
            with mock.patch.object(
                media_router, "_resolve_account_dir", return_value=account_dir
            ), mock.patch.object(
                media_router, "_resolve_account_wxid_dir", return_value=wxid_dir
            ), mock.patch.object(
                media_router, "_load_media_keys", return_value={"xor": 0xA5, "aes": ""}
            ), mock.patch.object(
                media_router, "_collect_all_dat_files", return_value=[(dat_path, "abc123")]
            ), mock.patch.object(
                media_router, "_get_resource_dir", return_value=resource_dir
            ), mock.patch.object(
                media_router, "_try_find_decrypted_resource", return_value=None
            ), mock.patch.object(
                media_router, "_decrypt_and_save_resource", return_value=(True, "ok")
            ):
                fake_request = _FakeDisconnectingRequest(disconnect_after=999)
                stream_call = media_router.decrypt_all_media_stream(
                    request=fake_request,
                    account="wxid_demo",
                )
                response = asyncio.run(stream_call)
                events = asyncio.run(_read_sse_events(response))

            event_types = [event.get("type") for event in events]
            self.assertEqual(event_types, ["scanning", "start", "progress", "complete"])
            for index in (1, 2, 3):
                self.assertEqual(events[index].get("concurrency"), 10)

    def test_stream_uses_requested_concurrency(self):
        """An explicit ``concurrency=7`` is echoed by start/progress/complete."""
        with TemporaryDirectory() as tmp:
            base = Path(tmp)
            account_dir = base / "account"
            wxid_dir = base / "wxid"
            resource_dir = account_dir / "resource"
            dat_path = wxid_dir / "image.dat"
            wxid_dir.mkdir(parents=True, exist_ok=True)
            dat_path.write_bytes(b"encrypted")

            with mock.patch.object(
                media_router, "_resolve_account_dir", return_value=account_dir
            ), mock.patch.object(
                media_router, "_resolve_account_wxid_dir", return_value=wxid_dir
            ), mock.patch.object(
                media_router, "_load_media_keys", return_value={"xor": 0xA5, "aes": ""}
            ), mock.patch.object(
                media_router, "_collect_all_dat_files", return_value=[(dat_path, "abc123")]
            ), mock.patch.object(
                media_router, "_get_resource_dir", return_value=resource_dir
            ), mock.patch.object(
                media_router, "_try_find_decrypted_resource", return_value=None
            ), mock.patch.object(
                media_router, "_decrypt_and_save_resource", return_value=(True, "ok")
            ):
                fake_request = _FakeDisconnectingRequest(disconnect_after=999)
                stream_call = media_router.decrypt_all_media_stream(
                    request=fake_request,
                    account="wxid_demo",
                    concurrency=7,
                )
                response = asyncio.run(stream_call)
                events = asyncio.run(_read_sse_events(response))

            for index in (1, 2, 3):
                self.assertEqual(events[index].get("concurrency"), 7)
def test_stream_stops_processing_when_client_disconnects(self):
with TemporaryDirectory() as td:
root = Path(td)
@@ -43,28 +116,15 @@ class TestMediaDecryptStreamCancel(unittest.TestCase):
with mock.patch.object(media_router, "_load_media_keys", return_value={"xor": 0xA5, "aes": ""}):
with mock.patch.object(media_router, "_collect_all_dat_files", return_value=[(dat_path, "abc123")]):
with mock.patch.object(media_router, "_get_resource_dir", return_value=resource_dir):
with mock.patch.object(media_router, "_try_find_decrypted_resource", return_value=None):
with mock.patch.object(media_router, "_decrypt_and_save_resource", decrypt_mock):
response = asyncio.run(
media_router.decrypt_all_media_stream(
request=request,
account="wxid_demo",
with mock.patch.object(media_router, "_try_find_decrypted_resource", return_value=None):
with mock.patch.object(media_router, "_decrypt_and_save_resource", decrypt_mock):
response = asyncio.run(
media_router.decrypt_all_media_stream(
request=request,
account="wxid_demo",
)
)
)
async def _read_chunks():
chunks = []
async for chunk in response.body_iterator:
chunks.append(chunk.decode("utf-8") if isinstance(chunk, bytes) else str(chunk))
return chunks
chunks = asyncio.run(_read_chunks())
events = []
for chunk in chunks:
for line in chunk.splitlines():
if line.startswith("data: "):
events.append(json.loads(line[len("data: ") :]))
events = asyncio.run(_read_sse_events(response))
self.assertEqual([event.get("type") for event in events], ["scanning", "start"])
decrypt_mock.assert_not_called()
+190
View File
@@ -0,0 +1,190 @@
import asyncio
import json
import sys
import unittest
from pathlib import Path
from tempfile import TemporaryDirectory
from unittest import mock
# ROOT is the directory one level above the folder containing this file.
# Its ``src`` subdirectory is placed first on ``sys.path`` so the project
# import that follows is looked up there before anywhere else.
ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT / "src"))
from wechat_decrypt_tool.routers import media as media_router # noqa: E402 pylint: disable=wrong-import-position
# Raw bytes of a tiny PNG used as the fake download payload. The IHDR chunk
# declares width 1, height 1, bit depth 8 and colour type 6.
PNG_1X1 = bytes.fromhex(
    "".join(
        (
            "89504E470D0A1A0A",  # PNG file signature
            "0000000D49484452000000010000000108060000001F15C489",  # IHDR chunk
            "0000000D49444154789C6360606060000000050001A5F64540",  # IDAT chunk
            "0000000049454E44AE426082",  # IEND chunk
        )
    )
)
class _FakeRequest:
async def is_disconnected(self):
return False
class _FakeDisconnectingRequest:
def __init__(self, disconnect_after: int):
self._disconnect_after = disconnect_after
self._calls = 0
async def is_disconnected(self):
self._calls += 1
return self._calls >= self._disconnect_after
def _emoji_catalog(md5: str):
return (
{
md5: {
"md5": md5,
"urls": [f"https://example.com/{md5}.png"],
"aes_keys": [],
"sources": ["message_xml"],
}
},
{
"total_candidates": 1,
"total_candidates_with_url": 1,
"source_counts": {"message_xml": 1},
},
)
async def _read_sse_events(response) -> list[dict]:
chunks = []
async for chunk in response.body_iterator:
chunks.append(chunk.decode("utf-8") if isinstance(chunk, bytes) else str(chunk))
events = []
for chunk in chunks:
for line in chunk.splitlines():
if line.startswith("data: "):
events.append(json.loads(line[len("data: ") :]))
return events
class TestMediaEmojiDownloadStream(unittest.TestCase):
    """Checks the SSE events emitted by ``media_router.download_all_emojis_stream``."""

    def test_stream_downloads_missing_emoji_and_saves_resource(self):
        """A catalogued emoji with no cached file is fetched once and saved."""
        with TemporaryDirectory() as tmp:
            account_dir = Path(tmp) / "account"
            account_dir.mkdir(parents=True, exist_ok=True)
            md5 = "a" * 32
            expected_file = account_dir / "resource" / md5[:2] / f"{md5}.png"

            # Patches are entered left to right and stay active while the
            # stream is consumed, because the response body is produced lazily.
            with mock.patch.object(
                media_router, "_resolve_account_dir", return_value=account_dir
            ), mock.patch.object(
                media_router,
                "_collect_emoticon_download_catalog",
                return_value=_emoji_catalog(md5),
            ), mock.patch.object(
                media_router,
                "_try_fetch_emoticon_from_remote",
                return_value=(PNG_1X1, "image/png"),
            ) as fetch_mock:
                stream_call = media_router.download_all_emojis_stream(
                    request=_FakeRequest(),
                    account="wxid_demo",
                )
                response = asyncio.run(stream_call)
                events = asyncio.run(_read_sse_events(response))

            event_types = [event.get("type") for event in events]
            self.assertEqual(event_types, ["scanning", "start", "progress", "complete"])
            self.assertEqual(events[2].get("status"), "success")
            self.assertEqual(events[3].get("success_count"), 1)
            self.assertEqual(events[1].get("concurrency"), 20)
            self.assertTrue(expected_file.exists())
            fetch_mock.assert_called_once()

    def test_stream_uses_requested_concurrency(self):
        """An explicit ``concurrency=7`` is echoed by start/progress/complete."""
        with TemporaryDirectory() as tmp:
            account_dir = Path(tmp) / "account"
            account_dir.mkdir(parents=True, exist_ok=True)
            md5 = "d" * 32

            with mock.patch.object(
                media_router, "_resolve_account_dir", return_value=account_dir
            ), mock.patch.object(
                media_router,
                "_collect_emoticon_download_catalog",
                return_value=_emoji_catalog(md5),
            ), mock.patch.object(
                media_router,
                "_try_fetch_emoticon_from_remote",
                return_value=(PNG_1X1, "image/png"),
            ):
                stream_call = media_router.download_all_emojis_stream(
                    request=_FakeRequest(),
                    account="wxid_demo",
                    concurrency=7,
                )
                response = asyncio.run(stream_call)
                events = asyncio.run(_read_sse_events(response))

            for index in (1, 2, 3):
                self.assertEqual(events[index].get("concurrency"), 7)

    def test_stream_skips_existing_downloaded_emoji(self):
        """An emoji whose file already sits in the resource dir is skipped."""
        with TemporaryDirectory() as tmp:
            account_dir = Path(tmp) / "account"
            md5 = "b" * 32
            resource_dir = account_dir / "resource" / md5[:2]
            account_dir.mkdir(parents=True, exist_ok=True)
            resource_dir.mkdir(parents=True, exist_ok=True)
            # Pre-seed the file that the download step is expected to detect.
            cached = resource_dir / f"{md5}.png"
            cached.write_bytes(PNG_1X1)

            with mock.patch.object(
                media_router, "_resolve_account_dir", return_value=account_dir
            ), mock.patch.object(
                media_router,
                "_collect_emoticon_download_catalog",
                return_value=_emoji_catalog(md5),
            ), mock.patch.object(
                media_router, "_try_fetch_emoticon_from_remote"
            ) as fetch_mock:
                stream_call = media_router.download_all_emojis_stream(
                    request=_FakeRequest(),
                    account="wxid_demo",
                )
                response = asyncio.run(stream_call)
                events = asyncio.run(_read_sse_events(response))

            event_types = [event.get("type") for event in events]
            self.assertEqual(event_types, ["scanning", "start", "progress", "complete"])
            self.assertEqual(events[2].get("status"), "skip")
            self.assertEqual(events[3].get("skip_count"), 1)
            fetch_mock.assert_not_called()

    def test_stream_stops_before_processing_when_client_disconnects(self):
        """A disconnect on the third poll ends the stream before any fetch."""
        with TemporaryDirectory() as tmp:
            account_dir = Path(tmp) / "account"
            account_dir.mkdir(parents=True, exist_ok=True)
            md5 = "c" * 32

            with mock.patch.object(
                media_router, "_resolve_account_dir", return_value=account_dir
            ), mock.patch.object(
                media_router,
                "_collect_emoticon_download_catalog",
                return_value=_emoji_catalog(md5),
            ), mock.patch.object(
                media_router, "_try_fetch_emoticon_from_remote"
            ) as fetch_mock:
                stream_call = media_router.download_all_emojis_stream(
                    request=_FakeDisconnectingRequest(disconnect_after=3),
                    account="wxid_demo",
                )
                response = asyncio.run(stream_call)
                events = asyncio.run(_read_sse_events(response))

            event_types = [event.get("type") for event in events]
            self.assertEqual(event_types, ["scanning", "start"])
            fetch_mock.assert_not_called()
# Run the tests in this module when the file is executed directly.
if __name__ == "__main__":
    unittest.main()
+109
View File
@@ -0,0 +1,109 @@
import sqlite3
import sys
import unittest
from pathlib import Path
from tempfile import TemporaryDirectory
# ROOT is the directory one level above the folder containing this file.
# Its ``src`` subdirectory is placed first on ``sys.path`` so the project
# import that follows is looked up there before anywhere else.
ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT / "src"))
from wechat_decrypt_tool.media_helpers import ( # noqa: E402 pylint: disable=wrong-import-position
_collect_emoticon_download_catalog,
_lookup_emoticon_info,
)
class TestMediaEmoticonCatalog(unittest.TestCase):
    """Tests for ``_collect_emoticon_download_catalog`` and ``_lookup_emoticon_info``."""

    def test_catalog_merges_emoticon_db_extern_md5_and_message_xml(self):
        """Catalog combines ``emoticon.db`` rows with emoji XML from a message DB.

        The fixture holds one ``kNonStoreEmoticonTable`` row (md5 plus extern
        md5, with a CDN URL) and three ``Msg_demo`` rows: one new emoji with a
        URL, one repeating the primary md5 with a second URL, and one emoji
        without any URL. The assertions pin which md5s appear in the catalog,
        which sources each entry records, the reported statistics, and that a
        lookup by extern md5 resolves to the primary row.
        """
        with TemporaryDirectory() as td:
            account_dir = Path(td) / "account"
            account_dir.mkdir(parents=True, exist_ok=True)

            primary_md5 = "a" * 32
            extern_md5 = "b" * 32
            message_md5 = "c" * 32
            no_url_md5 = "d" * 32
            message_extern_md5 = "e" * 32
            aes_key = "1" * 32

            conn = sqlite3.connect(str(account_dir / "emoticon.db"))
            # Close in ``finally``: if a statement fails, a still-open handle
            # would be left behind while TemporaryDirectory tries to clean up.
            try:
                conn.execute(
                    "CREATE TABLE kNonStoreEmoticonTable ("
                    "md5 TEXT, extern_md5 TEXT, aes_key TEXT, cdn_url TEXT, encrypt_url TEXT, "
                    "extern_url TEXT, thumb_url TEXT, tp_url TEXT)"
                )
                conn.execute(
                    "INSERT INTO kNonStoreEmoticonTable VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
                    (
                        primary_md5,
                        extern_md5,
                        aes_key,
                        f"https://example.com/{primary_md5}.gif",
                        "",
                        "",
                        "",
                        "",
                    ),
                )
                conn.commit()
            finally:
                conn.close()

            conn = sqlite3.connect(str(account_dir / "message_0.db"))
            try:
                conn.execute(
                    "CREATE TABLE Msg_demo ("
                    "local_type INTEGER, compress_content BLOB, message_content BLOB, packed_info_data BLOB)"
                )
                conn.executemany(
                    "INSERT INTO Msg_demo VALUES (?, ?, ?, ?)",
                    [
                        (
                            47,
                            None,
                            (
                                f'<msg><emoji md5="{message_md5}" externmd5="{message_extern_md5}" '
                                f'aeskey="{aes_key}" cdnurl="https://example.com/{message_md5}.png" /></msg>'
                            ),
                            bytes([0x10, 0x45]),
                        ),
                        (
                            47,
                            None,
                            f'<msg><emoji md5="{primary_md5}" cdnurl="https://example.com/{primary_md5}-2.png" /></msg>',
                            bytes([0x10, 0x45]),
                        ),
                        (
                            47,
                            None,
                            f'<msg><emoji md5="{no_url_md5}" /></msg>',
                            bytes([0x10, 0x45]),
                        ),
                    ],
                )
                conn.commit()
            finally:
                conn.close()

            catalog, stats = _collect_emoticon_download_catalog(account_dir)

            # Only md5s that have at least one URL are expected in the catalog.
            self.assertEqual(set(catalog), {primary_md5, extern_md5, message_md5})
            self.assertIn("emoticon_db_md5", catalog[primary_md5]["sources"])
            self.assertIn("message_xml", catalog[primary_md5]["sources"])
            self.assertIn("emoticon_db_extern_md5", catalog[extern_md5]["sources"])
            self.assertIn("message_xml", catalog[message_md5]["sources"])
            self.assertNotIn(no_url_md5, catalog)

            self.assertEqual(stats["emoticon_db_md5"], 1)
            self.assertEqual(stats["emoticon_db_extern_md5"], 1)
            self.assertEqual(stats["message_xml_rows"], 3)
            self.assertEqual(stats["message_xml_md5"], 3)
            self.assertEqual(stats["message_xml_md5_with_url"], 2)
            self.assertEqual(stats["message_xml_extern_md5"], 1)
            self.assertEqual(stats["message_builtin_expr_ids"], 1)
            self.assertEqual(stats["source_counts"]["message_xml"], 2)

            # Looking up the extern md5 should resolve to the primary row.
            info = _lookup_emoticon_info(str(account_dir), extern_md5)
            self.assertEqual(info["md5"], primary_md5)
            self.assertEqual(info["extern_md5"], extern_md5)
# Run the tests in this module when the file is executed directly.
if __name__ == "__main__":
    unittest.main()