mirror of
https://github.com/LifeArchiveProject/WeChatDataAnalysis.git
synced 2026-06-18 15:54:08 +08:00
feat(chat-edit-backend): 新增消息编辑快照、回滚与修复接口
- 新增 message_edits 存储,记录首次快照、编辑次数与已修改字段 - 新增消息编辑链路:raw 查询、编辑、修改状态、按消息/会话恢复 - 新增发送者修复与 packed_info_data 方向反转能力 - 编辑流程支持 output/db_storage 双写与失败回滚 - path_fix 改为按需校验 db_storage_path,并缓存 body 避免重复读取问题 - 补充 chat_edit_store 单元测试覆盖核心行为
This commit is contained in:
@@ -0,0 +1,487 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
import sqlite3
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from .app_paths import get_output_dir
|
||||||
|
|
||||||
|
_HEX_RE = re.compile(r"^[0-9a-fA-F]+$")
|
||||||
|
|
||||||
|
|
||||||
|
def _db_path() -> Path:
    """Return the path of the SQLite file that stores message edit snapshots.

    The file is named ``message_edits.db`` and lives directly inside the
    directory returned by ``get_output_dir()``.
    """
    return get_output_dir() / "message_edits.db"
|
||||||
|
|
||||||
|
|
||||||
|
def _connect() -> sqlite3.Connection:
    """Open a connection to the edit store, creating directory and schema on demand.

    Returns:
        A connection whose ``row_factory`` is ``sqlite3.Row``. The caller owns
        it and must close it.

    Raises:
        Whatever ``sqlite3.connect`` or the schema setup raises; in the latter
        case the freshly opened connection is closed first so it cannot leak.
    """
    db_path = _db_path()
    db_path.parent.mkdir(parents=True, exist_ok=True)
    conn = sqlite3.connect(str(db_path), timeout=5)
    try:
        conn.row_factory = sqlite3.Row
        _ensure_schema(conn)
    except Exception:
        # Nobody else holds a reference yet, so release the handle here.
        conn.close()
        raise
    return conn
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_schema() -> None:
    """Make sure the edit store file and its schema exist.

    Opens a connection (which creates the schema as a side effect of
    ``_connect``) and closes it again. Errors from opening propagate; an error
    while closing is ignored.
    """
    conn = _connect()
    try:
        conn.close()
    except Exception:
        # Best-effort close: the schema work has already been committed.
        pass
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_schema(conn: sqlite3.Connection) -> None:
    """Create the ``message_edits`` table and its indexes if absent, then commit.

    One row is kept per edited message, keyed by
    ``(account, session_id, db, table_name, local_id)``. Tables created before
    the ``edited_cols_json`` column existed get that column added.

    Args:
        conn: Open connection to the edit store database.
    """
    conn.execute(
        """
        CREATE TABLE IF NOT EXISTS message_edits (
            account TEXT NOT NULL,
            session_id TEXT NOT NULL,
            db TEXT NOT NULL,
            table_name TEXT NOT NULL,
            local_id INTEGER NOT NULL,
            first_edited_at INTEGER NOT NULL,
            last_edited_at INTEGER NOT NULL,
            edit_count INTEGER NOT NULL,
            original_msg_json TEXT NOT NULL,
            original_resource_json TEXT,
            edited_cols_json TEXT,
            PRIMARY KEY (account, session_id, db, table_name, local_id)
        )
        """
    )
    conn.execute(
        "CREATE INDEX IF NOT EXISTS idx_message_edits_account_session ON message_edits(account, session_id)"
    )
    conn.execute("CREATE INDEX IF NOT EXISTS idx_message_edits_account_last ON message_edits(account, last_edited_at)")

    # Backwards-compatible migrations for existing DBs.
    try:
        # PRAGMA table_info yields (cid, name, type, ...); index 1 is the column name.
        cols = {
            str(r[1] or "").strip().lower()
            for r in conn.execute("PRAGMA table_info(message_edits)").fetchall()
            if r and len(r) > 1 and r[1]
        }
        if "edited_cols_json" not in cols:
            conn.execute("ALTER TABLE message_edits ADD COLUMN edited_cols_json TEXT")
    except sqlite3.Error:
        # The migration is best-effort; a failure must not block use of the store.
        pass
    conn.commit()
|
||||||
|
|
||||||
|
|
||||||
|
def _now_ms() -> int:
    """Return the current wall-clock time as integer milliseconds since the epoch."""
    return int(time.time() * 1000)
|
||||||
|
|
||||||
|
|
||||||
|
def format_message_id(db: str, table_name: str, local_id: int) -> str:
    """Build the ``"<db>:<table_name>:<local_id>"`` identifier for a message.

    ``db`` and ``table_name`` are stripped of surrounding whitespace; falsy
    values become an empty string and a falsy ``local_id`` becomes ``0``.
    """
    db_part = str(db or "").strip()
    table_part = str(table_name or "").strip()
    return f"{db_part}:{table_part}:{int(local_id or 0)}"
