diff --git a/src/wechat_decrypt_tool/avatar_cache.py b/src/wechat_decrypt_tool/avatar_cache.py new file mode 100644 index 0000000..c37eaee --- /dev/null +++ b/src/wechat_decrypt_tool/avatar_cache.py @@ -0,0 +1,454 @@ +from __future__ import annotations + +import hashlib +import os +import re +import sqlite3 +import time +from email.utils import formatdate +from pathlib import Path +from typing import Any, Optional +from urllib.parse import urlsplit, urlunsplit + +from .app_paths import get_output_dir +from .logging_config import get_logger + +logger = get_logger(__name__) + +AVATAR_CACHE_TTL_SECONDS = 7 * 24 * 60 * 60 + + +def is_avatar_cache_enabled() -> bool: + v = str(os.environ.get("WECHAT_TOOL_AVATAR_CACHE_ENABLED", "1") or "").strip().lower() + return v not in {"", "0", "false", "off", "no"} + + +def get_avatar_cache_root_dir() -> Path: + return get_output_dir() / "avatar_cache" + + +def _safe_segment(value: str) -> str: + cleaned = re.sub(r"[^0-9A-Za-z._-]+", "_", str(value or "").strip()) + cleaned = cleaned.strip("._-") + return cleaned or "default" + + +def _account_layout(account: str) -> tuple[Path, Path, Path, Path]: + account_dir = get_avatar_cache_root_dir() / _safe_segment(account) + files_dir = account_dir / "files" + tmp_dir = account_dir / "tmp" + db_path = account_dir / "avatar_cache.db" + return account_dir, files_dir, tmp_dir, db_path + + +def _ensure_account_layout(account: str) -> tuple[Path, Path, Path, Path]: + account_dir, files_dir, tmp_dir, db_path = _account_layout(account) + account_dir.mkdir(parents=True, exist_ok=True) + files_dir.mkdir(parents=True, exist_ok=True) + tmp_dir.mkdir(parents=True, exist_ok=True) + return account_dir, files_dir, tmp_dir, db_path + + +def _connect(account: str) -> sqlite3.Connection: + _, _, _, db_path = _ensure_account_layout(account) + conn = sqlite3.connect(str(db_path), timeout=5) + conn.row_factory = sqlite3.Row + _ensure_schema(conn) + return conn + + +def _ensure_schema(conn: sqlite3.Connection) 
-> None: + conn.execute( + """ + CREATE TABLE IF NOT EXISTS avatar_cache_entries ( + account TEXT NOT NULL, + cache_key TEXT NOT NULL, + source_kind TEXT NOT NULL, + username TEXT NOT NULL DEFAULT '', + source_url TEXT NOT NULL DEFAULT '', + source_md5 TEXT NOT NULL DEFAULT '', + source_update_time INTEGER NOT NULL DEFAULT 0, + rel_path TEXT NOT NULL DEFAULT '', + media_type TEXT NOT NULL DEFAULT 'application/octet-stream', + size_bytes INTEGER NOT NULL DEFAULT 0, + etag TEXT NOT NULL DEFAULT '', + last_modified TEXT NOT NULL DEFAULT '', + fetched_at INTEGER NOT NULL DEFAULT 0, + checked_at INTEGER NOT NULL DEFAULT 0, + expires_at INTEGER NOT NULL DEFAULT 0, + PRIMARY KEY (account, cache_key) + ) + """ + ) + conn.execute( + "CREATE INDEX IF NOT EXISTS idx_avatar_cache_entries_account_username ON avatar_cache_entries(account, username)" + ) + conn.execute( + "CREATE INDEX IF NOT EXISTS idx_avatar_cache_entries_account_source ON avatar_cache_entries(account, source_kind, source_url)" + ) + conn.commit() + + +def _row_to_dict(row: Optional[sqlite3.Row]) -> Optional[dict[str, Any]]: + if row is None: + return None + out: dict[str, Any] = {} + for k in row.keys(): + out[str(k)] = row[k] + return out + + +def normalize_avatar_source_url(url: str) -> str: + raw = str(url or "").strip() + if not raw: + return "" + try: + p = urlsplit(raw) + except Exception: + return raw + scheme = str(p.scheme or "").lower() + host = str(p.hostname or "").lower() + if not scheme or not host: + return raw + netloc = host + if p.port: + netloc = f"{host}:{int(p.port)}" + path = p.path or "/" + return urlunsplit((scheme, netloc, path, p.query or "", "")) + + +def cache_key_for_avatar_user(username: str) -> str: + u = str(username or "").strip() + return hashlib.sha1(f"user:{u}".encode("utf-8", errors="ignore")).hexdigest() + + +def cache_key_for_avatar_url(url: str) -> str: + u = normalize_avatar_source_url(url) + return hashlib.sha1(f"url:{u}".encode("utf-8", errors="ignore")).hexdigest() + 


def get_avatar_cache_entry(account: str, cache_key: str) -> Optional[dict[str, Any]]:
    """Look up one cache entry by its key.

    Returns:
        The stored row as a dict, or ``None`` when the cache is disabled, the
        key is empty, no row matches, or the database cannot be read. Lookup is
        deliberately best-effort: cache trouble must never fail the caller.
    """
    # Stored rows carry the stripped account/key (see the upsert), so query
    # with the same normalisation.
    acct = str(account or "").strip()
    ck = str(cache_key or "").strip()
    if not is_avatar_cache_enabled() or not ck:
        return None
    try:
        conn = _connect(acct)
    except Exception:
        return None
    try:
        row = conn.execute(
            "SELECT * FROM avatar_cache_entries WHERE account = ? AND cache_key = ? LIMIT 1",
            (acct, ck),
        ).fetchone()
        return _row_to_dict(row)
    except Exception:
        return None
    finally:
        try:
            conn.close()
        except Exception:
            pass


def get_avatar_cache_user_entry(account: str, username: str) -> Optional[dict[str, Any]]:
    """Return the cache entry keyed by *username*, or ``None``."""
    if not username:
        return None
    return get_avatar_cache_entry(account, cache_key_for_avatar_user(username))


def get_avatar_cache_url_entry(account: str, source_url: str) -> Optional[dict[str, Any]]:
    """Return the cache entry keyed by *source_url*, or ``None``."""
    if not source_url:
        return None
    return get_avatar_cache_entry(account, cache_key_for_avatar_url(source_url))


def resolve_avatar_cache_entry_path(account: str, entry: Optional[dict[str, Any]]) -> Optional[Path]:
    """Resolve an entry's ``rel_path`` to an absolute path inside the account dir.

    Returns:
        The resolved path, or ``None`` when the entry has no ``rel_path``, the
        path escapes the account directory, or it cannot be resolved.
    """
    if not entry:
        return None
    rel = str(entry.get("rel_path") or "").strip().replace("\\", "/")
    if not rel:
        return None
    account_dir, _, _, _ = _account_layout(account)
    try:
        root = account_dir.resolve()
        target = (account_dir / rel).resolve()
    except (OSError, RuntimeError, ValueError):
        # Fail closed: a path whose containment cannot be verified must not be
        # handed to a file response.
        return None
    if target != root and root not in target.parents:
        return None
    return target


def avatar_cache_entry_file_exists(account: str, entry: Optional[dict[str, Any]]) -> Optional[Path]:
    """Return the entry's file path if it is an existing regular file, else ``None``."""
    path = resolve_avatar_cache_entry_path(account, entry)
    if path is None:
        return None
    try:
        return path if path.is_file() else None
    except (OSError, ValueError):
        return None


