feat(chat-media): 新增头像缓存并统一头像出口

This commit is contained in:
2977094657
2026-02-09 00:14:32 +08:00
parent 993593ca7f
commit 36d9af2b28
7 changed files with 1127 additions and 86 deletions

View File

@@ -0,0 +1,454 @@
from __future__ import annotations
import hashlib
import os
import re
import sqlite3
import time
from email.utils import formatdate
from pathlib import Path
from typing import Any, Optional
from urllib.parse import urlsplit, urlunsplit
from .app_paths import get_output_dir
from .logging_config import get_logger
logger = get_logger(__name__)
AVATAR_CACHE_TTL_SECONDS = 7 * 24 * 60 * 60
def is_avatar_cache_enabled() -> bool:
v = str(os.environ.get("WECHAT_TOOL_AVATAR_CACHE_ENABLED", "1") or "").strip().lower()
return v not in {"", "0", "false", "off", "no"}
def get_avatar_cache_root_dir() -> Path:
return get_output_dir() / "avatar_cache"
def _safe_segment(value: str) -> str:
cleaned = re.sub(r"[^0-9A-Za-z._-]+", "_", str(value or "").strip())
cleaned = cleaned.strip("._-")
return cleaned or "default"
def _account_layout(account: str) -> tuple[Path, Path, Path, Path]:
account_dir = get_avatar_cache_root_dir() / _safe_segment(account)
files_dir = account_dir / "files"
tmp_dir = account_dir / "tmp"
db_path = account_dir / "avatar_cache.db"
return account_dir, files_dir, tmp_dir, db_path
def _ensure_account_layout(account: str) -> tuple[Path, Path, Path, Path]:
account_dir, files_dir, tmp_dir, db_path = _account_layout(account)
account_dir.mkdir(parents=True, exist_ok=True)
files_dir.mkdir(parents=True, exist_ok=True)
tmp_dir.mkdir(parents=True, exist_ok=True)
return account_dir, files_dir, tmp_dir, db_path
def _connect(account: str) -> sqlite3.Connection:
_, _, _, db_path = _ensure_account_layout(account)
conn = sqlite3.connect(str(db_path), timeout=5)
conn.row_factory = sqlite3.Row
_ensure_schema(conn)
return conn
def _ensure_schema(conn: sqlite3.Connection) -> None:
conn.execute(
"""
CREATE TABLE IF NOT EXISTS avatar_cache_entries (
account TEXT NOT NULL,
cache_key TEXT NOT NULL,
source_kind TEXT NOT NULL,
username TEXT NOT NULL DEFAULT '',
source_url TEXT NOT NULL DEFAULT '',
source_md5 TEXT NOT NULL DEFAULT '',
source_update_time INTEGER NOT NULL DEFAULT 0,
rel_path TEXT NOT NULL DEFAULT '',
media_type TEXT NOT NULL DEFAULT 'application/octet-stream',
size_bytes INTEGER NOT NULL DEFAULT 0,
etag TEXT NOT NULL DEFAULT '',
last_modified TEXT NOT NULL DEFAULT '',
fetched_at INTEGER NOT NULL DEFAULT 0,
checked_at INTEGER NOT NULL DEFAULT 0,
expires_at INTEGER NOT NULL DEFAULT 0,
PRIMARY KEY (account, cache_key)
)
"""
)
conn.execute(
"CREATE INDEX IF NOT EXISTS idx_avatar_cache_entries_account_username ON avatar_cache_entries(account, username)"
)
conn.execute(
"CREATE INDEX IF NOT EXISTS idx_avatar_cache_entries_account_source ON avatar_cache_entries(account, source_kind, source_url)"
)
conn.commit()
def _row_to_dict(row: Optional[sqlite3.Row]) -> Optional[dict[str, Any]]:
if row is None:
return None
out: dict[str, Any] = {}
for k in row.keys():
out[str(k)] = row[k]
return out
def normalize_avatar_source_url(url: str) -> str:
raw = str(url or "").strip()
if not raw:
return ""
try:
p = urlsplit(raw)
except Exception:
return raw
scheme = str(p.scheme or "").lower()
host = str(p.hostname or "").lower()
if not scheme or not host:
return raw
netloc = host
if p.port:
netloc = f"{host}:{int(p.port)}"
path = p.path or "/"
return urlunsplit((scheme, netloc, path, p.query or "", ""))
def cache_key_for_avatar_user(username: str) -> str:
u = str(username or "").strip()
return hashlib.sha1(f"user:{u}".encode("utf-8", errors="ignore")).hexdigest()
def cache_key_for_avatar_url(url: str) -> str:
u = normalize_avatar_source_url(url)
return hashlib.sha1(f"url:{u}".encode("utf-8", errors="ignore")).hexdigest()
def get_avatar_cache_entry(account: str, cache_key: str) -> Optional[dict[str, Any]]:
if (not is_avatar_cache_enabled()) or (not cache_key):
return None
try:
conn = _connect(account)
except Exception:
return None
try:
row = conn.execute(
"SELECT * FROM avatar_cache_entries WHERE account = ? AND cache_key = ? LIMIT 1",
(str(account or ""), str(cache_key or "")),
).fetchone()
return _row_to_dict(row)
except Exception:
return None
finally:
try:
conn.close()
except Exception:
pass
def get_avatar_cache_user_entry(account: str, username: str) -> Optional[dict[str, Any]]:
if not username:
return None
return get_avatar_cache_entry(account, cache_key_for_avatar_user(username))
def get_avatar_cache_url_entry(account: str, source_url: str) -> Optional[dict[str, Any]]:
if not source_url:
return None
return get_avatar_cache_entry(account, cache_key_for_avatar_url(source_url))
def resolve_avatar_cache_entry_path(account: str, entry: Optional[dict[str, Any]]) -> Optional[Path]:
if not entry:
return None
rel = str(entry.get("rel_path") or "").strip().replace("\\", "/")
if not rel:
return None
account_dir, _, _, _ = _account_layout(account)
p = account_dir / rel
try:
account_dir_resolved = account_dir.resolve()
p_resolved = p.resolve()
if p_resolved != account_dir_resolved and account_dir_resolved not in p_resolved.parents:
return None
return p_resolved
except Exception:
return p
def avatar_cache_entry_file_exists(account: str, entry: Optional[dict[str, Any]]) -> Optional[Path]:
p = resolve_avatar_cache_entry_path(account, entry)
if not p:
return None
try:
if p.exists() and p.is_file():
return p
except Exception:
return None
return None
def avatar_cache_entry_is_fresh(entry: Optional[dict[str, Any]], now_ts: Optional[int] = None) -> bool:
if not entry:
return False
try:
expires = int(entry.get("expires_at") or 0)
except Exception:
expires = 0
if expires <= 0:
return False
now0 = int(now_ts or time.time())
return expires > now0
def _guess_ext(media_type: str) -> str:
mt = str(media_type or "").strip().lower()
if mt == "image/jpeg":
return "jpg"
if mt == "image/png":
return "png"
if mt == "image/gif":
return "gif"
if mt == "image/webp":
return "webp"
if mt == "image/bmp":
return "bmp"
if mt == "image/svg+xml":
return "svg"
if mt == "image/avif":
return "avif"
if mt.startswith("image/"):
return mt.split("/", 1)[1].split("+", 1)[0].split(";", 1)[0] or "img"
return "dat"
def _http_date_from_ts(ts: Optional[int]) -> str:
try:
t = int(ts or 0)
except Exception:
t = 0
if t <= 0:
return ""
try:
return formatdate(timeval=float(t), usegmt=True)
except Exception:
return ""
def upsert_avatar_cache_entry(
account: str,
*,
cache_key: str,
source_kind: str,
username: str = "",
source_url: str = "",
source_md5: str = "",
source_update_time: int = 0,
rel_path: str = "",
media_type: str = "application/octet-stream",
size_bytes: int = 0,
etag: str = "",
last_modified: str = "",
fetched_at: Optional[int] = None,
checked_at: Optional[int] = None,
expires_at: Optional[int] = None,
) -> Optional[dict[str, Any]]:
if (not is_avatar_cache_enabled()) or (not cache_key):
return None
acct = str(account or "").strip()
ck = str(cache_key or "").strip()
sk = str(source_kind or "").strip().lower()
if not acct or not ck or not sk:
return None
source_url_norm = normalize_avatar_source_url(source_url) if source_url else ""
now_ts = int(time.time())
fetched = int(fetched_at if fetched_at is not None else now_ts)
checked = int(checked_at if checked_at is not None else now_ts)
expire_ts = int(expires_at if expires_at is not None else (checked + AVATAR_CACHE_TTL_SECONDS))
try:
conn = _connect(acct)
except Exception as e:
logger.warning(f"[avatar_cache_error] open db failed account={acct} err={e}")
return None
try:
conn.execute(
"""
INSERT INTO avatar_cache_entries (
account, cache_key, source_kind, username, source_url,
source_md5, source_update_time, rel_path, media_type, size_bytes,
etag, last_modified, fetched_at, checked_at, expires_at
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(account, cache_key) DO UPDATE SET
source_kind=excluded.source_kind,
username=excluded.username,
source_url=excluded.source_url,
source_md5=excluded.source_md5,
source_update_time=excluded.source_update_time,
rel_path=excluded.rel_path,
media_type=excluded.media_type,
size_bytes=excluded.size_bytes,
etag=excluded.etag,
last_modified=excluded.last_modified,
fetched_at=excluded.fetched_at,
checked_at=excluded.checked_at,
expires_at=excluded.expires_at
""",
(
acct,
ck,
sk,
str(username or "").strip(),
source_url_norm,
str(source_md5 or "").strip().lower(),
int(source_update_time or 0),
str(rel_path or "").strip().replace("\\", "/"),
str(media_type or "application/octet-stream").strip() or "application/octet-stream",
int(size_bytes or 0),
str(etag or "").strip(),
str(last_modified or "").strip(),
fetched,
checked,
expire_ts,
),
)
conn.commit()
row = conn.execute(
"SELECT * FROM avatar_cache_entries WHERE account = ? AND cache_key = ? LIMIT 1",
(acct, ck),
).fetchone()
return _row_to_dict(row)
except Exception as e:
logger.warning(f"[avatar_cache_error] upsert failed account={acct} cache_key={ck} err={e}")
return None
finally:
try:
conn.close()
except Exception:
pass
def touch_avatar_cache_entry(account: str, cache_key: str, *, ttl_seconds: int = AVATAR_CACHE_TTL_SECONDS) -> bool:
if (not is_avatar_cache_enabled()) or (not cache_key):
return False
now_ts = int(time.time())
try:
conn = _connect(account)
except Exception:
return False
try:
conn.execute(
"UPDATE avatar_cache_entries SET checked_at = ?, expires_at = ? WHERE account = ? AND cache_key = ?",
(now_ts, now_ts + max(60, int(ttl_seconds or AVATAR_CACHE_TTL_SECONDS)), str(account or ""), str(cache_key or "")),
)
conn.commit()
return True
except Exception:
return False
finally:
try:
conn.close()
except Exception:
pass
def write_avatar_cache_payload(
account: str,
*,
source_kind: str,
username: str = "",
source_url: str = "",
payload: bytes,
media_type: str,
source_md5: str = "",
source_update_time: int = 0,
etag: str = "",
last_modified: str = "",
ttl_seconds: int = AVATAR_CACHE_TTL_SECONDS,
) -> tuple[Optional[dict[str, Any]], Optional[Path]]:
if (not is_avatar_cache_enabled()) or (not payload):
return None, None
acct = str(account or "").strip()
sk = str(source_kind or "").strip().lower()
if not acct or sk not in {"user", "url"}:
return None, None
source_url_norm = normalize_avatar_source_url(source_url) if source_url else ""
if sk == "user":
cache_key = cache_key_for_avatar_user(username)
else:
cache_key = cache_key_for_avatar_url(source_url_norm)
digest = hashlib.sha1(bytes(payload)).hexdigest()
ext = _guess_ext(media_type)
rel_path = f"files/{digest[:2]}/{digest}.{ext}"
try:
account_dir, _, tmp_dir, _ = _ensure_account_layout(acct)
except Exception as e:
logger.warning(f"[avatar_cache_error] ensure dirs failed account={acct} err={e}")
return None, None
abs_path = account_dir / rel_path
try:
abs_path.parent.mkdir(parents=True, exist_ok=True)
if (not abs_path.exists()) or (int(abs_path.stat().st_size) != len(payload)):
tmp_path = tmp_dir / f"{digest}.{time.time_ns()}.tmp"
tmp_path.write_bytes(payload)
os.replace(str(tmp_path), str(abs_path))
except Exception as e:
logger.warning(f"[avatar_cache_error] write file failed account={acct} path={abs_path} err={e}")
return None, None
if (not etag) and digest:
etag = f'"{digest}"'
if (not last_modified) and source_update_time:
last_modified = _http_date_from_ts(source_update_time)
if not last_modified:
last_modified = _http_date_from_ts(int(time.time()))
entry = upsert_avatar_cache_entry(
acct,
cache_key=cache_key,
source_kind=sk,
username=username,
source_url=source_url_norm,
source_md5=source_md5,
source_update_time=int(source_update_time or 0),
rel_path=rel_path,
media_type=media_type,
size_bytes=len(payload),
etag=etag,
last_modified=last_modified,
fetched_at=int(time.time()),
checked_at=int(time.time()),
expires_at=int(time.time()) + max(60, int(ttl_seconds or AVATAR_CACHE_TTL_SECONDS)),
)
if not entry:
return None, None
return entry, abs_path
def build_avatar_cache_response_headers(
entry: Optional[dict[str, Any]], *, max_age: int = AVATAR_CACHE_TTL_SECONDS
) -> dict[str, str]:
headers: dict[str, str] = {
"Cache-Control": f"public, max-age={int(max_age)}",
}
if not entry:
return headers
etag = str(entry.get("etag") or "").strip()
last_modified = str(entry.get("last_modified") or "").strip()
if etag:
headers["ETag"] = etag
if last_modified:
headers["Last-Modified"] = last_modified
return headers