|
||||||
|
|
||||||
|
|
||||||
|
def parse_message_id(message_id: str) -> tuple[str, str, int]:
    """Split a ``"<db>:<table_name>:<local_id>"`` identifier into its parts.

    Args:
        message_id: Identifier as produced by ``format_message_id``.

    Returns:
        ``(db, table_name, local_id)`` with the two names stripped.

    Raises:
        ValueError: If there are not exactly three ``:``-separated parts, a
            name is empty, or ``local_id`` is not a positive integer.
    """
    parts = str(message_id or "").split(":", 2)
    if len(parts) != 3:
        raise ValueError("Invalid message_id format.")

    db = parts[0].strip()
    table_name = parts[1].strip()
    try:
        local_id = int(parts[2] or 0)
    except (TypeError, ValueError):
        # Surface one uniform error instead of int()'s own message.
        raise ValueError("Invalid message_id format.") from None

    if not db or not table_name or local_id <= 0:
        raise ValueError("Invalid message_id format.")
    return db, table_name, local_id
|
||||||
|
|
||||||
|
|
||||||
|
def _bytes_to_hex(value: bytes) -> str:
    """Encode *value* as lowercase hex digits prefixed with ``0x``."""
    return "0x" + value.hex()
|
||||||
|
|
||||||
|
|
||||||
|
def _hex_to_bytes(value: str) -> Optional[bytes]:
    """Decode a ``0x``-prefixed hex string back into bytes.

    Returns:
        The decoded bytes, or ``None`` when *value* (after stripping) does not
        start with ``0x``, has no digits, has an odd number of digits, or
        contains a non-hex character.
    """
    text = str(value or "").strip()
    if not text.startswith("0x"):
        return None

    hex_part = text[2:]
    if not hex_part or len(hex_part) % 2 != 0 or _HEX_RE.match(hex_part) is None:
        return None
    try:
        return bytes.fromhex(hex_part)
    except ValueError:
        return None