def avatar_cache_entry_is_fresh(entry: Optional[dict[str, Any]], now_ts: Optional[int] = None) -> bool:
    """Return ``True`` while the entry's ``expires_at`` lies in the future.

    Args:
        entry: Cache row as returned by the lookup helpers.
        now_ts: Reference time in epoch seconds; defaults to the current time.
    """
    if not entry:
        return False
    try:
        expires = int(entry.get("expires_at") or 0)
    except (TypeError, ValueError, OverflowError):
        expires = 0
    if expires <= 0:
        return False
    # ``is None`` rather than truthiness so an explicit 0 is honoured.
    now = int(time.time() if now_ts is None else now_ts)
    return expires > now


# File extensions for the media types the cache is expected to see.
_AVATAR_MEDIA_TYPE_EXT = {
    "image/jpeg": "jpg",
    "image/png": "png",
    "image/gif": "gif",
    "image/webp": "webp",
    "image/bmp": "bmp",
    "image/svg+xml": "svg",
    "image/avif": "avif",
}


def _guess_ext(media_type: str) -> str:
    """Map a media type to a file extension (without the leading dot).

    Known image types use a fixed table; other ``image/*`` types fall back to
    their sanitised subtype (``img`` if nothing usable remains); anything else
    yields ``dat``.
    """
    # Drop parameters such as ``; charset=...`` before matching.
    essence = str(media_type or "").split(";", 1)[0].strip().lower()
    known = _AVATAR_MEDIA_TYPE_EXT.get(essence)
    if known:
        return known
    if not essence.startswith("image/"):
        return "dat"
    subtype = essence.split("/", 1)[1].split("+", 1)[0]
    # The media type can originate from a remote Content-Type header and ends
    # up in a file name, so strip anything that could act as a path separator
    # or form a ``..`` component.
    cleaned = re.sub(r"[^0-9a-z._-]+", "", subtype).strip("._-")
    return cleaned or "img"


def _http_date_from_ts(ts: Optional[int]) -> str:
    """Format epoch seconds as an HTTP date (GMT); ``""`` for missing/invalid input."""
    try:
        seconds = int(ts or 0)
    except (TypeError, ValueError, OverflowError):
        seconds = 0
    if seconds <= 0:
        return ""
    try:
        return formatdate(timeval=float(seconds), usegmt=True)
    except Exception:
        return ""


def upsert_avatar_cache_entry(
    account: str,
    *,
    cache_key: str,
    source_kind: str,
    username: str = "",
    source_url: str = "",
    source_md5: str = "",
    source_update_time: int = 0,
    rel_path: str = "",
    media_type: str = "application/octet-stream",
    size_bytes: int = 0,
    etag: str = "",
    last_modified: str = "",
    fetched_at: Optional[int] = None,
    checked_at: Optional[int] = None,
    expires_at: Optional[int] = None,
) -> Optional[dict[str, Any]]:
    """Insert or fully replace the metadata row for ``(account, cache_key)``.

    String fields are stripped, ``source_url`` is normalised, ``source_md5`` is
    lower-cased and backslashes in ``rel_path`` become ``/``. ``fetched_at`` and
    ``checked_at`` default to the current time; ``expires_at`` defaults to
    ``checked_at + AVATAR_CACHE_TTL_SECONDS``.

    Returns:
        The row as stored, or ``None`` when the cache is disabled, a required
        identifier (account, cache key, source kind) is empty, or the database
        operation fails (a warning is logged).
    """
    if (not is_avatar_cache_enabled()) or (not cache_key):
        return None

    acct = str(account or "").strip()
    ck = str(cache_key or "").strip()
    sk = str(source_kind or "").strip().lower()
    if not acct or not ck or not sk:
        return None

    source_url_norm = normalize_avatar_source_url(source_url) if source_url else ""

    now_ts = int(time.time())
    fetched = int(fetched_at if fetched_at is not None else now_ts)
    checked = int(checked_at if checked_at is not None else now_ts)
    expire_ts = int(expires_at if expires_at is not None else (checked + AVATAR_CACHE_TTL_SECONDS))

    try:
        conn = _connect(acct)
    except Exception as e:
        logger.warning(f"[avatar_cache_error] open db failed account={acct} err={e}")
        return None
    try:
        conn.execute(
            """
            INSERT INTO avatar_cache_entries (
                account, cache_key, source_kind, username, source_url,
                source_md5, source_update_time, rel_path, media_type, size_bytes,
                etag, last_modified, fetched_at, checked_at, expires_at
            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            ON CONFLICT(account, cache_key) DO UPDATE SET
                source_kind=excluded.source_kind,
                username=excluded.username,
                source_url=excluded.source_url,
                source_md5=excluded.source_md5,
                source_update_time=excluded.source_update_time,
                rel_path=excluded.rel_path,
                media_type=excluded.media_type,
                size_bytes=excluded.size_bytes,
                etag=excluded.etag,
                last_modified=excluded.last_modified,
                fetched_at=excluded.fetched_at,
                checked_at=excluded.checked_at,
                expires_at=excluded.expires_at
            """,
            (
                acct,
                ck,
                sk,
                str(username or "").strip(),
                source_url_norm,
                str(source_md5 or "").strip().lower(),
                int(source_update_time or 0),
                str(rel_path or "").strip().replace("\\", "/"),
                str(media_type or "application/octet-stream").strip() or "application/octet-stream",
                int(size_bytes or 0),
                str(etag or "").strip(),
                str(last_modified or "").strip(),
                fetched,
                checked,
                expire_ts,
            ),
        )
        conn.commit()
        # Read the row back so the caller sees exactly what was persisted.
        row = conn.execute(
            "SELECT * FROM avatar_cache_entries WHERE account = ? AND cache_key = ? LIMIT 1",
            (acct, ck),
        ).fetchone()
        return _row_to_dict(row)
    except Exception as e:
        logger.warning(f"[avatar_cache_error] upsert failed account={acct} cache_key={ck} err={e}")
        return None
    finally:
        try:
            conn.close()
        except Exception:
            pass


def touch_avatar_cache_entry(account: str, cache_key: str, *, ttl_seconds: int = AVATAR_CACHE_TTL_SECONDS) -> bool:
    """Mark an entry as just checked and push its expiry ``ttl_seconds`` ahead.

    The TTL is clamped to at least 60 seconds; a falsy TTL falls back to the
    module default.

    Returns:
        ``True`` when the UPDATE ran without error (even if no row matched),
        ``False`` when the cache is disabled, the key is empty or the database
        operation failed.
    """
    # Same normalisation as the upsert, so the WHERE clause matches stored rows.
    acct = str(account or "").strip()
    ck = str(cache_key or "").strip()
    if not is_avatar_cache_enabled() or not ck:
        return False
    now_ts = int(time.time())
    try:
        conn = _connect(acct)
    except Exception:
        return False
    try:
        conn.execute(
            "UPDATE avatar_cache_entries SET checked_at = ?, expires_at = ? WHERE account = ? AND cache_key = ?",
            (now_ts, now_ts + max(60, int(ttl_seconds or AVATAR_CACHE_TTL_SECONDS)), acct, ck),
        )
        conn.commit()
        return True
    except Exception:
        return False
    finally:
        try:
            conn.close()
        except Exception:
            pass


