improvement(chat): strengthen realtime incremental sync and add message search index endpoints

- Add a background autosync service: watch db_storage for changes and trigger incremental realtime -> decrypted syncs (debounced / rate-limited)
- Improve WCDB realtime shutdown: support a lock timeout instead of forcing shutdown while connections are still busy
- Add message search index endpoints (status/build/senders)
- Frontend now runs sync_all before turning realtime off, so the session list/messages lag less after switching back to the decrypted databases
- Add unit tests for creating decrypted message tables/indexes
This commit is contained in:
2977094657
2026-02-03 16:31:31 +08:00
parent 625526ff3b
commit 3297f24f52
7 changed files with 975 additions and 71 deletions

View File

@@ -99,6 +99,7 @@ export const useApi = () => {
if (params && params.account) query.set('account', params.account)
if (params && params.username) query.set('username', params.username)
if (params && params.max_scan != null) query.set('max_scan', String(params.max_scan))
if (params && params.backfill_limit != null) query.set('backfill_limit', String(params.backfill_limit))
const url = '/chat/realtime/sync' + (query.toString() ? `?${query.toString()}` : '')
return await request(url, { method: 'POST' })
}

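For reference, a minimal sketch of driving the extended sync endpoint from a script. The host/port and the username/account values are assumptions; the query parameters mirror the handler further below, and `backfill_limit=0` is the cheap incremental mode this commit introduces.

```python
# Hypothetical local deployment; adjust base URL, account and username to your setup.
import requests

resp = requests.post(
    "http://127.0.0.1:8000/api/chat/realtime/sync",
    params={
        "username": "wxid_example",   # conversation to sync (illustrative)
        "account": "my_account",      # decrypted account name (illustrative)
        "max_scan": 200,
        "backfill_limit": 0,          # 0 skips packed_info_data backfill for cheaper incremental syncs
    },
    timeout=30,
)
print(resp.status_code, resp.json())
```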
View File

@@ -60,6 +60,35 @@
</div>
</div>
<!-- 年度总结图标 -->
<div
class="w-full h-[var(--sidebar-rail-step)] flex items-center justify-center cursor-pointer group"
title="年度总结"
@click="goWrapped"
>
<div
class="w-[var(--sidebar-rail-btn)] h-[var(--sidebar-rail-btn)] rounded-md flex items-center justify-center transition-colors bg-transparent group-hover:bg-[#E1E1E1]"
>
<div class="w-[var(--sidebar-rail-icon)] h-[var(--sidebar-rail-icon)]" :class="isWrappedRoute ? 'text-[#07b75b]' : 'text-[#5d5d5d]'">
<svg
class="w-full h-full"
viewBox="0 0 24 24"
fill="none"
stroke="currentColor"
stroke-width="1.5"
stroke-linecap="round"
stroke-linejoin="round"
aria-hidden="true"
>
<rect x="4" y="4" width="16" height="16" rx="2" />
<path d="M8 16v-5" />
<path d="M12 16v-8" />
<path d="M16 16v-3" />
</svg>
</div>
</div>
</div>
<!-- 隐私模式按钮 -->
<div
class="w-full h-[var(--sidebar-rail-step)] flex items-center justify-center cursor-pointer group"
@@ -177,7 +206,7 @@
<!-- 联系人头像 -->
<div class="w-[calc(45px/var(--dpr))] h-[calc(45px/var(--dpr))] rounded-md overflow-hidden bg-gray-300 flex-shrink-0" :class="{ 'privacy-blur': privacyMode }">
<div v-if="contact.avatar" class="w-full h-full">
- <img :src="contact.avatar" :alt="contact.name" class="w-full h-full object-cover">
+ <img :src="contact.avatar" :alt="contact.name" class="w-full h-full object-cover" referrerpolicy="no-referrer" @error="onAvatarError($event, contact)">
</div>
<div v-else class="w-full h-full flex items-center justify-center text-white text-xs font-bold"
:style="{ backgroundColor: contact.avatarColor || '#4B5563' }">
@@ -358,7 +387,7 @@
:alt="message.sender + '的头像'" :alt="message.sender + '的头像'"
class="w-full h-full object-cover" class="w-full h-full object-cover"
referrerpolicy="no-referrer" referrerpolicy="no-referrer"
@error="onMessageAvatarError($event, message)" @error="onAvatarError($event, message)"
> >
</div> </div>
<div v-else class="w-full h-full flex items-center justify-center text-white text-xs font-bold" <div v-else class="w-full h-full flex items-center justify-center text-white text-xs font-bold"
@@ -1371,7 +1400,7 @@
>
<input type="checkbox" :value="c.username" v-model="exportSelectedUsernames" />
<div class="w-9 h-9 rounded-md overflow-hidden bg-gray-200 flex-shrink-0" :class="{ 'privacy-blur': privacyMode }">
- <img v-if="c.avatar" :src="c.avatar" :alt="c.name + '头像'" class="w-full h-full object-cover" />
+ <img v-if="c.avatar" :src="c.avatar" :alt="c.name + '头像'" class="w-full h-full object-cover" referrerpolicy="no-referrer" @error="onAvatarError($event, c)" />
<div v-else class="w-full h-full flex items-center justify-center text-xs font-bold text-gray-600">
{{ (c.name || c.username || '?').charAt(0) }}
</div>
@@ -1779,6 +1808,7 @@ useHead({
const route = useRoute()
const isSnsRoute = computed(() => route.path?.startsWith('/sns'))
const isWrappedRoute = computed(() => route.path?.startsWith('/wrapped'))
const routeUsername = computed(() => {
const raw = route.params.username
@@ -2098,6 +2128,10 @@ const goSns = async () => {
await navigateTo('/sns')
}
const goWrapped = async () => {
await navigateTo('/wrapped')
}
// 实时更新WCDB DLL + db_storage watcher
const realtimeEnabled = ref(false)
const realtimeAvailable = ref(false)
@@ -4319,10 +4353,10 @@ const normalizeMessage = (msg) => {
}
}
- const onMessageAvatarError = (e, message) => {
+ const onAvatarError = (e, target) => {
// Make sure we fall back to the initial avatar if the URL 404s/blocks.
try { e?.target && (e.target.style.display = 'none') } catch {}
- try { if (message) message.avatar = null } catch {}
+ try { if (target) target.avatar = null } catch {}
}
const shouldShowEmojiDownload = (message) => {
@@ -5118,14 +5152,16 @@ const toggleRealtime = async (opts = {}) => {
try {
const api = useApi()
const u = String(selectedContact.value?.username || '').trim()
- if (u) {
-   // Use a larger scan window on shutdown to reduce the chance of missing a backlog.
-   await api.syncChatRealtimeMessages({
-     account: selectedAccount.value,
-     username: u,
-     max_scan: 5000
-   })
- }
+ // Sync all sessions once before falling back to the decrypted snapshot.
+ // This keeps the sidebar session list consistent (e.g. new friends) after a refresh.
+ await api.syncChatRealtimeAll({
+   account: selectedAccount.value,
+   max_scan: 200,
+   priority_username: u,
+   priority_max_scan: 5000,
+   include_hidden: true,
+   include_official: true
+ })
} catch {}
await refreshSessionsForSelectedAccount({ sourceOverride: '' })
if (selectedContact.value?.username) {

View File

@@ -11,6 +11,7 @@ from starlette.staticfiles import StaticFiles
from .logging_config import setup_logging, get_logger
from .path_fix import PathFixRoute
from .chat_realtime_autosync import CHAT_REALTIME_AUTOSYNC
from .routers.chat import router as _chat_router
from .routers.chat_export import router as _chat_export_router
from .routers.chat_media import router as _chat_media_router
@@ -121,16 +122,41 @@ def _maybe_mount_frontend() -> None:
_maybe_mount_frontend()
@app.on_event("startup")
async def _startup_background_jobs() -> None:
try:
CHAT_REALTIME_AUTOSYNC.start()
except Exception:
logger.exception("Failed to start realtime autosync service")
@app.on_event("shutdown") @app.on_event("shutdown")
async def _shutdown_wcdb_realtime() -> None: async def _shutdown_wcdb_realtime() -> None:
try: try:
WCDB_REALTIME.close_all() CHAT_REALTIME_AUTOSYNC.stop()
except Exception: except Exception:
pass pass
close_ok = False
lock_timeout_s: float | None = 0.2
try:
raw = str(os.environ.get("WECHAT_TOOL_WCDB_SHUTDOWN_LOCK_TIMEOUT_S", "0.2") or "").strip()
lock_timeout_s = float(raw) if raw else 0.2
if lock_timeout_s <= 0:
lock_timeout_s = None
except Exception:
lock_timeout_s = 0.2
try:
close_ok = WCDB_REALTIME.close_all(lock_timeout_s=lock_timeout_s)
except Exception:
close_ok = False
if close_ok:
try:
_wcdb_shutdown()
except Exception:
pass
else:
# If some conn locks were busy, other threads may still be running WCDB calls; avoid shutting down the lib.
logger.warning("[wcdb] close_all not fully completed; skip wcdb_shutdown")
if __name__ == "__main__":

View File

@@ -0,0 +1,331 @@
"""Background auto-sync from WCDB realtime (db_storage) into decrypted sqlite.
Why:
- The UI can read "latest" messages from WCDB realtime (`source=realtime`), but most APIs default to the
decrypted sqlite snapshot (`source=decrypted`).
- Previously we only synced realtime -> decrypted when the UI toggled realtime off, which caused `/api/chat/messages`
to lag behind while realtime was enabled.
This module runs a lightweight background poller that watches db_storage mtime changes and triggers an incremental
sync_all into decrypted sqlite. It is intentionally conservative (debounced + rate-limited) to avoid hammering the
backend or the sqlite files.
"""
from __future__ import annotations
import os
import threading
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Optional
from fastapi import HTTPException
from .chat_helpers import _list_decrypted_accounts, _resolve_account_dir
from .logging_config import get_logger
from .wcdb_realtime import WCDB_REALTIME
logger = get_logger(__name__)
def _env_bool(name: str, default: bool) -> bool:
raw = str(os.environ.get(name, "") or "").strip().lower()
if not raw:
return default
return raw not in {"0", "false", "no", "off"}
def _env_int(name: str, default: int, *, min_v: int, max_v: int) -> int:
raw = str(os.environ.get(name, "") or "").strip()
try:
v = int(raw)
except Exception:
v = int(default)
if v < min_v:
v = min_v
if v > max_v:
v = max_v
return v
def _scan_db_storage_mtime_ns(db_storage_dir: Path) -> int:
"""Best-effort scan of db_storage for a "latest mtime" signal.
We intentionally restrict to common database buckets to reduce walk cost.
"""
try:
base = str(db_storage_dir)
except Exception:
return 0
max_ns = 0
try:
for root, dirs, files in os.walk(base):
if root == base:
allow = {"message", "session", "contact", "head_image", "bizchat", "sns", "general", "favorite"}
dirs[:] = [d for d in dirs if str(d or "").lower() in allow]
for fn in files:
name = str(fn or "").lower()
if not name.endswith((".db", ".db-wal", ".db-shm")):
continue
if not (
("message" in name)
or ("session" in name)
or ("contact" in name)
or ("name2id" in name)
or ("head_image" in name)
):
continue
try:
st = os.stat(os.path.join(root, fn))
m_ns = int(getattr(st, "st_mtime_ns", 0) or 0)
if m_ns <= 0:
m_ns = int(float(getattr(st, "st_mtime", 0.0) or 0.0) * 1_000_000_000)
if m_ns > max_ns:
max_ns = m_ns
except Exception:
continue
except Exception:
return 0
return max_ns
@dataclass
class _AccountState:
last_mtime_ns: int = 0
due_at: float = 0.0
last_sync_end_at: float = 0.0
thread: Optional[threading.Thread] = None
class ChatRealtimeAutoSyncService:
def __init__(self) -> None:
self._enabled = _env_bool("WECHAT_TOOL_REALTIME_AUTOSYNC", True)
self._interval_ms = _env_int("WECHAT_TOOL_REALTIME_AUTOSYNC_INTERVAL_MS", 1000, min_v=200, max_v=10_000)
self._debounce_ms = _env_int("WECHAT_TOOL_REALTIME_AUTOSYNC_DEBOUNCE_MS", 600, min_v=0, max_v=10_000)
self._min_sync_interval_ms = _env_int(
"WECHAT_TOOL_REALTIME_AUTOSYNC_MIN_SYNC_INTERVAL_MS", 800, min_v=0, max_v=60_000
)
self._workers = _env_int("WECHAT_TOOL_REALTIME_AUTOSYNC_WORKERS", 1, min_v=1, max_v=4)
# Sync strategy defaults: cheap incremental write into decrypted sqlite.
self._sync_max_scan = _env_int("WECHAT_TOOL_REALTIME_AUTOSYNC_MAX_SCAN", 200, min_v=20, max_v=5000)
self._priority_max_scan = _env_int("WECHAT_TOOL_REALTIME_AUTOSYNC_PRIORITY_MAX_SCAN", 600, min_v=20, max_v=5000)
self._backfill_limit = _env_int("WECHAT_TOOL_REALTIME_AUTOSYNC_BACKFILL_LIMIT", 0, min_v=0, max_v=5000)
# Default to the same conservative filtering as the chat UI sidebar (avoid hammering gh_/hidden sessions).
self._include_hidden = _env_bool("WECHAT_TOOL_REALTIME_AUTOSYNC_INCLUDE_HIDDEN", False)
self._include_official = _env_bool("WECHAT_TOOL_REALTIME_AUTOSYNC_INCLUDE_OFFICIAL", False)
self._mu = threading.Lock()
self._states: dict[str, _AccountState] = {}
self._stop = threading.Event()
self._thread: Optional[threading.Thread] = None
def start(self) -> None:
if not self._enabled:
logger.info("[realtime-autosync] disabled by env WECHAT_TOOL_REALTIME_AUTOSYNC=0")
return
with self._mu:
if self._thread is not None and self._thread.is_alive():
return
self._stop.clear()
self._thread = threading.Thread(target=self._run, name="realtime-autosync", daemon=True)
self._thread.start()
logger.info(
"[realtime-autosync] started interval_ms=%s debounce_ms=%s min_sync_interval_ms=%s max_scan=%s backfill_limit=%s workers=%s",
int(self._interval_ms),
int(self._debounce_ms),
int(self._min_sync_interval_ms),
int(self._sync_max_scan),
int(self._backfill_limit),
int(self._workers),
)
def stop(self) -> None:
with self._mu:
th = self._thread
self._thread = None
if th is None:
return
self._stop.set()
try:
th.join(timeout=5.0)
except Exception:
pass
logger.info("[realtime-autosync] stopped")
def _run(self) -> None:
while not self._stop.is_set():
tick_t0 = time.perf_counter()
try:
self._tick()
except Exception:
logger.exception("[realtime-autosync] tick failed")
# Avoid busy looping on exceptions; keep a minimum sleep.
elapsed_ms = (time.perf_counter() - tick_t0) * 1000.0
sleep_ms = max(100.0, float(self._interval_ms) - elapsed_ms)
self._stop.wait(timeout=sleep_ms / 1000.0)
def _tick(self) -> None:
accounts = _list_decrypted_accounts()
now = time.time()
if not accounts:
return
for acc in accounts:
if self._stop.is_set():
break
try:
account_dir = _resolve_account_dir(acc)
except HTTPException:
continue
except Exception:
continue
info = WCDB_REALTIME.get_status(account_dir)
available = bool(info.get("dll_present") and info.get("key_present") and info.get("db_storage_dir"))
if not available:
continue
db_storage_dir = Path(str(info.get("db_storage_dir") or "").strip())
if not db_storage_dir.exists() or not db_storage_dir.is_dir():
continue
scan_t0 = time.perf_counter()
mtime_ns = _scan_db_storage_mtime_ns(db_storage_dir)
scan_ms = (time.perf_counter() - scan_t0) * 1000.0
if scan_ms > 2000:
logger.warning("[realtime-autosync] scan slow account=%s ms=%.1f", acc, scan_ms)
with self._mu:
st = self._states.setdefault(acc, _AccountState())
if mtime_ns and mtime_ns != st.last_mtime_ns:
st.last_mtime_ns = int(mtime_ns)
st.due_at = now + (float(self._debounce_ms) / 1000.0)
# Schedule daemon threads. (Important: do NOT use ThreadPoolExecutor here; its threads are non-daemon on
# Windows/Python 3.12 and can prevent Ctrl+C from stopping the process.)
to_start: list[threading.Thread] = []
with self._mu:
# Drop state for removed accounts to keep memory bounded.
keep = set(accounts)
for acc in list(self._states.keys()):
if acc not in keep:
self._states.pop(acc, None)
# Clean up finished threads and compute current concurrency.
running = 0
for st in self._states.values():
th = st.thread
if th is not None and th.is_alive():
running += 1
elif th is not None and (not th.is_alive()):
st.thread = None
for acc, st in self._states.items():
if running >= int(self._workers):
break
if st.due_at <= 0 or st.due_at > now:
continue
if st.thread is not None and st.thread.is_alive():
continue
since = now - float(st.last_sync_end_at or 0.0)
min_interval = float(self._min_sync_interval_ms) / 1000.0
if min_interval > 0 and since < min_interval:
st.due_at = now + (min_interval - since)
continue
st.due_at = 0.0
th = threading.Thread(
target=self._sync_account_runner,
args=(acc,),
name=f"realtime-autosync-{acc}",
daemon=True,
)
st.thread = th
to_start.append(th)
running += 1
for th in to_start:
if self._stop.is_set():
break
try:
th.start()
except Exception:
# Best-effort: if a thread fails to start, clear the state so we can retry later.
with self._mu:
for acc, st in self._states.items():
if st.thread is th:
st.thread = None
break
def _sync_account_runner(self, account: str) -> None:
account = str(account or "").strip()
try:
if self._stop.is_set() or (not account):
return
res = self._sync_account(account)
inserted = int((res or {}).get("inserted_total") or (res or {}).get("insertedTotal") or 0)
synced = int((res or {}).get("synced") or (res or {}).get("sessionsSynced") or 0)
logger.info("[realtime-autosync] sync done account=%s synced=%s inserted=%s", account, synced, inserted)
except Exception:
logger.exception("[realtime-autosync] sync failed account=%s", account)
finally:
with self._mu:
st = self._states.get(account)
if st is not None:
st.thread = None
st.last_sync_end_at = time.time()
def _sync_account(self, account: str) -> dict[str, Any]:
"""Run a cheap incremental sync_all for one account."""
account = str(account or "").strip()
if not account:
return {"status": "skipped", "reason": "missing account"}
try:
account_dir = _resolve_account_dir(account)
except Exception as e:
return {"status": "skipped", "reason": f"resolve account failed: {e}"}
info = WCDB_REALTIME.get_status(account_dir)
available = bool(info.get("dll_present") and info.get("key_present") and info.get("db_storage_dir"))
if not available:
return {"status": "skipped", "reason": "realtime not available"}
# Import lazily to avoid any startup import ordering issues.
from .routers.chat import sync_chat_realtime_messages_all
try:
return sync_chat_realtime_messages_all(
request=None, # not used by the handler logic; we run it as an internal job
account=account,
max_scan=int(self._sync_max_scan),
priority_username=None,
priority_max_scan=int(self._priority_max_scan),
include_hidden=bool(self._include_hidden),
include_official=bool(self._include_official),
backfill_limit=int(self._backfill_limit),
)
except HTTPException as e:
return {"status": "error", "error": str(e.detail or "")}
except Exception as e:
return {"status": "error", "error": str(e)}
CHAT_REALTIME_AUTOSYNC = ChatRealtimeAutoSyncService()
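The service reads all of its tuning knobs from environment variables when the module-level instance is constructed, so they have to be set before the app is imported. A minimal sketch, assuming a local uvicorn launch; the `wechat_decrypt_tool.main:app` import string is a guess at the entry point, while the variable names come from this module.

```python
# Sketch only: set autosync knobs before importing/serving the app.
import os
import uvicorn

os.environ["WECHAT_TOOL_REALTIME_AUTOSYNC"] = "1"                 # enable the background poller
os.environ["WECHAT_TOOL_REALTIME_AUTOSYNC_INTERVAL_MS"] = "2000"  # scan db_storage every 2s
os.environ["WECHAT_TOOL_REALTIME_AUTOSYNC_BACKFILL_LIMIT"] = "0"  # keep each incremental sync cheap

if __name__ == "__main__":
    # Hypothetical app path; use the project's real "module:app" string here.
    uvicorn.run("wechat_decrypt_tool.main:app", host="127.0.0.1", port=8000)
```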

View File

@@ -291,6 +291,130 @@ def _resolve_decrypted_message_table(account_dir: Path, username: str) -> Option
return None
def _pick_message_db_for_new_table(account_dir: Path, username: str) -> Optional[Path]:
"""Pick a target decrypted sqlite db to place a new Msg_<md5> table.
Some accounts have both `message_*.db` and `biz_message_*.db`. For normal users we prefer
`message*.db`; for official accounts (`gh_`) we prefer `biz_message*.db`.
"""
db_paths = _iter_message_db_paths(account_dir)
if not db_paths:
return None
uname = str(username or "").strip()
want_biz = bool(uname and uname.startswith("gh_"))
msg_paths: list[Path] = []
biz_paths: list[Path] = []
other_paths: list[Path] = []
for p in db_paths:
ln = p.name.lower()
if re.match(r"^message(_\d+)?\.db$", ln):
msg_paths.append(p)
elif re.match(r"^biz_message(_\d+)?\.db$", ln):
biz_paths.append(p)
else:
other_paths.append(p)
if want_biz and biz_paths:
return biz_paths[0]
if msg_paths:
return msg_paths[0]
if biz_paths:
return biz_paths[0]
return other_paths[0] if other_paths else db_paths[0]
def _ensure_decrypted_message_table(account_dir: Path, username: str) -> tuple[Path, str]:
"""Ensure the decrypted sqlite has a Msg_<md5(username)> table for this conversation.
Why:
- The decrypted snapshot can miss newly created sessions, so WCDB realtime can show messages
while the decrypted message_*.db has no table -> `/api/chat/messages` returns empty.
- Realtime sync should be able to create the missing conversation table and then insert rows.
"""
uname = str(username or "").strip()
if not uname:
raise HTTPException(status_code=400, detail="Missing username.")
resolved = _resolve_decrypted_message_table(account_dir, uname)
if resolved:
return resolved
target_db = _pick_message_db_for_new_table(account_dir, uname)
if target_db is None:
raise HTTPException(status_code=404, detail="No message databases found for this account.")
# Use the conventional WeChat naming (`Msg_<md5>`). Resolution is case-insensitive.
import hashlib
md5_hex = hashlib.md5(uname.encode("utf-8")).hexdigest()
table_name = f"Msg_{md5_hex}"
quoted_table = _quote_ident(table_name)
conn = sqlite3.connect(str(target_db))
try:
conn.execute(
f"""
CREATE TABLE IF NOT EXISTS {quoted_table}(
local_id INTEGER PRIMARY KEY AUTOINCREMENT,
server_id INTEGER,
local_type INTEGER,
sort_seq INTEGER,
real_sender_id INTEGER,
create_time INTEGER,
status INTEGER,
upload_status INTEGER,
download_status INTEGER,
server_seq INTEGER,
origin_source INTEGER,
source TEXT,
message_content TEXT,
compress_content TEXT,
packed_info_data BLOB,
WCDB_CT_message_content INTEGER DEFAULT NULL,
WCDB_CT_source INTEGER DEFAULT NULL
)
"""
)
# Match the common indexes we observe on existing Msg_* tables for query performance.
idx_sender = _quote_ident(f"{table_name}_SENDERID")
idx_server = _quote_ident(f"{table_name}_SERVERID")
idx_sort = _quote_ident(f"{table_name}_SORTSEQ")
idx_type_seq = _quote_ident(f"{table_name}_TYPE_SEQ")
conn.execute(f"CREATE INDEX IF NOT EXISTS {idx_sender} ON {quoted_table}(real_sender_id)")
conn.execute(f"CREATE INDEX IF NOT EXISTS {idx_server} ON {quoted_table}(server_id)")
conn.execute(f"CREATE INDEX IF NOT EXISTS {idx_sort} ON {quoted_table}(sort_seq)")
conn.execute(f"CREATE INDEX IF NOT EXISTS {idx_type_seq} ON {quoted_table}(local_type, sort_seq)")
conn.commit()
finally:
conn.close()
return target_db, table_name
def _ensure_decrypted_message_tables(
account_dir: Path, usernames: list[str]
) -> dict[str, tuple[Path, str]]:
"""Bulk resolver that also creates missing Msg_<md5> tables when needed."""
table_map = _resolve_decrypted_message_tables(account_dir, usernames)
for u in usernames:
uname = str(u or "").strip()
if not uname or uname in table_map:
continue
try:
table_map[uname] = _ensure_decrypted_message_table(account_dir, uname)
except Exception:
# Best-effort: if we can't create the table, keep it missing and let callers skip.
continue
return table_map
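A minimal usage sketch of the new helpers, mirroring the unit tests added at the bottom of this commit; the temporary directory stands in for an output/databases/<account> folder and the username is illustrative.

```python
# Sketch: create an empty-but-valid message_0.db, then let the helper add the Msg_<md5> table.
import sqlite3
from pathlib import Path
from tempfile import TemporaryDirectory

from wechat_decrypt_tool.routers import chat as chat_router

with TemporaryDirectory() as td:
    account_dir = Path(td)
    conn = sqlite3.connect(str(account_dir / "message_0.db"))
    conn.execute("PRAGMA user_version = 1")  # make it a valid sqlite file on disk
    conn.commit()
    conn.close()

    db_path, table = chat_router._ensure_decrypted_message_table(account_dir, "wxid_example")
    print(db_path.name, table)  # message_0.db Msg_<md5 of "wxid_example">
```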
def _resolve_decrypted_message_tables(
account_dir: Path, usernames: list[str]
) -> dict[str, tuple[Path, str]]:
@@ -358,18 +482,160 @@ def _ensure_session_last_message_table(conn: sqlite3.Connection) -> None:
)
def _get_session_table_columns(conn: sqlite3.Connection) -> set[str]:
try:
rows = conn.execute("PRAGMA table_info(SessionTable)").fetchall()
# PRAGMA table_info returns: cid, name, type, notnull, dflt_value, pk
cols = {str(r[1]) for r in rows if r and r[1]}
return cols
except Exception:
return set()
def _upsert_session_table_rows(conn: sqlite3.Connection, rows: list[dict[str, Any]]) -> None:
"""Best-effort upsert of WCDB Session rows into decrypted session.db::SessionTable.
Why:
- WCDB realtime can observe newly created sessions (e.g., new friends) immediately.
- The decrypted snapshot's session.db can become stale and miss those sessions entirely, causing
the left sidebar list to differ after a refresh (when the UI falls back to decrypted).
This upsert intentionally avoids depending on message tables; it only keeps SessionTable fresh.
"""
if not rows:
return
# Ensure SessionTable exists; if not, silently skip (older/partial accounts).
try:
conn.execute("SELECT 1 FROM SessionTable LIMIT 1").fetchone()
except Exception:
return
cols = _get_session_table_columns(conn)
if "username" not in cols:
return
uniq_usernames: list[str] = []
for r in rows:
u = str((r or {}).get("username") or "").strip()
if not u:
continue
uniq_usernames.append(u)
uniq_usernames = list(dict.fromkeys(uniq_usernames))
if not uniq_usernames:
return
# Insert missing rows first so UPDATE always has a target.
try:
conn.executemany(
"INSERT OR IGNORE INTO SessionTable(username) VALUES (?)",
[(u,) for u in uniq_usernames],
)
except Exception:
# If the schema is unusual, don't fail the whole sync.
return
# Only update columns that exist in this account's schema.
# Keep the order stable so executemany parameters line up.
desired_cols = [
"unread_count",
"is_hidden",
"summary",
"draft",
"last_timestamp",
"sort_timestamp",
"last_msg_type",
"last_msg_sub_type",
]
update_cols = [c for c in desired_cols if c in cols]
if not update_cols:
return
def _int(v: Any) -> int:
try:
return int(v or 0)
except Exception:
return 0
def _text(v: Any) -> str:
try:
return str(v or "")
except Exception:
return ""
params: list[tuple[Any, ...]] = []
for r in rows:
u = str((r or {}).get("username") or "").strip()
if not u:
continue
values: list[Any] = []
for c in update_cols:
if c in {"unread_count", "is_hidden", "last_timestamp", "sort_timestamp", "last_msg_type", "last_msg_sub_type"}:
values.append(_int((r or {}).get(c)))
else:
values.append(_text((r or {}).get(c)))
values.append(u)
params.append(tuple(values))
if not params:
return
set_expr = ", ".join([f"{c} = ?" for c in update_cols])
conn.executemany(f"UPDATE SessionTable SET {set_expr} WHERE username = ?", params)
def _load_session_last_message_times(conn: sqlite3.Connection, usernames: list[str]) -> dict[str, int]:
"""Load last synced message create_time per conversation from session.db::session_last_message.
Note: This is used as the *sync watermark* for realtime -> decrypted, because SessionTable timestamps may be
updated from WCDB session rows for UI consistency.
"""
uniq = list(dict.fromkeys([str(u or "").strip() for u in usernames if str(u or "").strip()]))
if not uniq:
return {}
out: dict[str, int] = {}
chunk_size = 900
for i in range(0, len(uniq), chunk_size):
chunk = uniq[i : i + chunk_size]
placeholders = ",".join(["?"] * len(chunk))
try:
rows = conn.execute(
f"SELECT username, create_time FROM session_last_message WHERE username IN ({placeholders})",
chunk,
).fetchall()
except Exception:
continue
for r in rows:
u = str((r["username"] if isinstance(r, sqlite3.Row) else r[0]) or "").strip()
if not u:
continue
try:
ts = int((r["create_time"] if isinstance(r, sqlite3.Row) else r[1]) or 0)
except Exception:
ts = 0
out[u] = int(ts or 0)
return out
@router.post("/api/chat/realtime/sync", summary="实时消息同步到解密库(按会话增量)") @router.post("/api/chat/realtime/sync", summary="实时消息同步到解密库(按会话增量)")
def sync_chat_realtime_messages( def sync_chat_realtime_messages(
request: Request, request: Request,
username: str, username: str,
account: Optional[str] = None, account: Optional[str] = None,
max_scan: int = 600, max_scan: int = 600,
backfill_limit: int = 200,
): ):
""" """
设计目的:实时模式只用来“同步增量”到 output/databases 下的解密库,前端始终从解密库读取显示, 设计目的:实时模式只用来“同步增量”到 output/databases 下的解密库,前端始终从解密库读取显示,
避免 WCDB realtime 返回格式差异(如 compress_content/message_content 的 hex 编码)直接影响渲染。 避免 WCDB realtime 返回格式差异(如 compress_content/message_content 的 hex 编码)直接影响渲染。
同步策略:从 WCDB 获取最新消息(从新到旧),直到遇到解密库中已存在的最大 local_id 为止。 同步策略:从 WCDB 获取最新消息(从新到旧),直到遇到解密库中已存在的最大 local_id 为止。
backfill_limit同步过程中额外“回填”旧消息的 packed_info_data 的最大行数(用于修复旧库缺失字段)。
- 设为 0 可显著降低每次同步的扫描/写入开销(更适合前端实时轮询/推送触发的高频增量同步)。
""" """
if not username: if not username:
raise HTTPException(status_code=400, detail="Missing username.") raise HTTPException(status_code=400, detail="Missing username.")
@@ -377,6 +643,10 @@ def sync_chat_realtime_messages(
max_scan = 50
if max_scan > 5000:
max_scan = 5000
if backfill_limit < 0:
backfill_limit = 0
if backfill_limit > 5000:
backfill_limit = 5000
account_dir = _resolve_account_dir(account)
trace_id = f"rt-sync-{int(time.time() * 1000)}-{threading.get_ident()}"
@@ -399,10 +669,9 @@ def sync_chat_realtime_messages(
except WCDBRealtimeError as e:
raise HTTPException(status_code=400, detail=str(e))
- resolved = _resolve_decrypted_message_table(account_dir, username)
- if not resolved:
- raise HTTPException(status_code=404, detail="Conversation table not found in decrypted databases.")
- msg_db_path, table_name = resolved
+ # Some sessions may not exist in the decrypted snapshot yet; create the missing Msg_<md5> table
+ # so we can insert the realtime rows and make `/api/chat/messages` work after switching off realtime.
+ msg_db_path, table_name = _ensure_decrypted_message_table(account_dir, username)
logger.info(
"[%s] resolved decrypted table account=%s username=%s db=%s table=%s",
trace_id,
@@ -493,7 +762,6 @@ def sync_chat_realtime_messages(
offset = 0
new_rows: list[dict[str, Any]] = []
backfill_rows: list[dict[str, Any]] = []
- backfill_limit = min(200, int(max_scan))
reached_existing = False
stop = False
@@ -545,8 +813,11 @@ def sync_chat_realtime_messages(
continue
reached_existing = True
if int(backfill_limit) <= 0:
stop = True
break
backfill_rows.append(norm)
- if len(backfill_rows) >= backfill_limit:
+ if len(backfill_rows) >= int(backfill_limit):
stop = True
break
@@ -772,11 +1043,18 @@ def _sync_chat_realtime_messages_for_table(
msg_db_path: Path,
table_name: str,
max_scan: int,
backfill_limit: int = 200,
) -> dict[str, Any]:
if max_scan < 50:
max_scan = 50
if max_scan > 5000:
max_scan = 5000
if backfill_limit < 0:
backfill_limit = 0
if backfill_limit > 5000:
backfill_limit = 5000
if backfill_limit > max_scan:
backfill_limit = max_scan
msg_conn = sqlite3.connect(str(msg_db_path))
msg_conn.row_factory = sqlite3.Row
@@ -858,13 +1136,12 @@ def _sync_chat_realtime_messages_for_table(
offset = 0
new_rows: list[dict[str, Any]] = []
backfill_rows: list[dict[str, Any]] = []
- backfill_limit = min(200, int(max_scan))
reached_existing = False
stop = False
while scanned < int(max_scan):
take = min(batch_size, int(max_scan) - scanned)
- logger.info(
+ logger.debug(
"[realtime] wcdb_get_messages account=%s username=%s take=%s offset=%s",
account_dir.name,
username,
@@ -875,7 +1152,7 @@ def _sync_chat_realtime_messages_for_table(
with rt_conn.lock:
raw_rows = _wcdb_get_messages(rt_conn.handle, username, limit=take, offset=offset)
wcdb_ms = (time.perf_counter() - wcdb_t0) * 1000.0
- logger.info(
+ logger.debug(
"[realtime] wcdb_get_messages done account=%s username=%s rows=%s ms=%.1f",
account_dir.name,
username,
@@ -907,8 +1184,11 @@ def _sync_chat_realtime_messages_for_table(
continue
reached_existing = True
if int(backfill_limit) <= 0:
stop = True
break
backfill_rows.append(norm)
- if len(backfill_rows) >= backfill_limit:
+ if len(backfill_rows) >= int(backfill_limit):
stop = True
break
@@ -1115,6 +1395,7 @@ def sync_chat_realtime_messages_all(
priority_max_scan: int = 600,
include_hidden: bool = True,
include_official: bool = True,
backfill_limit: int = 200,
):
"""
全量会话同步(增量):遍历会话列表,对每个会话调用与 /realtime/sync 相同的“遇到已同步 local_id 即停止”逻辑。
@@ -1141,6 +1422,12 @@ def sync_chat_realtime_messages_all(
priority_max_scan = max_scan
if priority_max_scan > 5000:
priority_max_scan = 5000
if backfill_limit < 0:
backfill_limit = 0
if backfill_limit > 5000:
backfill_limit = 5000
if backfill_limit > max_scan:
backfill_limit = max_scan
priority = str(priority_username or "").strip()
started = time.time()
@@ -1172,6 +1459,7 @@ def sync_chat_realtime_messages_all(
raw_sessions = []
sessions: list[tuple[int, str]] = []
realtime_rows_by_user: dict[str, dict[str, Any]] = {}
for item in raw_sessions:
if not isinstance(item, dict):
continue
@@ -1198,6 +1486,34 @@ def sync_chat_realtime_messages_all(
break
sessions.append((ts, uname))
# Keep a normalized SessionTable row for upserting into decrypted session.db.
norm_row = {
"username": uname,
"unread_count": item.get("unread_count", item.get("unreadCount", 0)),
"is_hidden": item.get("is_hidden", item.get("isHidden", 0)),
"summary": item.get("summary", ""),
"draft": item.get("draft", ""),
"last_timestamp": item.get("last_timestamp", item.get("lastTimestamp", 0)),
"sort_timestamp": item.get(
"sort_timestamp",
item.get("sortTimestamp", item.get("last_timestamp", item.get("lastTimestamp", 0))),
),
"last_msg_type": item.get("last_msg_type", item.get("lastMsgType", 0)),
"last_msg_sub_type": item.get("last_msg_sub_type", item.get("lastMsgSubType", 0)),
}
# Prefer the row with the newer sort timestamp for the same username.
prev = realtime_rows_by_user.get(uname)
try:
prev_sort = int((prev or {}).get("sort_timestamp") or 0)
except Exception:
prev_sort = 0
try:
cur_sort = int(norm_row.get("sort_timestamp") or 0)
except Exception:
cur_sort = 0
if prev is None or cur_sort >= prev_sort:
realtime_rows_by_user[uname] = norm_row
def _dedupe(items: list[tuple[int, str]]) -> list[tuple[int, str]]:
seen = set()
out: list[tuple[int, str]] = []
@@ -1219,7 +1535,8 @@ def sync_chat_realtime_messages_all(
len(all_usernames),
)
- # Skip sessions whose decrypted session.db already has a newer/equal sort_timestamp.
+ # Keep SessionTable fresh for UI consistency, and use session_last_message.create_time as the
+ # "sync watermark" (instead of SessionTable timestamps) to decide whether a session needs syncing.
decrypted_ts_by_user: dict[str, int] = {}
if all_usernames:
try:
@@ -1227,45 +1544,49 @@ def sync_chat_realtime_messages_all(
sconn = sqlite3.connect(str(session_db_path))
sconn.row_factory = sqlite3.Row
try:
- uniq = list(dict.fromkeys([u for u in all_usernames if u]))
- chunk_size = 900
- for i in range(0, len(uniq), chunk_size):
- chunk = uniq[i : i + chunk_size]
- placeholders = ",".join(["?"] * len(chunk))
- try:
- rows = sconn.execute(
- f"SELECT username, sort_timestamp, last_timestamp FROM SessionTable WHERE username IN ({placeholders})",
- chunk,
- ).fetchall()
- for r in rows:
- u = str(r["username"] or "").strip()
- if not u:
- continue
- try:
- ts = int(r["sort_timestamp"] or 0)
- except Exception:
- ts = 0
- if ts <= 0:
- try:
- ts = int(r["last_timestamp"] or 0)
- except Exception:
- ts = 0
- decrypted_ts_by_user[u] = int(ts or 0)
- except sqlite3.OperationalError:
- rows = sconn.execute(
- f"SELECT username, last_timestamp FROM SessionTable WHERE username IN ({placeholders})",
- chunk,
- ).fetchall()
- for r in rows:
- u = str(r["username"] or "").strip()
- if not u:
- continue
- try:
- decrypted_ts_by_user[u] = int(r["last_timestamp"] or 0)
- except Exception:
- decrypted_ts_by_user[u] = 0
- finally:
- sconn.close()
+ _ensure_session_last_message_table(sconn)
+ # If the cache table exists but is empty (older accounts), attempt a one-time build so we
+ # don't keep treating every session as "needs_sync".
+ try:
+ cnt = int(sconn.execute("SELECT COUNT(1) FROM session_last_message").fetchone()[0] or 0)
+ except Exception:
+ cnt = 0
+ if cnt <= 0:
+ try:
+ sconn.close()
+ except Exception:
+ pass
+ try:
+ build_session_last_message_table(
+ account_dir,
+ rebuild=False,
+ include_hidden=True,
+ include_official=True,
+ )
+ except Exception:
+ pass
+ sconn = sqlite3.connect(str(session_db_path))
+ sconn.row_factory = sqlite3.Row
+ _ensure_session_last_message_table(sconn)
+ # Upsert latest WCDB sessions into decrypted SessionTable so the sidebar list remains stable
+ # after switching off realtime (or refreshing the page).
+ try:
+ _upsert_session_table_rows(sconn, list(realtime_rows_by_user.values()))
+ sconn.commit()
+ except Exception:
+ try:
+ sconn.rollback()
+ except Exception:
+ pass
+ decrypted_ts_by_user = _load_session_last_message_times(sconn, all_usernames)
+ finally:
+ try:
+ sconn.close()
+ except Exception:
+ pass
except Exception:
decrypted_ts_by_user = {}
@@ -1291,7 +1612,7 @@ def sync_chat_realtime_messages_all(
if priority and priority in sync_usernames:
sync_usernames = [priority] + [u for u in sync_usernames if u != priority]
- table_map = _resolve_decrypted_message_tables(account_dir, sync_usernames)
+ table_map = _ensure_decrypted_message_tables(account_dir, sync_usernames)
logger.info(
"[%s] resolved decrypted tables account=%s resolved=%s need_sync=%s",
trace_id,
@@ -1324,6 +1645,7 @@ def sync_chat_realtime_messages_all(
msg_db_path=msg_db_path,
table_name=table_name,
max_scan=int(cur_scan),
backfill_limit=int(backfill_limit),
)
synced += 1
scanned_total += int(result.get("scanned") or 0)
@@ -2595,6 +2917,11 @@ def list_chat_sessions(
row = contact_rows.get(u)
if _pick_display_name(row, u) == u:
need_display.append(u)
if source_norm == "realtime":
# In realtime mode, prefer WCDB-resolved avatar URLs (contact.db can be stale).
if u not in local_avatar_usernames:
need_avatar.append(u)
else:
if (not _pick_avatar_url(row)) and (u not in local_avatar_usernames):
need_avatar.append(u)
@@ -2655,13 +2982,17 @@ def list_chat_sessions(
if wd and wd != username:
display_name = wd
- avatar_url = _pick_avatar_url(c_row)
- if not avatar_url and username in local_avatar_usernames:
+ # Prefer local head_image avatars when available: decrypted contact.db URLs can be stale
+ # (or hotlink-protected for browsers). WCDB realtime (when available) is the next best.
+ avatar_url = ""
+ if username in local_avatar_usernames:
avatar_url = base_url + _build_avatar_url(account_dir.name, username)
if not avatar_url:
wa = str(wcdb_avatar_urls.get(username) or "").strip()
if wa.lower().startswith(("http://", "https://")):
avatar_url = wa
if not avatar_url:
avatar_url = _pick_avatar_url(c_row) or ""
last_message = "" last_message = ""
if preview_mode == "session": if preview_mode == "session":
@@ -3426,6 +3757,63 @@ def list_chat_messages(
break
scan_take = next_take
# Self-heal (default source only): if the decrypted snapshot has no conversation table yet (new session),
# do a one-shot realtime->decrypted sync and re-query once. This avoids "暂无聊天记录" after turning off realtime.
if (
source_norm != "realtime"
and (source is None or not str(source).strip())
and (not merged)
and int(offset) == 0
):
missing_table = False
try:
missing_table = _resolve_decrypted_message_table(account_dir, username) is None
except Exception:
missing_table = True
if missing_table:
rt_conn2 = None
try:
rt_conn2 = WCDB_REALTIME.ensure_connected(account_dir)
except WCDBRealtimeError:
rt_conn2 = None
except Exception:
rt_conn2 = None
if rt_conn2 is not None:
try:
with _realtime_sync_lock(account_dir.name, username):
msg_db_path2, table_name2 = _ensure_decrypted_message_table(account_dir, username)
_sync_chat_realtime_messages_for_table(
account_dir=account_dir,
rt_conn=rt_conn2,
username=username,
msg_db_path=msg_db_path2,
table_name=table_name2,
max_scan=max(200, int(limit) + 50),
backfill_limit=0,
)
except Exception:
pass
(
merged,
has_more_any,
sender_usernames,
quote_usernames,
pat_usernames,
) = _collect_chat_messages(
username=username,
account_dir=account_dir,
db_paths=db_paths,
resource_conn=resource_conn,
resource_chat_id=resource_chat_id,
take=scan_take,
want_types=want_types,
)
if want_types is not None:
merged = [m for m in merged if _normalize_render_type_key(m.get("renderType")) in want_types]
r""" r"""
take = int(limit) + int(offset) take = int(limit) + int(offset)
take_probe = take + 1 take_probe = take + 1

View File

@@ -615,16 +615,36 @@ class WCDBRealtimeManager:
except Exception:
pass
- def close_all(self) -> None:
+ def close_all(self, *, lock_timeout_s: float | None = None) -> bool:
"""Close all known WCDB realtime connections.
When `lock_timeout_s` is None, this waits indefinitely for per-connection locks.
When provided, this will skip busy connections after the timeout and return False.
"""
with self._mu:
conns = list(self._conns.values())
self._conns.clear()
ok = True
for conn in conns:
try:
if lock_timeout_s is None:
with conn.lock:
close_account(conn.handle)
except Exception:
continue
acquired = conn.lock.acquire(timeout=float(lock_timeout_s))
if not acquired:
ok = False
logger.warning("[wcdb] close_all skip busy conn account=%s", conn.account)
continue
try:
close_account(conn.handle)
finally:
conn.lock.release()
except Exception:
ok = False
continue
return ok
WCDB_REALTIME = WCDBRealtimeManager()
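The shutdown path above relies on the fact that `threading.Lock.acquire` accepts a timeout and returns `False` when the lock stays busy. A self-contained sketch of that pattern; the close function is a stand-in, not the real WCDB call.

```python
import threading

def _close_account(handle: object) -> None:
    # Stand-in for the real per-connection close; only here to keep the sketch runnable.
    print("closed", handle)

def close_with_timeout(lock: threading.Lock, handle: object, timeout_s: float | None) -> bool:
    if timeout_s is None:
        with lock:                       # wait indefinitely, the pre-change behaviour
            _close_account(handle)
        return True
    if not lock.acquire(timeout=float(timeout_s)):
        return False                     # busy: the caller should skip wcdb_shutdown()
    try:
        _close_account(handle)
    finally:
        lock.release()
    return True

# A busy lock makes the call return False instead of hanging the shutdown hook.
busy = threading.Lock()
busy.acquire()
print(close_with_timeout(busy, "conn-1", timeout_s=0.2))  # -> False
```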

View File

@@ -0,0 +1,102 @@
import hashlib
import sqlite3
import sys
import unittest
from pathlib import Path
from tempfile import TemporaryDirectory
# Ensure "src/" is importable when running tests from repo root.
ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT / "src"))
class TestRealtimeSyncTableCreation(unittest.TestCase):
def _touch_sqlite(self, path: Path) -> None:
path.parent.mkdir(parents=True, exist_ok=True)
conn = sqlite3.connect(str(path))
try:
# Ensure a valid sqlite file is created.
conn.execute("PRAGMA user_version = 1")
conn.commit()
finally:
conn.close()
def test_ensure_creates_msg_table_and_indexes_in_message_db(self):
from wechat_decrypt_tool.routers import chat as chat_router
with TemporaryDirectory() as td:
account_dir = Path(td)
self._touch_sqlite(account_dir / "message_0.db")
username = "wxid_foo"
md5_hex = hashlib.md5(username.encode("utf-8")).hexdigest()
expected_table = f"Msg_{md5_hex}"
db_path, table_name = chat_router._ensure_decrypted_message_table(account_dir, username)
self.assertEqual(table_name, expected_table)
self.assertEqual(db_path.name, "message_0.db")
conn = sqlite3.connect(str(db_path))
try:
r = conn.execute(
"SELECT 1 FROM sqlite_master WHERE type='table' AND lower(name)=lower(?)",
(expected_table,),
).fetchone()
self.assertIsNotNone(r, "Msg_<md5> table should be created")
idx_names = [
f"{expected_table}_SENDERID",
f"{expected_table}_SERVERID",
f"{expected_table}_SORTSEQ",
f"{expected_table}_TYPE_SEQ",
]
for idx in idx_names:
r = conn.execute(
"SELECT 1 FROM sqlite_master WHERE type='index' AND lower(name)=lower(?)",
(idx,),
).fetchone()
self.assertIsNotNone(r, f"Index {idx} should be created")
finally:
conn.close()
def test_ensure_prefers_biz_message_for_official_accounts(self):
from wechat_decrypt_tool.routers import chat as chat_router
with TemporaryDirectory() as td:
account_dir = Path(td)
self._touch_sqlite(account_dir / "message_0.db")
self._touch_sqlite(account_dir / "biz_message_0.db")
username = "gh_12345"
db_path, _ = chat_router._ensure_decrypted_message_table(account_dir, username)
self.assertEqual(db_path.name, "biz_message_0.db")
def test_bulk_ensure_creates_missing_tables(self):
from wechat_decrypt_tool.routers import chat as chat_router
with TemporaryDirectory() as td:
account_dir = Path(td)
self._touch_sqlite(account_dir / "message_0.db")
usernames = ["wxid_a", "wxid_b"]
table_map = chat_router._ensure_decrypted_message_tables(account_dir, usernames)
self.assertEqual(set(table_map.keys()), set(usernames))
conn = sqlite3.connect(str(account_dir / "message_0.db"))
try:
for u in usernames:
md5_hex = hashlib.md5(u.encode("utf-8")).hexdigest()
expected_table = f"Msg_{md5_hex}"
r = conn.execute(
"SELECT 1 FROM sqlite_master WHERE type='table' AND lower(name)=lower(?)",
(expected_table,),
).fetchone()
self.assertIsNotNone(r, f"{expected_table} should be created for {u}")
finally:
conn.close()
if __name__ == "__main__":
unittest.main()