|
||||||
|
|
||||||
|
|
||||||
|
def _jsonify_blobs(obj: Any) -> Any:
    """Recursively replace binary values with ``0x``-prefixed hex strings.

    ``bytes``, ``bytearray`` and ``memoryview`` values become hex strings,
    dict keys are converted with ``str``, and lists and tuples both come back
    as lists. Every other value is returned unchanged.
    """
    if obj is None:
        return None
    if isinstance(obj, (bytes, bytearray, memoryview)):
        return "0x" + bytes(obj).hex()
    if isinstance(obj, dict):
        return {str(key): _jsonify_blobs(val) for key, val in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [_jsonify_blobs(item) for item in obj]
    return obj
|
||||||
|
|
||||||
|
|
||||||
|
def _dejsonify_blobs(obj: Any) -> Any:
    """Recursively turn ``0x``-prefixed hex strings back into ``bytes``.

    Strings that ``_hex_to_bytes`` accepts are replaced by their decoded
    bytes; all other strings are kept. Dicts and lists are walked recursively.

    NOTE(review): an ordinary text value that happens to look like ``0x`` plus
    an even number of hex digits is also decoded to bytes, because the JSON
    form cannot tell it apart from an encoded blob.
    """
    if obj is None:
        return None
    if isinstance(obj, str):
        decoded = _hex_to_bytes(obj)
        return obj if decoded is None else decoded
    if isinstance(obj, dict):
        return {str(key): _dejsonify_blobs(val) for key, val in obj.items()}
    if isinstance(obj, list):
        return [_dejsonify_blobs(item) for item in obj]
    return obj
|
||||||
|
|
||||||
|
|
||||||
|
def dumps_json_with_blobs(obj: Any) -> str:
    """Serialize *obj* to compact JSON, encoding binary values as hex strings.

    Non-ASCII characters are written as-is (``ensure_ascii=False``) and no
    whitespace is emitted between tokens.
    """
    return json.dumps(_jsonify_blobs(obj), ensure_ascii=False, separators=(",", ":"))
|
||||||
|
|
||||||
|
|
||||||
|
def loads_json_with_blobs(payload: str) -> Any:
    """Parse JSON written by ``dumps_json_with_blobs`` and restore binary values.

    An empty or ``None`` payload is treated as JSON ``null`` and yields
    ``None``. Malformed JSON raises ``json.JSONDecodeError``.
    """
    text = str(payload or "") or "null"
    return _dejsonify_blobs(json.loads(text))
|
||||||
|
|
||||||
|
|
||||||
|
def upsert_original_once(
    *,
    account: str,
    session_id: str,
    db: str,
    table_name: str,
    local_id: int,
    original_msg: dict[str, Any],
    original_resource: Optional[dict[str, Any]],
    now_ms: Optional[int] = None,
) -> None:
    """Insert the original snapshot for a message only once, then bump counters on subsequent edits.

    The first call for a key stores the snapshot with ``edit_count = 1``.
    Later calls leave the stored snapshot untouched and only update
    ``last_edited_at`` and increment ``edit_count``.

    Args:
        account, session_id, db, table_name, local_id: Key of the message.
            Strings are stripped; ``local_id`` must be positive.
        original_msg: Message row before the first edit (binary values allowed).
        original_resource: Associated resource row before the first edit, or ``None``.
        now_ms: Timestamp in milliseconds; defaults to the current time.

    Raises:
        ValueError: If any key part is empty or ``local_id`` is not positive.
    """
    a = str(account or "").strip()
    sid = str(session_id or "").strip()
    db_norm = str(db or "").strip()
    t = str(table_name or "").strip()
    lid = int(local_id or 0)
    if not a or not sid or not db_norm or not t or lid <= 0:
        raise ValueError("Missing required keys for message edit store.")

    ts = int(now_ms if now_ms is not None else _now_ms())
    msg_json = dumps_json_with_blobs(original_msg or {})
    res_json = dumps_json_with_blobs(original_resource) if original_resource is not None else None

    conn: Optional[sqlite3.Connection] = None
    try:
        conn = _connect()
        # INSERT OR IGNORE lets the primary key decide whether this is the
        # first edit, instead of a separate SELECT that another writer could
        # invalidate before the INSERT runs.
        cur = conn.execute(
            """
            INSERT OR IGNORE INTO message_edits(
                account, session_id, db, table_name, local_id,
                first_edited_at, last_edited_at, edit_count,
                original_msg_json, original_resource_json, edited_cols_json
            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """,
            (a, sid, db_norm, t, lid, ts, ts, 1, msg_json, res_json, None),
        )
        if int(cur.rowcount or 0) <= 0:
            # Row already existed: keep the snapshot, only bump the counters.
            conn.execute(
                """
                UPDATE message_edits
                SET last_edited_at = ?, edit_count = edit_count + 1
                WHERE account = ? AND session_id = ? AND db = ? AND table_name = ? AND local_id = ?
                """,
                (ts, a, sid, db_norm, t, lid),
            )
        conn.commit()
    finally:
        try:
            if conn is not None:
                conn.close()
        except Exception:
            # Best-effort close; never mask the outcome of the write above.
            pass
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_json_str_list(payload: Any) -> list[str]:
    """Coerce *payload* into a list of non-empty, stripped strings.

    A list or tuple is used directly; anything else is converted with ``str``
    and parsed as JSON, which must yield a list. Falsy items and items that
    are blank after stripping are dropped. ``None``, blank text, invalid JSON
    and non-list JSON all produce ``[]``.
    """
    if payload is None:
        return []

    if isinstance(payload, (list, tuple)):
        items = payload
    else:
        text = str(payload or "").strip()
        if not text:
            return []
        try:
            items = json.loads(text)
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            return []
        if not isinstance(items, list):
            return []

    cleaned = (str(item or "").strip() for item in items)
    return [name for name in cleaned if name]
|
||||||
|
|
||||||
|
|
||||||
|
def merge_edited_columns(
    *,
    account: str,
    session_id: str,
    db: str,
    table_name: str,
    local_id: int,
    columns: list[str],
) -> bool:
    """Merge edited message column names into the per-message edit record.

    This allows reset to restore only the fields actually modified by the tool.

    Column names are lower-cased, de-duplicated against those already stored
    and written back as a sorted JSON array.

    Returns:
        ``True`` if the record was updated or *columns* held nothing to add;
        ``False`` if the key is incomplete or no edit record exists for it.
    """
    a = str(account or "").strip()
    sid = str(session_id or "").strip()
    db_norm = str(db or "").strip()
    t = str(table_name or "").strip()
    lid = int(local_id or 0)
    if not a or not sid or not db_norm or not t or lid <= 0:
        return False

    cols_in = [name for name in (str(x or "").strip() for x in (columns or [])) if name]
    if not cols_in:
        return True

    key = (a, sid, db_norm, t, lid)
    conn: Optional[sqlite3.Connection] = None
    try:
        conn = _connect()
        row = conn.execute(
            """
            SELECT edited_cols_json
            FROM message_edits
            WHERE account = ? AND session_id = ? AND db = ? AND table_name = ? AND local_id = ?
            LIMIT 1
            """,
            key,
        ).fetchone()
        if row is None:
            return False

        existing = _parse_json_str_list(row[0])
        merged = {c.lower() for c in existing} | {c.lower() for c in cols_in}
        payload = json.dumps(sorted(merged), ensure_ascii=False, separators=(",", ":"))
        conn.execute(
            """
            UPDATE message_edits
            SET edited_cols_json = ?
            WHERE account = ? AND session_id = ? AND db = ? AND table_name = ? AND local_id = ?
            """,
            (payload, *key),
        )
        conn.commit()
        return True
    finally:
        try:
            if conn is not None:
                conn.close()
        except Exception:
            # Best-effort close; never mask the result computed above.
            pass
|
||||||
|
|
||||||
|
|
||||||
|
def _row_to_dict(row: Optional[sqlite3.Row]) -> Optional[dict[str, Any]]:
    """Copy a ``sqlite3.Row`` into a plain dict keyed by column name.

    Returns ``None`` when *row* is ``None``.
    """
    if row is None:
        return None
    return {str(key): row[key] for key in row.keys()}
|
||||||
|
|
||||||
|
|
||||||
|
def list_sessions(account: str) -> list[dict[str, Any]]:
    """Summarize which sessions of *account* contain edited messages.

    Returns:
        One dict per session with ``session_id``, ``msg_count`` (number of
        edited messages) and ``last_edited_at`` (latest edit timestamp),
        ordered by ``last_edited_at`` descending. Rows with a blank session id
        are skipped. An empty *account* yields ``[]``.
    """
    a = str(account or "").strip()
    if not a:
        return []

    conn: Optional[sqlite3.Connection] = None
    try:
        conn = _connect()
        rows = conn.execute(
            """
            SELECT session_id, COUNT(*) AS msg_count, MAX(last_edited_at) AS last_edited_at
            FROM message_edits
            WHERE account = ?
            GROUP BY session_id
            ORDER BY last_edited_at DESC
            """,
            (a,),
        ).fetchall()

        out: list[dict[str, Any]] = []
        for r in rows:
            try:
                sid = str(r["session_id"] or "").strip()
            except Exception:
                sid = ""
            if not sid:
                continue
            out.append(
                {
                    "session_id": sid,
                    "msg_count": int(r["msg_count"] or 0),
                    "last_edited_at": int(r["last_edited_at"] or 0),
                }
            )
        return out
    finally:
        try:
            if conn is not None:
                conn.close()
        except Exception:
            # Best-effort close; never mask the result computed above.
            pass
|
||||||
|
|
||||||
|
|
||||||
|
def list_messages(account: str, session_id: str) -> list[dict[str, Any]]:
    """Return every edit record of one session, oldest last-edit first.

    Each item holds all columns of the ``message_edits`` row plus a
    ``message_id`` built with ``format_message_id`` (an empty string if it
    cannot be built). Rows are ordered by ``last_edited_at`` then ``local_id``.
    An empty *account* or *session_id* yields ``[]``.
    """
    a = str(account or "").strip()
    sid = str(session_id or "").strip()
    if not a or not sid:
        return []

    conn: Optional[sqlite3.Connection] = None
    try:
        conn = _connect()
        rows = conn.execute(
            """
            SELECT *
            FROM message_edits
            WHERE account = ? AND session_id = ?
            ORDER BY last_edited_at ASC, local_id ASC
            """,
            (a, sid),
        ).fetchall()

        out: list[dict[str, Any]] = []
        for r in rows:
            item = _row_to_dict(r) or {}
            try:
                item["message_id"] = format_message_id(
                    item.get("db") or "",
                    item.get("table_name") or "",
                    item.get("local_id") or 0,
                )
            except Exception:
                item["message_id"] = ""
            out.append(item)
        return out
    finally:
        try:
            if conn is not None:
                conn.close()
        except Exception:
            # Best-effort close; never mask the result computed above.
            pass
|
||||||
|
|
||||||
|
|
||||||
|
def get_message_edit(account: str, session_id: str, message_id: str) -> Optional[dict[str, Any]]:
    """Fetch the edit record of a single message.

    Args:
        account: Account the message belongs to.
        session_id: Session the message belongs to.
        message_id: Identifier in ``"<db>:<table_name>:<local_id>"`` form.

    Returns:
        The row as a dict with an added ``message_id`` key, or ``None`` when an
        argument is empty, *message_id* is malformed, or no record exists.
    """
    a = str(account or "").strip()
    sid = str(session_id or "").strip()
    if not a or not sid or not message_id:
        return None
    try:
        db, table_name, local_id = parse_message_id(message_id)
    except ValueError:
        return None

    conn: Optional[sqlite3.Connection] = None
    try:
        conn = _connect()
        row = conn.execute(
            """
            SELECT *
            FROM message_edits
            WHERE account = ? AND session_id = ? AND db = ? AND table_name = ? AND local_id = ?
            LIMIT 1
            """,
            (a, sid, db, table_name, int(local_id)),
        ).fetchone()
        item = _row_to_dict(row)
        if not item:
            return None
        item["message_id"] = format_message_id(db, table_name, local_id)
        return item
    finally:
        try:
            if conn is not None:
                conn.close()
        except Exception:
            # Best-effort close; never mask the result computed above.
            pass
|
||||||
|
|
||||||
|
|
||||||
|
def delete_message_edit(account: str, session_id: str, message_id: str) -> bool:
    """Delete the edit record of a single message.

    Returns:
        ``True`` if a row was removed; ``False`` when an argument is empty,
        *message_id* is malformed, or no matching record existed.
    """
    a = str(account or "").strip()
    sid = str(session_id or "").strip()
    if not a or not sid or not message_id:
        return False
    try:
        db, table_name, local_id = parse_message_id(message_id)
    except ValueError:
        return False

    conn: Optional[sqlite3.Connection] = None
    try:
        conn = _connect()
        cur = conn.execute(
            """
            DELETE FROM message_edits
            WHERE account = ? AND session_id = ? AND db = ? AND table_name = ? AND local_id = ?
            """,
            (a, sid, db, table_name, int(local_id)),
        )
        conn.commit()
        return int(getattr(cur, "rowcount", 0) or 0) > 0
    finally:
        try:
            if conn is not None:
                conn.close()
        except Exception:
            # Best-effort close; never mask the result computed above.
            pass
|
||||||
|
|
||||||
|
|
||||||
|
def update_message_edit_local_id(
    *,
    account: str,
    session_id: str,
    db: str,
    table_name: str,
    old_local_id: int,
    new_local_id: int,
) -> bool:
    """Update the primary key local_id for an existing edit record (unsafe operations may change Msg.local_id).

    Returns:
        ``True`` if a record was moved or the two ids are already equal.
        ``False`` if the key is incomplete, an id is not positive, no record
        matched ``old_local_id``, or the update raised (for example because a
        record with ``new_local_id`` already exists).
    """
    a = str(account or "").strip()
    sid = str(session_id or "").strip()
    db_norm = str(db or "").strip()
    t = str(table_name or "").strip()
    old_lid = int(old_local_id or 0)
    new_lid = int(new_local_id or 0)
    if not a or not sid or not db_norm or not t or old_lid <= 0 or new_lid <= 0:
        return False
    if old_lid == new_lid:
        return True

    conn: Optional[sqlite3.Connection] = None
    try:
        conn = _connect()
        cur = conn.execute(
            """
            UPDATE message_edits
            SET local_id = ?
            WHERE account = ? AND session_id = ? AND db = ? AND table_name = ? AND local_id = ?
            """,
            (new_lid, a, sid, db_norm, t, old_lid),
        )
        conn.commit()
        return int(getattr(cur, "rowcount", 0) or 0) > 0
    except Exception:
        # Callers get a boolean outcome; any failure is reported as "not moved".
        return False
    finally:
        try:
            if conn is not None:
                conn.close()
        except Exception:
            # Best-effort close; never mask the result computed above.
            pass
|
||||||
Binary file not shown.
@@ -32,10 +32,8 @@ class PathFixRequest(Request):
|
|||||||
def _validate_paths_in_json(self, json_data: dict) -> Optional[str]:
|
def _validate_paths_in_json(self, json_data: dict) -> Optional[str]:
|
||||||
"""验证JSON中的路径,返回错误信息(如果有)"""
|
"""验证JSON中的路径,返回错误信息(如果有)"""
|
||||||
logger.info(f"开始验证路径,JSON数据: {json_data}")
|
logger.info(f"开始验证路径,JSON数据: {json_data}")
|
||||||
# 检查db_storage_path字段(现在是必需的)
|
# 仅在提供 db_storage_path 时进行校验(例如 /api/decrypt)。
|
||||||
if 'db_storage_path' not in json_data:
|
# 其它 API 的 JSON payload 不一定包含路径字段,不应强制要求。
|
||||||
return "缺少必需的db_storage_path参数,请提供具体的数据库存储路径。"
|
|
||||||
|
|
||||||
if 'db_storage_path' in json_data:
|
if 'db_storage_path' in json_data:
|
||||||
path = json_data['db_storage_path']
|
path = json_data['db_storage_path']
|
||||||
|
|
||||||
@@ -115,11 +113,16 @@ class PathFixRequest(Request):
|
|||||||
|
|
||||||
async def body(self) -> bytes:
|
async def body(self) -> bytes:
|
||||||
"""重写body方法,预处理JSON中的路径问题"""
|
"""重写body方法,预处理JSON中的路径问题"""
|
||||||
|
cached = getattr(self.state, "_pathfix_body_bytes", None)
|
||||||
|
if isinstance(cached, (bytes, bytearray)):
|
||||||
|
return bytes(cached)
|
||||||
|
|
||||||
body = await super().body()
|
body = await super().body()
|
||||||
|
|
||||||
# 只处理JSON请求
|
# 只处理JSON请求
|
||||||
content_type = self.headers.get("content-type", "")
|
content_type = self.headers.get("content-type", "")
|
||||||
if "application/json" not in content_type:
|
if "application/json" not in content_type:
|
||||||
|
self.state._pathfix_body_bytes = body
|
||||||
return body
|
return body
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -134,6 +137,7 @@ class PathFixRequest(Request):
|
|||||||
logger.info(f"检测到路径错误: {path_error}")
|
logger.info(f"检测到路径错误: {path_error}")
|
||||||
# 我们将错误信息存储在请求中,稍后在路由处理器中检查
|
# 我们将错误信息存储在请求中,稍后在路由处理器中检查
|
||||||
self.state.path_validation_error = path_error
|
self.state.path_validation_error = path_error
|
||||||
|
self.state._pathfix_body_bytes = body
|
||||||
return body
|
return body
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
# JSON格式错误,继续尝试修复
|
# JSON格式错误,继续尝试修复
|
||||||
@@ -169,17 +173,30 @@ class PathFixRequest(Request):
|
|||||||
if path_error:
|
if path_error:
|
||||||
logger.info(f"修复后检测到路径错误: {path_error}")
|
logger.info(f"修复后检测到路径错误: {path_error}")
|
||||||
self.state.path_validation_error = path_error
|
self.state.path_validation_error = path_error
|
||||||
return fixed_body_str.encode('utf-8')
|
fixed_bytes = fixed_body_str.encode('utf-8')
|
||||||
|
self.state._pathfix_body_bytes = fixed_bytes
|
||||||
|
try:
|
||||||
|
self._body = fixed_bytes # type: ignore[attr-defined]
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return fixed_bytes
|
||||||
else:
|
else:
|
||||||
logger.info(f"修复后路径验证通过")
|
logger.info(f"修复后路径验证通过")
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
logger.warning(f"修复后JSON仍然解析失败: {e}")
|
logger.warning(f"修复后JSON仍然解析失败: {e}")
|
||||||
|
|
||||||
return fixed_body_str.encode('utf-8')
|
fixed_bytes = fixed_body_str.encode('utf-8')
|
||||||
|
self.state._pathfix_body_bytes = fixed_bytes
|
||||||
|
try:
|
||||||
|
self._body = fixed_bytes # type: ignore[attr-defined]
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return fixed_bytes
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# 如果处理失败,返回原始body
|
# 如果处理失败,返回原始body
|
||||||
logger.warning(f"JSON路径修复失败,使用原始请求体: {e}")
|
logger.warning(f"JSON路径修复失败,使用原始请求体: {e}")
|
||||||
|
self.state._pathfix_body_bytes = body
|
||||||
return body
|
return body
|
||||||
|
|
||||||
|
|
||||||
@@ -193,12 +210,17 @@ class PathFixRoute(APIRoute):
|
|||||||
# 将Request替换为我们的自定义Request
|
# 将Request替换为我们的自定义Request
|
||||||
custom_request = PathFixRequest(request.scope, request.receive)
|
custom_request = PathFixRequest(request.scope, request.receive)
|
||||||
|
|
||||||
# 检查是否有路径验证错误
|
# 仅对 JSON 请求预读 body,以触发路径修复/校验逻辑,并在发现错误时提前返回 400。
|
||||||
if hasattr(custom_request.state, 'path_validation_error'):
|
try:
|
||||||
raise HTTPException(
|
content_type = (custom_request.headers.get("content-type", "") or "").lower()
|
||||||
status_code=400,
|
if "application/json" in content_type:
|
||||||
detail=custom_request.state.path_validation_error,
|
await custom_request.body()
|
||||||
)
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
path_err = getattr(custom_request.state, "path_validation_error", None)
|
||||||
|
if path_err:
|
||||||
|
raise HTTPException(status_code=400, detail=path_err)
|
||||||
|
|
||||||
return await original_route_handler(custom_request)
|
return await original_route_handler(custom_request)
|
||||||
|
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -175,6 +175,36 @@ def _load_wcdb_lib() -> ctypes.CDLL:
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# Optional (newer DLLs): update a single message content in message db.
|
||||||
|
# Signature: wcdb_update_message(handle, sessionId, localId, createTime, newContent, outError)
|
||||||
|
try:
|
||||||
|
lib.wcdb_update_message.argtypes = [
|
||||||
|
ctypes.c_int64,
|
||||||
|
ctypes.c_char_p,
|
||||||
|
ctypes.c_int64,
|
||||||
|
ctypes.c_int32,
|
||||||
|
ctypes.c_char_p,
|
||||||
|
ctypes.POINTER(ctypes.c_char_p),
|
||||||
|
]
|
||||||
|
lib.wcdb_update_message.restype = ctypes.c_int
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Optional (newer DLLs): delete a single message in message db.
|
||||||
|
# Signature: wcdb_delete_message(handle, sessionId, localId, createTime, dbPathHint, outError)
|
||||||
|
try:
|
||||||
|
lib.wcdb_delete_message.argtypes = [
|
||||||
|
ctypes.c_int64,
|
||||||
|
ctypes.c_char_p,
|
||||||
|
ctypes.c_int64,
|
||||||
|
ctypes.c_int32,
|
||||||
|
ctypes.c_char_p,
|
||||||
|
ctypes.POINTER(ctypes.c_char_p),
|
||||||
|
]
|
||||||
|
lib.wcdb_delete_message.restype = ctypes.c_int
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
# Optional (newer DLLs): wcdb_get_sns_timeline(handle, limit, offset, usernames_json, keyword, start_time, end_time, out_json)
|
# Optional (newer DLLs): wcdb_get_sns_timeline(handle, limit, offset, usernames_json, keyword, start_time, end_time, out_json)
|
||||||
try:
|
try:
|
||||||
lib.wcdb_get_sns_timeline.argtypes = [
|
lib.wcdb_get_sns_timeline.argtypes = [
|
||||||
@@ -259,6 +289,32 @@ def _call_out_json(fn, *args) -> str:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _call_out_error(fn, *args) -> None:
    """Call a native function whose last parameter is an out error string.

    ``fn`` is invoked as ``fn(*args, byref(out))`` where ``out`` is a
    ``c_char_p``. A return code of ``0`` means success.

    Raises:
        WCDBRealtimeError: If the return code is non-zero. The message carries
            the code, the native error text when one was returned, and up to
            six entries from ``get_native_logs()``.
    """
    lib = _load_wcdb_lib()
    out = ctypes.c_char_p()
    rc = int(fn(*args, ctypes.byref(out)))
    try:
        if rc != 0:
            # errors="replace" keeps the decode from failing on bad bytes.
            err = out.value.decode("utf-8", errors="replace") if out.value else ""
            logs = get_native_logs()
            hint = f" logs={logs[:6]}" if logs else ""
            if err:
                raise WCDBRealtimeError(f"wcdb api call failed: {rc}. error={err}.{hint}")
            raise WCDBRealtimeError(f"wcdb api call failed: {rc}.{hint}")
    finally:
        # Hand the native string back to the library on success and failure alike.
        try:
            if out.value:
                lib.wcdb_free_string(out)
        except Exception:
            pass
|
||||||
|
|
||||||
|
|
||||||
def get_native_logs() -> list[str]:
|
def get_native_logs() -> list[str]:
|
||||||
try:
|
try:
|
||||||
_ensure_initialized()
|
_ensure_initialized()
|
||||||
@@ -500,6 +556,64 @@ def exec_query(handle: int, *, kind: str, path: Optional[str], sql: str) -> list
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
def update_message(handle: int, *, session_id: str, local_id: int, create_time: int, new_content: str) -> None:
    """Update a single message content in the live encrypted db_storage via WCDB.

    Requires wcdb_update_message export in wcdb_api.dll.

    Args:
        handle: Native WCDB handle, passed as a 64-bit integer.
        session_id: Session the message belongs to; must not be blank.
        local_id: Message local id, passed as a 64-bit integer (falsy becomes 0).
        create_time: Message create time, passed as a 32-bit integer (falsy becomes 0).
        new_content: Replacement content, sent UTF-8 encoded.

    Raises:
        WCDBRealtimeError: If the export is missing, *session_id* is blank, or
            the native call reports a failure.
    """
    _ensure_initialized()
    lib = _load_wcdb_lib()
    fn = getattr(lib, "wcdb_update_message", None)
    if not fn:
        raise WCDBRealtimeError("Current wcdb_api.dll does not support update_message.")

    sid = str(session_id or "").strip()
    if not sid:
        raise WCDBRealtimeError("Missing session_id for update_message.")

    _call_out_error(
        fn,
        ctypes.c_int64(int(handle)),
        sid.encode("utf-8"),
        ctypes.c_int64(int(local_id or 0)),
        ctypes.c_int32(int(create_time or 0)),
        str(new_content or "").encode("utf-8"),
    )
|
||||||
|
|
||||||
|
|
||||||
|
def delete_message(
    handle: int,
    *,
    session_id: str,
    local_id: int,
    create_time: int,
    db_path_hint: str | None = None,
) -> None:
    """Delete a single message in the live encrypted db_storage via WCDB.

    Requires wcdb_delete_message export in wcdb_api.dll.

    Args:
        handle: Native WCDB handle, passed as a 64-bit integer.
        session_id: Session the message belongs to; must not be blank.
        local_id: Message local id, passed as a 64-bit integer (falsy becomes 0).
        create_time: Message create time, passed as a 32-bit integer (falsy becomes 0).
        db_path_hint: Optional database path hint; sent as an empty string when omitted.

    Raises:
        WCDBRealtimeError: If the export is missing, *session_id* is blank, or
            the native call reports a failure.
    """
    _ensure_initialized()
    lib = _load_wcdb_lib()
    fn = getattr(lib, "wcdb_delete_message", None)
    if not fn:
        raise WCDBRealtimeError("Current wcdb_api.dll does not support delete_message.")

    sid = str(session_id or "").strip()
    if not sid:
        raise WCDBRealtimeError("Missing session_id for delete_message.")

    hint = str(db_path_hint or "").strip()
    _call_out_error(
        fn,
        ctypes.c_int64(int(handle)),
        sid.encode("utf-8"),
        ctypes.c_int64(int(local_id or 0)),
        ctypes.c_int32(int(create_time or 0)),
        hint.encode("utf-8"),
    )
|
||||||
|
|
||||||
|
|
||||||
def get_sns_timeline(
|
def get_sns_timeline(
|
||||||
handle: int,
|
handle: int,
|
||||||
*,
|
*,
|
||||||
|
|||||||
@@ -0,0 +1,182 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import json
|
||||||
|
import sqlite3
|
||||||
|
import unittest
|
||||||
|
import importlib
|
||||||
|
from pathlib import Path
|
||||||
|
from tempfile import TemporaryDirectory
|
||||||
|
|
||||||
|
|
||||||
|
# Make the in-repo ``src`` directory importable when the tests run from a checkout.
ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT / "src"))
|
||||||
|
|
||||||
|
|
||||||
|
class TestChatEditStore(unittest.TestCase):
    """Exercise the message edit store against a throwaway data directory."""

    def setUp(self):
        # Point the tool at a temporary data dir, then reload the modules so
        # they pick up the new WECHAT_TOOL_DATA_DIR value.
        self._prev_data_dir = os.environ.get("WECHAT_TOOL_DATA_DIR")
        self._td = TemporaryDirectory()
        os.environ["WECHAT_TOOL_DATA_DIR"] = self._td.name

        import wechat_decrypt_tool.app_paths as app_paths
        import wechat_decrypt_tool.chat_edit_store as chat_edit_store

        importlib.reload(app_paths)
        importlib.reload(chat_edit_store)

        self.app_paths = app_paths
        self.store = chat_edit_store

    def tearDown(self):
        # Restore the environment exactly as it was before the test.
        if self._prev_data_dir is None:
            os.environ.pop("WECHAT_TOOL_DATA_DIR", None)
        else:
            os.environ["WECHAT_TOOL_DATA_DIR"] = self._prev_data_dir
        self._td.cleanup()

    def test_ensure_schema_creates_db(self):
        self.store.ensure_schema()
        db_path = self.app_paths.get_output_dir() / "message_edits.db"
        self.assertTrue(db_path.exists())

        conn = sqlite3.connect(str(db_path))
        try:
            row = conn.execute(
                "SELECT name FROM sqlite_master WHERE type='table' AND name='message_edits' LIMIT 1"
            ).fetchone()
            self.assertIsNotNone(row)
        finally:
            conn.close()

    def test_blob_hex_roundtrip(self):
        payload = {"a": b"\x00\xff", "nested": {"b": memoryview(b"\x01\x02")}}
        dumped = self.store.dumps_json_with_blobs(payload)
        self.assertIn("0x00ff", dumped.lower())
        self.assertIn("0x0102", dumped.lower())

        loaded = self.store.loads_json_with_blobs(dumped)
        self.assertEqual(loaded["a"], b"\x00\xff")
        self.assertEqual(loaded["nested"]["b"], b"\x01\x02")

    def test_message_id_format_parse(self):
        mid = self.store.format_message_id("message_0", "Msg_foo", 123)
        self.assertEqual(mid, "message_0:Msg_foo:123")

        db, table, local_id = self.store.parse_message_id(mid)
        self.assertEqual(db, "message_0")
        self.assertEqual(table, "Msg_foo")
        self.assertEqual(local_id, 123)

        with self.assertRaises(ValueError):
            self.store.parse_message_id("bad")

    def test_upsert_original_once_does_not_overwrite_snapshot(self):
        now1 = 1000
        now2 = 2000
        self.store.upsert_original_once(
            account="wxid_me",
            session_id="wxid_you",
            db="message_0",
            table_name="Msg_foo",
            local_id=1,
            original_msg={"local_id": 1, "message_content": "hello", "compress_content": b"\x01"},
            original_resource={"message_id": 9, "packed_info": b"\x02"},
            now_ms=now1,
        )

        # A second call for the same key must only bump the counters.
        self.store.upsert_original_once(
            account="wxid_me",
            session_id="wxid_you",
            db="message_0",
            table_name="Msg_foo",
            local_id=1,
            original_msg={"local_id": 1, "message_content": "SHOULD_NOT_OVERWRITE", "compress_content": b"\x03"},
            original_resource={"message_id": 9, "packed_info": b"\x04"},
            now_ms=now2,
        )

        mid = self.store.format_message_id("message_0", "Msg_foo", 1)
        item = self.store.get_message_edit("wxid_me", "wxid_you", mid)
        self.assertIsNotNone(item)
        self.assertEqual(int(item["first_edited_at"]), now1)
        self.assertEqual(int(item["last_edited_at"]), now2)
        self.assertEqual(int(item["edit_count"]), 2)

        original_msg = self.store.loads_json_with_blobs(item["original_msg_json"])
        self.assertEqual(original_msg["message_content"], "hello")
        self.assertEqual(original_msg["compress_content"], b"\x01")

        original_res = self.store.loads_json_with_blobs(item["original_resource_json"])
        self.assertEqual(int(original_res["message_id"]), 9)
        self.assertEqual(original_res["packed_info"], b"\x02")

    def test_update_message_edit_local_id_moves_primary_key(self):
        self.store.upsert_original_once(
            account="wxid_me",
            session_id="wxid_you",
            db="message_0",
            table_name="Msg_foo",
            local_id=10,
            original_msg={"local_id": 10, "message_content": "hello"},
            original_resource=None,
            now_ms=1234,
        )

        ok = self.store.update_message_edit_local_id(
            account="wxid_me",
            session_id="wxid_you",
            db="message_0",
            table_name="Msg_foo",
            old_local_id=10,
            new_local_id=11,
        )
        self.assertTrue(ok)

        old_mid = self.store.format_message_id("message_0", "Msg_foo", 10)
        new_mid = self.store.format_message_id("message_0", "Msg_foo", 11)
        self.assertIsNone(self.store.get_message_edit("wxid_me", "wxid_you", old_mid))
        self.assertIsNotNone(self.store.get_message_edit("wxid_me", "wxid_you", new_mid))

    def test_list_sessions_counts(self):
        self.store.upsert_original_once(
            account="wxid_me",
            session_id="u1",
            db="message_0",
            table_name="Msg_foo",
            local_id=1,
            original_msg={"local_id": 1, "message_content": "a"},
            original_resource=None,
            now_ms=100,
        )
        self.store.upsert_original_once(
            account="wxid_me",
            session_id="u1",
            db="message_0",
            table_name="Msg_foo",
            local_id=2,
            original_msg={"local_id": 2, "message_content": "b"},
            original_resource=None,
            now_ms=200,
        )
        self.store.upsert_original_once(
            account="wxid_me",
            session_id="u2",
            db="message_0",
            table_name="Msg_foo",
            local_id=3,
            original_msg={"local_id": 3, "message_content": "c"},
            original_resource=None,
            now_ms=300,
        )

        stats = self.store.list_sessions("wxid_me")
        by_sid = {s["session_id"]: s for s in stats}
        self.assertEqual(int(by_sid["u1"]["msg_count"]), 2)
        self.assertEqual(int(by_sid["u1"]["last_edited_at"]), 200)
        self.assertEqual(int(by_sid["u2"]["msg_count"]), 1)
        self.assertEqual(int(by_sid["u2"]["last_edited_at"]), 300)
|
||||||
|
|
||||||
|
|
||||||
|
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
|
||||||
|
|
||||||
Reference in New Issue
Block a user