View File

@@ -45,7 +45,6 @@ from ..chat_helpers import (
_normalize_xml_url, _normalize_xml_url,
_parse_app_message, _parse_app_message,
_parse_pat_message, _parse_pat_message,
_pick_avatar_url,
_pick_display_name, _pick_display_name,
_query_head_image_usernames, _query_head_image_usernames,
_quote_ident, _quote_ident,
@@ -85,6 +84,19 @@ _REALTIME_SYNC_LOCKS: dict[tuple[str, str], threading.Lock] = {}
_REALTIME_SYNC_ALL_LOCKS: dict[str, threading.Lock] = {} _REALTIME_SYNC_ALL_LOCKS: dict[str, threading.Lock] = {}
def _avatar_url_unified(
*,
account_dir: Path,
username: str,
local_avatar_usernames: set[str] | None = None,
) -> str:
u = str(username or "").strip()
if not u:
return ""
# Unified avatar entrypoint: backend decides local db vs remote fallback + cache.
return _build_avatar_url(str(account_dir.name or ""), u)
def _realtime_sync_lock(account: str, username: str) -> threading.Lock: def _realtime_sync_lock(account: str, username: str) -> threading.Lock:
key = (str(account or "").strip(), str(username or "").strip()) key = (str(account or "").strip(), str(username or "").strip())
with _REALTIME_SYNC_MU: with _REALTIME_SYNC_MU:
@@ -1946,9 +1958,11 @@ async def chat_search_index_senders(
continue continue
cnt = int(r["c"] or 0) cnt = int(r["c"] or 0)
row = contact_rows.get(su) row = contact_rows.get(su)
avatar_url = _pick_avatar_url(row) avatar_url = _avatar_url_unified(
if (not avatar_url) and (su in local_sender_avatars): account_dir=account_dir,
avatar_url = _build_avatar_url(account_dir.name, su) username=su,
local_avatar_usernames=local_sender_avatars,
)
senders.append( senders.append(
{ {
"username": su, "username": su,
@@ -2568,7 +2582,7 @@ def _postprocess_full_messages(
row = sender_contact_rows.get(u) row = sender_contact_rows.get(u)
if _pick_display_name(row, u) == u: if _pick_display_name(row, u) == u:
need_display.append(u) need_display.append(u)
if (not _pick_avatar_url(row)) and (u not in local_sender_avatars): if u not in local_sender_avatars:
need_avatar.append(u) need_avatar.append(u)
need_display = list(dict.fromkeys(need_display)) need_display = list(dict.fromkeys(need_display))
@@ -2606,13 +2620,11 @@ def _postprocess_full_messages(
if wd and wd != su: if wd and wd != su:
display_name = wd display_name = wd
m["senderDisplayName"] = display_name m["senderDisplayName"] = display_name
avatar_url = _pick_avatar_url(row) avatar_url = base_url + _avatar_url_unified(
if not avatar_url and su in local_sender_avatars: account_dir=account_dir,
avatar_url = base_url + _build_avatar_url(account_dir.name, su) username=su,
if not avatar_url: local_avatar_usernames=local_sender_avatars,
wa = str(wcdb_avatar_urls.get(su) or "").strip() )
if wa.lower().startswith(("http://", "https://")):
avatar_url = wa
m["senderAvatar"] = avatar_url m["senderAvatar"] = avatar_url
qu = str(m.get("quoteUsername") or "").strip() qu = str(m.get("quoteUsername") or "").strip()
@@ -2922,7 +2934,7 @@ def list_chat_sessions(
if u not in local_avatar_usernames: if u not in local_avatar_usernames:
need_avatar.append(u) need_avatar.append(u)
else: else:
if (not _pick_avatar_url(row)) and (u not in local_avatar_usernames): if u not in local_avatar_usernames:
need_avatar.append(u) need_avatar.append(u)
need_display = list(dict.fromkeys(need_display)) need_display = list(dict.fromkeys(need_display))
@@ -2984,15 +2996,11 @@ def list_chat_sessions(
# Prefer local head_image avatars when available: decrypted contact.db URLs can be stale # Prefer local head_image avatars when available: decrypted contact.db URLs can be stale
# (or hotlink-protected for browsers). WCDB realtime (when available) is the next best. # (or hotlink-protected for browsers). WCDB realtime (when available) is the next best.
avatar_url = "" avatar_url = base_url + _avatar_url_unified(
if username in local_avatar_usernames: account_dir=account_dir,
avatar_url = base_url + _build_avatar_url(account_dir.name, username) username=username,
if not avatar_url: local_avatar_usernames=local_avatar_usernames,
wa = str(wcdb_avatar_urls.get(username) or "").strip() )
if wa.lower().startswith(("http://", "https://")):
avatar_url = wa
if not avatar_url:
avatar_url = _pick_avatar_url(c_row) or ""
last_message = "" last_message = ""
if preview_mode == "session": if preview_mode == "session":
@@ -4388,7 +4396,7 @@ def list_chat_messages(
row = sender_contact_rows.get(u) row = sender_contact_rows.get(u)
if _pick_display_name(row, u) == u: if _pick_display_name(row, u) == u:
need_display.append(u) need_display.append(u)
if (not _pick_avatar_url(row)) and (u not in local_sender_avatars): if u not in local_sender_avatars:
need_avatar.append(u) need_avatar.append(u)
need_display = list(dict.fromkeys(need_display)) need_display = list(dict.fromkeys(need_display))
@@ -4426,13 +4434,11 @@ def list_chat_messages(
if wd and wd != su: if wd and wd != su:
display_name = wd display_name = wd
m["senderDisplayName"] = display_name m["senderDisplayName"] = display_name
avatar_url = _pick_avatar_url(row) avatar_url = base_url + _avatar_url_unified(
if not avatar_url and su in local_sender_avatars: account_dir=account_dir,
avatar_url = base_url + _build_avatar_url(account_dir.name, su) username=su,
if not avatar_url: local_avatar_usernames=local_sender_avatars,
wa = str(wcdb_avatar_urls.get(su) or "").strip() )
if wa.lower().startswith(("http://", "https://")):
avatar_url = wa
m["senderAvatar"] = avatar_url m["senderAvatar"] = avatar_url
qu = str(m.get("quoteUsername") or "").strip() qu = str(m.get("quoteUsername") or "").strip()
@@ -4897,7 +4903,7 @@ async def _search_chat_messages_via_fts(
row = contact_rows.get(uu) row = contact_rows.get(uu)
if _pick_display_name(row, uu) == uu: if _pick_display_name(row, uu) == uu:
need_display.append(uu) need_display.append(uu)
if (not _pick_avatar_url(row)) and (uu not in local_avatar_usernames): if uu not in local_avatar_usernames:
need_avatar.append(uu) need_avatar.append(uu)
need_display = list(dict.fromkeys(need_display)) need_display = list(dict.fromkeys(need_display))
@@ -4919,13 +4925,11 @@ async def _search_chat_messages_via_fts(
wd = str(wcdb_display_names.get(username) or "").strip() wd = str(wcdb_display_names.get(username) or "").strip()
if wd and wd != username: if wd and wd != username:
conv_name = wd conv_name = wd
conv_avatar = _pick_avatar_url(conv_row) conv_avatar = base_url + _avatar_url_unified(
if (not conv_avatar) and (username in local_avatar_usernames): account_dir=account_dir,
conv_avatar = base_url + _build_avatar_url(account_dir.name, username) username=username,
if not conv_avatar: local_avatar_usernames=local_avatar_usernames,
wa = str(wcdb_avatar_urls.get(username) or "").strip() )
if wa.lower().startswith(("http://", "https://")):
conv_avatar = wa
for h in hits: for h in hits:
su = str(h.get("senderUsername") or "").strip() su = str(h.get("senderUsername") or "").strip()
@@ -4939,13 +4943,11 @@ async def _search_chat_messages_via_fts(
if wd and wd != su: if wd and wd != su:
display_name = wd display_name = wd
h["senderDisplayName"] = display_name h["senderDisplayName"] = display_name
avatar_url = _pick_avatar_url(row) avatar_url = base_url + _avatar_url_unified(
if (not avatar_url) and (su in local_avatar_usernames): account_dir=account_dir,
avatar_url = base_url + _build_avatar_url(account_dir.name, su) username=su,
if not avatar_url: local_avatar_usernames=local_avatar_usernames,
wa = str(wcdb_avatar_urls.get(su) or "").strip() )
if wa.lower().startswith(("http://", "https://")):
avatar_url = wa
h["senderAvatar"] = avatar_url h["senderAvatar"] = avatar_url
else: else:
uniq_contacts = list( uniq_contacts = list(
@@ -4968,7 +4970,7 @@ async def _search_chat_messages_via_fts(
row = contact_rows.get(uu) row = contact_rows.get(uu)
if _pick_display_name(row, uu) == uu: if _pick_display_name(row, uu) == uu:
need_display.append(uu) need_display.append(uu)
if (not _pick_avatar_url(row)) and (uu not in local_avatar_usernames): if uu not in local_avatar_usernames:
need_avatar.append(uu) need_avatar.append(uu)
need_display = list(dict.fromkeys(need_display)) need_display = list(dict.fromkeys(need_display))
@@ -4994,13 +4996,11 @@ async def _search_chat_messages_via_fts(
if wd and wd != cu: if wd and wd != cu:
conv_name = wd conv_name = wd
h["conversationName"] = conv_name or cu h["conversationName"] = conv_name or cu
conv_avatar = _pick_avatar_url(crow) conv_avatar = base_url + _avatar_url_unified(
if (not conv_avatar) and cu and (cu in local_avatar_usernames): account_dir=account_dir,
conv_avatar = base_url + _build_avatar_url(account_dir.name, cu) username=cu,
if not conv_avatar and cu: local_avatar_usernames=local_avatar_usernames,
wa = str(wcdb_avatar_urls.get(cu) or "").strip() )
if wa.lower().startswith(("http://", "https://")):
conv_avatar = wa
h["conversationAvatar"] = conv_avatar h["conversationAvatar"] = conv_avatar
if su: if su:
row = contact_rows.get(su) row = contact_rows.get(su)
@@ -5010,13 +5010,11 @@ async def _search_chat_messages_via_fts(
if wd and wd != su: if wd and wd != su:
display_name = wd display_name = wd
h["senderDisplayName"] = display_name h["senderDisplayName"] = display_name
avatar_url = _pick_avatar_url(row) avatar_url = base_url + _avatar_url_unified(
if (not avatar_url) and (su in local_avatar_usernames): account_dir=account_dir,
avatar_url = base_url + _build_avatar_url(account_dir.name, su) username=su,
if not avatar_url: local_avatar_usernames=local_avatar_usernames,
wa = str(wcdb_avatar_urls.get(su) or "").strip() )
if wa.lower().startswith(("http://", "https://")):
avatar_url = wa
h["senderAvatar"] = avatar_url h["senderAvatar"] = avatar_url
return { return {

View File

@@ -8,7 +8,7 @@ import os
import sqlite3 import sqlite3
import subprocess import subprocess
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Any, Optional
from urllib.parse import urlparse from urllib.parse import urlparse
import requests import requests
@@ -16,6 +16,21 @@ from fastapi import APIRouter, HTTPException
from fastapi.responses import FileResponse, Response from fastapi.responses import FileResponse, Response
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from ..avatar_cache import (
AVATAR_CACHE_TTL_SECONDS,
avatar_cache_entry_file_exists,
avatar_cache_entry_is_fresh,
build_avatar_cache_response_headers,
cache_key_for_avatar_user,
cache_key_for_avatar_url,
get_avatar_cache_url_entry,
get_avatar_cache_user_entry,
is_avatar_cache_enabled,
normalize_avatar_source_url,
touch_avatar_cache_entry,
upsert_avatar_cache_entry,
write_avatar_cache_payload,
)
from ..logging_config import get_logger from ..logging_config import get_logger
from ..media_helpers import ( from ..media_helpers import (
_convert_silk_to_wav, _convert_silk_to_wav,
@@ -43,14 +58,56 @@ from ..media_helpers import (
_try_find_decrypted_resource, _try_find_decrypted_resource,
_try_strip_media_prefix, _try_strip_media_prefix,
) )
from ..chat_helpers import _extract_md5_from_packed_info from ..chat_helpers import _extract_md5_from_packed_info, _load_contact_rows, _pick_avatar_url
from ..path_fix import PathFixRoute from ..path_fix import PathFixRoute
from ..wcdb_realtime import WCDB_REALTIME, get_avatar_urls as _wcdb_get_avatar_urls
logger = get_logger(__name__) logger = get_logger(__name__)
router = APIRouter(route_class=PathFixRoute) router = APIRouter(route_class=PathFixRoute)
def _resolve_avatar_remote_url(*, account_dir: Path, username: str) -> str:
u = str(username or "").strip()
if not u:
return ""
# 1) contact.db first (cheap local lookup)
try:
rows = _load_contact_rows(account_dir / "contact.db", [u])
row = rows.get(u)
raw = str(_pick_avatar_url(row) or "").strip()
if raw.lower().startswith(("http://", "https://")):
return normalize_avatar_source_url(raw)
except Exception:
pass
# 2) WCDB fallback (more complete on enterprise/openim IDs)
try:
wcdb_conn = WCDB_REALTIME.ensure_connected(account_dir)
with wcdb_conn.lock:
mp = _wcdb_get_avatar_urls(wcdb_conn.handle, [u])
wa = str(mp.get(u) or "").strip()
if wa.lower().startswith(("http://", "https://")):
return normalize_avatar_source_url(wa)
except Exception:
pass
return ""
def _parse_304_headers(headers: Any) -> tuple[str, str]:
try:
etag = str((headers or {}).get("ETag") or "").strip()
except Exception:
etag = ""
try:
last_modified = str((headers or {}).get("Last-Modified") or "").strip()
except Exception:
last_modified = ""
return etag, last_modified
@lru_cache(maxsize=4096) @lru_cache(maxsize=4096)
def _fast_probe_image_path_in_chat_attach( def _fast_probe_image_path_in_chat_attach(
*, *,
@@ -267,27 +324,309 @@ async def get_chat_avatar(username: str, account: Optional[str] = None):
if not username: if not username:
raise HTTPException(status_code=400, detail="Missing username.") raise HTTPException(status_code=400, detail="Missing username.")
account_dir = _resolve_account_dir(account) account_dir = _resolve_account_dir(account)
account_name = str(account_dir.name or "").strip()
user_key = str(username or "").strip()
# 1) Try on-disk cache first (fast path)
user_entry = None
cached_file = None
if is_avatar_cache_enabled() and account_name and user_key:
try:
user_entry = get_avatar_cache_user_entry(account_name, user_key)
cached_file = avatar_cache_entry_file_exists(account_name, user_entry)
if cached_file is not None:
logger.info(f"[avatar_cache_hit] kind=user account={account_name} username={user_key}")
except Exception as e:
logger.warning(f"[avatar_cache_error] read user cache failed account={account_name} username={user_key} err={e}")
head_image_db_path = account_dir / "head_image.db" head_image_db_path = account_dir / "head_image.db"
if not head_image_db_path.exists(): if not head_image_db_path.exists():
# No local head_image.db: allow fallback from cached/remote URL path.
if cached_file is not None and user_entry:
headers = build_avatar_cache_response_headers(user_entry)
return FileResponse(
str(cached_file),
media_type=str(user_entry.get("media_type") or "application/octet-stream"),
headers=headers,
)
raise HTTPException(status_code=404, detail="head_image.db not found.") raise HTTPException(status_code=404, detail="head_image.db not found.")
conn = sqlite3.connect(str(head_image_db_path)) conn = sqlite3.connect(str(head_image_db_path))
try: try:
row = conn.execute( meta = conn.execute(
"SELECT image_buffer FROM head_image WHERE username = ? ORDER BY update_time DESC LIMIT 1", "SELECT md5, update_time FROM head_image WHERE username = ? ORDER BY update_time DESC LIMIT 1",
(username,), (username,),
).fetchone() ).fetchone()
if meta and meta[0] is not None:
db_md5 = str(meta[0] or "").strip().lower()
try:
db_update_time = int(meta[1] or 0)
except Exception:
db_update_time = 0
# Cache still valid against head_image metadata.
if cached_file is not None and user_entry:
cached_md5 = str(user_entry.get("source_md5") or "").strip().lower()
try:
cached_update = int(user_entry.get("source_update_time") or 0)
except Exception:
cached_update = 0
if cached_md5 == db_md5 and cached_update == db_update_time:
touch_avatar_cache_entry(account_name, str(user_entry.get("cache_key") or ""))
headers = build_avatar_cache_response_headers(user_entry)
return FileResponse(
str(cached_file),
media_type=str(user_entry.get("media_type") or "application/octet-stream"),
headers=headers,
)
# Refresh from blob (changed or first-load)
row = conn.execute(
"SELECT image_buffer FROM head_image WHERE username = ? ORDER BY update_time DESC LIMIT 1",
(username,),
).fetchone()
if row and row[0] is not None:
data = bytes(row[0]) if isinstance(row[0], (memoryview, bytearray)) else row[0]
if not isinstance(data, (bytes, bytearray)):
data = bytes(data)
if data:
media_type = _detect_image_media_type(data)
media_type = media_type if media_type.startswith("image/") else "application/octet-stream"
entry, out_path = write_avatar_cache_payload(
account_name,
source_kind="user",
username=user_key,
payload=bytes(data),
media_type=media_type,
source_md5=db_md5,
source_update_time=db_update_time,
ttl_seconds=AVATAR_CACHE_TTL_SECONDS,
)
if entry and out_path:
logger.info(
f"[avatar_cache_download] kind=user account={account_name} username={user_key} src=head_image"
)
headers = build_avatar_cache_response_headers(entry)
return FileResponse(str(out_path), media_type=media_type, headers=headers)
# cache write failed: fallback to response bytes
logger.warning(
f"[avatar_cache_error] kind=user account={account_name} username={user_key} action=write_fallback"
)
return Response(content=bytes(data), media_type=media_type)
# meta not found (no local avatar blob)
row = None
finally: finally:
conn.close() conn.close()
if not row or row[0] is None: # 2) Fallback: remote avatar URL (contact/WCDB), cache by URL.
raise HTTPException(status_code=404, detail="Avatar not found.") remote_url = _resolve_avatar_remote_url(account_dir=account_dir, username=user_key)
if remote_url and is_avatar_cache_enabled():
url_entry = get_avatar_cache_url_entry(account_name, remote_url)
url_file = avatar_cache_entry_file_exists(account_name, url_entry)
if url_entry and url_file and avatar_cache_entry_is_fresh(url_entry):
logger.info(f"[avatar_cache_hit] kind=url account={account_name} username={user_key}")
touch_avatar_cache_entry(account_name, str(url_entry.get("cache_key") or ""))
# Keep user-key mapping aligned, so next user lookup is direct.
try:
upsert_avatar_cache_entry(
account_name,
cache_key=cache_key_for_avatar_user(user_key),
source_kind="user",
username=user_key,
source_url=remote_url,
source_md5=str(url_entry.get("source_md5") or ""),
source_update_time=int(url_entry.get("source_update_time") or 0),
rel_path=str(url_entry.get("rel_path") or ""),
media_type=str(url_entry.get("media_type") or "application/octet-stream"),
size_bytes=int(url_entry.get("size_bytes") or 0),
etag=str(url_entry.get("etag") or ""),
last_modified=str(url_entry.get("last_modified") or ""),
fetched_at=int(url_entry.get("fetched_at") or 0),
checked_at=int(url_entry.get("checked_at") or 0),
expires_at=int(url_entry.get("expires_at") or 0),
)
except Exception:
pass
headers = build_avatar_cache_response_headers(url_entry)
return FileResponse(
str(url_file),
media_type=str(url_entry.get("media_type") or "application/octet-stream"),
headers=headers,
)
data = bytes(row[0]) if isinstance(row[0], (memoryview, bytearray)) else row[0] # Revalidate / download remote avatar
if not isinstance(data, (bytes, bytearray)): def _download_remote_avatar(
data = bytes(data) source_url: str,
media_type = _detect_image_media_type(data) *,
return Response(content=data, media_type=media_type) etag: str,
last_modified: str,
) -> tuple[bytes, str, str, str, bool]:
base_headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120 Safari/537.36",
"Accept": "image/avif,image/webp,image/apng,image/*,*/*;q=0.8",
}
header_variants = [
{
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/107.0.0.0 Safari/537.36 MicroMessenger/7.0.20.1781(0x6700143B) WindowsWechat(0x63090719) XWEB/8351",
"Accept": "image/avif,image/webp,image/apng,image/svg+xml,image/*,*/*;q=0.8",
"Accept-Language": "zh-CN,zh;q=0.9",
"Referer": "https://servicewechat.com/",
"Origin": "https://servicewechat.com",
"Range": "bytes=0-",
},
{"Referer": "https://wx.qq.com/", "Origin": "https://wx.qq.com"},
{"Referer": "https://mp.weixin.qq.com/", "Origin": "https://mp.weixin.qq.com"},
{"Referer": "https://www.baidu.com/", "Origin": "https://www.baidu.com"},
{},
]
last_err: Exception | None = None
for extra in header_variants:
headers = dict(base_headers)
headers.update(extra)
if etag:
headers["If-None-Match"] = etag
if last_modified:
headers["If-Modified-Since"] = last_modified
r = requests.get(source_url, headers=headers, timeout=20, stream=True)
try:
if r.status_code == 304:
e2, lm2 = _parse_304_headers(r.headers)
return b"", "", (e2 or etag), (lm2 or last_modified), True
r.raise_for_status()
content_type = str(r.headers.get("Content-Type") or "").strip()
e2, lm2 = _parse_304_headers(r.headers)
max_bytes = 10 * 1024 * 1024
chunks: list[bytes] = []
total = 0
for ch in r.iter_content(chunk_size=64 * 1024):
if not ch:
continue
chunks.append(ch)
total += len(ch)
if total > max_bytes:
raise HTTPException(status_code=400, detail="Avatar too large (>10MB).")
return b"".join(chunks), content_type, e2, lm2, False
except HTTPException:
raise
except Exception as e:
last_err = e
finally:
try:
r.close()
except Exception:
pass
raise last_err or RuntimeError("avatar remote download failed")
etag0 = str((url_entry or {}).get("etag") or "").strip()
lm0 = str((url_entry or {}).get("last_modified") or "").strip()
try:
payload, ct, etag_new, lm_new, not_modified = await asyncio.to_thread(
_download_remote_avatar,
remote_url,
etag=etag0,
last_modified=lm0,
)
except Exception as e:
logger.warning(f"[avatar_cache_error] kind=url account={account_name} username={user_key} err={e}")
if url_entry and url_file:
headers = build_avatar_cache_response_headers(url_entry)
return FileResponse(
str(url_file),
media_type=str(url_entry.get("media_type") or "application/octet-stream"),
headers=headers,
)
raise HTTPException(status_code=404, detail="Avatar not found.")
if not_modified and url_entry and url_file:
touch_avatar_cache_entry(account_name, cache_key_for_avatar_url(remote_url))
if etag_new or lm_new:
try:
upsert_avatar_cache_entry(
account_name,
cache_key=cache_key_for_avatar_url(remote_url),
source_kind="url",
username=user_key,
source_url=remote_url,
source_md5=str(url_entry.get("source_md5") or ""),
source_update_time=int(url_entry.get("source_update_time") or 0),
rel_path=str(url_entry.get("rel_path") or ""),
media_type=str(url_entry.get("media_type") or "application/octet-stream"),
size_bytes=int(url_entry.get("size_bytes") or 0),
etag=etag_new or etag0,
last_modified=lm_new or lm0,
)
except Exception:
pass
logger.info(f"[avatar_cache_revalidate] kind=url account={account_name} username={user_key} status=304")
headers = build_avatar_cache_response_headers(url_entry)
return FileResponse(
str(url_file),
media_type=str(url_entry.get("media_type") or "application/octet-stream"),
headers=headers,
)
if payload:
payload2, media_type, _ext = _detect_media_type_and_ext(payload)
if media_type == "application/octet-stream" and ct:
try:
mt = ct.split(";")[0].strip()
if mt.startswith("image/"):
media_type = mt
except Exception:
pass
if str(media_type or "").startswith("image/"):
entry, out_path = write_avatar_cache_payload(
account_name,
source_kind="url",
username=user_key,
source_url=remote_url,
payload=payload2,
media_type=media_type,
etag=etag_new,
last_modified=lm_new,
ttl_seconds=AVATAR_CACHE_TTL_SECONDS,
)
if entry and out_path:
# bind user-key record to same file for quicker next access
try:
upsert_avatar_cache_entry(
account_name,
cache_key=cache_key_for_avatar_user(user_key),
source_kind="user",
username=user_key,
source_url=remote_url,
source_md5=str(entry.get("source_md5") or ""),
source_update_time=int(entry.get("source_update_time") or 0),
rel_path=str(entry.get("rel_path") or ""),
media_type=str(entry.get("media_type") or "application/octet-stream"),
size_bytes=int(entry.get("size_bytes") or 0),
etag=str(entry.get("etag") or ""),
last_modified=str(entry.get("last_modified") or ""),
fetched_at=int(entry.get("fetched_at") or 0),
checked_at=int(entry.get("checked_at") or 0),
expires_at=int(entry.get("expires_at") or 0),
)
except Exception:
pass
logger.info(f"[avatar_cache_download] kind=url account={account_name} username={user_key}")
headers = build_avatar_cache_response_headers(entry)
return FileResponse(str(out_path), media_type=media_type, headers=headers)
if cached_file is not None and user_entry:
headers = build_avatar_cache_response_headers(user_entry)
return FileResponse(
str(cached_file),
media_type=str(user_entry.get("media_type") or "application/octet-stream"),
headers=headers,
)
raise HTTPException(status_code=404, detail="Avatar not found.")
class EmojiDownloadRequest(BaseModel): class EmojiDownloadRequest(BaseModel):
@@ -434,7 +773,25 @@ async def proxy_image(url: str):
if not _is_allowed_proxy_image_host(host): if not _is_allowed_proxy_image_host(host):
raise HTTPException(status_code=400, detail="Unsupported url host for proxy_image.") raise HTTPException(status_code=400, detail="Unsupported url host for proxy_image.")
def _download_bytes() -> tuple[bytes, str]: source_url = normalize_avatar_source_url(u)
proxy_account = "_proxy"
cache_entry = get_avatar_cache_url_entry(proxy_account, source_url) if is_avatar_cache_enabled() else None
cache_file = avatar_cache_entry_file_exists(proxy_account, cache_entry)
if cache_entry and cache_file and avatar_cache_entry_is_fresh(cache_entry):
logger.info(f"[avatar_cache_hit] kind=proxy_url account={proxy_account}")
touch_avatar_cache_entry(proxy_account, cache_key_for_avatar_url(source_url))
headers = build_avatar_cache_response_headers(cache_entry)
return FileResponse(
str(cache_file),
media_type=str(cache_entry.get("media_type") or "application/octet-stream"),
headers=headers,
)
def _download_bytes(
*,
if_none_match: str = "",
if_modified_since: str = "",
) -> tuple[bytes, str, str, str, bool]:
base_headers = { base_headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120 Safari/537.36", "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120 Safari/537.36",
"Accept": "image/avif,image/webp,image/apng,image/*,*/*;q=0.8", "Accept": "image/avif,image/webp,image/apng,image/*,*/*;q=0.8",
@@ -464,10 +821,20 @@ async def proxy_image(url: str):
for extra in header_variants: for extra in header_variants:
headers = dict(base_headers) headers = dict(base_headers)
headers.update(extra) headers.update(extra)
if if_none_match:
headers["If-None-Match"] = if_none_match
if if_modified_since:
headers["If-Modified-Since"] = if_modified_since
r = requests.get(u, headers=headers, timeout=20, stream=True) r = requests.get(u, headers=headers, timeout=20, stream=True)
try: try:
if r.status_code == 304:
etag0 = str(r.headers.get("ETag") or "").strip()
lm0 = str(r.headers.get("Last-Modified") or "").strip()
return b"", "", etag0, lm0, True
r.raise_for_status() r.raise_for_status()
content_type = str(r.headers.get("Content-Type") or "").strip() content_type = str(r.headers.get("Content-Type") or "").strip()
etag0 = str(r.headers.get("ETag") or "").strip()
lm0 = str(r.headers.get("Last-Modified") or "").strip()
max_bytes = 10 * 1024 * 1024 max_bytes = 10 * 1024 * 1024
chunks: list[bytes] = [] chunks: list[bytes] = []
total = 0 total = 0
@@ -478,7 +845,7 @@ async def proxy_image(url: str):
total += len(ch) total += len(ch)
if total > max_bytes: if total > max_bytes:
raise HTTPException(status_code=400, detail="Proxy image too large (>10MB).") raise HTTPException(status_code=400, detail="Proxy image too large (>10MB).")
return b"".join(chunks), content_type return b"".join(chunks), content_type, etag0, lm0, False
except HTTPException: except HTTPException:
# Hard failure, don't retry with another referer. # Hard failure, don't retry with another referer.
raise raise
@@ -493,14 +860,50 @@ async def proxy_image(url: str):
# All variants failed. # All variants failed.
raise last_err or RuntimeError("proxy_image download failed") raise last_err or RuntimeError("proxy_image download failed")
etag0 = str((cache_entry or {}).get("etag") or "").strip()
lm0 = str((cache_entry or {}).get("last_modified") or "").strip()
try: try:
data, ct = await asyncio.to_thread(_download_bytes) data, ct, etag_new, lm_new, not_modified = await asyncio.to_thread(
_download_bytes,
if_none_match=etag0,
if_modified_since=lm0,
)
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
logger.warning(f"proxy_image failed: url={u} err={e}") logger.warning(f"proxy_image failed: url={u} err={e}")
if cache_entry and cache_file:
headers = build_avatar_cache_response_headers(cache_entry)
return FileResponse(
str(cache_file),
media_type=str(cache_entry.get("media_type") or "application/octet-stream"),
headers=headers,
)
raise HTTPException(status_code=502, detail=f"Proxy image failed: {e}") raise HTTPException(status_code=502, detail=f"Proxy image failed: {e}")
if not_modified and cache_entry and cache_file:
logger.info(f"[avatar_cache_revalidate] kind=proxy_url account={proxy_account} status=304")
upsert_avatar_cache_entry(
proxy_account,
cache_key=cache_key_for_avatar_url(source_url),
source_kind="url",
source_url=source_url,
username="",
source_md5=str(cache_entry.get("source_md5") or ""),
source_update_time=int(cache_entry.get("source_update_time") or 0),
rel_path=str(cache_entry.get("rel_path") or ""),
media_type=str(cache_entry.get("media_type") or "application/octet-stream"),
size_bytes=int(cache_entry.get("size_bytes") or 0),
etag=etag_new or etag0,
last_modified=lm_new or lm0,
)
headers = build_avatar_cache_response_headers(cache_entry)
return FileResponse(
str(cache_file),
media_type=str(cache_entry.get("media_type") or "application/octet-stream"),
headers=headers,
)
if not data: if not data:
raise HTTPException(status_code=502, detail="Proxy returned empty body.") raise HTTPException(status_code=502, detail="Proxy returned empty body.")
@@ -518,8 +921,24 @@ async def proxy_image(url: str):
if not str(media_type or "").startswith("image/"): if not str(media_type or "").startswith("image/"):
raise HTTPException(status_code=502, detail="Proxy did not return an image.") raise HTTPException(status_code=502, detail="Proxy did not return an image.")
if is_avatar_cache_enabled():
entry, out_path = write_avatar_cache_payload(
proxy_account,
source_kind="url",
source_url=source_url,
payload=payload,
media_type=media_type,
etag=etag_new,
last_modified=lm_new,
ttl_seconds=AVATAR_CACHE_TTL_SECONDS,
)
if entry and out_path:
logger.info(f"[avatar_cache_download] kind=proxy_url account={proxy_account}")
headers = build_avatar_cache_response_headers(entry)
return FileResponse(str(out_path), media_type=media_type, headers=headers)
resp = Response(content=payload, media_type=media_type) resp = Response(content=payload, media_type=media_type)
resp.headers["Cache-Control"] = "public, max-age=86400" resp.headers["Cache-Control"] = f"public, max-age={AVATAR_CACHE_TTL_SECONDS}"
return resp return resp

View File

@@ -17,7 +17,6 @@ from ...chat_helpers import (
_decode_sqlite_text, _decode_sqlite_text,
_iter_message_db_paths, _iter_message_db_paths,
_load_contact_rows, _load_contact_rows,
_pick_avatar_url,
_pick_display_name, _pick_display_name,
_quote_ident, _quote_ident,
_should_keep_session, _should_keep_session,
@@ -701,7 +700,7 @@ def build_card_00_global_overview(
u, cnt = stats.top_contact u, cnt = stats.top_contact
row = contact_rows.get(u) row = contact_rows.get(u)
display = _pick_display_name(row, u) display = _pick_display_name(row, u)
avatar = _pick_avatar_url(row) or (_build_avatar_url(str(account_dir.name or ""), u) if u else "") avatar = _build_avatar_url(str(account_dir.name or ""), u) if u else ""
top_contact_obj = { top_contact_obj = {
"username": u, "username": u,
"displayName": display, "displayName": display,
@@ -716,7 +715,7 @@ def build_card_00_global_overview(
u, cnt = stats.top_group u, cnt = stats.top_group
row = contact_rows.get(u) row = contact_rows.get(u)
display = _pick_display_name(row, u) display = _pick_display_name(row, u)
avatar = _pick_avatar_url(row) or (_build_avatar_url(str(account_dir.name or ""), u) if u else "") avatar = _build_avatar_url(str(account_dir.name or ""), u) if u else ""
top_group_obj = { top_group_obj = {
"username": u, "username": u,
"displayName": display, "displayName": display,

View File

@@ -14,7 +14,6 @@ from ...chat_helpers import (
_build_avatar_url, _build_avatar_url,
_iter_message_db_paths, _iter_message_db_paths,
_load_contact_rows, _load_contact_rows,
_pick_avatar_url,
_pick_display_name, _pick_display_name,
_quote_ident, _quote_ident,
_row_to_search_hit, _row_to_search_hit,
@@ -713,7 +712,7 @@ def _fetch_message_moment_payload(
contact_row = contact_rows.get(username) contact_row = contact_rows.get(username)
display = _pick_display_name(contact_row, username) display = _pick_display_name(contact_row, username)
avatar = _pick_avatar_url(contact_row) or (_build_avatar_url(str(account_dir.name or ""), username) if username else "") avatar = _build_avatar_url(str(account_dir.name or ""), username) if username else ""
return { return {
"timestamp": int(ref.ts), "timestamp": int(ref.ts),

View File

@@ -12,7 +12,6 @@ from typing import Any, Optional
from ...chat_helpers import ( from ...chat_helpers import (
_build_avatar_url, _build_avatar_url,
_load_contact_rows, _load_contact_rows,
_pick_avatar_url,
_pick_display_name, _pick_display_name,
_should_keep_session, _should_keep_session,
) )
@@ -385,7 +384,7 @@ def compute_reply_speed_stats(*, account_dir: Path, year: int) -> dict[str, Any]
def conv_to_obj(score: float | None, agg: _ConvAgg) -> dict[str, Any]: def conv_to_obj(score: float | None, agg: _ConvAgg) -> dict[str, Any]:
row = contact_rows.get(agg.username) row = contact_rows.get(agg.username)
display = _pick_display_name(row, agg.username) display = _pick_display_name(row, agg.username)
avatar = _pick_avatar_url(row) or (_build_avatar_url(str(account_dir.name or ""), agg.username) if agg.username else "") avatar = _build_avatar_url(str(account_dir.name or ""), agg.username) if agg.username else ""
avg_s = agg.avg_gap() avg_s = agg.avg_gap()
out: dict[str, Any] = { out: dict[str, Any] = {
"username": agg.username, "username": agg.username,
@@ -420,7 +419,7 @@ def compute_reply_speed_stats(*, account_dir: Path, year: int) -> dict[str, Any]
else: else:
row = contact_rows.get(global_fastest_u) row = contact_rows.get(global_fastest_u)
display = _pick_display_name(row, global_fastest_u) display = _pick_display_name(row, global_fastest_u)
avatar = _pick_avatar_url(row) or (_build_avatar_url(str(account_dir.name or ""), global_fastest_u) if global_fastest_u else "") avatar = _build_avatar_url(str(account_dir.name or ""), global_fastest_u) if global_fastest_u else ""
fastest_obj = { fastest_obj = {
"username": global_fastest_u, "username": global_fastest_u,
"displayName": display, "displayName": display,
@@ -440,7 +439,7 @@ def compute_reply_speed_stats(*, account_dir: Path, year: int) -> dict[str, Any]
else: else:
row = contact_rows.get(global_slowest_u) row = contact_rows.get(global_slowest_u)
display = _pick_display_name(row, global_slowest_u) display = _pick_display_name(row, global_slowest_u)
avatar = _pick_avatar_url(row) or (_build_avatar_url(str(account_dir.name or ""), global_slowest_u) if global_slowest_u else "") avatar = _build_avatar_url(str(account_dir.name or ""), global_slowest_u) if global_slowest_u else ""
slowest_obj = { slowest_obj = {
"username": global_slowest_u, "username": global_slowest_u,
"displayName": display, "displayName": display,
@@ -547,7 +546,7 @@ def compute_reply_speed_stats(*, account_dir: Path, year: int) -> dict[str, Any]
row = contact_rows.get(u) row = contact_rows.get(u)
display = _pick_display_name(row, u) display = _pick_display_name(row, u)
avatar = _pick_avatar_url(row) or (_build_avatar_url(str(account_dir.name or ""), u) if u else "") avatar = _build_avatar_url(str(account_dir.name or ""), u) if u else ""
series.append( series.append(
{ {
"username": u, "username": u,
@@ -595,7 +594,7 @@ def compute_reply_speed_stats(*, account_dir: Path, year: int) -> dict[str, Any]
if not u: if not u:
continue continue
display = _pick_display_name(r, u) display = _pick_display_name(r, u)
avatar = _pick_avatar_url(r) or (_build_avatar_url(str(account_dir.name or ""), u) if u else "") avatar = _build_avatar_url(str(account_dir.name or ""), u) if u else ""
all_contacts_list.append({ all_contacts_list.append({
"username": u, "username": u,
"displayName": display, "displayName": display,

View File

@@ -0,0 +1,173 @@
import os
import sqlite3
import sys
import unittest
import importlib
from pathlib import Path
from tempfile import TemporaryDirectory
ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT / "src"))
class TestAvatarCacheChatMedia(unittest.TestCase):
def _seed_contact_db(self, path: Path, *, username: str = "wxid_friend") -> None:
conn = sqlite3.connect(str(path))
try:
conn.execute(
"""
CREATE TABLE contact (
username TEXT,
remark TEXT,
nick_name TEXT,
alias TEXT,
local_type INTEGER,
verify_flag INTEGER,
big_head_url TEXT,
small_head_url TEXT
)
"""
)
conn.execute(
"""
CREATE TABLE stranger (
username TEXT,
remark TEXT,
nick_name TEXT,
alias TEXT,
local_type INTEGER,
verify_flag INTEGER,
big_head_url TEXT,
small_head_url TEXT
)
"""
)
conn.execute(
"INSERT INTO contact VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
(
username,
"",
"测试好友",
"",
1,
0,
"https://wx.qlogo.cn/mmhead/ver_1/test_remote_avatar/132",
"",
),
)
conn.commit()
finally:
conn.close()
def _seed_session_db(self, path: Path, *, username: str = "wxid_friend") -> None:
conn = sqlite3.connect(str(path))
try:
conn.execute(
"""
CREATE TABLE SessionTable (
username TEXT,
sort_timestamp INTEGER,
last_timestamp INTEGER
)
"""
)
conn.execute("INSERT INTO SessionTable VALUES (?, ?, ?)", (username, 200, 200))
conn.commit()
finally:
conn.close()
def _seed_head_image_db(self, path: Path, *, username: str = "wxid_friend") -> None:
# 1x1 PNG
png = bytes.fromhex(
"89504E470D0A1A0A"
"0000000D49484452000000010000000108060000001F15C489"
"0000000D49444154789C6360606060000000050001A5F64540"
"0000000049454E44AE426082"
)
conn = sqlite3.connect(str(path))
try:
conn.execute("CREATE TABLE head_image(username TEXT PRIMARY KEY, md5 TEXT, image_buffer BLOB, update_time INTEGER)")
conn.execute(
"INSERT INTO head_image VALUES (?, ?, ?, ?)",
(username, "0123456789abcdef0123456789abcdef", sqlite3.Binary(png), 1735689600),
)
conn.commit()
finally:
conn.close()
def test_chat_avatar_caches_to_output_avatar_cache(self):
from fastapi import FastAPI
from fastapi.testclient import TestClient
with TemporaryDirectory() as td:
root = Path(td)
account = "wxid_test"
username = "wxid_friend"
account_dir = root / "output" / "databases" / account
account_dir.mkdir(parents=True, exist_ok=True)
self._seed_contact_db(account_dir / "contact.db", username=username)
self._seed_session_db(account_dir / "session.db", username=username)
self._seed_head_image_db(account_dir / "head_image.db", username=username)
prev_data = None
prev_cache = None
try:
prev_data = os.environ.get("WECHAT_TOOL_DATA_DIR")
prev_cache = os.environ.get("WECHAT_TOOL_AVATAR_CACHE_ENABLED")
os.environ["WECHAT_TOOL_DATA_DIR"] = str(root)
os.environ["WECHAT_TOOL_AVATAR_CACHE_ENABLED"] = "1"
import wechat_decrypt_tool.app_paths as app_paths
import wechat_decrypt_tool.chat_helpers as chat_helpers
import wechat_decrypt_tool.avatar_cache as avatar_cache
import wechat_decrypt_tool.routers.chat_media as chat_media
importlib.reload(app_paths)
importlib.reload(chat_helpers)
importlib.reload(avatar_cache)
importlib.reload(chat_media)
app = FastAPI()
app.include_router(chat_media.router)
client = TestClient(app)
resp = client.get("/api/chat/avatar", params={"account": account, "username": username})
self.assertEqual(resp.status_code, 200)
self.assertTrue(resp.headers.get("content-type", "").startswith("image/"))
cache_db = root / "output" / "avatar_cache" / account / "avatar_cache.db"
self.assertTrue(cache_db.exists())
conn = sqlite3.connect(str(cache_db))
try:
row = conn.execute(
"SELECT cache_key, source_kind, username, rel_path, media_type FROM avatar_cache_entries WHERE source_kind = 'user' LIMIT 1"
).fetchone()
self.assertIsNotNone(row)
rel_path = str(row[3] or "")
finally:
conn.close()
self.assertTrue(rel_path)
cache_file = (root / "output" / "avatar_cache" / account / rel_path).resolve()
self.assertTrue(cache_file.exists())
resp2 = client.get("/api/chat/avatar", params={"account": account, "username": username})
self.assertEqual(resp2.status_code, 200)
self.assertEqual(resp2.content, resp.content)
finally:
if prev_data is None:
os.environ.pop("WECHAT_TOOL_DATA_DIR", None)
else:
os.environ["WECHAT_TOOL_DATA_DIR"] = prev_data
if prev_cache is None:
os.environ.pop("WECHAT_TOOL_AVATAR_CACHE_ENABLED", None)
else:
os.environ["WECHAT_TOOL_AVATAR_CACHE_ENABLED"] = prev_cache
if __name__ == "__main__":
unittest.main()