def write_avatar_cache_payload(
    account: str,
    *,
    source_kind: str,
    username: str = "",
    source_url: str = "",
    payload: bytes,
    media_type: str,
    source_md5: str = "",
    source_update_time: int = 0,
    etag: str = "",
    last_modified: str = "",
    ttl_seconds: int = AVATAR_CACHE_TTL_SECONDS,
) -> tuple[Optional[dict[str, Any]], Optional[Path]]:
    """Store *payload* on disk and record it in the cache index.

    The file is content-addressed (``files/<sha1[:2]>/<sha1>.<ext>``), so
    identical payloads share one file. It is written to ``tmp/`` first and then
    moved into place with ``os.replace``. A missing ``etag`` defaults to the
    quoted SHA-1 digest; a missing ``last_modified`` is derived from
    ``source_update_time`` or, failing that, the current time.

    Args:
        source_kind: ``"user"`` (keyed by *username*) or ``"url"`` (keyed by
            the normalised *source_url*).
        ttl_seconds: Lifetime of the entry, clamped to at least 60 seconds.

    Returns:
        ``(entry, absolute_path)`` on success, ``(None, None)`` when the cache
        is disabled, the input is unusable, or writing the file or the index
        row fails (a warning is logged).
    """
    if (not is_avatar_cache_enabled()) or (not payload):
        return None, None

    acct = str(account or "").strip()
    sk = str(source_kind or "").strip().lower()
    if not acct or sk not in {"user", "url"}:
        return None, None

    source_url_norm = normalize_avatar_source_url(source_url) if source_url else ""
    if sk == "user":
        # Without an identity every such entry would share a single cache key.
        if not str(username or "").strip():
            return None, None
        cache_key = cache_key_for_avatar_user(username)
    else:
        if not source_url_norm:
            return None, None
        cache_key = cache_key_for_avatar_url(source_url_norm)

    data = bytes(payload)
    digest = hashlib.sha1(data).hexdigest()
    ext = _guess_ext(media_type)
    rel_path = f"files/{digest[:2]}/{digest}.{ext}"

    try:
        account_dir, _, tmp_dir, _ = _ensure_account_layout(acct)
    except Exception as e:
        logger.warning(f"[avatar_cache_error] ensure dirs failed account={acct} err={e}")
        return None, None

    abs_path = account_dir / rel_path
    try:
        abs_path.parent.mkdir(parents=True, exist_ok=True)
        # Rewrite only when the file is missing or visibly truncated.
        if (not abs_path.exists()) or (int(abs_path.stat().st_size) != len(data)):
            tmp_path = tmp_dir / f"{digest}.{time.time_ns()}.tmp"
            try:
                tmp_path.write_bytes(data)
                os.replace(str(tmp_path), str(abs_path))
            finally:
                # After a successful replace the temp name is gone; after a
                # failure this removes the orphan instead of leaking it.
                try:
                    os.unlink(str(tmp_path))
                except OSError:
                    pass
    except Exception as e:
        logger.warning(f"[avatar_cache_error] write file failed account={acct} path={abs_path} err={e}")
        return None, None

    # One timestamp for the whole record keeps fetched/checked/expires coherent.
    now_ts = int(time.time())
    if not etag:
        etag = f'"{digest}"'
    if (not last_modified) and source_update_time:
        last_modified = _http_date_from_ts(source_update_time)
    if not last_modified:
        last_modified = _http_date_from_ts(now_ts)

    entry = upsert_avatar_cache_entry(
        acct,
        cache_key=cache_key,
        source_kind=sk,
        username=username,
        source_url=source_url_norm,
        source_md5=source_md5,
        source_update_time=int(source_update_time or 0),
        rel_path=rel_path,
        media_type=media_type,
        size_bytes=len(data),
        etag=etag,
        last_modified=last_modified,
        fetched_at=now_ts,
        checked_at=now_ts,
        expires_at=now_ts + max(60, int(ttl_seconds or AVATAR_CACHE_TTL_SECONDS)),
    )
    if not entry:
        return None, None
    return entry, abs_path


def build_avatar_cache_response_headers(
    entry: Optional[dict[str, Any]], *, max_age: int = AVATAR_CACHE_TTL_SECONDS
) -> dict[str, str]:
    """Build HTTP caching headers for serving a cached avatar.

    Always sets ``Cache-Control: public, max-age=<max_age>`` (a negative
    ``max_age`` is clamped to 0); adds ``ETag`` and ``Last-Modified`` when the
    entry carries non-empty values for them.
    """
    headers: dict[str, str] = {
        "Cache-Control": f"public, max-age={max(0, int(max_age))}",
    }
    if not entry:
        return headers
    etag = str(entry.get("etag") or "").strip()
    last_modified = str(entry.get("last_modified") or "").strip()
    if etag:
        headers["ETag"] = etag
    if last_modified:
        headers["Last-Modified"] = last_modified
    return headers

