diff --git a/src/wechat_decrypt_tool/chat_edit_store.py b/src/wechat_decrypt_tool/chat_edit_store.py new file mode 100644 index 0000000..e66328b --- /dev/null +++ b/src/wechat_decrypt_tool/chat_edit_store.py @@ -0,0 +1,487 @@ +from __future__ import annotations + +import json +import re +import sqlite3 +import time +from pathlib import Path +from typing import Any, Optional + +from .app_paths import get_output_dir + +_HEX_RE = re.compile(r"^[0-9a-fA-F]+$") + + +def _db_path() -> Path: + return get_output_dir() / "message_edits.db" + + +def _connect() -> sqlite3.Connection: + db_path = _db_path() + db_path.parent.mkdir(parents=True, exist_ok=True) + conn = sqlite3.connect(str(db_path), timeout=5) + conn.row_factory = sqlite3.Row + _ensure_schema(conn) + return conn + + +def ensure_schema() -> None: + conn: Optional[sqlite3.Connection] = None + try: + conn = _connect() + finally: + try: + if conn is not None: + conn.close() + except Exception: + pass + + +def _ensure_schema(conn: sqlite3.Connection) -> None: + conn.execute( + """ + CREATE TABLE IF NOT EXISTS message_edits ( + account TEXT NOT NULL, + session_id TEXT NOT NULL, + db TEXT NOT NULL, + table_name TEXT NOT NULL, + local_id INTEGER NOT NULL, + first_edited_at INTEGER NOT NULL, + last_edited_at INTEGER NOT NULL, + edit_count INTEGER NOT NULL, + original_msg_json TEXT NOT NULL, + original_resource_json TEXT, + edited_cols_json TEXT, + PRIMARY KEY (account, session_id, db, table_name, local_id) + ) + """ + ) + conn.execute( + "CREATE INDEX IF NOT EXISTS idx_message_edits_account_session ON message_edits(account, session_id)" + ) + conn.execute("CREATE INDEX IF NOT EXISTS idx_message_edits_account_last ON message_edits(account, last_edited_at)") + + # Backwards-compatible migrations for existing DBs. + try: + cols = { + str(r[1] or "").strip().lower() + for r in conn.execute("PRAGMA table_info(message_edits)").fetchall() + if r and len(r) > 1 and r[1] + } + if "edited_cols_json" not in cols: + conn.execute("ALTER TABLE message_edits ADD COLUMN edited_cols_json TEXT") + except Exception: + pass + conn.commit() + + +def _now_ms() -> int: + return int(time.time() * 1000) + + +def format_message_id(db: str, table_name: str, local_id: int) -> str: + return f"{str(db or '').strip()}:{str(table_name or '').strip()}:{int(local_id or 0)}" + + +def parse_message_id(message_id: str) -> tuple[str, str, int]: + parts = str(message_id or "").split(":", 2) + if len(parts) != 3: + raise ValueError("Invalid message_id format.") + db = str(parts[0] or "").strip() + table_name = str(parts[1] or "").strip() + try: + local_id = int(parts[2] or 0) + except Exception: + raise ValueError("Invalid message_id format.") + if not db or not table_name or local_id <= 0: + raise ValueError("Invalid message_id format.") + return db, table_name, local_id + + +def _bytes_to_hex(value: bytes) -> str: + return "0x" + value.hex() + + +def _hex_to_bytes(value: str) -> Optional[bytes]: + s = str(value or "").strip() + if not s.startswith("0x"): + return None + hex_part = s[2:] + if (not hex_part) or (len(hex_part) % 2 != 0) or (_HEX_RE.match(hex_part) is None): + return None + try: + return bytes.fromhex(hex_part) + except Exception: + return None + + +def _jsonify_blobs(obj: Any) -> Any: + if obj is None: + return None + if isinstance(obj, (bytes, bytearray, memoryview)): + return _bytes_to_hex(bytes(obj)) + if isinstance(obj, dict): + return {str(k): _jsonify_blobs(v) for k, v in obj.items()} + if isinstance(obj, (list, tuple)): + return [_jsonify_blobs(v) for v in obj] + return obj + + +def _dejsonify_blobs(obj: Any) -> Any: + if obj is None: + return None + if isinstance(obj, str): + b = _hex_to_bytes(obj) + return b if b is not None else obj + if isinstance(obj, dict): + return {str(k): _dejsonify_blobs(v) for k, v in obj.items()} + if isinstance(obj, list): + return [_dejsonify_blobs(v) for v in obj] + return obj + + +def dumps_json_with_blobs(obj: Any) -> str: + return json.dumps(_jsonify_blobs(obj), ensure_ascii=False, separators=(",", ":")) + + +def loads_json_with_blobs(payload: str) -> Any: + return _dejsonify_blobs(json.loads(str(payload or "") or "null")) + + +def upsert_original_once( + *, + account: str, + session_id: str, + db: str, + table_name: str, + local_id: int, + original_msg: dict[str, Any], + original_resource: Optional[dict[str, Any]], + now_ms: Optional[int] = None, +) -> None: + """Insert the original snapshot for a message only once, then bump counters on subsequent edits.""" + a = str(account or "").strip() + sid = str(session_id or "").strip() + db_norm = str(db or "").strip() + t = str(table_name or "").strip() + lid = int(local_id or 0) + if not a or not sid or not db_norm or not t or lid <= 0: + raise ValueError("Missing required keys for message edit store.") + + ts = int(now_ms if now_ms is not None else _now_ms()) + msg_json = dumps_json_with_blobs(original_msg or {}) + res_json = dumps_json_with_blobs(original_resource) if original_resource is not None else None + + conn: Optional[sqlite3.Connection] = None + try: + conn = _connect() + existing = conn.execute( + """ + SELECT 1 + FROM message_edits + WHERE account = ? AND session_id = ? AND db = ? AND table_name = ? AND local_id = ? + LIMIT 1 + """, + (a, sid, db_norm, t, lid), + ).fetchone() + if existing is None: + conn.execute( + """ + INSERT INTO message_edits( + account, session_id, db, table_name, local_id, + first_edited_at, last_edited_at, edit_count, + original_msg_json, original_resource_json, edited_cols_json + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + (a, sid, db_norm, t, lid, ts, ts, 1, msg_json, res_json, None), + ) + else: + conn.execute( + """ + UPDATE message_edits + SET last_edited_at = ?, edit_count = edit_count + 1 + WHERE account = ? AND session_id = ? AND db = ? AND table_name = ? AND local_id = ? + """, + (ts, a, sid, db_norm, t, lid), + ) + conn.commit() + finally: + try: + if conn is not None: + conn.close() + except Exception: + pass + + +def _parse_json_str_list(payload: Any) -> list[str]: + if payload is None: + return [] + if isinstance(payload, (list, tuple)): + return [str(x or "").strip() for x in payload if str(x or "").strip()] + s = str(payload or "").strip() + if not s: + return [] + try: + v = json.loads(s) + except Exception: + return [] + if not isinstance(v, list): + return [] + return [str(x or "").strip() for x in v if str(x or "").strip()] + + +def merge_edited_columns( + *, + account: str, + session_id: str, + db: str, + table_name: str, + local_id: int, + columns: list[str], +) -> bool: + """Merge edited message column names into the per-message edit record. + + This allows reset to restore only the fields actually modified by the tool. + """ + a = str(account or "").strip() + sid = str(session_id or "").strip() + db_norm = str(db or "").strip() + t = str(table_name or "").strip() + lid = int(local_id or 0) + if not a or not sid or not db_norm or not t or lid <= 0: + return False + + cols_in = [str(x or "").strip() for x in (columns or []) if str(x or "").strip()] + if not cols_in: + return True + + conn: Optional[sqlite3.Connection] = None + try: + conn = _connect() + row = conn.execute( + """ + SELECT edited_cols_json + FROM message_edits + WHERE account = ? AND session_id = ? AND db = ? AND table_name = ? AND local_id = ? + LIMIT 1 + """, + (a, sid, db_norm, t, lid), + ).fetchone() + if row is None: + return False + + existing = _parse_json_str_list(row[0] if row and len(row) else None) + merged = {c.lower() for c in existing if c} | {c.lower() for c in cols_in if c} + merged_list = sorted(merged) + payload = json.dumps(merged_list, ensure_ascii=False, separators=(",", ":")) + conn.execute( + """ + UPDATE message_edits + SET edited_cols_json = ? + WHERE account = ? AND session_id = ? AND db = ? AND table_name = ? AND local_id = ? + """, + (payload, a, sid, db_norm, t, lid), + ) + conn.commit() + return True + finally: + try: + if conn is not None: + conn.close() + except Exception: + pass + + +def _row_to_dict(row: Optional[sqlite3.Row]) -> Optional[dict[str, Any]]: + if row is None: + return None + out: dict[str, Any] = {} + for k in row.keys(): + out[str(k)] = row[k] + return out + + +def list_sessions(account: str) -> list[dict[str, Any]]: + a = str(account or "").strip() + if not a: + return [] + + conn: Optional[sqlite3.Connection] = None + try: + conn = _connect() + rows = conn.execute( + """ + SELECT session_id, COUNT(*) AS msg_count, MAX(last_edited_at) AS last_edited_at + FROM message_edits + WHERE account = ? + GROUP BY session_id + ORDER BY last_edited_at DESC + """, + (a,), + ).fetchall() + out: list[dict[str, Any]] = [] + for r in rows: + try: + sid = str(r["session_id"] or "").strip() + except Exception: + sid = "" + if not sid: + continue + out.append( + { + "session_id": sid, + "msg_count": int(r["msg_count"] or 0), + "last_edited_at": int(r["last_edited_at"] or 0), + } + ) + return out + finally: + try: + if conn is not None: + conn.close() + except Exception: + pass + + +def list_messages(account: str, session_id: str) -> list[dict[str, Any]]: + a = str(account or "").strip() + sid = str(session_id or "").strip() + if not a or not sid: + return [] + + conn: Optional[sqlite3.Connection] = None + try: + conn = _connect() + rows = conn.execute( + """ + SELECT * + FROM message_edits + WHERE account = ? AND session_id = ? + ORDER BY last_edited_at ASC, local_id ASC + """, + (a, sid), + ).fetchall() + out: list[dict[str, Any]] = [] + for r in rows: + item = _row_to_dict(r) or {} + try: + item["message_id"] = format_message_id(item.get("db") or "", item.get("table_name") or "", item.get("local_id") or 0) + except Exception: + item["message_id"] = "" + out.append(item) + return out + finally: + try: + if conn is not None: + conn.close() + except Exception: + pass + + +def get_message_edit(account: str, session_id: str, message_id: str) -> Optional[dict[str, Any]]: + a = str(account or "").strip() + sid = str(session_id or "").strip() + if not a or not sid or not message_id: + return None + try: + db, table_name, local_id = parse_message_id(message_id) + except Exception: + return None + + conn: Optional[sqlite3.Connection] = None + try: + conn = _connect() + row = conn.execute( + """ + SELECT * + FROM message_edits + WHERE account = ? AND session_id = ? AND db = ? AND table_name = ? AND local_id = ? + LIMIT 1 + """, + (a, sid, db, table_name, int(local_id)), + ).fetchone() + item = _row_to_dict(row) + if not item: + return None + item["message_id"] = format_message_id(db, table_name, local_id) + return item + finally: + try: + if conn is not None: + conn.close() + except Exception: + pass + + +def delete_message_edit(account: str, session_id: str, message_id: str) -> bool: + a = str(account or "").strip() + sid = str(session_id or "").strip() + if not a or not sid or not message_id: + return False + try: + db, table_name, local_id = parse_message_id(message_id) + except Exception: + return False + + conn: Optional[sqlite3.Connection] = None + try: + conn = _connect() + cur = conn.execute( + """ + DELETE FROM message_edits + WHERE account = ? AND session_id = ? AND db = ? AND table_name = ? AND local_id = ? + """, + (a, sid, db, table_name, int(local_id)), + ) + conn.commit() + return int(getattr(cur, "rowcount", 0) or 0) > 0 + finally: + try: + if conn is not None: + conn.close() + except Exception: + pass + + +def update_message_edit_local_id( + *, + account: str, + session_id: str, + db: str, + table_name: str, + old_local_id: int, + new_local_id: int, +) -> bool: + """Update the primary key local_id for an existing edit record (unsafe operations may change Msg.local_id).""" + a = str(account or "").strip() + sid = str(session_id or "").strip() + db_norm = str(db or "").strip() + t = str(table_name or "").strip() + old_lid = int(old_local_id or 0) + new_lid = int(new_local_id or 0) + if not a or not sid or not db_norm or not t or old_lid <= 0 or new_lid <= 0: + return False + if old_lid == new_lid: + return True + + conn: Optional[sqlite3.Connection] = None + try: + conn = _connect() + cur = conn.execute( + """ + UPDATE message_edits + SET local_id = ? + WHERE account = ? AND session_id = ? AND db = ? AND table_name = ? AND local_id = ? + """, + (new_lid, a, sid, db_norm, t, old_lid), + ) + conn.commit() + return int(getattr(cur, "rowcount", 0) or 0) > 0 + except Exception: + return False + finally: + try: + if conn is not None: + conn.close() + except Exception: + pass diff --git a/src/wechat_decrypt_tool/native/wcdb_api.dll b/src/wechat_decrypt_tool/native/wcdb_api.dll index 0ec075d..ecce05d 100644 Binary files a/src/wechat_decrypt_tool/native/wcdb_api.dll and b/src/wechat_decrypt_tool/native/wcdb_api.dll differ diff --git a/src/wechat_decrypt_tool/path_fix.py b/src/wechat_decrypt_tool/path_fix.py index c0ddf65..e05ff08 100644 --- a/src/wechat_decrypt_tool/path_fix.py +++ b/src/wechat_decrypt_tool/path_fix.py @@ -32,10 +32,8 @@ class PathFixRequest(Request): def _validate_paths_in_json(self, json_data: dict) -> Optional[str]: """验证JSON中的路径,返回错误信息(如果有)""" logger.info(f"开始验证路径,JSON数据: {json_data}") - # 检查db_storage_path字段(现在是必需的) - if 'db_storage_path' not in json_data: - return "缺少必需的db_storage_path参数,请提供具体的数据库存储路径。" - + # 仅在提供 db_storage_path 时进行校验(例如 /api/decrypt)。 + # 其它 API 的 JSON payload 不一定包含路径字段,不应强制要求。 if 'db_storage_path' in json_data: path = json_data['db_storage_path'] @@ -115,11 +113,16 @@ class PathFixRequest(Request): async def body(self) -> bytes: """重写body方法,预处理JSON中的路径问题""" + cached = getattr(self.state, "_pathfix_body_bytes", None) + if isinstance(cached, (bytes, bytearray)): + return bytes(cached) + body = await super().body() # 只处理JSON请求 content_type = self.headers.get("content-type", "") if "application/json" not in content_type: + self.state._pathfix_body_bytes = body return body try: @@ -134,6 +137,7 @@ class PathFixRequest(Request): logger.info(f"检测到路径错误: {path_error}") # 我们将错误信息存储在请求中,稍后在路由处理器中检查 self.state.path_validation_error = path_error + self.state._pathfix_body_bytes = body return body except json.JSONDecodeError as e: # JSON格式错误,继续尝试修复 @@ -169,17 +173,30 @@ class PathFixRequest(Request): if path_error: logger.info(f"修复后检测到路径错误: {path_error}") self.state.path_validation_error = path_error - return fixed_body_str.encode('utf-8') + fixed_bytes = fixed_body_str.encode('utf-8') + self.state._pathfix_body_bytes = fixed_bytes + try: + self._body = fixed_bytes # type: ignore[attr-defined] + except Exception: + pass + return fixed_bytes else: logger.info(f"修复后路径验证通过") except json.JSONDecodeError as e: logger.warning(f"修复后JSON仍然解析失败: {e}") - return fixed_body_str.encode('utf-8') + fixed_bytes = fixed_body_str.encode('utf-8') + self.state._pathfix_body_bytes = fixed_bytes + try: + self._body = fixed_bytes # type: ignore[attr-defined] + except Exception: + pass + return fixed_bytes except Exception as e: # 如果处理失败,返回原始body logger.warning(f"JSON路径修复失败,使用原始请求体: {e}") + self.state._pathfix_body_bytes = body return body @@ -193,12 +210,17 @@ class PathFixRoute(APIRoute): # 将Request替换为我们的自定义Request custom_request = PathFixRequest(request.scope, request.receive) - # 检查是否有路径验证错误 - if hasattr(custom_request.state, 'path_validation_error'): - raise HTTPException( - status_code=400, - detail=custom_request.state.path_validation_error, - ) + # 仅对 JSON 请求预读 body,以触发路径修复/校验逻辑,并在发现错误时提前返回 400。 + try: + content_type = (custom_request.headers.get("content-type", "") or "").lower() + if "application/json" in content_type: + await custom_request.body() + except Exception: + pass + + path_err = getattr(custom_request.state, "path_validation_error", None) + if path_err: + raise HTTPException(status_code=400, detail=path_err) return await original_route_handler(custom_request) diff --git a/src/wechat_decrypt_tool/routers/chat.py b/src/wechat_decrypt_tool/routers/chat.py index b34f746..494e040 100644 --- a/src/wechat_decrypt_tool/routers/chat.py +++ b/src/wechat_decrypt_tool/routers/chat.py @@ -64,7 +64,8 @@ from ..chat_helpers import ( _split_group_sender_prefix, _to_char_token_text, ) -from ..media_helpers import _try_find_decrypted_resource +from ..media_helpers import _resolve_account_db_storage_dir, _try_find_decrypted_resource +from .. import chat_edit_store from ..path_fix import PathFixRoute from ..session_last_message import ( build_session_last_message_table, @@ -74,12 +75,14 @@ from ..session_last_message import ( from ..wcdb_realtime import ( WCDBRealtimeError, WCDB_REALTIME, + exec_query as _wcdb_exec_query, get_avatar_urls as _wcdb_get_avatar_urls, get_display_names as _wcdb_get_display_names, get_group_members as _wcdb_get_group_members, get_group_nicknames as _wcdb_get_group_nicknames, get_messages as _wcdb_get_messages, get_sessions as _wcdb_get_sessions, + update_message as _wcdb_update_message, ) logger = get_logger(__name__) @@ -98,6 +101,266 @@ def _is_hex_md5(value: Any) -> bool: return len(s) == 32 and all(c in "0123456789abcdef" for c in s) +_HEX_RE = re.compile(r"^[0-9a-fA-F]+$") + + +def _hex_to_bytes(value: str) -> Optional[bytes]: + s = str(value or "").strip() + if not s.startswith("0x"): + return None + hex_part = s[2:] + if (not hex_part) or (len(hex_part) % 2 != 0) or (_HEX_RE.match(hex_part) is None): + return None + try: + return bytes.fromhex(hex_part) + except Exception: + return None + + +def _bytes_to_hex(value: bytes) -> str: + return "0x" + value.hex() + + +def _is_mostly_printable_text(s: str) -> bool: + if not s: + return False + sample = s[:600] + if not sample: + return False + printable = sum(1 for ch in sample if ch.isprintable() or ch in {"\n", "\r", "\t"}) + return (printable / len(sample)) >= 0.85 + + +def _jsonify_db_value(key: str, value: Any) -> Any: + """Convert sqlite row values into JSON-friendly values (best-effort).""" + if value is None: + return None + if isinstance(value, memoryview): + value = value.tobytes() + if isinstance(value, (bytes, bytearray)): + b = bytes(value) + k = str(key or "").strip().lower() + if k in {"compress_content", "packed_info_data", "packed_info", "packedinfo", "packedinfodata"} or k.endswith( + "_data" + ): + return _bytes_to_hex(b) + if not b: + return "" + try: + s = b.decode("utf-8") + if _is_mostly_printable_text(s): + return s + except Exception: + pass + return _bytes_to_hex(b) + if isinstance(value, (int, float, bool, str)): + return value + try: + return str(value) + except Exception: + return None + + +def _sql_literal(value: Any) -> str: + """Build a SQLite literal for WCDB exec_query (no parameters supported).""" + if value is None: + return "NULL" + if isinstance(value, bool): + return "1" if value else "0" + if isinstance(value, (int, float)): + try: + return str(int(value)) + except Exception: + return "0" + if isinstance(value, memoryview): + value = value.tobytes() + if isinstance(value, (bytes, bytearray)): + b = bytes(value) + return "X'" + b.hex() + "'" + s = str(value) + return "'" + s.replace("'", "''") + "'" + + +def _normalize_edit_value(col: str, value: Any, *, from_snapshot: bool = False) -> Any: + c = str(col or "").strip().lower() + if value is None: + return None + if isinstance(value, str): + # Allow editing BLOBs via 0x... hex strings (unsafe only, enforced elsewhere). + b = _hex_to_bytes(value) + if b is not None: + return b + + # Some WCDB exec_query snapshots return raw BLOBs as bare hex strings (without 0x prefix). + # When restoring from snapshots (reset), convert them back to bytes so SQLite stores them as BLOB again. + want_blob_hex = ( + c in {"packed_info_data", "packed_info", "packedinfo", "packedinfodata"} + or c.endswith("_data") + or c in {"source"} + or (from_snapshot and c in {"message_content", "compress_content"}) + ) + if want_blob_hex: + s = value.strip() + # Heuristic for message_content: avoid converting legitimate short "hex-like" text messages. + min_len = 0 + if c == "message_content": + s_lower = s.lower() + # zstd frame magic: 28 b5 2f fd + if s_lower.startswith("28b52ffd"): + min_len = 16 + else: + min_len = 64 + if s and (len(s) >= min_len) and (len(s) % 2 == 0) and (_HEX_RE.fullmatch(s) is not None): + try: + return bytes.fromhex(s) + except Exception: + return value + if c in { + "local_id", + "create_time", + "server_id", + "local_type", + "sort_seq", + } or c.startswith("wcdb_ct_"): + s = value.strip() + if s and re.fullmatch(r"-?\d+", s): + try: + return int(s) + except Exception: + return value + return value + + +def _is_safe_edit_column(col: str, *, unsafe: bool) -> bool: + if unsafe: + return True + c = str(col or "").strip().lower() + if not c: + return False + if c == "local_id": + return False + if c.startswith("wcdb_ct_"): + return False + if c in {"compress_content", "packed_info_data", "packed_info"}: + return False + return True + + +def _pb_read_varint(buf: bytes, i: int) -> tuple[int, int]: + """Read a protobuf varint from buf starting at i, returning (value, next_index).""" + x = 0 + shift = 0 + while i < len(buf) and shift < 64: + b = buf[i] + i += 1 + x |= (b & 0x7F) << shift + if (b & 0x80) == 0: + return x, i + shift += 7 + raise ValueError("Invalid varint.") + + +def _pb_write_varint(x: int) -> bytes: + """Write a protobuf varint for a non-negative integer.""" + n = int(x or 0) + if n < 0: + raise ValueError("Negative varint.") + out = bytearray() + while True: + b = n & 0x7F + n >>= 7 + if n: + out.append(b | 0x80) + else: + out.append(b) + break + return bytes(out) + + +def _swap_packed_info_from_to(packed: bytes | bytearray | memoryview) -> tuple[bytes, int, int]: + """Swap protobuf field #1 and #2 varint values in packed_info_data. + + Empirically, WeChat uses packed_info_data as a tiny protobuf containing at least: + - field 1: fromId (Name2Id rowid) + - field 2: toId (Name2Id rowid) + + Swapping these flips message direction in the WeChat client. + Returns (new_bytes, old_field1, old_field2). + """ + if isinstance(packed, memoryview): + data = packed.tobytes() + else: + data = bytes(packed) + if not data: + raise ValueError("Empty packed_info_data.") + + # Pass 1: find the first occurrences of field 1/2 varints. + i = 0 + v1: Optional[int] = None + v2: Optional[int] = None + while i < len(data): + key, i = _pb_read_varint(data, i) + field_num = key >> 3 + wire = key & 7 + if wire == 0: + val, i = _pb_read_varint(data, i) + if field_num == 1 and v1 is None: + v1 = int(val) + elif field_num == 2 and v2 is None: + v2 = int(val) + continue + if wire == 1: + i += 8 + continue + if wire == 2: + ln, i = _pb_read_varint(data, i) + i += int(ln) + continue + if wire == 5: + i += 4 + continue + raise ValueError(f"Unsupported wire type: {wire}") + + if v1 is None or v2 is None: + raise ValueError("packed_info_data does not contain field #1 and #2 varints.") + + # Pass 2: rebuild and swap values for all field 1/2 varints. + i = 0 + out = bytearray() + while i < len(data): + key, i2 = _pb_read_varint(data, i) + field_num = key >> 3 + wire = key & 7 + out += _pb_write_varint(key) + i = i2 + + if wire == 0: + val, i = _pb_read_varint(data, i) + if field_num == 1: + val = int(v2) + elif field_num == 2: + val = int(v1) + out += _pb_write_varint(int(val)) + continue + if wire == 1: + out += data[i : i + 8] + i += 8 + continue + if wire == 2: + ln, i = _pb_read_varint(data, i) + out += _pb_write_varint(int(ln)) + out += data[i : i + int(ln)] + i += int(ln) + continue + if wire == 5: + out += data[i : i + 4] + i += 4 + continue + raise ValueError(f"Unsupported wire type: {wire}") + + return bytes(out), int(v1), int(v2) + + def _avatar_url_unified( *, account_dir: Path, @@ -6859,3 +7122,1626 @@ async def resolve_app_message( if found_appmsg: raise HTTPException(status_code=404, detail="App message decode failed.") raise HTTPException(status_code=404, detail="Message not found for server_id.") + + +def _normalize_table_name_case(conn: sqlite3.Connection, table_name: str) -> str: + t = str(table_name or "").strip() + if not t: + return "" + try: + r = conn.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND lower(name)=lower(?) LIMIT 1", + (t,), + ).fetchone() + if r is not None and r[0]: + # With `conn.text_factory = bytes`, sqlite_master.name can be returned as bytes. + # Decode it to avoid querying a non-existent table like "b'Msg_...'". + return _decode_sqlite_text(r[0]).strip() + except Exception: + pass + return t + + +def _table_info_columns(conn: sqlite3.Connection, table_name: str) -> set[str]: + t = str(table_name or "").strip() + if not t: + return set() + quoted = _quote_ident(t) + try: + cols = conn.execute(f"PRAGMA table_info({quoted})").fetchall() + except Exception: + return set() + out: set[str] = set() + for c in cols: + try: + name = _decode_sqlite_text(c[1]).strip() + except Exception: + continue + if name: + out.add(name) + return out + + +def _has_column(conn: sqlite3.Connection, table_name: str, column_name: str) -> bool: + want = str(column_name or "").strip().lower() + if not want: + return False + for c in _table_info_columns(conn, table_name): + if str(c or "").strip().lower() == want: + return True + return False + + +def _lookup_output_my_rowid(conn: sqlite3.Connection, my_wxid: str) -> Optional[int]: + try: + r = conn.execute( + "SELECT rowid FROM Name2Id WHERE user_name = ? LIMIT 1", + (str(my_wxid or "").strip(),), + ).fetchone() + if r is None: + return None + return int(r[0]) + except Exception: + return None + + +def _lookup_output_username_by_rowid(conn: sqlite3.Connection, rowid: int) -> str: + try: + r = conn.execute( + "SELECT user_name FROM Name2Id WHERE rowid = ? LIMIT 1", + (int(rowid or 0),), + ).fetchone() + if r is None: + return "" + return _decode_sqlite_text(r[0]).strip() + except Exception: + return "" + + +def _select_output_message_row(conn: sqlite3.Connection, *, table_name: str, local_id: int) -> Optional[sqlite3.Row]: + t = _normalize_table_name_case(conn, table_name) + if not t: + return None + quoted_table = _quote_ident(t) + has_packed_info_data = _has_column(conn, t, "packed_info_data") + packed_select = "m.packed_info_data AS packed_info_data, " if has_packed_info_data else "NULL AS packed_info_data, " + sql_with_join = ( + "SELECT " + "m.local_id, m.server_id, m.local_type, m.sort_seq, m.real_sender_id, m.create_time, " + "m.message_content, m.compress_content, " + + packed_select + + "n.user_name AS sender_username " + f"FROM {quoted_table} m " + "LEFT JOIN Name2Id n ON m.real_sender_id = n.rowid " + "WHERE m.local_id = ? " + "LIMIT 1" + ) + sql_no_join = ( + "SELECT " + "m.local_id, m.server_id, m.local_type, m.sort_seq, m.real_sender_id, m.create_time, " + "m.message_content, m.compress_content, " + + packed_select + + "'' AS sender_username " + f"FROM {quoted_table} m " + "WHERE m.local_id = ? " + "LIMIT 1" + ) + try: + return conn.execute(sql_with_join, (int(local_id),)).fetchone() + except Exception: + try: + return conn.execute(sql_no_join, (int(local_id),)).fetchone() + except Exception: + return None + + +def _resolve_db_storage_message_paths(account_dir: Path, db_stem: str) -> tuple[Path, Path]: + db_storage_dir = _resolve_account_db_storage_dir(account_dir) + if db_storage_dir is None: + raise HTTPException(status_code=400, detail="Cannot resolve db_storage directory for this account.") + db_name = str(db_stem or "").strip() + if not db_name: + raise HTTPException(status_code=400, detail="Invalid message_id.") + msg_db_path = db_storage_dir / "message" / f"{db_name}.db" + res_db_path = db_storage_dir / "message" / "message_resource.db" + return msg_db_path, res_db_path + + +def _build_wcdb_update_sql(*, table_name: str, updates: dict[str, Any], where_local_id: int) -> str: + t = str(table_name or "").strip() + if not t: + raise HTTPException(status_code=400, detail="Missing table_name.") + if not updates: + raise HTTPException(status_code=400, detail="Missing edits.") + parts: list[str] = [] + for k, v in updates.items(): + col = str(k or "").strip() + if not col: + continue + parts.append(f"{_quote_ident(col)} = {_sql_literal(v)}") + if not parts: + raise HTTPException(status_code=400, detail="Missing edits.") + return f"UPDATE {_quote_ident(t)} SET " + ", ".join(parts) + f" WHERE local_id = {int(where_local_id)}" + + +def _build_sqlite_update_sql(*, table_name: str, updates: dict[str, Any], where_local_id: int) -> tuple[str, list[Any]]: + t = str(table_name or "").strip() + if not t: + raise HTTPException(status_code=400, detail="Missing table_name.") + if not updates: + raise HTTPException(status_code=400, detail="Missing edits.") + cols: list[str] = [] + params: list[Any] = [] + for k, v in updates.items(): + col = str(k or "").strip() + if not col: + continue + cols.append(f"{_quote_ident(col)} = ?") + params.append(v) + if not cols: + raise HTTPException(status_code=400, detail="Missing edits.") + sql = f"UPDATE {_quote_ident(t)} SET " + ", ".join(cols) + " WHERE local_id = ?" + params.append(int(where_local_id)) + return sql, params + + +@router.get("/api/chat/messages/raw", summary="获取单条消息原始字段(output 解密库)") +def get_chat_message_raw( + *, + account: Optional[str] = None, + username: str, + message_id: str, +) -> dict[str, Any]: + if not username: + raise HTTPException(status_code=400, detail="Missing username.") + if not message_id: + raise HTTPException(status_code=400, detail="Missing message_id.") + + account_dir = _resolve_account_dir(account) + try: + db_stem, table_name_in, local_id = chat_edit_store.parse_message_id(message_id) + except Exception: + raise HTTPException(status_code=400, detail="Invalid message_id.") + + db_path = account_dir / f"{db_stem}.db" + if not db_path.exists(): + raise HTTPException(status_code=404, detail="Message database not found.") + + conn: Optional[sqlite3.Connection] = None + try: + conn = sqlite3.connect(str(db_path)) + conn.row_factory = sqlite3.Row + conn.text_factory = bytes + table_name = _normalize_table_name_case(conn, table_name_in) + if not table_name: + raise HTTPException(status_code=404, detail="Message table not found.") + + quoted_table = _quote_ident(table_name) + row = conn.execute(f"SELECT * FROM {quoted_table} WHERE local_id = ? LIMIT 1", (int(local_id),)).fetchone() + if row is None: + raise HTTPException(status_code=404, detail="Message not found.") + + out_row: dict[str, Any] = {} + for k in row.keys(): + out_row[str(k)] = _jsonify_db_value(str(k), row[k]) + + return { + "status": "success", + "account": account_dir.name, + "username": username, + "messageId": f"{db_stem}:{table_name}:{int(local_id)}", + "row": out_row, + } + finally: + if conn is not None: + try: + conn.close() + except Exception: + pass + + +@router.post("/api/chat/messages/edit", summary="编辑/修改消息(写入真实库 db_storage 并同步 output)") +async def edit_chat_message(request: Request) -> dict[str, Any]: + payload = await request.json() + if not isinstance(payload, dict): + raise HTTPException(status_code=400, detail="Invalid payload.") + + account = str(payload.get("account") or "").strip() or None + session_id = str(payload.get("session_id") or payload.get("username") or payload.get("sessionId") or "").strip() + message_id_in = str(payload.get("message_id") or payload.get("messageId") or "").strip() + edits_in = payload.get("edits") + unsafe = bool(payload.get("unsafe") or False) + + if not session_id: + raise HTTPException(status_code=400, detail="Missing session_id.") + if not message_id_in: + raise HTTPException(status_code=400, detail="Missing message_id.") + if not isinstance(edits_in, dict) or not edits_in: + raise HTTPException(status_code=400, detail="Missing edits.") + + account_dir = _resolve_account_dir(account) + base_url = str(request.base_url).rstrip("/") + + try: + db_stem, table_name_in, local_id_old = chat_edit_store.parse_message_id(message_id_in) + except Exception: + raise HTTPException(status_code=400, detail="Invalid message_id.") + + msg_db_path_out = account_dir / f"{db_stem}.db" + if not msg_db_path_out.exists(): + raise HTTPException(status_code=404, detail="Message database not found.") + + msg_db_path_real, res_db_path_real = _resolve_db_storage_message_paths(account_dir, db_stem) + if not msg_db_path_real.exists(): + raise HTTPException(status_code=404, detail="Real message database not found in db_storage.") + + # Validate edits against output schema and normalize table name casing. + table_name = table_name_in + edits: dict[str, Any] = {} + explicit_keys: set[str] = set() + conn_schema: Optional[sqlite3.Connection] = None + try: + conn_schema = sqlite3.connect(str(msg_db_path_out)) + conn_schema.row_factory = sqlite3.Row + table_name = _normalize_table_name_case(conn_schema, table_name_in) + if not table_name: + raise HTTPException(status_code=404, detail="Message table not found.") + cols = _table_info_columns(conn_schema, table_name) + if not cols: + raise HTTPException(status_code=404, detail="Message table not found.") + + for k, v in edits_in.items(): + col = str(k or "").strip() + if not col: + continue + if col not in cols: + raise HTTPException(status_code=400, detail=f"Unknown column: {col}") + if not _is_safe_edit_column(col, unsafe=unsafe): + raise HTTPException(status_code=400, detail=f"Unsafe column requires unsafe=true: {col}") + explicit_keys.add(col) + edits[col] = _normalize_edit_value(col, v) + if not edits: + raise HTTPException(status_code=400, detail="Missing edits.") + finally: + if conn_schema is not None: + try: + conn_schema.close() + except Exception: + pass + + message_id = f"{db_stem}:{table_name}:{int(local_id_old)}" + + # Decide update strategy for real db_storage. + only_message_content = (set(edits.keys()) == {"message_content"}) and ("compress_content" not in explicit_keys) + + # Default behavior: clear compress_content when message_content changes, unless explicitly provided. + output_edits = dict(edits) + if "message_content" in edits and ("compress_content" not in explicit_keys): + output_edits.setdefault("compress_content", None) + + new_local_id = int(edits.get("local_id") or 0) if "local_id" in edits else int(local_id_old) + if new_local_id <= 0: + new_local_id = int(local_id_old) + + # Resource sync mapping when Msg fields change. + resource_sync_map: dict[str, str] = { + "local_type": "message_local_type", + "create_time": "message_create_time", + "server_id": "message_svr_id", + "origin_source": "message_origin_source", + } + if unsafe: + resource_sync_map["local_id"] = "message_local_id" + + warnings: list[str] = [] + + with _realtime_sync_lock(account_dir.name, session_id): + # Ensure WCDB realtime connection. + try: + wcdb_conn = WCDB_REALTIME.ensure_connected(account_dir) + except WCDBRealtimeError as e: + raise HTTPException(status_code=400, detail=str(e)) + + # Read original row from real db_storage (snapshot). + original_row: Optional[dict[str, Any]] = None + original_create_time = 0 + try: + select_sql = f"SELECT * FROM {_quote_ident(table_name)} WHERE local_id = {int(local_id_old)} LIMIT 1" + with wcdb_conn.lock: + rows = _wcdb_exec_query( + wcdb_conn.handle, + kind="message", + path=str(msg_db_path_real), + sql=select_sql, + ) + if rows and isinstance(rows[0], dict): + original_row = rows[0] + try: + original_create_time = int(original_row.get("create_time") or 0) + except Exception: + original_create_time = 0 + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to read original message row: {e}") + if not original_row: + raise HTTPException(status_code=404, detail="Message not found in real db_storage.") + + # Read original resource row from real db_storage (optional). + original_resource_row: Optional[dict[str, Any]] = None + try: + if res_db_path_real.exists() and original_create_time > 0: + res_sql = ( + "SELECT * FROM MessageResourceInfo " + f"WHERE message_local_id = {int(local_id_old)} AND message_create_time = {int(original_create_time)} " + "ORDER BY message_id DESC " + "LIMIT 1" + ) + with wcdb_conn.lock: + res_rows = _wcdb_exec_query( + wcdb_conn.handle, + kind="message", + path=str(res_db_path_real), + sql=res_sql, + ) + if res_rows and isinstance(res_rows[0], dict): + original_resource_row = res_rows[0] + except Exception: + original_resource_row = None + + # Create snapshot record only if this message hasn't been edited via this tool. + created_record = False + existing_record = chat_edit_store.get_message_edit(account_dir.name, session_id, message_id) + if existing_record is None: + try: + chat_edit_store.upsert_original_once( + account=account_dir.name, + session_id=session_id, + db=db_stem, + table_name=table_name, + local_id=int(local_id_old), + original_msg=original_row, + original_resource=original_resource_row, + now_ms=int(time.time() * 1000), + ) + created_record = True + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to write edit snapshot: {e}") + + # Read current output row (for rollback if real update fails). + out_before: dict[str, Any] = {} + conn_out: Optional[sqlite3.Connection] = None + try: + conn_out = sqlite3.connect(str(msg_db_path_out), timeout=5) + conn_out.row_factory = sqlite3.Row + table_name_out = _normalize_table_name_case(conn_out, table_name) + if not table_name_out: + raise HTTPException(status_code=404, detail="Message table not found.") + quoted = _quote_ident(table_name_out) + row_before = conn_out.execute( + f"SELECT * FROM {quoted} WHERE local_id = ? LIMIT 1", + (int(local_id_old),), + ).fetchone() + if row_before is None: + if created_record: + try: + chat_edit_store.delete_message_edit(account_dir.name, session_id, message_id) + except Exception: + pass + raise HTTPException(status_code=404, detail="Message not found in output database.") + for k in row_before.keys(): + out_before[str(k)] = row_before[k] + + # Apply edits to output decrypted db first; if this fails, do not touch the real db_storage. + sql_out, params_out = _build_sqlite_update_sql( + table_name=table_name_out, + updates=output_edits, + where_local_id=int(local_id_old), + ) + cur_out = conn_out.execute(sql_out, params_out) + conn_out.commit() + if int(getattr(cur_out, "rowcount", 0) or 0) <= 0: + if created_record: + try: + chat_edit_store.delete_message_edit(account_dir.name, session_id, message_id) + except Exception: + pass + raise HTTPException(status_code=404, detail="Message not found in output database.") + except HTTPException: + raise + except Exception as e: + if created_record: + try: + chat_edit_store.delete_message_edit(account_dir.name, session_id, message_id) + except Exception: + pass + raise HTTPException(status_code=500, detail=f"Failed to update output database: {e}") + + # Apply edits to real db_storage. If it fails, rollback output changes. + try: + if only_message_content: + new_content = edits.get("message_content") + if isinstance(new_content, (bytes, bytearray, memoryview)): + try: + new_content = bytes(new_content).decode("utf-8", errors="replace") + except Exception: + new_content = "" + _wcdb_update_message( + wcdb_conn.handle, + session_id=session_id, + local_id=int(local_id_old), + create_time=int(original_create_time), + new_content=str(new_content or ""), + ) + else: + real_edits = dict(edits) + if "message_content" in edits and ("compress_content" not in explicit_keys): + real_edits.setdefault("compress_content", None) + sql_real = _build_wcdb_update_sql( + table_name=table_name, + updates=real_edits, + where_local_id=int(local_id_old), + ) + with wcdb_conn.lock: + _wcdb_exec_query( + wcdb_conn.handle, + kind="message", + path=str(msg_db_path_real), + sql=sql_real, + ) + except Exception as e: + # Roll back output changes. + try: + where_lid = int(new_local_id) if ("local_id" in edits) else int(local_id_old) + cols_now = _table_info_columns(conn_out, table_name_out) + rollback_updates = {k: v for k, v in out_before.items() if str(k or "") in cols_now} + sql_rb, params_rb = _build_sqlite_update_sql( + table_name=table_name_out, + updates=rollback_updates, + where_local_id=where_lid, + ) + conn_out.execute(sql_rb, params_rb) + conn_out.commit() + except Exception: + pass + # Remove newly-created snapshot record (real db was not touched successfully). + if created_record: + try: + chat_edit_store.delete_message_edit(account_dir.name, session_id, message_id) + except Exception: + pass + raise HTTPException(status_code=500, detail=f"Failed to update real db_storage: {e}") + finally: + if conn_out is not None: + try: + conn_out.close() + except Exception: + pass + + # Sync message_resource key fields (best-effort). + try: + msg_to_res_updates: dict[str, Any] = {} + for msg_col, res_col in resource_sync_map.items(): + if msg_col in edits: + msg_to_res_updates[res_col] = _normalize_edit_value(res_col, edits[msg_col]) + if msg_to_res_updates: + res_message_id = 0 + if original_resource_row is not None: + try: + res_message_id = int(original_resource_row.get("message_id") or 0) + except Exception: + res_message_id = 0 + if res_message_id > 0: + # real db_storage + if res_db_path_real.exists(): + parts = [f"{_quote_ident(k)} = {_sql_literal(v)}" for k, v in msg_to_res_updates.items()] + sql_res_real = ( + "UPDATE MessageResourceInfo SET " + + ", ".join(parts) + + f" WHERE message_id = {int(res_message_id)}" + ) + with wcdb_conn.lock: + _wcdb_exec_query( + wcdb_conn.handle, + kind="message", + path=str(res_db_path_real), + sql=sql_res_real, + ) + + # output decrypted + out_res_db_path = account_dir / "message_resource.db" + if out_res_db_path.exists(): + conn_res = sqlite3.connect(str(out_res_db_path), timeout=5) + try: + set_cols = ", ".join([f"{_quote_ident(k)} = ?" for k in msg_to_res_updates.keys()]) + params = list(msg_to_res_updates.values()) + [int(res_message_id)] + conn_res.execute( + f"UPDATE MessageResourceInfo SET {set_cols} WHERE message_id = ?", + params, + ) + conn_res.commit() + finally: + conn_res.close() + else: + warnings.append("message_resource row not found; skipped resource sync.") + except Exception as e: + warnings.append(f"Failed to sync message_resource: {e}") + + # If local_id changed (unsafe), move the edit record key so future reset works. + edit_record_local_id = int(local_id_old) + if "local_id" in edits and int(new_local_id) != int(local_id_old): + ok = chat_edit_store.update_message_edit_local_id( + account=account_dir.name, + session_id=session_id, + db=db_stem, + table_name=table_name, + old_local_id=int(local_id_old), + new_local_id=int(new_local_id), + ) + if not ok: + warnings.append("Failed to update edit record key after local_id change.") + else: + edit_record_local_id = int(new_local_id) + + # If this was an already-tracked message, bump edit metadata. + if existing_record is not None: + try: + chat_edit_store.upsert_original_once( + account=account_dir.name, + session_id=session_id, + db=db_stem, + table_name=table_name, + local_id=int(edit_record_local_id), + original_msg={}, + original_resource=None, + now_ms=int(time.time() * 1000), + ) + except Exception: + pass + + # Track which columns were actually modified so reset can restore only those fields. + try: + chat_edit_store.merge_edited_columns( + account=account_dir.name, + session_id=session_id, + db=db_stem, + table_name=table_name, + local_id=int(edit_record_local_id), + columns=list(output_edits.keys()), + ) + except Exception: + pass + + # Build updated message object (best-effort, from output). + updated_message: Optional[dict[str, Any]] = None + try: + conn_msg = sqlite3.connect(str(msg_db_path_out)) + conn_msg.row_factory = sqlite3.Row + conn_msg.text_factory = bytes + row = _select_output_message_row(conn_msg, table_name=table_name, local_id=int(new_local_id)) + if row is not None: + my_rowid = _lookup_output_my_rowid(conn_msg, account_dir.name) + out_res_db_path2 = account_dir / "message_resource.db" + resource_conn: Optional[sqlite3.Connection] = None + resource_chat_id: Optional[int] = None + try: + if out_res_db_path2.exists(): + resource_conn = sqlite3.connect(str(out_res_db_path2)) + resource_conn.row_factory = sqlite3.Row + resource_chat_id = _resource_lookup_chat_id(resource_conn, session_id) + except Exception: + if resource_conn is not None: + try: + resource_conn.close() + except Exception: + pass + resource_conn = None + resource_chat_id = None + + merged: list[dict[str, Any]] = [] + sender_usernames: list[str] = [] + quote_usernames: list[str] = [] + pat_usernames: set[str] = set() + _append_full_messages_from_rows( + merged=merged, + sender_usernames=sender_usernames, + quote_usernames=quote_usernames, + pat_usernames=pat_usernames, + rows=[row], + db_path=msg_db_path_out, + table_name=table_name, + username=session_id, + account_dir=account_dir, + is_group=bool(session_id.endswith("@chatroom")), + my_rowid=my_rowid, + resource_conn=resource_conn, + resource_chat_id=resource_chat_id, + ) + _postprocess_full_messages( + merged=merged, + sender_usernames=sender_usernames, + quote_usernames=quote_usernames, + pat_usernames=pat_usernames, + account_dir=account_dir, + username=session_id, + base_url=base_url, + contact_db_path=account_dir / "contact.db", + head_image_db_path=account_dir / "head_image.db", + ) + if merged: + updated_message = merged[0] + if resource_conn is not None: + try: + resource_conn.close() + except Exception: + pass + conn_msg.close() + except Exception: + updated_message = None + + resp: dict[str, Any] = { + "status": "success", + "account": account_dir.name, + "session_id": session_id, + "messageId": f"{db_stem}:{table_name}:{int(new_local_id)}", + } + if warnings: + resp["warnings"] = warnings + if updated_message is not None: + resp["updated_message"] = updated_message + return resp + + +@router.get("/api/chat/edits/sessions", summary="获取有修改记录的会话列表") +def list_chat_edited_sessions(request: Request, account: Optional[str] = None) -> dict[str, Any]: + account_dir = _resolve_account_dir(account) + base_url = str(request.base_url).rstrip("/") + + stats = chat_edit_store.list_sessions(account_dir.name) + session_ids = [str(s.get("session_id") or "").strip() for s in stats if str(s.get("session_id") or "").strip()] + contact_db_path = account_dir / "contact.db" + contact_rows = _load_contact_rows(contact_db_path, session_ids) + + sessions: list[dict[str, Any]] = [] + for s in stats: + uname = str(s.get("session_id") or "").strip() + if not uname: + continue + row = contact_rows.get(uname) + name = _pick_display_name(row, uname) if row is not None else uname + avatar = base_url + _avatar_url_unified(account_dir=account_dir, username=uname) + sessions.append( + { + "username": uname, + "name": name, + "avatar": avatar, + "isGroup": bool(uname.endswith("@chatroom")), + "editedCount": int(s.get("msg_count") or 0), + "lastEditedAt": int(s.get("last_edited_at") or 0), + } + ) + + return { + "status": "success", + "account": account_dir.name, + "sessions": sessions, + } + + +@router.get("/api/chat/edits/messages", summary="获取某会话下所有被修改过的消息(原/现对比)") +def list_chat_edited_messages( + request: Request, + username: str, + account: Optional[str] = None, +) -> dict[str, Any]: + if not username: + raise HTTPException(status_code=400, detail="Missing username.") + account_dir = _resolve_account_dir(account) + base_url = str(request.base_url).rstrip("/") + + edits = chat_edit_store.list_messages(account_dir.name, username) + if not edits: + return {"status": "success", "account": account_dir.name, "username": username, "items": []} + + # Open resource DB once (optional). + resource_conn: Optional[sqlite3.Connection] = None + resource_chat_id: Optional[int] = None + out_res_db_path = account_dir / "message_resource.db" + try: + if out_res_db_path.exists(): + resource_conn = sqlite3.connect(str(out_res_db_path)) + resource_conn.row_factory = sqlite3.Row + resource_chat_id = _resource_lookup_chat_id(resource_conn, username) + except Exception: + if resource_conn is not None: + try: + resource_conn.close() + except Exception: + pass + + resource_conn = None + resource_chat_id = None + + is_group = bool(username.endswith("@chatroom")) + + msg_conns: dict[str, sqlite3.Connection] = {} + my_rowids: dict[str, Optional[int]] = {} + + merged_current: list[dict[str, Any]] = [] + sender_usernames_current: list[str] = [] + quote_usernames_current: list[str] = [] + pat_usernames_current: set[str] = set() + + merged_original: list[dict[str, Any]] = [] + sender_usernames_original: list[str] = [] + quote_usernames_original: list[str] = [] + pat_usernames_original: set[str] = set() + + current_raw_by_id: dict[str, dict[str, Any]] = {} + original_raw_by_id: dict[str, Any] = {} + + try: + for rec in edits: + db_stem = str(rec.get("db") or "").strip() + table_name = str(rec.get("table_name") or "").strip() + try: + local_id = int(rec.get("local_id") or 0) + except Exception: + local_id = 0 + if not db_stem or not table_name or local_id <= 0: + continue + + message_id = str(rec.get("message_id") or "").strip() or f"{db_stem}:{table_name}:{int(local_id)}" + + conn_msg = msg_conns.get(db_stem) + if conn_msg is None: + db_path_out = account_dir / f"{db_stem}.db" + if not db_path_out.exists(): + continue + conn_msg = sqlite3.connect(str(db_path_out)) + conn_msg.row_factory = sqlite3.Row + conn_msg.text_factory = bytes + msg_conns[db_stem] = conn_msg + my_rowids[db_stem] = _lookup_output_my_rowid(conn_msg, account_dir.name) + + row_cur = _select_output_message_row(conn_msg, table_name=table_name, local_id=local_id) + if row_cur is not None: + _append_full_messages_from_rows( + merged=merged_current, + sender_usernames=sender_usernames_current, + quote_usernames=quote_usernames_current, + pat_usernames=pat_usernames_current, + rows=[row_cur], + db_path=account_dir / f"{db_stem}.db", + table_name=table_name, + username=username, + account_dir=account_dir, + is_group=is_group, + my_rowid=my_rowids.get(db_stem), + resource_conn=resource_conn, + resource_chat_id=resource_chat_id, + ) + cur_raw: dict[str, Any] = {} + for k in row_cur.keys(): + cur_raw[str(k)] = _jsonify_db_value(str(k), row_cur[k]) + current_raw_by_id[message_id] = cur_raw + + # Original raw snapshot (for UI raw display) + try: + original_raw_by_id[message_id] = json.loads(str(rec.get("original_msg_json") or "") or "null") + except Exception: + original_raw_by_id[message_id] = None + + # Original row for rendering + try: + orig_row = chat_edit_store.loads_json_with_blobs(str(rec.get("original_msg_json") or "") or "") + except Exception: + orig_row = None + if isinstance(orig_row, dict): + try: + rsid = int(orig_row.get("real_sender_id") or 0) + except Exception: + rsid = 0 + sender_username = _lookup_output_username_by_rowid(conn_msg, rsid) if rsid > 0 else "" + orig_row["sender_username"] = sender_username + orig_row.setdefault("packed_info_data", None) + _append_full_messages_from_rows( + merged=merged_original, + sender_usernames=sender_usernames_original, + quote_usernames=quote_usernames_original, + pat_usernames=pat_usernames_original, + rows=[orig_row], + db_path=account_dir / f"{db_stem}.db", + table_name=table_name, + username=username, + account_dir=account_dir, + is_group=is_group, + my_rowid=my_rowids.get(db_stem), + resource_conn=resource_conn, + resource_chat_id=resource_chat_id, + ) + + if merged_current: + _postprocess_full_messages( + merged=merged_current, + sender_usernames=sender_usernames_current, + quote_usernames=quote_usernames_current, + pat_usernames=pat_usernames_current, + account_dir=account_dir, + username=username, + base_url=base_url, + contact_db_path=account_dir / "contact.db", + head_image_db_path=account_dir / "head_image.db", + ) + if merged_original: + _postprocess_full_messages( + merged=merged_original, + sender_usernames=sender_usernames_original, + quote_usernames=quote_usernames_original, + pat_usernames=pat_usernames_original, + account_dir=account_dir, + username=username, + base_url=base_url, + contact_db_path=account_dir / "contact.db", + head_image_db_path=account_dir / "head_image.db", + ) + + current_by_id = {str(m.get("id") or ""): m for m in merged_current if str(m.get("id") or "").strip()} + original_by_id = {str(m.get("id") or ""): m for m in merged_original if str(m.get("id") or "").strip()} + + items: list[dict[str, Any]] = [] + for rec in edits: + mid = str(rec.get("message_id") or "").strip() + if not mid: + try: + mid = chat_edit_store.format_message_id( + rec.get("db") or "", + rec.get("table_name") or "", + int(rec.get("local_id") or 0), + ) + except Exception: + mid = "" + if not mid: + continue + items.append( + { + "messageId": mid, + "firstEditedAt": int(rec.get("first_edited_at") or 0), + "lastEditedAt": int(rec.get("last_edited_at") or 0), + "editCount": int(rec.get("edit_count") or 0), + "original": original_by_id.get(mid), + "current": current_by_id.get(mid), + "originalRaw": original_raw_by_id.get(mid), + "currentRaw": current_raw_by_id.get(mid), + } + ) + + items.sort(key=lambda x: int(((x.get("current") or x.get("original") or {}) or {}).get("createTime") or 0)) + return { + "status": "success", + "account": account_dir.name, + "username": username, + "items": items, + } + finally: + for c in msg_conns.values(): + try: + c.close() + except Exception: + pass + if resource_conn is not None: + try: + resource_conn.close() + except Exception: + pass + + +@router.get("/api/chat/edits/message_status", summary="某条消息是否被本工具修改过") +def get_chat_edit_status(*, account: Optional[str] = None, username: str, message_id: str) -> dict[str, Any]: + if not username: + raise HTTPException(status_code=400, detail="Missing username.") + if not message_id: + raise HTTPException(status_code=400, detail="Missing message_id.") + account_dir = _resolve_account_dir(account) + item = chat_edit_store.get_message_edit(account_dir.name, username, message_id) + if not item: + return {"modified": False} + return { + "modified": True, + "firstEditedAt": int(item.get("first_edited_at") or 0), + "lastEditedAt": int(item.get("last_edited_at") or 0), + "editCount": int(item.get("edit_count") or 0), + } + + +@router.post("/api/chat/messages/repair_sender", summary="修复某条消息的发送者(real_sender_id)") +async def repair_chat_message_sender(request: Request) -> dict[str, Any]: + """Repair message sender for cases where an incorrect reset wrote wrong metadata. + + Currently this supports forcing the message to be treated as "sent by me" by setting + `real_sender_id` to the account's Name2Id rowid, in both db_storage and output DB. + """ + payload = await request.json() + if not isinstance(payload, dict): + raise HTTPException(status_code=400, detail="Invalid payload.") + + account = str(payload.get("account") or "").strip() or None + session_id = str(payload.get("session_id") or payload.get("username") or payload.get("sessionId") or "").strip() + message_id = str(payload.get("message_id") or payload.get("messageId") or "").strip() + mode = str(payload.get("mode") or "me").strip().lower() + + if not session_id: + raise HTTPException(status_code=400, detail="Missing session_id.") + if not message_id: + raise HTTPException(status_code=400, detail="Missing message_id.") + if mode not in {"me"}: + raise HTTPException(status_code=400, detail="Unsupported mode.") + + account_dir = _resolve_account_dir(account) + try: + db_stem, table_name_in, local_id = chat_edit_store.parse_message_id(message_id) + except Exception: + raise HTTPException(status_code=400, detail="Invalid message_id.") + + msg_db_path_out = account_dir / f"{db_stem}.db" + if not msg_db_path_out.exists(): + raise HTTPException(status_code=404, detail="Message database not found.") + + msg_db_path_real, _res_db_path_real = _resolve_db_storage_message_paths(account_dir, db_stem) + if not msg_db_path_real.exists(): + raise HTTPException(status_code=404, detail="Real message database not found in db_storage.") + + # Resolve output table name casing and the "my" rowid for this message DB. + table_name_out = "" + my_rowid_out: Optional[int] = None + conn_probe: Optional[sqlite3.Connection] = None + try: + conn_probe = sqlite3.connect(str(msg_db_path_out), timeout=5) + conn_probe.row_factory = sqlite3.Row + table_name_out = _normalize_table_name_case(conn_probe, table_name_in) or "" + if not table_name_out: + raise HTTPException(status_code=404, detail="Message table not found.") + + r = conn_probe.execute( + "SELECT rowid FROM Name2Id WHERE user_name = ? ORDER BY rowid ASC LIMIT 1", + (account_dir.name,), + ).fetchone() + if r is not None: + try: + my_rowid_out = int(r[0]) + except Exception: + my_rowid_out = None + finally: + if conn_probe is not None: + try: + conn_probe.close() + except Exception: + pass + + if my_rowid_out is None or my_rowid_out <= 0: + raise HTTPException(status_code=404, detail="Name2Id row not found for account in output db.") + + with _realtime_sync_lock(account_dir.name, session_id): + try: + wcdb_conn = WCDB_REALTIME.ensure_connected(account_dir) + except WCDBRealtimeError as e: + raise HTTPException(status_code=400, detail=str(e)) + + # Resolve "my" rowid from the live db_storage message DB. + sql_my = ( + "SELECT rowid FROM Name2Id WHERE user_name = " + + _sql_literal(account_dir.name) + + " ORDER BY rowid ASC LIMIT 1" + ) + with wcdb_conn.lock: + rows = _wcdb_exec_query(wcdb_conn.handle, kind="message", path=str(msg_db_path_real), sql=sql_my) + + my_rowid_real = 0 + if rows and isinstance(rows[0], dict): + for k, v in rows[0].items(): + if str(k or "").strip().lower() == "rowid": + try: + my_rowid_real = int(v or 0) + except Exception: + my_rowid_real = 0 + break + + if my_rowid_real <= 0: + raise HTTPException(status_code=404, detail="Name2Id row not found for account in real db_storage.") + + # 1) Update real db_storage (source of truth). + try: + sql_real = _build_wcdb_update_sql( + table_name=table_name_in, + updates={"real_sender_id": int(my_rowid_real)}, + where_local_id=int(local_id), + ) + with wcdb_conn.lock: + _wcdb_exec_query(wcdb_conn.handle, kind="message", path=str(msg_db_path_real), sql=sql_real) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to update real db_storage: {e}") + + # 2) Sync output decrypted DB so UI reflects the change immediately. + try: + conn_out = sqlite3.connect(str(msg_db_path_out), timeout=5) + try: + sql_out, params_out = _build_sqlite_update_sql( + table_name=table_name_out, + updates={"real_sender_id": int(my_rowid_out)}, + where_local_id=int(local_id), + ) + conn_out.execute(sql_out, params_out) + conn_out.commit() + finally: + conn_out.close() + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to update output db: {e}") + + return { + "status": "success", + "account": account_dir.name, + "sessionId": session_id, + "messageId": f"{db_stem}:{table_name_out or table_name_in}:{int(local_id)}", + "mode": mode, + } + + +@router.post("/api/chat/messages/flip_direction", summary="反转某条消息在微信客户端的左右位置(packed_info_data)") +async def flip_chat_message_direction(request: Request) -> dict[str, Any]: + """Flip a message's bubble side in the *WeChat client* by swapping from/to in packed_info_data. + + Note: this intentionally edits `packed_info_data` (a protobuf-like BLOB). It is risky. + A snapshot is recorded so users can undo via `/api/chat/edits/reset_message`. + """ + + payload = await request.json() + if not isinstance(payload, dict): + raise HTTPException(status_code=400, detail="Invalid payload.") + + account = str(payload.get("account") or "").strip() or None + session_id = str(payload.get("session_id") or payload.get("username") or payload.get("sessionId") or "").strip() + message_id_in = str(payload.get("message_id") or payload.get("messageId") or "").strip() + + if not session_id: + raise HTTPException(status_code=400, detail="Missing session_id.") + if not message_id_in: + raise HTTPException(status_code=400, detail="Missing message_id.") + + account_dir = _resolve_account_dir(account) + try: + db_stem, table_name_in, local_id = chat_edit_store.parse_message_id(message_id_in) + except Exception: + raise HTTPException(status_code=400, detail="Invalid message_id.") + + msg_db_path_out = account_dir / f"{db_stem}.db" + if not msg_db_path_out.exists(): + raise HTTPException(status_code=404, detail="Message database not found.") + + msg_db_path_real, _res_db_path_real = _resolve_db_storage_message_paths(account_dir, db_stem) + if not msg_db_path_real.exists(): + raise HTTPException(status_code=404, detail="Real message database not found in db_storage.") + + def _coerce_packed_bytes(value: Any) -> Optional[bytes]: + if value is None: + return None + if isinstance(value, memoryview): + value = value.tobytes() + if isinstance(value, bytearray): + value = bytes(value) + if isinstance(value, bytes): + # If a past bug stored the blob as TEXT hex, sqlite may return ASCII bytes here. + try: + s = value.decode("ascii").strip() + except Exception: + return value + if not s: + return b"" + b = _hex_to_bytes(s) + if b is not None: + return b + if (len(s) % 2 == 0) and (_HEX_RE.fullmatch(s) is not None): + try: + return bytes.fromhex(s) + except Exception: + return value + return value + if isinstance(value, str): + s = value.strip() + if not s: + return b"" + b = _hex_to_bytes(s) + if b is not None: + return b + if (len(s) % 2 == 0) and (_HEX_RE.fullmatch(s) is not None): + try: + return bytes.fromhex(s) + except Exception: + return None + return s.encode("utf-8", errors="replace") + return None + + # Resolve output table name casing and read packed_info_data bytes from output DB. + table_name_out = "" + packed_before: Optional[bytes] = None + conn_out_probe: Optional[sqlite3.Connection] = None + try: + conn_out_probe = sqlite3.connect(str(msg_db_path_out), timeout=5) + conn_out_probe.row_factory = sqlite3.Row + conn_out_probe.text_factory = bytes + table_name_out = _normalize_table_name_case(conn_out_probe, table_name_in) or "" + if not table_name_out: + raise HTTPException(status_code=404, detail="Message table not found.") + cols = _table_info_columns(conn_out_probe, table_name_out) + if not cols or ("packed_info_data" not in cols): + raise HTTPException(status_code=400, detail="packed_info_data column not found.") + quoted = _quote_ident(table_name_out) + row = conn_out_probe.execute( + f"SELECT packed_info_data FROM {quoted} WHERE local_id = ? LIMIT 1", + (int(local_id),), + ).fetchone() + if row is None: + raise HTTPException(status_code=404, detail="Message not found in output database.") + packed_before = _coerce_packed_bytes(row["packed_info_data"]) + finally: + if conn_out_probe is not None: + try: + conn_out_probe.close() + except Exception: + pass + + if not packed_before: + raise HTTPException(status_code=400, detail="packed_info_data is empty; cannot flip direction.") + + try: + packed_after, old_from_id, old_to_id = _swap_packed_info_from_to(packed_before) + except Exception as e: + raise HTTPException(status_code=400, detail=f"Cannot flip packed_info_data: {e}") + + # Apply to output DB first, then real db_storage. Record snapshot so users can undo. + message_id = f"{db_stem}:{table_name_out or table_name_in}:{int(local_id)}" + created_record = False + + with _realtime_sync_lock(account_dir.name, session_id): + try: + wcdb_conn = WCDB_REALTIME.ensure_connected(account_dir) + except WCDBRealtimeError as e: + raise HTTPException(status_code=400, detail=str(e)) + + # Snapshot original row from real db_storage (only once). + existing_record = chat_edit_store.get_message_edit(account_dir.name, session_id, message_id) + if existing_record is None: + try: + select_sql = f"SELECT * FROM {_quote_ident(table_name_in)} WHERE local_id = {int(local_id)} LIMIT 1" + with wcdb_conn.lock: + rows = _wcdb_exec_query( + wcdb_conn.handle, + kind="message", + path=str(msg_db_path_real), + sql=select_sql, + ) + if not rows or not isinstance(rows[0], dict): + raise HTTPException(status_code=404, detail="Message not found in real db_storage.") + original_row = rows[0] + + chat_edit_store.upsert_original_once( + account=account_dir.name, + session_id=session_id, + db=db_stem, + table_name=table_name_out or table_name_in, + local_id=int(local_id), + original_msg=original_row, + original_resource=None, + now_ms=int(time.time() * 1000), + ) + created_record = True + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to write edit snapshot: {e}") + + # 1) Update output decrypted DB (so UI can show it in raw view). + try: + conn_out = sqlite3.connect(str(msg_db_path_out), timeout=5) + try: + sql_out, params_out = _build_sqlite_update_sql( + table_name=table_name_out, + updates={"packed_info_data": packed_after}, + where_local_id=int(local_id), + ) + cur = conn_out.execute(sql_out, params_out) + conn_out.commit() + if int(getattr(cur, "rowcount", 0) or 0) <= 0: + raise HTTPException(status_code=404, detail="Message not found in output database.") + finally: + conn_out.close() + except HTTPException: + if created_record: + try: + chat_edit_store.delete_message_edit(account_dir.name, session_id, message_id) + except Exception: + pass + raise + except Exception as e: + if created_record: + try: + chat_edit_store.delete_message_edit(account_dir.name, session_id, message_id) + except Exception: + pass + raise HTTPException(status_code=500, detail=f"Failed to update output database: {e}") + + # 2) Update real db_storage (source of truth). Rollback output on failure. + try: + sql_real = _build_wcdb_update_sql( + table_name=table_name_in, + updates={"packed_info_data": packed_after}, + where_local_id=int(local_id), + ) + with wcdb_conn.lock: + _wcdb_exec_query(wcdb_conn.handle, kind="message", path=str(msg_db_path_real), sql=sql_real) + except Exception as e: + # Roll back output changes. + try: + conn_rb = sqlite3.connect(str(msg_db_path_out), timeout=5) + try: + sql_rb, params_rb = _build_sqlite_update_sql( + table_name=table_name_out, + updates={"packed_info_data": packed_before}, + where_local_id=int(local_id), + ) + conn_rb.execute(sql_rb, params_rb) + conn_rb.commit() + finally: + conn_rb.close() + except Exception: + pass + + if created_record: + try: + chat_edit_store.delete_message_edit(account_dir.name, session_id, message_id) + except Exception: + pass + raise HTTPException(status_code=500, detail=f"Failed to update real db_storage: {e}") + + # Track which columns were modified so reset restores only those. + try: + chat_edit_store.merge_edited_columns( + account=account_dir.name, + session_id=session_id, + db=db_stem, + table_name=table_name_out or table_name_in, + local_id=int(local_id), + columns=["packed_info_data"], + ) + except Exception: + pass + + # Bump edit metadata for already-tracked messages. + if existing_record is not None: + try: + chat_edit_store.upsert_original_once( + account=account_dir.name, + session_id=session_id, + db=db_stem, + table_name=table_name_out or table_name_in, + local_id=int(local_id), + original_msg={}, + original_resource=None, + now_ms=int(time.time() * 1000), + ) + except Exception: + pass + + return { + "status": "success", + "account": account_dir.name, + "sessionId": session_id, + "messageId": message_id, + "before": { + "packed_info_data": _bytes_to_hex(packed_before), + "fromId": int(old_from_id), + "toId": int(old_to_id), + }, + "after": { + "packed_info_data": _bytes_to_hex(packed_after), + "fromId": int(old_to_id), + "toId": int(old_from_id), + }, + } + + +def _restore_message_from_snapshot( + *, + account_dir: Path, + session_id: str, + message_id: str, + record: dict[str, Any], + wcdb_conn, +) -> None: + db_stem, table_name, local_id_current = chat_edit_store.parse_message_id(message_id) + msg_db_path_out = account_dir / f"{db_stem}.db" + if not msg_db_path_out.exists(): + raise HTTPException(status_code=404, detail="Message database not found.") + + msg_db_path_real, res_db_path_real = _resolve_db_storage_message_paths(account_dir, db_stem) + if not msg_db_path_real.exists(): + raise HTTPException(status_code=404, detail="Real message database not found in db_storage.") + + original_msg = chat_edit_store.loads_json_with_blobs(str(record.get("original_msg_json") or "") or "") + if not isinstance(original_msg, dict): + raise HTTPException(status_code=500, detail="Invalid original snapshot.") + + original_resource = None + if str(record.get("original_resource_json") or ""): + try: + original_resource = chat_edit_store.loads_json_with_blobs(str(record.get("original_resource_json") or "") or "") + except Exception: + original_resource = None + + edited_cols: set[str] = set() + try: + raw = str(record.get("edited_cols_json") or "").strip() + if raw: + v = json.loads(raw) + if isinstance(v, list): + edited_cols = {str(x or "").strip().lower() for x in v if str(x or "").strip()} + except Exception: + edited_cols = set() + + # Backward compatible default: older records didn't track edited columns. + if not edited_cols: + edited_cols = {"message_content", "compress_content"} + + # Editing content implicitly clears compress_content unless explicitly provided. + if "message_content" in edited_cols: + edited_cols.add("compress_content") + + orig_key_map = {str(k or "").strip().lower(): str(k) for k in original_msg.keys()} + + # Read current create_time from real db to call wcdb_update_message reliably. + cur_create_time = 0 + try: + sql_ct = f"SELECT create_time FROM {_quote_ident(table_name)} WHERE local_id = {int(local_id_current)} LIMIT 1" + with wcdb_conn.lock: + rows = _wcdb_exec_query(wcdb_conn.handle, kind="message", path=str(msg_db_path_real), sql=sql_ct) + if rows and isinstance(rows[0], dict): + cur_create_time = int(rows[0].get("create_time") or 0) + except Exception: + cur_create_time = 0 + if cur_create_time <= 0: + raise HTTPException(status_code=404, detail="Message not found in real db_storage.") + + # Restore message_content via wcdb_update_message (best-effort). + # Some builds store message_content as an encrypted/compressed BLOB; WCDB exec_query may return it as bare hex. + # In that case, don't call update_message with the hex string; restoring the raw column bytes below is safer. + if "message_content" in edited_cols and "message_content" in orig_key_map: + try: + content = original_msg.get(orig_key_map["message_content"]) + if isinstance(content, str): + s = content.strip() + if s and (len(s) % 2 == 0) and (_HEX_RE.fullmatch(s) is not None): + s_lower = s.lower() + if (len(s) >= 64) or (s_lower.startswith("28b52ffd") and len(s) >= 16): + content = None + if isinstance(content, (bytes, bytearray, memoryview)): + try: + content = bytes(content).decode("utf-8", errors="replace") + except Exception: + content = "" + if content is not None: + _wcdb_update_message( + wcdb_conn.handle, + session_id=session_id, + local_id=int(local_id_current), + create_time=int(cur_create_time), + new_content=str(content or ""), + ) + except Exception: + pass + + # Restore only columns that were actually edited by the tool. + try: + restore_updates: dict[str, Any] = {} + for col_lc in sorted(edited_cols): + k = orig_key_map.get(col_lc) + if not k: + continue + restore_updates[k] = _normalize_edit_value(k, original_msg.get(k), from_snapshot=True) + + if restore_updates: + sql_real = _build_wcdb_update_sql( + table_name=table_name, + updates=restore_updates, + where_local_id=int(local_id_current), + ) + with wcdb_conn.lock: + _wcdb_exec_query(wcdb_conn.handle, kind="message", path=str(msg_db_path_real), sql=sql_real) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to restore real db_storage: {e}") + + # Restore output decrypted Msg_*. + try: + conn_out = sqlite3.connect(str(msg_db_path_out), timeout=5) + try: + tnorm = _normalize_table_name_case(conn_out, table_name) + if not tnorm: + raise HTTPException(status_code=404, detail="Message table not found.") + cols = _table_info_columns(conn_out, tnorm) + col_map = {str(c or "").strip().lower(): str(c) for c in cols if str(c or "").strip()} + restore_out: dict[str, Any] = {} + for col_lc in sorted(edited_cols): + col = col_map.get(col_lc) + k = orig_key_map.get(col_lc) + if not col or not k: + continue + restore_out[col] = _normalize_edit_value(col, original_msg.get(k), from_snapshot=True) + + if restore_out: + sql_out, params = _build_sqlite_update_sql( + table_name=tnorm, + updates=restore_out, + where_local_id=int(local_id_current), + ) + conn_out.execute(sql_out, params) + conn_out.commit() + finally: + conn_out.close() + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"Failed to restore output database: {e}") + + # Restore message_resource key fields (best-effort, by message_id). + need_restore_resource = any( + k in edited_cols for k in {"local_type", "create_time", "server_id", "origin_source", "local_id"} + ) + if need_restore_resource and isinstance(original_resource, dict): + try: + res_message_id = int(original_resource.get("message_id") or 0) + except Exception: + res_message_id = 0 + if res_message_id > 0: + restore_res: dict[str, Any] = {} + msg_to_res = { + "local_type": "message_local_type", + "create_time": "message_create_time", + "server_id": "message_svr_id", + "origin_source": "message_origin_source", + "local_id": "message_local_id", + } + for msg_col, res_col in msg_to_res.items(): + if msg_col not in edited_cols: + continue + if res_col in original_resource: + restore_res[res_col] = _normalize_edit_value(res_col, original_resource.get(res_col), from_snapshot=True) + if restore_res: + try: + parts = [f"{_quote_ident(k)} = {_sql_literal(v)}" for k, v in restore_res.items()] + sql_res_real = ( + "UPDATE MessageResourceInfo SET " + ", ".join(parts) + f" WHERE message_id = {int(res_message_id)}" + ) + if res_db_path_real.exists(): + with wcdb_conn.lock: + _wcdb_exec_query( + wcdb_conn.handle, + kind="message", + path=str(res_db_path_real), + sql=sql_res_real, + ) + except Exception: + pass + + try: + out_res_db_path = account_dir / "message_resource.db" + if out_res_db_path.exists(): + conn_res = sqlite3.connect(str(out_res_db_path), timeout=5) + try: + set_cols = ", ".join([f"{_quote_ident(k)} = ?" for k in restore_res.keys()]) + params = list(restore_res.values()) + [int(res_message_id)] + conn_res.execute(f"UPDATE MessageResourceInfo SET {set_cols} WHERE message_id = ?", params) + conn_res.commit() + finally: + conn_res.close() + except Exception: + pass + + +@router.post("/api/chat/edits/reset_message", summary="恢复某条消息到首次快照,并删除修改记录") +async def reset_chat_edited_message(request: Request) -> dict[str, Any]: + payload = await request.json() + if not isinstance(payload, dict): + raise HTTPException(status_code=400, detail="Invalid payload.") + + account = str(payload.get("account") or "").strip() or None + session_id = str(payload.get("session_id") or payload.get("username") or payload.get("sessionId") or "").strip() + message_id = str(payload.get("message_id") or payload.get("messageId") or "").strip() + if not session_id: + raise HTTPException(status_code=400, detail="Missing session_id.") + if not message_id: + raise HTTPException(status_code=400, detail="Missing message_id.") + + account_dir = _resolve_account_dir(account) + record = chat_edit_store.get_message_edit(account_dir.name, session_id, message_id) + if not record: + raise HTTPException(status_code=404, detail="Edit record not found.") + + with _realtime_sync_lock(account_dir.name, session_id): + try: + wcdb_conn = WCDB_REALTIME.ensure_connected(account_dir) + except WCDBRealtimeError as e: + raise HTTPException(status_code=400, detail=str(e)) + + _restore_message_from_snapshot( + account_dir=account_dir, + session_id=session_id, + message_id=message_id, + record=record, + wcdb_conn=wcdb_conn, + ) + + try: + chat_edit_store.delete_message_edit(account_dir.name, session_id, message_id) + except Exception: + pass + + return {"status": "success"} + + +@router.post("/api/chat/edits/reset_session", summary="一键恢复某会话下全部修改记录") +async def reset_chat_edited_session(request: Request) -> dict[str, Any]: + payload = await request.json() + if not isinstance(payload, dict): + raise HTTPException(status_code=400, detail="Invalid payload.") + + account = str(payload.get("account") or "").strip() or None + session_id = str(payload.get("session_id") or payload.get("username") or payload.get("sessionId") or "").strip() + if not session_id: + raise HTTPException(status_code=400, detail="Missing session_id.") + + account_dir = _resolve_account_dir(account) + records = chat_edit_store.list_messages(account_dir.name, session_id) + if not records: + return {"status": "success", "restored": 0, "failed": 0, "failures": []} + + restored = 0 + failures: list[dict[str, Any]] = [] + + with _realtime_sync_lock(account_dir.name, session_id): + try: + wcdb_conn = WCDB_REALTIME.ensure_connected(account_dir) + except WCDBRealtimeError as e: + raise HTTPException(status_code=400, detail=str(e)) + + for rec in records: + mid = str(rec.get("message_id") or "").strip() + if not mid: + try: + mid = chat_edit_store.format_message_id( + rec.get("db") or "", + rec.get("table_name") or "", + int(rec.get("local_id") or 0), + ) + except Exception: + mid = "" + if not mid: + continue + try: + _restore_message_from_snapshot( + account_dir=account_dir, + session_id=session_id, + message_id=mid, + record=rec, + wcdb_conn=wcdb_conn, + ) + try: + chat_edit_store.delete_message_edit(account_dir.name, session_id, mid) + except Exception: + pass + restored += 1 + except Exception as e: + failures.append({"messageId": mid, "error": str(e)}) + + return {"status": "success", "restored": int(restored), "failed": int(len(failures)), "failures": failures} diff --git a/src/wechat_decrypt_tool/wcdb_realtime.py b/src/wechat_decrypt_tool/wcdb_realtime.py index aeee2b3..a875033 100644 --- a/src/wechat_decrypt_tool/wcdb_realtime.py +++ b/src/wechat_decrypt_tool/wcdb_realtime.py @@ -175,6 +175,36 @@ def _load_wcdb_lib() -> ctypes.CDLL: except Exception: pass + # Optional (newer DLLs): update a single message content in message db. + # Signature: wcdb_update_message(handle, sessionId, localId, createTime, newContent, outError) + try: + lib.wcdb_update_message.argtypes = [ + ctypes.c_int64, + ctypes.c_char_p, + ctypes.c_int64, + ctypes.c_int32, + ctypes.c_char_p, + ctypes.POINTER(ctypes.c_char_p), + ] + lib.wcdb_update_message.restype = ctypes.c_int + except Exception: + pass + + # Optional (newer DLLs): delete a single message in message db. + # Signature: wcdb_delete_message(handle, sessionId, localId, createTime, dbPathHint, outError) + try: + lib.wcdb_delete_message.argtypes = [ + ctypes.c_int64, + ctypes.c_char_p, + ctypes.c_int64, + ctypes.c_int32, + ctypes.c_char_p, + ctypes.POINTER(ctypes.c_char_p), + ] + lib.wcdb_delete_message.restype = ctypes.c_int + except Exception: + pass + # Optional (newer DLLs): wcdb_get_sns_timeline(handle, limit, offset, usernames_json, keyword, start_time, end_time, out_json) try: lib.wcdb_get_sns_timeline.argtypes = [ @@ -259,6 +289,32 @@ def _call_out_json(fn, *args) -> str: pass +def _call_out_error(fn, *args) -> None: + lib = _load_wcdb_lib() + out = ctypes.c_char_p() + rc = int(fn(*args, ctypes.byref(out))) + try: + if rc != 0: + err = "" + try: + if out.value: + err = (out.value or b"").decode("utf-8", errors="replace") + except Exception: + err = "" + + logs = get_native_logs() + hint = f" logs={logs[:6]}" if logs else "" + if err: + raise WCDBRealtimeError(f"wcdb api call failed: {rc}. error={err}.{hint}") + raise WCDBRealtimeError(f"wcdb api call failed: {rc}.{hint}") + finally: + try: + if out.value: + lib.wcdb_free_string(out) + except Exception: + pass + + def get_native_logs() -> list[str]: try: _ensure_initialized() @@ -500,6 +556,64 @@ def exec_query(handle: int, *, kind: str, path: Optional[str], sql: str) -> list return [] +def update_message(handle: int, *, session_id: str, local_id: int, create_time: int, new_content: str) -> None: + """Update a single message content in the live encrypted db_storage via WCDB. + + Requires wcdb_update_message export in wcdb_api.dll. + """ + _ensure_initialized() + lib = _load_wcdb_lib() + fn = getattr(lib, "wcdb_update_message", None) + if not fn: + raise WCDBRealtimeError("Current wcdb_api.dll does not support update_message.") + + sid = str(session_id or "").strip() + if not sid: + raise WCDBRealtimeError("Missing session_id for update_message.") + + _call_out_error( + fn, + ctypes.c_int64(int(handle)), + sid.encode("utf-8"), + ctypes.c_int64(int(local_id or 0)), + ctypes.c_int32(int(create_time or 0)), + str(new_content or "").encode("utf-8"), + ) + + +def delete_message( + handle: int, + *, + session_id: str, + local_id: int, + create_time: int, + db_path_hint: str | None = None, +) -> None: + """Delete a single message in the live encrypted db_storage via WCDB. + + Requires wcdb_delete_message export in wcdb_api.dll. + """ + _ensure_initialized() + lib = _load_wcdb_lib() + fn = getattr(lib, "wcdb_delete_message", None) + if not fn: + raise WCDBRealtimeError("Current wcdb_api.dll does not support delete_message.") + + sid = str(session_id or "").strip() + if not sid: + raise WCDBRealtimeError("Missing session_id for delete_message.") + + hint = str(db_path_hint or "").strip() + _call_out_error( + fn, + ctypes.c_int64(int(handle)), + sid.encode("utf-8"), + ctypes.c_int64(int(local_id or 0)), + ctypes.c_int32(int(create_time or 0)), + hint.encode("utf-8"), + ) + + def get_sns_timeline( handle: int, *, diff --git a/tests/test_chat_edit_store.py b/tests/test_chat_edit_store.py new file mode 100644 index 0000000..c8ae040 --- /dev/null +++ b/tests/test_chat_edit_store.py @@ -0,0 +1,182 @@ +import os +import sys +import json +import sqlite3 +import unittest +import importlib +from pathlib import Path +from tempfile import TemporaryDirectory + + +ROOT = Path(__file__).resolve().parents[1] +sys.path.insert(0, str(ROOT / "src")) + + +class TestChatEditStore(unittest.TestCase): + def setUp(self): + self._prev_data_dir = os.environ.get("WECHAT_TOOL_DATA_DIR") + self._td = TemporaryDirectory() + os.environ["WECHAT_TOOL_DATA_DIR"] = self._td.name + + import wechat_decrypt_tool.app_paths as app_paths + import wechat_decrypt_tool.chat_edit_store as chat_edit_store + + importlib.reload(app_paths) + importlib.reload(chat_edit_store) + + self.app_paths = app_paths + self.store = chat_edit_store + + def tearDown(self): + if self._prev_data_dir is None: + os.environ.pop("WECHAT_TOOL_DATA_DIR", None) + else: + os.environ["WECHAT_TOOL_DATA_DIR"] = self._prev_data_dir + self._td.cleanup() + + def test_ensure_schema_creates_db(self): + self.store.ensure_schema() + db_path = self.app_paths.get_output_dir() / "message_edits.db" + self.assertTrue(db_path.exists()) + + conn = sqlite3.connect(str(db_path)) + try: + row = conn.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name='message_edits' LIMIT 1" + ).fetchone() + self.assertIsNotNone(row) + finally: + conn.close() + + def test_blob_hex_roundtrip(self): + payload = {"a": b"\x00\xff", "nested": {"b": memoryview(b"\x01\x02")}} + dumped = self.store.dumps_json_with_blobs(payload) + self.assertIn("0x00ff", dumped.lower()) + self.assertIn("0x0102", dumped.lower()) + + loaded = self.store.loads_json_with_blobs(dumped) + self.assertEqual(loaded["a"], b"\x00\xff") + self.assertEqual(loaded["nested"]["b"], b"\x01\x02") + + def test_message_id_format_parse(self): + mid = self.store.format_message_id("message_0", "Msg_foo", 123) + self.assertEqual(mid, "message_0:Msg_foo:123") + + db, table, local_id = self.store.parse_message_id(mid) + self.assertEqual(db, "message_0") + self.assertEqual(table, "Msg_foo") + self.assertEqual(local_id, 123) + + with self.assertRaises(ValueError): + self.store.parse_message_id("bad") + + def test_upsert_original_once_does_not_overwrite_snapshot(self): + now1 = 1000 + now2 = 2000 + self.store.upsert_original_once( + account="wxid_me", + session_id="wxid_you", + db="message_0", + table_name="Msg_foo", + local_id=1, + original_msg={"local_id": 1, "message_content": "hello", "compress_content": b"\x01"}, + original_resource={"message_id": 9, "packed_info": b"\x02"}, + now_ms=now1, + ) + + self.store.upsert_original_once( + account="wxid_me", + session_id="wxid_you", + db="message_0", + table_name="Msg_foo", + local_id=1, + original_msg={"local_id": 1, "message_content": "SHOULD_NOT_OVERWRITE", "compress_content": b"\x03"}, + original_resource={"message_id": 9, "packed_info": b"\x04"}, + now_ms=now2, + ) + + mid = self.store.format_message_id("message_0", "Msg_foo", 1) + item = self.store.get_message_edit("wxid_me", "wxid_you", mid) + self.assertIsNotNone(item) + self.assertEqual(int(item["first_edited_at"]), now1) + self.assertEqual(int(item["last_edited_at"]), now2) + self.assertEqual(int(item["edit_count"]), 2) + + original_msg = self.store.loads_json_with_blobs(item["original_msg_json"]) + self.assertEqual(original_msg["message_content"], "hello") + self.assertEqual(original_msg["compress_content"], b"\x01") + + original_res = self.store.loads_json_with_blobs(item["original_resource_json"]) + self.assertEqual(int(original_res["message_id"]), 9) + self.assertEqual(original_res["packed_info"], b"\x02") + + def test_update_message_edit_local_id_moves_primary_key(self): + self.store.upsert_original_once( + account="wxid_me", + session_id="wxid_you", + db="message_0", + table_name="Msg_foo", + local_id=10, + original_msg={"local_id": 10, "message_content": "hello"}, + original_resource=None, + now_ms=1234, + ) + + ok = self.store.update_message_edit_local_id( + account="wxid_me", + session_id="wxid_you", + db="message_0", + table_name="Msg_foo", + old_local_id=10, + new_local_id=11, + ) + self.assertTrue(ok) + + old_mid = self.store.format_message_id("message_0", "Msg_foo", 10) + new_mid = self.store.format_message_id("message_0", "Msg_foo", 11) + self.assertIsNone(self.store.get_message_edit("wxid_me", "wxid_you", old_mid)) + self.assertIsNotNone(self.store.get_message_edit("wxid_me", "wxid_you", new_mid)) + + def test_list_sessions_counts(self): + self.store.upsert_original_once( + account="wxid_me", + session_id="u1", + db="message_0", + table_name="Msg_foo", + local_id=1, + original_msg={"local_id": 1, "message_content": "a"}, + original_resource=None, + now_ms=100, + ) + self.store.upsert_original_once( + account="wxid_me", + session_id="u1", + db="message_0", + table_name="Msg_foo", + local_id=2, + original_msg={"local_id": 2, "message_content": "b"}, + original_resource=None, + now_ms=200, + ) + self.store.upsert_original_once( + account="wxid_me", + session_id="u2", + db="message_0", + table_name="Msg_foo", + local_id=3, + original_msg={"local_id": 3, "message_content": "c"}, + original_resource=None, + now_ms=300, + ) + + stats = self.store.list_sessions("wxid_me") + by_sid = {s["session_id"]: s for s in stats} + self.assertEqual(int(by_sid["u1"]["msg_count"]), 2) + self.assertEqual(int(by_sid["u1"]["last_edited_at"]), 200) + self.assertEqual(int(by_sid["u2"]["msg_count"]), 1) + self.assertEqual(int(by_sid["u2"]["last_edited_at"]), 300) + + +if __name__ == "__main__": + unittest.main() +