mirror of
https://github.com/LifeArchiveProject/WeChatDataAnalysis.git
synced 2026-06-18 15:54:08 +08:00
feat(chat-edit-backend): 新增消息编辑快照、回滚与修复接口
- 新增 message_edits 存储,记录首次快照、编辑次数与已修改字段 - 新增消息编辑链路:raw 查询、编辑、修改状态、按消息/会话恢复 - 新增发送者修复与 packed_info_data 方向反转能力 - 编辑流程支持 output/db_storage 双写与失败回滚 - path_fix 改为按需校验 db_storage_path,并缓存 body 避免重复读取问题 - 补充 chat_edit_store 单元测试覆盖核心行为
This commit is contained in:
@@ -0,0 +1,487 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
import sqlite3
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
from .app_paths import get_output_dir
|
||||
|
||||
_HEX_RE = re.compile(r"^[0-9a-fA-F]+$")
|
||||
|
||||
|
||||
def _db_path() -> Path:
|
||||
return get_output_dir() / "message_edits.db"
|
||||
|
||||
|
||||
def _connect() -> sqlite3.Connection:
|
||||
db_path = _db_path()
|
||||
db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
conn = sqlite3.connect(str(db_path), timeout=5)
|
||||
conn.row_factory = sqlite3.Row
|
||||
_ensure_schema(conn)
|
||||
return conn
|
||||
|
||||
|
||||
def ensure_schema() -> None:
|
||||
conn: Optional[sqlite3.Connection] = None
|
||||
try:
|
||||
conn = _connect()
|
||||
finally:
|
||||
try:
|
||||
if conn is not None:
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _ensure_schema(conn: sqlite3.Connection) -> None:
|
||||
conn.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS message_edits (
|
||||
account TEXT NOT NULL,
|
||||
session_id TEXT NOT NULL,
|
||||
db TEXT NOT NULL,
|
||||
table_name TEXT NOT NULL,
|
||||
local_id INTEGER NOT NULL,
|
||||
first_edited_at INTEGER NOT NULL,
|
||||
last_edited_at INTEGER NOT NULL,
|
||||
edit_count INTEGER NOT NULL,
|
||||
original_msg_json TEXT NOT NULL,
|
||||
original_resource_json TEXT,
|
||||
edited_cols_json TEXT,
|
||||
PRIMARY KEY (account, session_id, db, table_name, local_id)
|
||||
)
|
||||
"""
|
||||
)
|
||||
conn.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_message_edits_account_session ON message_edits(account, session_id)"
|
||||
)
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_message_edits_account_last ON message_edits(account, last_edited_at)")
|
||||
|
||||
# Backwards-compatible migrations for existing DBs.
|
||||
try:
|
||||
cols = {
|
||||
str(r[1] or "").strip().lower()
|
||||
for r in conn.execute("PRAGMA table_info(message_edits)").fetchall()
|
||||
if r and len(r) > 1 and r[1]
|
||||
}
|
||||
if "edited_cols_json" not in cols:
|
||||
conn.execute("ALTER TABLE message_edits ADD COLUMN edited_cols_json TEXT")
|
||||
except Exception:
|
||||
pass
|
||||
conn.commit()
|
||||
|
||||
|
||||
def _now_ms() -> int:
|
||||
return int(time.time() * 1000)
|
||||
|
||||
|
||||
def format_message_id(db: str, table_name: str, local_id: int) -> str:
|
||||
return f"{str(db or '').strip()}:{str(table_name or '').strip()}:{int(local_id or 0)}"
|
||||
|
||||
|
||||
def parse_message_id(message_id: str) -> tuple[str, str, int]:
|
||||
parts = str(message_id or "").split(":", 2)
|
||||
if len(parts) != 3:
|
||||
raise ValueError("Invalid message_id format.")
|
||||
db = str(parts[0] or "").strip()
|
||||
table_name = str(parts[1] or "").strip()
|
||||
try:
|
||||
local_id = int(parts[2] or 0)
|
||||
except Exception:
|
||||
raise ValueError("Invalid message_id format.")
|
||||
if not db or not table_name or local_id <= 0:
|
||||
raise ValueError("Invalid message_id format.")
|
||||
return db, table_name, local_id
|
||||
|
||||
|
||||
def _bytes_to_hex(value: bytes) -> str:
|
||||
return "0x" + value.hex()
|
||||
|
||||
|
||||
def _hex_to_bytes(value: str) -> Optional[bytes]:
|
||||
s = str(value or "").strip()
|
||||
if not s.startswith("0x"):
|
||||
return None
|
||||
hex_part = s[2:]
|
||||
if (not hex_part) or (len(hex_part) % 2 != 0) or (_HEX_RE.match(hex_part) is None):
|
||||
return None
|
||||
try:
|
||||
return bytes.fromhex(hex_part)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _jsonify_blobs(obj: Any) -> Any:
|
||||
if obj is None:
|
||||
return None
|
||||
if isinstance(obj, (bytes, bytearray, memoryview)):
|
||||
return _bytes_to_hex(bytes(obj))
|
||||
if isinstance(obj, dict):
|
||||
return {str(k): _jsonify_blobs(v) for k, v in obj.items()}
|
||||
if isinstance(obj, (list, tuple)):
|
||||
return [_jsonify_blobs(v) for v in obj]
|
||||
return obj
|
||||
|
||||
|
||||
def _dejsonify_blobs(obj: Any) -> Any:
|
||||
if obj is None:
|
||||
return None
|
||||
if isinstance(obj, str):
|
||||
b = _hex_to_bytes(obj)
|
||||
return b if b is not None else obj
|
||||
if isinstance(obj, dict):
|
||||
return {str(k): _dejsonify_blobs(v) for k, v in obj.items()}
|
||||
if isinstance(obj, list):
|
||||
return [_dejsonify_blobs(v) for v in obj]
|
||||
return obj
|
||||
|
||||
|
||||
def dumps_json_with_blobs(obj: Any) -> str:
|
||||
return json.dumps(_jsonify_blobs(obj), ensure_ascii=False, separators=(",", ":"))
|
||||
|
||||
|
||||
def loads_json_with_blobs(payload: str) -> Any:
|
||||
return _dejsonify_blobs(json.loads(str(payload or "") or "null"))
|
||||
|
||||
|
||||
def upsert_original_once(
|
||||
*,
|
||||
account: str,
|
||||
session_id: str,
|
||||
db: str,
|
||||
table_name: str,
|
||||
local_id: int,
|
||||
original_msg: dict[str, Any],
|
||||
original_resource: Optional[dict[str, Any]],
|
||||
now_ms: Optional[int] = None,
|
||||
) -> None:
|
||||
"""Insert the original snapshot for a message only once, then bump counters on subsequent edits."""
|
||||
a = str(account or "").strip()
|
||||
sid = str(session_id or "").strip()
|
||||
db_norm = str(db or "").strip()
|
||||
t = str(table_name or "").strip()
|
||||
lid = int(local_id or 0)
|
||||
if not a or not sid or not db_norm or not t or lid <= 0:
|
||||
raise ValueError("Missing required keys for message edit store.")
|
||||
|
||||
ts = int(now_ms if now_ms is not None else _now_ms())
|
||||
msg_json = dumps_json_with_blobs(original_msg or {})
|
||||
res_json = dumps_json_with_blobs(original_resource) if original_resource is not None else None
|
||||
|
||||
conn: Optional[sqlite3.Connection] = None
|
||||
try:
|
||||
conn = _connect()
|
||||
existing = conn.execute(
|
||||
"""
|
||||
SELECT 1
|
||||
FROM message_edits
|
||||
WHERE account = ? AND session_id = ? AND db = ? AND table_name = ? AND local_id = ?
|
||||
LIMIT 1
|
||||
""",
|
||||
(a, sid, db_norm, t, lid),
|
||||
).fetchone()
|
||||
if existing is None:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO message_edits(
|
||||
account, session_id, db, table_name, local_id,
|
||||
first_edited_at, last_edited_at, edit_count,
|
||||
original_msg_json, original_resource_json, edited_cols_json
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(a, sid, db_norm, t, lid, ts, ts, 1, msg_json, res_json, None),
|
||||
)
|
||||
else:
|
||||
conn.execute(
|
||||
"""
|
||||
UPDATE message_edits
|
||||
SET last_edited_at = ?, edit_count = edit_count + 1
|
||||
WHERE account = ? AND session_id = ? AND db = ? AND table_name = ? AND local_id = ?
|
||||
""",
|
||||
(ts, a, sid, db_norm, t, lid),
|
||||
)
|
||||
conn.commit()
|
||||
finally:
|
||||
try:
|
||||
if conn is not None:
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _parse_json_str_list(payload: Any) -> list[str]:
|
||||
if payload is None:
|
||||
return []
|
||||
if isinstance(payload, (list, tuple)):
|
||||
return [str(x or "").strip() for x in payload if str(x or "").strip()]
|
||||
s = str(payload or "").strip()
|
||||
if not s:
|
||||
return []
|
||||
try:
|
||||
v = json.loads(s)
|
||||
except Exception:
|
||||
return []
|
||||
if not isinstance(v, list):
|
||||
return []
|
||||
return [str(x or "").strip() for x in v if str(x or "").strip()]
|
||||
|
||||
|
||||
def merge_edited_columns(
|
||||
*,
|
||||
account: str,
|
||||
session_id: str,
|
||||
db: str,
|
||||
table_name: str,
|
||||
local_id: int,
|
||||
columns: list[str],
|
||||
) -> bool:
|
||||
"""Merge edited message column names into the per-message edit record.
|
||||
|
||||
This allows reset to restore only the fields actually modified by the tool.
|
||||
"""
|
||||
a = str(account or "").strip()
|
||||
sid = str(session_id or "").strip()
|
||||
db_norm = str(db or "").strip()
|
||||
t = str(table_name or "").strip()
|
||||
lid = int(local_id or 0)
|
||||
if not a or not sid or not db_norm or not t or lid <= 0:
|
||||
return False
|
||||
|
||||
cols_in = [str(x or "").strip() for x in (columns or []) if str(x or "").strip()]
|
||||
if not cols_in:
|
||||
return True
|
||||
|
||||
conn: Optional[sqlite3.Connection] = None
|
||||
try:
|
||||
conn = _connect()
|
||||
row = conn.execute(
|
||||
"""
|
||||
SELECT edited_cols_json
|
||||
FROM message_edits
|
||||
WHERE account = ? AND session_id = ? AND db = ? AND table_name = ? AND local_id = ?
|
||||
LIMIT 1
|
||||
""",
|
||||
(a, sid, db_norm, t, lid),
|
||||
).fetchone()
|
||||
if row is None:
|
||||
return False
|
||||
|
||||
existing = _parse_json_str_list(row[0] if row and len(row) else None)
|
||||
merged = {c.lower() for c in existing if c} | {c.lower() for c in cols_in if c}
|
||||
merged_list = sorted(merged)
|
||||
payload = json.dumps(merged_list, ensure_ascii=False, separators=(",", ":"))
|
||||
conn.execute(
|
||||
"""
|
||||
UPDATE message_edits
|
||||
SET edited_cols_json = ?
|
||||
WHERE account = ? AND session_id = ? AND db = ? AND table_name = ? AND local_id = ?
|
||||
""",
|
||||
(payload, a, sid, db_norm, t, lid),
|
||||
)
|
||||
conn.commit()
|
||||
return True
|
||||
finally:
|
||||
try:
|
||||
if conn is not None:
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _row_to_dict(row: Optional[sqlite3.Row]) -> Optional[dict[str, Any]]:
|
||||
if row is None:
|
||||
return None
|
||||
out: dict[str, Any] = {}
|
||||
for k in row.keys():
|
||||
out[str(k)] = row[k]
|
||||
return out
|
||||
|
||||
|
||||
def list_sessions(account: str) -> list[dict[str, Any]]:
|
||||
a = str(account or "").strip()
|
||||
if not a:
|
||||
return []
|
||||
|
||||
conn: Optional[sqlite3.Connection] = None
|
||||
try:
|
||||
conn = _connect()
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT session_id, COUNT(*) AS msg_count, MAX(last_edited_at) AS last_edited_at
|
||||
FROM message_edits
|
||||
WHERE account = ?
|
||||
GROUP BY session_id
|
||||
ORDER BY last_edited_at DESC
|
||||
""",
|
||||
(a,),
|
||||
).fetchall()
|
||||
out: list[dict[str, Any]] = []
|
||||
for r in rows:
|
||||
try:
|
||||
sid = str(r["session_id"] or "").strip()
|
||||
except Exception:
|
||||
sid = ""
|
||||
if not sid:
|
||||
continue
|
||||
out.append(
|
||||
{
|
||||
"session_id": sid,
|
||||
"msg_count": int(r["msg_count"] or 0),
|
||||
"last_edited_at": int(r["last_edited_at"] or 0),
|
||||
}
|
||||
)
|
||||
return out
|
||||
finally:
|
||||
try:
|
||||
if conn is not None:
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def list_messages(account: str, session_id: str) -> list[dict[str, Any]]:
|
||||
a = str(account or "").strip()
|
||||
sid = str(session_id or "").strip()
|
||||
if not a or not sid:
|
||||
return []
|
||||
|
||||
conn: Optional[sqlite3.Connection] = None
|
||||
try:
|
||||
conn = _connect()
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT *
|
||||
FROM message_edits
|
||||
WHERE account = ? AND session_id = ?
|
||||
ORDER BY last_edited_at ASC, local_id ASC
|
||||
""",
|
||||
(a, sid),
|
||||
).fetchall()
|
||||
out: list[dict[str, Any]] = []
|
||||
for r in rows:
|
||||
item = _row_to_dict(r) or {}
|
||||
try:
|
||||
item["message_id"] = format_message_id(item.get("db") or "", item.get("table_name") or "", item.get("local_id") or 0)
|
||||
except Exception:
|
||||
item["message_id"] = ""
|
||||
out.append(item)
|
||||
return out
|
||||
finally:
|
||||
try:
|
||||
if conn is not None:
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def get_message_edit(account: str, session_id: str, message_id: str) -> Optional[dict[str, Any]]:
|
||||
a = str(account or "").strip()
|
||||
sid = str(session_id or "").strip()
|
||||
if not a or not sid or not message_id:
|
||||
return None
|
||||
try:
|
||||
db, table_name, local_id = parse_message_id(message_id)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
conn: Optional[sqlite3.Connection] = None
|
||||
try:
|
||||
conn = _connect()
|
||||
row = conn.execute(
|
||||
"""
|
||||
SELECT *
|
||||
FROM message_edits
|
||||
WHERE account = ? AND session_id = ? AND db = ? AND table_name = ? AND local_id = ?
|
||||
LIMIT 1
|
||||
""",
|
||||
(a, sid, db, table_name, int(local_id)),
|
||||
).fetchone()
|
||||
item = _row_to_dict(row)
|
||||
if not item:
|
||||
return None
|
||||
item["message_id"] = format_message_id(db, table_name, local_id)
|
||||
return item
|
||||
finally:
|
||||
try:
|
||||
if conn is not None:
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def delete_message_edit(account: str, session_id: str, message_id: str) -> bool:
|
||||
a = str(account or "").strip()
|
||||
sid = str(session_id or "").strip()
|
||||
if not a or not sid or not message_id:
|
||||
return False
|
||||
try:
|
||||
db, table_name, local_id = parse_message_id(message_id)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
conn: Optional[sqlite3.Connection] = None
|
||||
try:
|
||||
conn = _connect()
|
||||
cur = conn.execute(
|
||||
"""
|
||||
DELETE FROM message_edits
|
||||
WHERE account = ? AND session_id = ? AND db = ? AND table_name = ? AND local_id = ?
|
||||
""",
|
||||
(a, sid, db, table_name, int(local_id)),
|
||||
)
|
||||
conn.commit()
|
||||
return int(getattr(cur, "rowcount", 0) or 0) > 0
|
||||
finally:
|
||||
try:
|
||||
if conn is not None:
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def update_message_edit_local_id(
|
||||
*,
|
||||
account: str,
|
||||
session_id: str,
|
||||
db: str,
|
||||
table_name: str,
|
||||
old_local_id: int,
|
||||
new_local_id: int,
|
||||
) -> bool:
|
||||
"""Update the primary key local_id for an existing edit record (unsafe operations may change Msg.local_id)."""
|
||||
a = str(account or "").strip()
|
||||
sid = str(session_id or "").strip()
|
||||
db_norm = str(db or "").strip()
|
||||
t = str(table_name or "").strip()
|
||||
old_lid = int(old_local_id or 0)
|
||||
new_lid = int(new_local_id or 0)
|
||||
if not a or not sid or not db_norm or not t or old_lid <= 0 or new_lid <= 0:
|
||||
return False
|
||||
if old_lid == new_lid:
|
||||
return True
|
||||
|
||||
conn: Optional[sqlite3.Connection] = None
|
||||
try:
|
||||
conn = _connect()
|
||||
cur = conn.execute(
|
||||
"""
|
||||
UPDATE message_edits
|
||||
SET local_id = ?
|
||||
WHERE account = ? AND session_id = ? AND db = ? AND table_name = ? AND local_id = ?
|
||||
""",
|
||||
(new_lid, a, sid, db_norm, t, old_lid),
|
||||
)
|
||||
conn.commit()
|
||||
return int(getattr(cur, "rowcount", 0) or 0) > 0
|
||||
except Exception:
|
||||
return False
|
||||
finally:
|
||||
try:
|
||||
if conn is not None:
|
||||
conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
Binary file not shown.
@@ -32,10 +32,8 @@ class PathFixRequest(Request):
|
||||
def _validate_paths_in_json(self, json_data: dict) -> Optional[str]:
|
||||
"""验证JSON中的路径,返回错误信息(如果有)"""
|
||||
logger.info(f"开始验证路径,JSON数据: {json_data}")
|
||||
# 检查db_storage_path字段(现在是必需的)
|
||||
if 'db_storage_path' not in json_data:
|
||||
return "缺少必需的db_storage_path参数,请提供具体的数据库存储路径。"
|
||||
|
||||
# 仅在提供 db_storage_path 时进行校验(例如 /api/decrypt)。
|
||||
# 其它 API 的 JSON payload 不一定包含路径字段,不应强制要求。
|
||||
if 'db_storage_path' in json_data:
|
||||
path = json_data['db_storage_path']
|
||||
|
||||
@@ -115,11 +113,16 @@ class PathFixRequest(Request):
|
||||
|
||||
async def body(self) -> bytes:
|
||||
"""重写body方法,预处理JSON中的路径问题"""
|
||||
cached = getattr(self.state, "_pathfix_body_bytes", None)
|
||||
if isinstance(cached, (bytes, bytearray)):
|
||||
return bytes(cached)
|
||||
|
||||
body = await super().body()
|
||||
|
||||
# 只处理JSON请求
|
||||
content_type = self.headers.get("content-type", "")
|
||||
if "application/json" not in content_type:
|
||||
self.state._pathfix_body_bytes = body
|
||||
return body
|
||||
|
||||
try:
|
||||
@@ -134,6 +137,7 @@ class PathFixRequest(Request):
|
||||
logger.info(f"检测到路径错误: {path_error}")
|
||||
# 我们将错误信息存储在请求中,稍后在路由处理器中检查
|
||||
self.state.path_validation_error = path_error
|
||||
self.state._pathfix_body_bytes = body
|
||||
return body
|
||||
except json.JSONDecodeError as e:
|
||||
# JSON格式错误,继续尝试修复
|
||||
@@ -169,17 +173,30 @@ class PathFixRequest(Request):
|
||||
if path_error:
|
||||
logger.info(f"修复后检测到路径错误: {path_error}")
|
||||
self.state.path_validation_error = path_error
|
||||
return fixed_body_str.encode('utf-8')
|
||||
fixed_bytes = fixed_body_str.encode('utf-8')
|
||||
self.state._pathfix_body_bytes = fixed_bytes
|
||||
try:
|
||||
self._body = fixed_bytes # type: ignore[attr-defined]
|
||||
except Exception:
|
||||
pass
|
||||
return fixed_bytes
|
||||
else:
|
||||
logger.info(f"修复后路径验证通过")
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(f"修复后JSON仍然解析失败: {e}")
|
||||
|
||||
return fixed_body_str.encode('utf-8')
|
||||
fixed_bytes = fixed_body_str.encode('utf-8')
|
||||
self.state._pathfix_body_bytes = fixed_bytes
|
||||
try:
|
||||
self._body = fixed_bytes # type: ignore[attr-defined]
|
||||
except Exception:
|
||||
pass
|
||||
return fixed_bytes
|
||||
|
||||
except Exception as e:
|
||||
# 如果处理失败,返回原始body
|
||||
logger.warning(f"JSON路径修复失败,使用原始请求体: {e}")
|
||||
self.state._pathfix_body_bytes = body
|
||||
return body
|
||||
|
||||
|
||||
@@ -193,12 +210,17 @@ class PathFixRoute(APIRoute):
|
||||
# 将Request替换为我们的自定义Request
|
||||
custom_request = PathFixRequest(request.scope, request.receive)
|
||||
|
||||
# 检查是否有路径验证错误
|
||||
if hasattr(custom_request.state, 'path_validation_error'):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=custom_request.state.path_validation_error,
|
||||
)
|
||||
# 仅对 JSON 请求预读 body,以触发路径修复/校验逻辑,并在发现错误时提前返回 400。
|
||||
try:
|
||||
content_type = (custom_request.headers.get("content-type", "") or "").lower()
|
||||
if "application/json" in content_type:
|
||||
await custom_request.body()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
path_err = getattr(custom_request.state, "path_validation_error", None)
|
||||
if path_err:
|
||||
raise HTTPException(status_code=400, detail=path_err)
|
||||
|
||||
return await original_route_handler(custom_request)
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -175,6 +175,36 @@ def _load_wcdb_lib() -> ctypes.CDLL:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Optional (newer DLLs): update a single message content in message db.
|
||||
# Signature: wcdb_update_message(handle, sessionId, localId, createTime, newContent, outError)
|
||||
try:
|
||||
lib.wcdb_update_message.argtypes = [
|
||||
ctypes.c_int64,
|
||||
ctypes.c_char_p,
|
||||
ctypes.c_int64,
|
||||
ctypes.c_int32,
|
||||
ctypes.c_char_p,
|
||||
ctypes.POINTER(ctypes.c_char_p),
|
||||
]
|
||||
lib.wcdb_update_message.restype = ctypes.c_int
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Optional (newer DLLs): delete a single message in message db.
|
||||
# Signature: wcdb_delete_message(handle, sessionId, localId, createTime, dbPathHint, outError)
|
||||
try:
|
||||
lib.wcdb_delete_message.argtypes = [
|
||||
ctypes.c_int64,
|
||||
ctypes.c_char_p,
|
||||
ctypes.c_int64,
|
||||
ctypes.c_int32,
|
||||
ctypes.c_char_p,
|
||||
ctypes.POINTER(ctypes.c_char_p),
|
||||
]
|
||||
lib.wcdb_delete_message.restype = ctypes.c_int
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Optional (newer DLLs): wcdb_get_sns_timeline(handle, limit, offset, usernames_json, keyword, start_time, end_time, out_json)
|
||||
try:
|
||||
lib.wcdb_get_sns_timeline.argtypes = [
|
||||
@@ -259,6 +289,32 @@ def _call_out_json(fn, *args) -> str:
|
||||
pass
|
||||
|
||||
|
||||
def _call_out_error(fn, *args) -> None:
|
||||
lib = _load_wcdb_lib()
|
||||
out = ctypes.c_char_p()
|
||||
rc = int(fn(*args, ctypes.byref(out)))
|
||||
try:
|
||||
if rc != 0:
|
||||
err = ""
|
||||
try:
|
||||
if out.value:
|
||||
err = (out.value or b"").decode("utf-8", errors="replace")
|
||||
except Exception:
|
||||
err = ""
|
||||
|
||||
logs = get_native_logs()
|
||||
hint = f" logs={logs[:6]}" if logs else ""
|
||||
if err:
|
||||
raise WCDBRealtimeError(f"wcdb api call failed: {rc}. error={err}.{hint}")
|
||||
raise WCDBRealtimeError(f"wcdb api call failed: {rc}.{hint}")
|
||||
finally:
|
||||
try:
|
||||
if out.value:
|
||||
lib.wcdb_free_string(out)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def get_native_logs() -> list[str]:
|
||||
try:
|
||||
_ensure_initialized()
|
||||
@@ -500,6 +556,64 @@ def exec_query(handle: int, *, kind: str, path: Optional[str], sql: str) -> list
|
||||
return []
|
||||
|
||||
|
||||
def update_message(handle: int, *, session_id: str, local_id: int, create_time: int, new_content: str) -> None:
|
||||
"""Update a single message content in the live encrypted db_storage via WCDB.
|
||||
|
||||
Requires wcdb_update_message export in wcdb_api.dll.
|
||||
"""
|
||||
_ensure_initialized()
|
||||
lib = _load_wcdb_lib()
|
||||
fn = getattr(lib, "wcdb_update_message", None)
|
||||
if not fn:
|
||||
raise WCDBRealtimeError("Current wcdb_api.dll does not support update_message.")
|
||||
|
||||
sid = str(session_id or "").strip()
|
||||
if not sid:
|
||||
raise WCDBRealtimeError("Missing session_id for update_message.")
|
||||
|
||||
_call_out_error(
|
||||
fn,
|
||||
ctypes.c_int64(int(handle)),
|
||||
sid.encode("utf-8"),
|
||||
ctypes.c_int64(int(local_id or 0)),
|
||||
ctypes.c_int32(int(create_time or 0)),
|
||||
str(new_content or "").encode("utf-8"),
|
||||
)
|
||||
|
||||
|
||||
def delete_message(
|
||||
handle: int,
|
||||
*,
|
||||
session_id: str,
|
||||
local_id: int,
|
||||
create_time: int,
|
||||
db_path_hint: str | None = None,
|
||||
) -> None:
|
||||
"""Delete a single message in the live encrypted db_storage via WCDB.
|
||||
|
||||
Requires wcdb_delete_message export in wcdb_api.dll.
|
||||
"""
|
||||
_ensure_initialized()
|
||||
lib = _load_wcdb_lib()
|
||||
fn = getattr(lib, "wcdb_delete_message", None)
|
||||
if not fn:
|
||||
raise WCDBRealtimeError("Current wcdb_api.dll does not support delete_message.")
|
||||
|
||||
sid = str(session_id or "").strip()
|
||||
if not sid:
|
||||
raise WCDBRealtimeError("Missing session_id for delete_message.")
|
||||
|
||||
hint = str(db_path_hint or "").strip()
|
||||
_call_out_error(
|
||||
fn,
|
||||
ctypes.c_int64(int(handle)),
|
||||
sid.encode("utf-8"),
|
||||
ctypes.c_int64(int(local_id or 0)),
|
||||
ctypes.c_int32(int(create_time or 0)),
|
||||
hint.encode("utf-8"),
|
||||
)
|
||||
|
||||
|
||||
def get_sns_timeline(
|
||||
handle: int,
|
||||
*,
|
||||
|
||||
@@ -0,0 +1,182 @@
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import sqlite3
|
||||
import unittest
|
||||
import importlib
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[1]
|
||||
sys.path.insert(0, str(ROOT / "src"))
|
||||
|
||||
|
||||
class TestChatEditStore(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self._prev_data_dir = os.environ.get("WECHAT_TOOL_DATA_DIR")
|
||||
self._td = TemporaryDirectory()
|
||||
os.environ["WECHAT_TOOL_DATA_DIR"] = self._td.name
|
||||
|
||||
import wechat_decrypt_tool.app_paths as app_paths
|
||||
import wechat_decrypt_tool.chat_edit_store as chat_edit_store
|
||||
|
||||
importlib.reload(app_paths)
|
||||
importlib.reload(chat_edit_store)
|
||||
|
||||
self.app_paths = app_paths
|
||||
self.store = chat_edit_store
|
||||
|
||||
def tearDown(self):
|
||||
if self._prev_data_dir is None:
|
||||
os.environ.pop("WECHAT_TOOL_DATA_DIR", None)
|
||||
else:
|
||||
os.environ["WECHAT_TOOL_DATA_DIR"] = self._prev_data_dir
|
||||
self._td.cleanup()
|
||||
|
||||
def test_ensure_schema_creates_db(self):
|
||||
self.store.ensure_schema()
|
||||
db_path = self.app_paths.get_output_dir() / "message_edits.db"
|
||||
self.assertTrue(db_path.exists())
|
||||
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
try:
|
||||
row = conn.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' AND name='message_edits' LIMIT 1"
|
||||
).fetchone()
|
||||
self.assertIsNotNone(row)
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def test_blob_hex_roundtrip(self):
|
||||
payload = {"a": b"\x00\xff", "nested": {"b": memoryview(b"\x01\x02")}}
|
||||
dumped = self.store.dumps_json_with_blobs(payload)
|
||||
self.assertIn("0x00ff", dumped.lower())
|
||||
self.assertIn("0x0102", dumped.lower())
|
||||
|
||||
loaded = self.store.loads_json_with_blobs(dumped)
|
||||
self.assertEqual(loaded["a"], b"\x00\xff")
|
||||
self.assertEqual(loaded["nested"]["b"], b"\x01\x02")
|
||||
|
||||
def test_message_id_format_parse(self):
|
||||
mid = self.store.format_message_id("message_0", "Msg_foo", 123)
|
||||
self.assertEqual(mid, "message_0:Msg_foo:123")
|
||||
|
||||
db, table, local_id = self.store.parse_message_id(mid)
|
||||
self.assertEqual(db, "message_0")
|
||||
self.assertEqual(table, "Msg_foo")
|
||||
self.assertEqual(local_id, 123)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
self.store.parse_message_id("bad")
|
||||
|
||||
def test_upsert_original_once_does_not_overwrite_snapshot(self):
|
||||
now1 = 1000
|
||||
now2 = 2000
|
||||
self.store.upsert_original_once(
|
||||
account="wxid_me",
|
||||
session_id="wxid_you",
|
||||
db="message_0",
|
||||
table_name="Msg_foo",
|
||||
local_id=1,
|
||||
original_msg={"local_id": 1, "message_content": "hello", "compress_content": b"\x01"},
|
||||
original_resource={"message_id": 9, "packed_info": b"\x02"},
|
||||
now_ms=now1,
|
||||
)
|
||||
|
||||
self.store.upsert_original_once(
|
||||
account="wxid_me",
|
||||
session_id="wxid_you",
|
||||
db="message_0",
|
||||
table_name="Msg_foo",
|
||||
local_id=1,
|
||||
original_msg={"local_id": 1, "message_content": "SHOULD_NOT_OVERWRITE", "compress_content": b"\x03"},
|
||||
original_resource={"message_id": 9, "packed_info": b"\x04"},
|
||||
now_ms=now2,
|
||||
)
|
||||
|
||||
mid = self.store.format_message_id("message_0", "Msg_foo", 1)
|
||||
item = self.store.get_message_edit("wxid_me", "wxid_you", mid)
|
||||
self.assertIsNotNone(item)
|
||||
self.assertEqual(int(item["first_edited_at"]), now1)
|
||||
self.assertEqual(int(item["last_edited_at"]), now2)
|
||||
self.assertEqual(int(item["edit_count"]), 2)
|
||||
|
||||
original_msg = self.store.loads_json_with_blobs(item["original_msg_json"])
|
||||
self.assertEqual(original_msg["message_content"], "hello")
|
||||
self.assertEqual(original_msg["compress_content"], b"\x01")
|
||||
|
||||
original_res = self.store.loads_json_with_blobs(item["original_resource_json"])
|
||||
self.assertEqual(int(original_res["message_id"]), 9)
|
||||
self.assertEqual(original_res["packed_info"], b"\x02")
|
||||
|
||||
def test_update_message_edit_local_id_moves_primary_key(self):
|
||||
self.store.upsert_original_once(
|
||||
account="wxid_me",
|
||||
session_id="wxid_you",
|
||||
db="message_0",
|
||||
table_name="Msg_foo",
|
||||
local_id=10,
|
||||
original_msg={"local_id": 10, "message_content": "hello"},
|
||||
original_resource=None,
|
||||
now_ms=1234,
|
||||
)
|
||||
|
||||
ok = self.store.update_message_edit_local_id(
|
||||
account="wxid_me",
|
||||
session_id="wxid_you",
|
||||
db="message_0",
|
||||
table_name="Msg_foo",
|
||||
old_local_id=10,
|
||||
new_local_id=11,
|
||||
)
|
||||
self.assertTrue(ok)
|
||||
|
||||
old_mid = self.store.format_message_id("message_0", "Msg_foo", 10)
|
||||
new_mid = self.store.format_message_id("message_0", "Msg_foo", 11)
|
||||
self.assertIsNone(self.store.get_message_edit("wxid_me", "wxid_you", old_mid))
|
||||
self.assertIsNotNone(self.store.get_message_edit("wxid_me", "wxid_you", new_mid))
|
||||
|
||||
def test_list_sessions_counts(self):
|
||||
self.store.upsert_original_once(
|
||||
account="wxid_me",
|
||||
session_id="u1",
|
||||
db="message_0",
|
||||
table_name="Msg_foo",
|
||||
local_id=1,
|
||||
original_msg={"local_id": 1, "message_content": "a"},
|
||||
original_resource=None,
|
||||
now_ms=100,
|
||||
)
|
||||
self.store.upsert_original_once(
|
||||
account="wxid_me",
|
||||
session_id="u1",
|
||||
db="message_0",
|
||||
table_name="Msg_foo",
|
||||
local_id=2,
|
||||
original_msg={"local_id": 2, "message_content": "b"},
|
||||
original_resource=None,
|
||||
now_ms=200,
|
||||
)
|
||||
self.store.upsert_original_once(
|
||||
account="wxid_me",
|
||||
session_id="u2",
|
||||
db="message_0",
|
||||
table_name="Msg_foo",
|
||||
local_id=3,
|
||||
original_msg={"local_id": 3, "message_content": "c"},
|
||||
original_resource=None,
|
||||
now_ms=300,
|
||||
)
|
||||
|
||||
stats = self.store.list_sessions("wxid_me")
|
||||
by_sid = {s["session_id"]: s for s in stats}
|
||||
self.assertEqual(int(by_sid["u1"]["msg_count"]), 2)
|
||||
self.assertEqual(int(by_sid["u1"]["last_edited_at"]), 200)
|
||||
self.assertEqual(int(by_sid["u2"]["msg_count"]), 1)
|
||||
self.assertEqual(int(by_sid["u2"]["last_edited_at"]), 300)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user