+ return _build_avatar_url(str(account_dir.name or ""), u) + + def _realtime_sync_lock(account: str, username: str) -> threading.Lock: key = (str(account or "").strip(), str(username or "").strip()) with _REALTIME_SYNC_MU: @@ -1946,9 +1958,11 @@ async def chat_search_index_senders( continue cnt = int(r["c"] or 0) row = contact_rows.get(su) - avatar_url = _pick_avatar_url(row) - if (not avatar_url) and (su in local_sender_avatars): - avatar_url = _build_avatar_url(account_dir.name, su) + avatar_url = _avatar_url_unified( + account_dir=account_dir, + username=su, + local_avatar_usernames=local_sender_avatars, + ) senders.append( { "username": su, @@ -2568,7 +2582,7 @@ def _postprocess_full_messages( row = sender_contact_rows.get(u) if _pick_display_name(row, u) == u: need_display.append(u) - if (not _pick_avatar_url(row)) and (u not in local_sender_avatars): + if u not in local_sender_avatars: need_avatar.append(u) need_display = list(dict.fromkeys(need_display)) @@ -2606,13 +2620,11 @@ def _postprocess_full_messages( if wd and wd != su: display_name = wd m["senderDisplayName"] = display_name - avatar_url = _pick_avatar_url(row) - if not avatar_url and su in local_sender_avatars: - avatar_url = base_url + _build_avatar_url(account_dir.name, su) - if not avatar_url: - wa = str(wcdb_avatar_urls.get(su) or "").strip() - if wa.lower().startswith(("http://", "https://")): - avatar_url = wa + avatar_url = base_url + _avatar_url_unified( + account_dir=account_dir, + username=su, + local_avatar_usernames=local_sender_avatars, + ) m["senderAvatar"] = avatar_url qu = str(m.get("quoteUsername") or "").strip() @@ -2922,7 +2934,7 @@ def list_chat_sessions( if u not in local_avatar_usernames: need_avatar.append(u) else: - if (not _pick_avatar_url(row)) and (u not in local_avatar_usernames): + if u not in local_avatar_usernames: need_avatar.append(u) need_display = list(dict.fromkeys(need_display)) @@ -2984,15 +2996,11 @@ def list_chat_sessions( # Prefer local head_image avatars 
when available: decrypted contact.db URLs can be stale # (or hotlink-protected for browsers). WCDB realtime (when available) is the next best. - avatar_url = "" - if username in local_avatar_usernames: - avatar_url = base_url + _build_avatar_url(account_dir.name, username) - if not avatar_url: - wa = str(wcdb_avatar_urls.get(username) or "").strip() - if wa.lower().startswith(("http://", "https://")): - avatar_url = wa - if not avatar_url: - avatar_url = _pick_avatar_url(c_row) or "" + avatar_url = base_url + _avatar_url_unified( + account_dir=account_dir, + username=username, + local_avatar_usernames=local_avatar_usernames, + ) last_message = "" if preview_mode == "session": @@ -4388,7 +4396,7 @@ def list_chat_messages( row = sender_contact_rows.get(u) if _pick_display_name(row, u) == u: need_display.append(u) - if (not _pick_avatar_url(row)) and (u not in local_sender_avatars): + if u not in local_sender_avatars: need_avatar.append(u) need_display = list(dict.fromkeys(need_display)) @@ -4426,13 +4434,11 @@ def list_chat_messages( if wd and wd != su: display_name = wd m["senderDisplayName"] = display_name - avatar_url = _pick_avatar_url(row) - if not avatar_url and su in local_sender_avatars: - avatar_url = base_url + _build_avatar_url(account_dir.name, su) - if not avatar_url: - wa = str(wcdb_avatar_urls.get(su) or "").strip() - if wa.lower().startswith(("http://", "https://")): - avatar_url = wa + avatar_url = base_url + _avatar_url_unified( + account_dir=account_dir, + username=su, + local_avatar_usernames=local_sender_avatars, + ) m["senderAvatar"] = avatar_url qu = str(m.get("quoteUsername") or "").strip() @@ -4897,7 +4903,7 @@ async def _search_chat_messages_via_fts( row = contact_rows.get(uu) if _pick_display_name(row, uu) == uu: need_display.append(uu) - if (not _pick_avatar_url(row)) and (uu not in local_avatar_usernames): + if uu not in local_avatar_usernames: need_avatar.append(uu) need_display = list(dict.fromkeys(need_display)) @@ -4919,13 +4925,11 @@ 
async def _search_chat_messages_via_fts( wd = str(wcdb_display_names.get(username) or "").strip() if wd and wd != username: conv_name = wd - conv_avatar = _pick_avatar_url(conv_row) - if (not conv_avatar) and (username in local_avatar_usernames): - conv_avatar = base_url + _build_avatar_url(account_dir.name, username) - if not conv_avatar: - wa = str(wcdb_avatar_urls.get(username) or "").strip() - if wa.lower().startswith(("http://", "https://")): - conv_avatar = wa + conv_avatar = base_url + _avatar_url_unified( + account_dir=account_dir, + username=username, + local_avatar_usernames=local_avatar_usernames, + ) for h in hits: su = str(h.get("senderUsername") or "").strip() @@ -4939,13 +4943,11 @@ async def _search_chat_messages_via_fts( if wd and wd != su: display_name = wd h["senderDisplayName"] = display_name - avatar_url = _pick_avatar_url(row) - if (not avatar_url) and (su in local_avatar_usernames): - avatar_url = base_url + _build_avatar_url(account_dir.name, su) - if not avatar_url: - wa = str(wcdb_avatar_urls.get(su) or "").strip() - if wa.lower().startswith(("http://", "https://")): - avatar_url = wa + avatar_url = base_url + _avatar_url_unified( + account_dir=account_dir, + username=su, + local_avatar_usernames=local_avatar_usernames, + ) h["senderAvatar"] = avatar_url else: uniq_contacts = list( @@ -4968,7 +4970,7 @@ async def _search_chat_messages_via_fts( row = contact_rows.get(uu) if _pick_display_name(row, uu) == uu: need_display.append(uu) - if (not _pick_avatar_url(row)) and (uu not in local_avatar_usernames): + if uu not in local_avatar_usernames: need_avatar.append(uu) need_display = list(dict.fromkeys(need_display)) @@ -4994,13 +4996,11 @@ async def _search_chat_messages_via_fts( if wd and wd != cu: conv_name = wd h["conversationName"] = conv_name or cu - conv_avatar = _pick_avatar_url(crow) - if (not conv_avatar) and cu and (cu in local_avatar_usernames): - conv_avatar = base_url + _build_avatar_url(account_dir.name, cu) - if not conv_avatar 
and cu: - wa = str(wcdb_avatar_urls.get(cu) or "").strip() - if wa.lower().startswith(("http://", "https://")): - conv_avatar = wa + conv_avatar = base_url + _avatar_url_unified( + account_dir=account_dir, + username=cu, + local_avatar_usernames=local_avatar_usernames, + ) h["conversationAvatar"] = conv_avatar if su: row = contact_rows.get(su) @@ -5010,13 +5010,11 @@ async def _search_chat_messages_via_fts( if wd and wd != su: display_name = wd h["senderDisplayName"] = display_name - avatar_url = _pick_avatar_url(row) - if (not avatar_url) and (su in local_avatar_usernames): - avatar_url = base_url + _build_avatar_url(account_dir.name, su) - if not avatar_url: - wa = str(wcdb_avatar_urls.get(su) or "").strip() - if wa.lower().startswith(("http://", "https://")): - avatar_url = wa + avatar_url = base_url + _avatar_url_unified( + account_dir=account_dir, + username=su, + local_avatar_usernames=local_avatar_usernames, + ) h["senderAvatar"] = avatar_url return { diff --git a/src/wechat_decrypt_tool/routers/chat_media.py b/src/wechat_decrypt_tool/routers/chat_media.py index d90b224..fcf4e31 100644 --- a/src/wechat_decrypt_tool/routers/chat_media.py +++ b/src/wechat_decrypt_tool/routers/chat_media.py @@ -8,7 +8,7 @@ import os import sqlite3 import subprocess from pathlib import Path -from typing import Optional +from typing import Any, Optional from urllib.parse import urlparse import requests @@ -16,6 +16,21 @@ from fastapi import APIRouter, HTTPException from fastapi.responses import FileResponse, Response from pydantic import BaseModel, Field +from ..avatar_cache import ( + AVATAR_CACHE_TTL_SECONDS, + avatar_cache_entry_file_exists, + avatar_cache_entry_is_fresh, + build_avatar_cache_response_headers, + cache_key_for_avatar_user, + cache_key_for_avatar_url, + get_avatar_cache_url_entry, + get_avatar_cache_user_entry, + is_avatar_cache_enabled, + normalize_avatar_source_url, + touch_avatar_cache_entry, + upsert_avatar_cache_entry, + write_avatar_cache_payload, +) from 
..logging_config import get_logger from ..media_helpers import ( _convert_silk_to_wav, @@ -43,14 +58,56 @@ from ..media_helpers import ( _try_find_decrypted_resource, _try_strip_media_prefix, ) -from ..chat_helpers import _extract_md5_from_packed_info +from ..chat_helpers import _extract_md5_from_packed_info, _load_contact_rows, _pick_avatar_url from ..path_fix import PathFixRoute +from ..wcdb_realtime import WCDB_REALTIME, get_avatar_urls as _wcdb_get_avatar_urls logger = get_logger(__name__) router = APIRouter(route_class=PathFixRoute) +def _resolve_avatar_remote_url(*, account_dir: Path, username: str) -> str: + u = str(username or "").strip() + if not u: + return "" + + # 1) contact.db first (cheap local lookup) + try: + rows = _load_contact_rows(account_dir / "contact.db", [u]) + row = rows.get(u) + raw = str(_pick_avatar_url(row) or "").strip() + if raw.lower().startswith(("http://", "https://")): + return normalize_avatar_source_url(raw) + except Exception: + pass + + # 2) WCDB fallback (more complete on enterprise/openim IDs) + try: + wcdb_conn = WCDB_REALTIME.ensure_connected(account_dir) + with wcdb_conn.lock: + mp = _wcdb_get_avatar_urls(wcdb_conn.handle, [u]) + wa = str(mp.get(u) or "").strip() + if wa.lower().startswith(("http://", "https://")): + return normalize_avatar_source_url(wa) + except Exception: + pass + + return "" + + +def _parse_304_headers(headers: Any) -> tuple[str, str]: + try: + etag = str((headers or {}).get("ETag") or "").strip() + except Exception: + etag = "" + try: + last_modified = str((headers or {}).get("Last-Modified") or "").strip() + except Exception: + last_modified = "" + return etag, last_modified + + @lru_cache(maxsize=4096) def _fast_probe_image_path_in_chat_attach( *, @@ -267,27 +324,309 @@ async def get_chat_avatar(username: str, account: Optional[str] = None): if not username: raise HTTPException(status_code=400, detail="Missing username.") account_dir = _resolve_account_dir(account) + account_name = 
str(account_dir.name or "").strip() + user_key = str(username or "").strip() + + # 1) Try on-disk cache first (fast path) + user_entry = None + cached_file = None + if is_avatar_cache_enabled() and account_name and user_key: + try: + user_entry = get_avatar_cache_user_entry(account_name, user_key) + cached_file = avatar_cache_entry_file_exists(account_name, user_entry) + if cached_file is not None: + logger.info(f"[avatar_cache_hit] kind=user account={account_name} username={user_key}") + except Exception as e: + logger.warning(f"[avatar_cache_error] read user cache failed account={account_name} username={user_key} err={e}") + head_image_db_path = account_dir / "head_image.db" if not head_image_db_path.exists(): + # No local head_image.db: allow fallback from cached/remote URL path. + if cached_file is not None and user_entry: + headers = build_avatar_cache_response_headers(user_entry) + return FileResponse( + str(cached_file), + media_type=str(user_entry.get("media_type") or "application/octet-stream"), + headers=headers, + ) raise HTTPException(status_code=404, detail="head_image.db not found.") conn = sqlite3.connect(str(head_image_db_path)) try: - row = conn.execute( - "SELECT image_buffer FROM head_image WHERE username = ? ORDER BY update_time DESC LIMIT 1", + meta = conn.execute( + "SELECT md5, update_time FROM head_image WHERE username = ? ORDER BY update_time DESC LIMIT 1", (username,), ).fetchone() + if meta and meta[0] is not None: + db_md5 = str(meta[0] or "").strip().lower() + try: + db_update_time = int(meta[1] or 0) + except Exception: + db_update_time = 0 + + # Cache still valid against head_image metadata. 
+ if cached_file is not None and user_entry: + cached_md5 = str(user_entry.get("source_md5") or "").strip().lower() + try: + cached_update = int(user_entry.get("source_update_time") or 0) + except Exception: + cached_update = 0 + if cached_md5 == db_md5 and cached_update == db_update_time: + touch_avatar_cache_entry(account_name, str(user_entry.get("cache_key") or "")) + headers = build_avatar_cache_response_headers(user_entry) + return FileResponse( + str(cached_file), + media_type=str(user_entry.get("media_type") or "application/octet-stream"), + headers=headers, + ) + + # Refresh from blob (changed or first-load) + row = conn.execute( + "SELECT image_buffer FROM head_image WHERE username = ? ORDER BY update_time DESC LIMIT 1", + (username,), + ).fetchone() + if row and row[0] is not None: + data = bytes(row[0]) if isinstance(row[0], (memoryview, bytearray)) else row[0] + if not isinstance(data, (bytes, bytearray)): + data = bytes(data) + if data: + media_type = _detect_image_media_type(data) + media_type = media_type if media_type.startswith("image/") else "application/octet-stream" + entry, out_path = write_avatar_cache_payload( + account_name, + source_kind="user", + username=user_key, + payload=bytes(data), + media_type=media_type, + source_md5=db_md5, + source_update_time=db_update_time, + ttl_seconds=AVATAR_CACHE_TTL_SECONDS, + ) + if entry and out_path: + logger.info( + f"[avatar_cache_download] kind=user account={account_name} username={user_key} src=head_image" + ) + headers = build_avatar_cache_response_headers(entry) + return FileResponse(str(out_path), media_type=media_type, headers=headers) + + # cache write failed: fallback to response bytes + logger.warning( + f"[avatar_cache_error] kind=user account={account_name} username={user_key} action=write_fallback" + ) + return Response(content=bytes(data), media_type=media_type) + + # meta not found (no local avatar blob) + row = None finally: conn.close() - if not row or row[0] is None: - raise 
HTTPException(status_code=404, detail="Avatar not found.") + # 2) Fallback: remote avatar URL (contact/WCDB), cache by URL. + remote_url = _resolve_avatar_remote_url(account_dir=account_dir, username=user_key) + if remote_url and is_avatar_cache_enabled(): + url_entry = get_avatar_cache_url_entry(account_name, remote_url) + url_file = avatar_cache_entry_file_exists(account_name, url_entry) + if url_entry and url_file and avatar_cache_entry_is_fresh(url_entry): + logger.info(f"[avatar_cache_hit] kind=url account={account_name} username={user_key}") + touch_avatar_cache_entry(account_name, str(url_entry.get("cache_key") or "")) + # Keep user-key mapping aligned, so next user lookup is direct. + try: + upsert_avatar_cache_entry( + account_name, + cache_key=cache_key_for_avatar_user(user_key), + source_kind="user", + username=user_key, + source_url=remote_url, + source_md5=str(url_entry.get("source_md5") or ""), + source_update_time=int(url_entry.get("source_update_time") or 0), + rel_path=str(url_entry.get("rel_path") or ""), + media_type=str(url_entry.get("media_type") or "application/octet-stream"), + size_bytes=int(url_entry.get("size_bytes") or 0), + etag=str(url_entry.get("etag") or ""), + last_modified=str(url_entry.get("last_modified") or ""), + fetched_at=int(url_entry.get("fetched_at") or 0), + checked_at=int(url_entry.get("checked_at") or 0), + expires_at=int(url_entry.get("expires_at") or 0), + ) + except Exception: + pass + headers = build_avatar_cache_response_headers(url_entry) + return FileResponse( + str(url_file), + media_type=str(url_entry.get("media_type") or "application/octet-stream"), + headers=headers, + ) - data = bytes(row[0]) if isinstance(row[0], (memoryview, bytearray)) else row[0] - if not isinstance(data, (bytes, bytearray)): - data = bytes(data) - media_type = _detect_image_media_type(data) - return Response(content=data, media_type=media_type) + # Revalidate / download remote avatar + def _download_remote_avatar( + source_url: str, + *, 
+ etag: str, + last_modified: str, + ) -> tuple[bytes, str, str, str, bool]: + base_headers = { + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120 Safari/537.36", + "Accept": "image/avif,image/webp,image/apng,image/*,*/*;q=0.8", + } + + header_variants = [ + { + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/107.0.0.0 Safari/537.36 MicroMessenger/7.0.20.1781(0x6700143B) WindowsWechat(0x63090719) XWEB/8351", + "Accept": "image/avif,image/webp,image/apng,image/svg+xml,image/*,*/*;q=0.8", + "Accept-Language": "zh-CN,zh;q=0.9", + "Referer": "https://servicewechat.com/", + "Origin": "https://servicewechat.com", + "Range": "bytes=0-", + }, + {"Referer": "https://wx.qq.com/", "Origin": "https://wx.qq.com"}, + {"Referer": "https://mp.weixin.qq.com/", "Origin": "https://mp.weixin.qq.com"}, + {"Referer": "https://www.baidu.com/", "Origin": "https://www.baidu.com"}, + {}, + ] + + last_err: Exception | None = None + for extra in header_variants: + headers = dict(base_headers) + headers.update(extra) + if etag: + headers["If-None-Match"] = etag + if last_modified: + headers["If-Modified-Since"] = last_modified + + r = requests.get(source_url, headers=headers, timeout=20, stream=True) + try: + if r.status_code == 304: + e2, lm2 = _parse_304_headers(r.headers) + return b"", "", (e2 or etag), (lm2 or last_modified), True + r.raise_for_status() + content_type = str(r.headers.get("Content-Type") or "").strip() + e2, lm2 = _parse_304_headers(r.headers) + max_bytes = 10 * 1024 * 1024 + chunks: list[bytes] = [] + total = 0 + for ch in r.iter_content(chunk_size=64 * 1024): + if not ch: + continue + chunks.append(ch) + total += len(ch) + if total > max_bytes: + raise HTTPException(status_code=400, detail="Avatar too large (>10MB).") + return b"".join(chunks), content_type, e2, lm2, False + except HTTPException: + raise + except Exception as e: + last_err = e + finally: + try: 
+ r.close() + except Exception: + pass + + raise last_err or RuntimeError("avatar remote download failed") + + etag0 = str((url_entry or {}).get("etag") or "").strip() + lm0 = str((url_entry or {}).get("last_modified") or "").strip() + try: + payload, ct, etag_new, lm_new, not_modified = await asyncio.to_thread( + _download_remote_avatar, + remote_url, + etag=etag0, + last_modified=lm0, + ) + except Exception as e: + logger.warning(f"[avatar_cache_error] kind=url account={account_name} username={user_key} err={e}") + if url_entry and url_file: + headers = build_avatar_cache_response_headers(url_entry) + return FileResponse( + str(url_file), + media_type=str(url_entry.get("media_type") or "application/octet-stream"), + headers=headers, + ) + raise HTTPException(status_code=404, detail="Avatar not found.") + + if not_modified and url_entry and url_file: + touch_avatar_cache_entry(account_name, cache_key_for_avatar_url(remote_url)) + if etag_new or lm_new: + try: + upsert_avatar_cache_entry( + account_name, + cache_key=cache_key_for_avatar_url(remote_url), + source_kind="url", + username=user_key, + source_url=remote_url, + source_md5=str(url_entry.get("source_md5") or ""), + source_update_time=int(url_entry.get("source_update_time") or 0), + rel_path=str(url_entry.get("rel_path") or ""), + media_type=str(url_entry.get("media_type") or "application/octet-stream"), + size_bytes=int(url_entry.get("size_bytes") or 0), + etag=etag_new or etag0, + last_modified=lm_new or lm0, + ) + except Exception: + pass + logger.info(f"[avatar_cache_revalidate] kind=url account={account_name} username={user_key} status=304") + headers = build_avatar_cache_response_headers(url_entry) + return FileResponse( + str(url_file), + media_type=str(url_entry.get("media_type") or "application/octet-stream"), + headers=headers, + ) + + if payload: + payload2, media_type, _ext = _detect_media_type_and_ext(payload) + if media_type == "application/octet-stream" and ct: + try: + mt = 
ct.split(";")[0].strip() + if mt.startswith("image/"): + media_type = mt + except Exception: + pass + if str(media_type or "").startswith("image/"): + entry, out_path = write_avatar_cache_payload( + account_name, + source_kind="url", + username=user_key, + source_url=remote_url, + payload=payload2, + media_type=media_type, + etag=etag_new, + last_modified=lm_new, + ttl_seconds=AVATAR_CACHE_TTL_SECONDS, + ) + if entry and out_path: + # bind user-key record to same file for quicker next access + try: + upsert_avatar_cache_entry( + account_name, + cache_key=cache_key_for_avatar_user(user_key), + source_kind="user", + username=user_key, + source_url=remote_url, + source_md5=str(entry.get("source_md5") or ""), + source_update_time=int(entry.get("source_update_time") or 0), + rel_path=str(entry.get("rel_path") or ""), + media_type=str(entry.get("media_type") or "application/octet-stream"), + size_bytes=int(entry.get("size_bytes") or 0), + etag=str(entry.get("etag") or ""), + last_modified=str(entry.get("last_modified") or ""), + fetched_at=int(entry.get("fetched_at") or 0), + checked_at=int(entry.get("checked_at") or 0), + expires_at=int(entry.get("expires_at") or 0), + ) + except Exception: + pass + logger.info(f"[avatar_cache_download] kind=url account={account_name} username={user_key}") + headers = build_avatar_cache_response_headers(entry) + return FileResponse(str(out_path), media_type=media_type, headers=headers) + + if cached_file is not None and user_entry: + headers = build_avatar_cache_response_headers(user_entry) + return FileResponse( + str(cached_file), + media_type=str(user_entry.get("media_type") or "application/octet-stream"), + headers=headers, + ) + + raise HTTPException(status_code=404, detail="Avatar not found.") class EmojiDownloadRequest(BaseModel): @@ -434,7 +773,25 @@ async def proxy_image(url: str): if not _is_allowed_proxy_image_host(host): raise HTTPException(status_code=400, detail="Unsupported url host for proxy_image.") - def 
_download_bytes() -> tuple[bytes, str]: + source_url = normalize_avatar_source_url(u) + proxy_account = "_proxy" + cache_entry = get_avatar_cache_url_entry(proxy_account, source_url) if is_avatar_cache_enabled() else None + cache_file = avatar_cache_entry_file_exists(proxy_account, cache_entry) + if cache_entry and cache_file and avatar_cache_entry_is_fresh(cache_entry): + logger.info(f"[avatar_cache_hit] kind=proxy_url account={proxy_account}") + touch_avatar_cache_entry(proxy_account, cache_key_for_avatar_url(source_url)) + headers = build_avatar_cache_response_headers(cache_entry) + return FileResponse( + str(cache_file), + media_type=str(cache_entry.get("media_type") or "application/octet-stream"), + headers=headers, + ) + + def _download_bytes( + *, + if_none_match: str = "", + if_modified_since: str = "", + ) -> tuple[bytes, str, str, str, bool]: base_headers = { "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120 Safari/537.36", "Accept": "image/avif,image/webp,image/apng,image/*,*/*;q=0.8", @@ -464,10 +821,20 @@ async def proxy_image(url: str): for extra in header_variants: headers = dict(base_headers) headers.update(extra) + if if_none_match: + headers["If-None-Match"] = if_none_match + if if_modified_since: + headers["If-Modified-Since"] = if_modified_since r = requests.get(u, headers=headers, timeout=20, stream=True) try: + if r.status_code == 304: + etag0 = str(r.headers.get("ETag") or "").strip() + lm0 = str(r.headers.get("Last-Modified") or "").strip() + return b"", "", etag0, lm0, True r.raise_for_status() content_type = str(r.headers.get("Content-Type") or "").strip() + etag0 = str(r.headers.get("ETag") or "").strip() + lm0 = str(r.headers.get("Last-Modified") or "").strip() max_bytes = 10 * 1024 * 1024 chunks: list[bytes] = [] total = 0 @@ -478,7 +845,7 @@ async def proxy_image(url: str): total += len(ch) if total > max_bytes: raise HTTPException(status_code=400, detail="Proxy image too large 
(>10MB).") - return b"".join(chunks), content_type + return b"".join(chunks), content_type, etag0, lm0, False except HTTPException: # Hard failure, don't retry with another referer. raise @@ -493,14 +860,50 @@ async def proxy_image(url: str): # All variants failed. raise last_err or RuntimeError("proxy_image download failed") + etag0 = str((cache_entry or {}).get("etag") or "").strip() + lm0 = str((cache_entry or {}).get("last_modified") or "").strip() try: - data, ct = await asyncio.to_thread(_download_bytes) + data, ct, etag_new, lm_new, not_modified = await asyncio.to_thread( + _download_bytes, + if_none_match=etag0, + if_modified_since=lm0, + ) except HTTPException: raise except Exception as e: logger.warning(f"proxy_image failed: url={u} err={e}") + if cache_entry and cache_file: + headers = build_avatar_cache_response_headers(cache_entry) + return FileResponse( + str(cache_file), + media_type=str(cache_entry.get("media_type") or "application/octet-stream"), + headers=headers, + ) raise HTTPException(status_code=502, detail=f"Proxy image failed: {e}") + if not_modified and cache_entry and cache_file: + logger.info(f"[avatar_cache_revalidate] kind=proxy_url account={proxy_account} status=304") + upsert_avatar_cache_entry( + proxy_account, + cache_key=cache_key_for_avatar_url(source_url), + source_kind="url", + source_url=source_url, + username="", + source_md5=str(cache_entry.get("source_md5") or ""), + source_update_time=int(cache_entry.get("source_update_time") or 0), + rel_path=str(cache_entry.get("rel_path") or ""), + media_type=str(cache_entry.get("media_type") or "application/octet-stream"), + size_bytes=int(cache_entry.get("size_bytes") or 0), + etag=etag_new or etag0, + last_modified=lm_new or lm0, + ) + headers = build_avatar_cache_response_headers(cache_entry) + return FileResponse( + str(cache_file), + media_type=str(cache_entry.get("media_type") or "application/octet-stream"), + headers=headers, + ) + if not data: raise 
HTTPException(status_code=502, detail="Proxy returned empty body.") @@ -518,8 +921,24 @@ async def proxy_image(url: str): if not str(media_type or "").startswith("image/"): raise HTTPException(status_code=502, detail="Proxy did not return an image.") + if is_avatar_cache_enabled(): + entry, out_path = write_avatar_cache_payload( + proxy_account, + source_kind="url", + source_url=source_url, + payload=payload, + media_type=media_type, + etag=etag_new, + last_modified=lm_new, + ttl_seconds=AVATAR_CACHE_TTL_SECONDS, + ) + if entry and out_path: + logger.info(f"[avatar_cache_download] kind=proxy_url account={proxy_account}") + headers = build_avatar_cache_response_headers(entry) + return FileResponse(str(out_path), media_type=media_type, headers=headers) + resp = Response(content=payload, media_type=media_type) - resp.headers["Cache-Control"] = "public, max-age=86400" + resp.headers["Cache-Control"] = f"public, max-age={AVATAR_CACHE_TTL_SECONDS}" return resp diff --git a/src/wechat_decrypt_tool/wrapped/cards/card_00_global_overview.py b/src/wechat_decrypt_tool/wrapped/cards/card_00_global_overview.py index e3853c4..04c31a0 100644 --- a/src/wechat_decrypt_tool/wrapped/cards/card_00_global_overview.py +++ b/src/wechat_decrypt_tool/wrapped/cards/card_00_global_overview.py @@ -17,7 +17,6 @@ from ...chat_helpers import ( _decode_sqlite_text, _iter_message_db_paths, _load_contact_rows, - _pick_avatar_url, _pick_display_name, _quote_ident, _should_keep_session, @@ -701,7 +700,7 @@ def build_card_00_global_overview( u, cnt = stats.top_contact row = contact_rows.get(u) display = _pick_display_name(row, u) - avatar = _pick_avatar_url(row) or (_build_avatar_url(str(account_dir.name or ""), u) if u else "") + avatar = _build_avatar_url(str(account_dir.name or ""), u) if u else "" top_contact_obj = { "username": u, "displayName": display, @@ -716,7 +715,7 @@ def build_card_00_global_overview( u, cnt = stats.top_group row = contact_rows.get(u) display = _pick_display_name(row, u) - 
avatar = _pick_avatar_url(row) or (_build_avatar_url(str(account_dir.name or ""), u) if u else "") + avatar = _build_avatar_url(str(account_dir.name or ""), u) if u else "" top_group_obj = { "username": u, "displayName": display, diff --git a/src/wechat_decrypt_tool/wrapped/cards/card_01_cyber_schedule.py b/src/wechat_decrypt_tool/wrapped/cards/card_01_cyber_schedule.py index a4704f0..b7fff77 100644 --- a/src/wechat_decrypt_tool/wrapped/cards/card_01_cyber_schedule.py +++ b/src/wechat_decrypt_tool/wrapped/cards/card_01_cyber_schedule.py @@ -14,7 +14,6 @@ from ...chat_helpers import ( _build_avatar_url, _iter_message_db_paths, _load_contact_rows, - _pick_avatar_url, _pick_display_name, _quote_ident, _row_to_search_hit, @@ -713,7 +712,7 @@ def _fetch_message_moment_payload( contact_row = contact_rows.get(username) display = _pick_display_name(contact_row, username) - avatar = _pick_avatar_url(contact_row) or (_build_avatar_url(str(account_dir.name or ""), username) if username else "") + avatar = _build_avatar_url(str(account_dir.name or ""), username) if username else "" return { "timestamp": int(ref.ts), diff --git a/src/wechat_decrypt_tool/wrapped/cards/card_03_reply_speed.py b/src/wechat_decrypt_tool/wrapped/cards/card_03_reply_speed.py index 7498520..86b0c1c 100644 --- a/src/wechat_decrypt_tool/wrapped/cards/card_03_reply_speed.py +++ b/src/wechat_decrypt_tool/wrapped/cards/card_03_reply_speed.py @@ -12,7 +12,6 @@ from typing import Any, Optional from ...chat_helpers import ( _build_avatar_url, _load_contact_rows, - _pick_avatar_url, _pick_display_name, _should_keep_session, ) @@ -385,7 +384,7 @@ def compute_reply_speed_stats(*, account_dir: Path, year: int) -> dict[str, Any] def conv_to_obj(score: float | None, agg: _ConvAgg) -> dict[str, Any]: row = contact_rows.get(agg.username) display = _pick_display_name(row, agg.username) - avatar = _pick_avatar_url(row) or (_build_avatar_url(str(account_dir.name or ""), agg.username) if agg.username else "") + avatar = 
_build_avatar_url(str(account_dir.name or ""), agg.username) if agg.username else "" avg_s = agg.avg_gap() out: dict[str, Any] = { "username": agg.username, @@ -420,7 +419,7 @@ def compute_reply_speed_stats(*, account_dir: Path, year: int) -> dict[str, Any] else: row = contact_rows.get(global_fastest_u) display = _pick_display_name(row, global_fastest_u) - avatar = _pick_avatar_url(row) or (_build_avatar_url(str(account_dir.name or ""), global_fastest_u) if global_fastest_u else "") + avatar = _build_avatar_url(str(account_dir.name or ""), global_fastest_u) if global_fastest_u else "" fastest_obj = { "username": global_fastest_u, "displayName": display, @@ -440,7 +439,7 @@ def compute_reply_speed_stats(*, account_dir: Path, year: int) -> dict[str, Any] else: row = contact_rows.get(global_slowest_u) display = _pick_display_name(row, global_slowest_u) - avatar = _pick_avatar_url(row) or (_build_avatar_url(str(account_dir.name or ""), global_slowest_u) if global_slowest_u else "") + avatar = _build_avatar_url(str(account_dir.name or ""), global_slowest_u) if global_slowest_u else "" slowest_obj = { "username": global_slowest_u, "displayName": display, @@ -547,7 +546,7 @@ def compute_reply_speed_stats(*, account_dir: Path, year: int) -> dict[str, Any] row = contact_rows.get(u) display = _pick_display_name(row, u) - avatar = _pick_avatar_url(row) or (_build_avatar_url(str(account_dir.name or ""), u) if u else "") + avatar = _build_avatar_url(str(account_dir.name or ""), u) if u else "" series.append( { "username": u, @@ -595,7 +594,7 @@ def compute_reply_speed_stats(*, account_dir: Path, year: int) -> dict[str, Any] if not u: continue display = _pick_display_name(r, u) - avatar = _pick_avatar_url(r) or (_build_avatar_url(str(account_dir.name or ""), u) if u else "") + avatar = _build_avatar_url(str(account_dir.name or ""), u) if u else "" all_contacts_list.append({ "username": u, "displayName": display, diff --git a/tests/test_avatar_cache_chat_media.py 
"""Regression test for the avatar cache behind ``/api/chat/avatar``.

Lives at ``tests/test_avatar_cache_chat_media.py``.
"""

import importlib
import os
import sqlite3
import sys
import unittest
from pathlib import Path
from tempfile import TemporaryDirectory


# Put the repository's ``src`` directory first on the import path so the
# package under test is importable straight from the checkout.
ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT / "src"))


class TestAvatarCacheChatMedia(unittest.TestCase):
    """Drive the chat-media avatar endpoint against seeded SQLite fixtures."""

    # Image stored in the ``head_image`` fixture (a 1x1 PNG).
    _TINY_PNG = bytes.fromhex(
        "89504E470D0A1A0A"
        "0000000D49484452000000010000000108060000001F15C489"
        "0000000D49444154789C6360606060000000050001A5F64540"
        "0000000049454E44AE426082"
    )

    def _seed_contact_db(self, path: Path, *, username: str = "wxid_friend") -> None:
        """Create ``contact`` and ``stranger`` tables at *path* and add one contact.

        Only ``contact`` receives a row; ``stranger`` is created empty. The
        row carries a remote URL in ``big_head_url`` and an empty
        ``small_head_url``.
        """
        contact_row = (
            username,
            "",
            "测试好友",
            "",
            1,
            0,
            "https://wx.qlogo.cn/mmhead/ver_1/test_remote_avatar/132",
            "",
        )
        db = sqlite3.connect(str(path))
        try:
            db.execute(
                """
                CREATE TABLE contact (
                    username TEXT,
                    remark TEXT,
                    nick_name TEXT,
                    alias TEXT,
                    local_type INTEGER,
                    verify_flag INTEGER,
                    big_head_url TEXT,
                    small_head_url TEXT
                )
                """
            )
            db.execute(
                """
                CREATE TABLE stranger (
                    username TEXT,
                    remark TEXT,
                    nick_name TEXT,
                    alias TEXT,
                    local_type INTEGER,
                    verify_flag INTEGER,
                    big_head_url TEXT,
                    small_head_url TEXT
                )
                """
            )
            db.execute("INSERT INTO contact VALUES (?, ?, ?, ?, ?, ?, ?, ?)", contact_row)
            db.commit()
        finally:
            db.close()

    def _seed_session_db(self, path: Path, *, username: str = "wxid_friend") -> None:
        """Create ``SessionTable`` at *path* with a single session for *username*."""
        session_row = (username, 200, 200)
        db = sqlite3.connect(str(path))
        try:
            db.execute(
                """
                CREATE TABLE SessionTable (
                    username TEXT,
                    sort_timestamp INTEGER,
                    last_timestamp INTEGER
                )
                """
            )
            db.execute("INSERT INTO SessionTable VALUES (?, ?, ?)", session_row)
            db.commit()
        finally:
            db.close()

    def _seed_head_image_db(self, path: Path, *, username: str = "wxid_friend") -> None:
        """Create ``head_image`` at *path* holding the tiny PNG for *username*."""
        image_row = (
            username,
            "0123456789abcdef0123456789abcdef",
            sqlite3.Binary(self._TINY_PNG),
            1735689600,
        )
        db = sqlite3.connect(str(path))
        try:
            db.execute(
                "CREATE TABLE head_image(username TEXT PRIMARY KEY, md5 TEXT, image_buffer BLOB, update_time INTEGER)"
            )
            db.execute("INSERT INTO head_image VALUES (?, ?, ?, ?)", image_row)
            db.commit()
        finally:
            db.close()

    def test_chat_avatar_caches_to_output_avatar_cache(self):
        """Fetch an avatar twice and check the on-disk cache it leaves behind.

        After the first request the test expects
        ``output/avatar_cache/<account>/avatar_cache.db`` to exist and to hold
        a ``source_kind = 'user'`` row whose ``rel_path`` points at an
        existing file. The second request must return the same bytes.
        """
        from fastapi import FastAPI
        from fastapi.testclient import TestClient

        with TemporaryDirectory() as tmp:
            root = Path(tmp)
            account = "wxid_test"
            username = "wxid_friend"

            databases_dir = root / "output" / "databases" / account
            databases_dir.mkdir(parents=True, exist_ok=True)
            self._seed_contact_db(databases_dir / "contact.db", username=username)
            self._seed_session_db(databases_dir / "session.db", username=username)
            self._seed_head_image_db(databases_dir / "head_image.db", username=username)

            overrides = {
                "WECHAT_TOOL_DATA_DIR": str(root),
                "WECHAT_TOOL_AVATAR_CACHE_ENABLED": "1",
            }
            # Remember the current values (None means "was unset") so the
            # finally block can put the environment back exactly as it was.
            saved = {name: os.environ.get(name) for name in overrides}
            try:
                for name, value in overrides.items():
                    os.environ[name] = value

                import wechat_decrypt_tool.app_paths as app_paths
                import wechat_decrypt_tool.chat_helpers as chat_helpers
                import wechat_decrypt_tool.avatar_cache as avatar_cache
                import wechat_decrypt_tool.routers.chat_media as chat_media

                # Re-execute the modules now that the overrides are in place.
                # NOTE(review): presumably they read the data dir at import
                # time - confirm against app_paths.
                for module in (app_paths, chat_helpers, avatar_cache, chat_media):
                    importlib.reload(module)

                app = FastAPI()
                app.include_router(chat_media.router)
                client = TestClient(app)
                query = {"account": account, "username": username}

                first = client.get("/api/chat/avatar", params=query)
                self.assertEqual(first.status_code, 200)
                self.assertTrue(first.headers.get("content-type", "").startswith("image/"))

                cache_root = root / "output" / "avatar_cache" / account
                cache_db = cache_root / "avatar_cache.db"
                self.assertTrue(cache_db.exists())

                db = sqlite3.connect(str(cache_db))
                try:
                    row = db.execute(
                        "SELECT cache_key, source_kind, username, rel_path, media_type "
                        "FROM avatar_cache_entries WHERE source_kind = 'user' LIMIT 1"
                    ).fetchone()
                    self.assertIsNotNone(row)
                    # rel_path is the fourth selected column.
                    rel_path = str(row[3] or "")
                finally:
                    db.close()

                self.assertTrue(rel_path)
                cache_file = (cache_root / rel_path).resolve()
                self.assertTrue(cache_file.exists())

                second = client.get("/api/chat/avatar", params=query)
                self.assertEqual(second.status_code, 200)
                self.assertEqual(second.content, first.content)
            finally:
                for name, previous in saved.items():
                    if previous is None:
                        os.environ.pop(name, None)
                    else:
                        os.environ[name] = previous


if __name__ == "__main__":
    unittest.main()