feat(chat-edit-backend): 新增消息编辑快照、回滚与修复接口

- 新增 message_edits 存储,记录首次快照、编辑次数与已修改字段

- 新增消息编辑链路:raw 查询、编辑、修改状态、按消息/会话恢复

- 新增发送者修复与 packed_info_data 方向反转能力

- 编辑流程支持 output/db_storage 双写与失败回滚

- path_fix 改为按需校验 db_storage_path,并缓存 body 避免重复读取问题

- 补充 chat_edit_store 单元测试覆盖核心行为
This commit is contained in:
2977094657
2026-02-22 11:55:47 +08:00
Unverified
parent abc8e3e828
commit d5927156f7
6 changed files with 2704 additions and 13 deletions
+487
View File
@@ -0,0 +1,487 @@
from __future__ import annotations
import json
import re
import sqlite3
import time
from pathlib import Path
from typing import Any, Optional
from .app_paths import get_output_dir
_HEX_RE = re.compile(r"^[0-9a-fA-F]+$")
def _db_path() -> Path:
return get_output_dir() / "message_edits.db"
def _connect() -> sqlite3.Connection:
db_path = _db_path()
db_path.parent.mkdir(parents=True, exist_ok=True)
conn = sqlite3.connect(str(db_path), timeout=5)
conn.row_factory = sqlite3.Row
_ensure_schema(conn)
return conn
def ensure_schema() -> None:
conn: Optional[sqlite3.Connection] = None
try:
conn = _connect()
finally:
try:
if conn is not None:
conn.close()
except Exception:
pass
def _ensure_schema(conn: sqlite3.Connection) -> None:
conn.execute(
"""
CREATE TABLE IF NOT EXISTS message_edits (
account TEXT NOT NULL,
session_id TEXT NOT NULL,
db TEXT NOT NULL,
table_name TEXT NOT NULL,
local_id INTEGER NOT NULL,
first_edited_at INTEGER NOT NULL,
last_edited_at INTEGER NOT NULL,
edit_count INTEGER NOT NULL,
original_msg_json TEXT NOT NULL,
original_resource_json TEXT,
edited_cols_json TEXT,
PRIMARY KEY (account, session_id, db, table_name, local_id)
)
"""
)
conn.execute(
"CREATE INDEX IF NOT EXISTS idx_message_edits_account_session ON message_edits(account, session_id)"
)
conn.execute("CREATE INDEX IF NOT EXISTS idx_message_edits_account_last ON message_edits(account, last_edited_at)")
# Backwards-compatible migrations for existing DBs.
try:
cols = {
str(r[1] or "").strip().lower()
for r in conn.execute("PRAGMA table_info(message_edits)").fetchall()
if r and len(r) > 1 and r[1]
}
if "edited_cols_json" not in cols:
conn.execute("ALTER TABLE message_edits ADD COLUMN edited_cols_json TEXT")
except Exception:
pass
conn.commit()
def _now_ms() -> int:
return int(time.time() * 1000)
def format_message_id(db: str, table_name: str, local_id: int) -> str:
return f"{str(db or '').strip()}:{str(table_name or '').strip()}:{int(local_id or 0)}"
def parse_message_id(message_id: str) -> tuple[str, str, int]:
parts = str(message_id or "").split(":", 2)
if len(parts) != 3:
raise ValueError("Invalid message_id format.")
db = str(parts[0] or "").strip()
table_name = str(parts[1] or "").strip()
try:
local_id = int(parts[2] or 0)
except Exception:
raise ValueError("Invalid message_id format.")
if not db or not table_name or local_id <= 0:
raise ValueError("Invalid message_id format.")
return db, table_name, local_id
def _bytes_to_hex(value: bytes) -> str:
return "0x" + value.hex()
def _hex_to_bytes(value: str) -> Optional[bytes]:
s = str(value or "").strip()
if not s.startswith("0x"):
return None
hex_part = s[2:]
if (not hex_part) or (len(hex_part) % 2 != 0) or (_HEX_RE.match(hex_part) is None):
return None
try:
return bytes.fromhex(hex_part)
except Exception:
return None
def _jsonify_blobs(obj: Any) -> Any:
if obj is None:
return None
if isinstance(obj, (bytes, bytearray, memoryview)):
return _bytes_to_hex(bytes(obj))
if isinstance(obj, dict):
return {str(k): _jsonify_blobs(v) for k, v in obj.items()}
if isinstance(obj, (list, tuple)):
return [_jsonify_blobs(v) for v in obj]
return obj
def _dejsonify_blobs(obj: Any) -> Any:
if obj is None:
return None
if isinstance(obj, str):
b = _hex_to_bytes(obj)
return b if b is not None else obj
if isinstance(obj, dict):
return {str(k): _dejsonify_blobs(v) for k, v in obj.items()}
if isinstance(obj, list):
return [_dejsonify_blobs(v) for v in obj]
return obj
def dumps_json_with_blobs(obj: Any) -> str:
return json.dumps(_jsonify_blobs(obj), ensure_ascii=False, separators=(",", ":"))
def loads_json_with_blobs(payload: str) -> Any:
return _dejsonify_blobs(json.loads(str(payload or "") or "null"))
def upsert_original_once(
*,
account: str,
session_id: str,
db: str,
table_name: str,
local_id: int,
original_msg: dict[str, Any],
original_resource: Optional[dict[str, Any]],
now_ms: Optional[int] = None,
) -> None:
"""Insert the original snapshot for a message only once, then bump counters on subsequent edits."""
a = str(account or "").strip()
sid = str(session_id or "").strip()
db_norm = str(db or "").strip()
t = str(table_name or "").strip()
lid = int(local_id or 0)
if not a or not sid or not db_norm or not t or lid <= 0:
raise ValueError("Missing required keys for message edit store.")
ts = int(now_ms if now_ms is not None else _now_ms())
msg_json = dumps_json_with_blobs(original_msg or {})
res_json = dumps_json_with_blobs(original_resource) if original_resource is not None else None
conn: Optional[sqlite3.Connection] = None
try:
conn = _connect()
existing = conn.execute(
"""
SELECT 1
FROM message_edits
WHERE account = ? AND session_id = ? AND db = ? AND table_name = ? AND local_id = ?
LIMIT 1
""",
(a, sid, db_norm, t, lid),
).fetchone()
if existing is None:
conn.execute(
"""
INSERT INTO message_edits(
account, session_id, db, table_name, local_id,
first_edited_at, last_edited_at, edit_count,
original_msg_json, original_resource_json, edited_cols_json
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(a, sid, db_norm, t, lid, ts, ts, 1, msg_json, res_json, None),
)
else:
conn.execute(
"""
UPDATE message_edits
SET last_edited_at = ?, edit_count = edit_count + 1
WHERE account = ? AND session_id = ? AND db = ? AND table_name = ? AND local_id = ?
""",
(ts, a, sid, db_norm, t, lid),
)
conn.commit()
finally:
try:
if conn is not None:
conn.close()
except Exception:
pass
def _parse_json_str_list(payload: Any) -> list[str]:
if payload is None:
return []
if isinstance(payload, (list, tuple)):
return [str(x or "").strip() for x in payload if str(x or "").strip()]
s = str(payload or "").strip()
if not s:
return []
try:
v = json.loads(s)
except Exception:
return []
if not isinstance(v, list):
return []
return [str(x or "").strip() for x in v if str(x or "").strip()]
def merge_edited_columns(
*,
account: str,
session_id: str,
db: str,
table_name: str,
local_id: int,
columns: list[str],
) -> bool:
"""Merge edited message column names into the per-message edit record.
This allows reset to restore only the fields actually modified by the tool.
"""
a = str(account or "").strip()
sid = str(session_id or "").strip()
db_norm = str(db or "").strip()
t = str(table_name or "").strip()
lid = int(local_id or 0)
if not a or not sid or not db_norm or not t or lid <= 0:
return False
cols_in = [str(x or "").strip() for x in (columns or []) if str(x or "").strip()]
if not cols_in:
return True
conn: Optional[sqlite3.Connection] = None
try:
conn = _connect()
row = conn.execute(
"""
SELECT edited_cols_json
FROM message_edits
WHERE account = ? AND session_id = ? AND db = ? AND table_name = ? AND local_id = ?
LIMIT 1
""",
(a, sid, db_norm, t, lid),
).fetchone()
if row is None:
return False
existing = _parse_json_str_list(row[0] if row and len(row) else None)
merged = {c.lower() for c in existing if c} | {c.lower() for c in cols_in if c}
merged_list = sorted(merged)
payload = json.dumps(merged_list, ensure_ascii=False, separators=(",", ":"))
conn.execute(
"""
UPDATE message_edits
SET edited_cols_json = ?
WHERE account = ? AND session_id = ? AND db = ? AND table_name = ? AND local_id = ?
""",
(payload, a, sid, db_norm, t, lid),
)
conn.commit()
return True
finally:
try:
if conn is not None:
conn.close()
except Exception:
pass
def _row_to_dict(row: Optional[sqlite3.Row]) -> Optional[dict[str, Any]]:
if row is None:
return None
out: dict[str, Any] = {}
for k in row.keys():
out[str(k)] = row[k]
return out
def list_sessions(account: str) -> list[dict[str, Any]]:
a = str(account or "").strip()
if not a:
return []
conn: Optional[sqlite3.Connection] = None
try:
conn = _connect()
rows = conn.execute(
"""
SELECT session_id, COUNT(*) AS msg_count, MAX(last_edited_at) AS last_edited_at
FROM message_edits
WHERE account = ?
GROUP BY session_id
ORDER BY last_edited_at DESC
""",
(a,),
).fetchall()
out: list[dict[str, Any]] = []
for r in rows:
try:
sid = str(r["session_id"] or "").strip()
except Exception:
sid = ""
if not sid:
continue
out.append(
{
"session_id": sid,
"msg_count": int(r["msg_count"] or 0),
"last_edited_at": int(r["last_edited_at"] or 0),
}
)
return out
finally:
try:
if conn is not None:
conn.close()
except Exception:
pass
def list_messages(account: str, session_id: str) -> list[dict[str, Any]]:
a = str(account or "").strip()
sid = str(session_id or "").strip()
if not a or not sid:
return []
conn: Optional[sqlite3.Connection] = None
try:
conn = _connect()
rows = conn.execute(
"""
SELECT *
FROM message_edits
WHERE account = ? AND session_id = ?
ORDER BY last_edited_at ASC, local_id ASC
""",
(a, sid),
).fetchall()
out: list[dict[str, Any]] = []
for r in rows:
item = _row_to_dict(r) or {}
try:
item["message_id"] = format_message_id(item.get("db") or "", item.get("table_name") or "", item.get("local_id") or 0)
except Exception:
item["message_id"] = ""
out.append(item)
return out
finally:
try:
if conn is not None:
conn.close()
except Exception:
pass
def get_message_edit(account: str, session_id: str, message_id: str) -> Optional[dict[str, Any]]:
a = str(account or "").strip()
sid = str(session_id or "").strip()
if not a or not sid or not message_id:
return None
try:
db, table_name, local_id = parse_message_id(message_id)
except Exception:
return None
conn: Optional[sqlite3.Connection] = None
try:
conn = _connect()
row = conn.execute(
"""
SELECT *
FROM message_edits
WHERE account = ? AND session_id = ? AND db = ? AND table_name = ? AND local_id = ?
LIMIT 1
""",
(a, sid, db, table_name, int(local_id)),
).fetchone()
item = _row_to_dict(row)
if not item:
return None
item["message_id"] = format_message_id(db, table_name, local_id)
return item
finally:
try:
if conn is not None:
conn.close()
except Exception:
pass
def delete_message_edit(account: str, session_id: str, message_id: str) -> bool:
a = str(account or "").strip()
sid = str(session_id or "").strip()
if not a or not sid or not message_id:
return False
try:
db, table_name, local_id = parse_message_id(message_id)
except Exception:
return False
conn: Optional[sqlite3.Connection] = None
try:
conn = _connect()
cur = conn.execute(
"""
DELETE FROM message_edits
WHERE account = ? AND session_id = ? AND db = ? AND table_name = ? AND local_id = ?
""",
(a, sid, db, table_name, int(local_id)),
)
conn.commit()
return int(getattr(cur, "rowcount", 0) or 0) > 0
finally:
try:
if conn is not None:
conn.close()
except Exception:
pass
def update_message_edit_local_id(
*,
account: str,
session_id: str,
db: str,
table_name: str,
old_local_id: int,
new_local_id: int,
) -> bool:
"""Update the primary key local_id for an existing edit record (unsafe operations may change Msg.local_id)."""
a = str(account or "").strip()
sid = str(session_id or "").strip()
db_norm = str(db or "").strip()
t = str(table_name or "").strip()
old_lid = int(old_local_id or 0)
new_lid = int(new_local_id or 0)
if not a or not sid or not db_norm or not t or old_lid <= 0 or new_lid <= 0:
return False
if old_lid == new_lid:
return True
conn: Optional[sqlite3.Connection] = None
try:
conn = _connect()
cur = conn.execute(
"""
UPDATE message_edits
SET local_id = ?
WHERE account = ? AND session_id = ? AND db = ? AND table_name = ? AND local_id = ?
""",
(new_lid, a, sid, db_norm, t, old_lid),
)
conn.commit()
return int(getattr(cur, "rowcount", 0) or 0) > 0
except Exception:
return False
finally:
try:
if conn is not None:
conn.close()
except Exception:
pass
Binary file not shown.
+34 -12
View File
@@ -32,10 +32,8 @@ class PathFixRequest(Request):
def _validate_paths_in_json(self, json_data: dict) -> Optional[str]:
"""验证JSON中的路径,返回错误信息(如果有)"""
logger.info(f"开始验证路径,JSON数据: {json_data}")
# 检查db_storage_path字段(现在是必需的)
if 'db_storage_path' not in json_data:
return "缺少必需的db_storage_path参数,请提供具体的数据库存储路径。"
# 仅在提供 db_storage_path 时进行校验(例如 /api/decrypt)。
# 其它 API 的 JSON payload 不一定包含路径字段,不应强制要求。
if 'db_storage_path' in json_data:
path = json_data['db_storage_path']
@@ -115,11 +113,16 @@ class PathFixRequest(Request):
async def body(self) -> bytes:
"""重写body方法,预处理JSON中的路径问题"""
cached = getattr(self.state, "_pathfix_body_bytes", None)
if isinstance(cached, (bytes, bytearray)):
return bytes(cached)
body = await super().body()
# 只处理JSON请求
content_type = self.headers.get("content-type", "")
if "application/json" not in content_type:
self.state._pathfix_body_bytes = body
return body
try:
@@ -134,6 +137,7 @@ class PathFixRequest(Request):
logger.info(f"检测到路径错误: {path_error}")
# 我们将错误信息存储在请求中,稍后在路由处理器中检查
self.state.path_validation_error = path_error
self.state._pathfix_body_bytes = body
return body
except json.JSONDecodeError as e:
# JSON格式错误,继续尝试修复
@@ -169,17 +173,30 @@ class PathFixRequest(Request):
if path_error:
logger.info(f"修复后检测到路径错误: {path_error}")
self.state.path_validation_error = path_error
return fixed_body_str.encode('utf-8')
fixed_bytes = fixed_body_str.encode('utf-8')
self.state._pathfix_body_bytes = fixed_bytes
try:
self._body = fixed_bytes # type: ignore[attr-defined]
except Exception:
pass
return fixed_bytes
else:
logger.info(f"修复后路径验证通过")
except json.JSONDecodeError as e:
logger.warning(f"修复后JSON仍然解析失败: {e}")
return fixed_body_str.encode('utf-8')
fixed_bytes = fixed_body_str.encode('utf-8')
self.state._pathfix_body_bytes = fixed_bytes
try:
self._body = fixed_bytes # type: ignore[attr-defined]
except Exception:
pass
return fixed_bytes
except Exception as e:
# 如果处理失败,返回原始body
logger.warning(f"JSON路径修复失败,使用原始请求体: {e}")
self.state._pathfix_body_bytes = body
return body
@@ -193,12 +210,17 @@ class PathFixRoute(APIRoute):
# 将Request替换为我们的自定义Request
custom_request = PathFixRequest(request.scope, request.receive)
# 检查是否有路径验证错误
if hasattr(custom_request.state, 'path_validation_error'):
raise HTTPException(
status_code=400,
detail=custom_request.state.path_validation_error,
)
# 仅对 JSON 请求预读 body,以触发路径修复/校验逻辑,并在发现错误时提前返回 400。
try:
content_type = (custom_request.headers.get("content-type", "") or "").lower()
if "application/json" in content_type:
await custom_request.body()
except Exception:
pass
path_err = getattr(custom_request.state, "path_validation_error", None)
if path_err:
raise HTTPException(status_code=400, detail=path_err)
return await original_route_handler(custom_request)
File diff suppressed because it is too large Load Diff
+114
View File
@@ -175,6 +175,36 @@ def _load_wcdb_lib() -> ctypes.CDLL:
except Exception:
pass
# Optional (newer DLLs): update a single message content in message db.
# Signature: wcdb_update_message(handle, sessionId, localId, createTime, newContent, outError)
try:
lib.wcdb_update_message.argtypes = [
ctypes.c_int64,
ctypes.c_char_p,
ctypes.c_int64,
ctypes.c_int32,
ctypes.c_char_p,
ctypes.POINTER(ctypes.c_char_p),
]
lib.wcdb_update_message.restype = ctypes.c_int
except Exception:
pass
# Optional (newer DLLs): delete a single message in message db.
# Signature: wcdb_delete_message(handle, sessionId, localId, createTime, dbPathHint, outError)
try:
lib.wcdb_delete_message.argtypes = [
ctypes.c_int64,
ctypes.c_char_p,
ctypes.c_int64,
ctypes.c_int32,
ctypes.c_char_p,
ctypes.POINTER(ctypes.c_char_p),
]
lib.wcdb_delete_message.restype = ctypes.c_int
except Exception:
pass
# Optional (newer DLLs): wcdb_get_sns_timeline(handle, limit, offset, usernames_json, keyword, start_time, end_time, out_json)
try:
lib.wcdb_get_sns_timeline.argtypes = [
@@ -259,6 +289,32 @@ def _call_out_json(fn, *args) -> str:
pass
def _call_out_error(fn, *args) -> None:
lib = _load_wcdb_lib()
out = ctypes.c_char_p()
rc = int(fn(*args, ctypes.byref(out)))
try:
if rc != 0:
err = ""
try:
if out.value:
err = (out.value or b"").decode("utf-8", errors="replace")
except Exception:
err = ""
logs = get_native_logs()
hint = f" logs={logs[:6]}" if logs else ""
if err:
raise WCDBRealtimeError(f"wcdb api call failed: {rc}. error={err}.{hint}")
raise WCDBRealtimeError(f"wcdb api call failed: {rc}.{hint}")
finally:
try:
if out.value:
lib.wcdb_free_string(out)
except Exception:
pass
def get_native_logs() -> list[str]:
try:
_ensure_initialized()
@@ -500,6 +556,64 @@ def exec_query(handle: int, *, kind: str, path: Optional[str], sql: str) -> list
return []
def update_message(handle: int, *, session_id: str, local_id: int, create_time: int, new_content: str) -> None:
"""Update a single message content in the live encrypted db_storage via WCDB.
Requires wcdb_update_message export in wcdb_api.dll.
"""
_ensure_initialized()
lib = _load_wcdb_lib()
fn = getattr(lib, "wcdb_update_message", None)
if not fn:
raise WCDBRealtimeError("Current wcdb_api.dll does not support update_message.")
sid = str(session_id or "").strip()
if not sid:
raise WCDBRealtimeError("Missing session_id for update_message.")
_call_out_error(
fn,
ctypes.c_int64(int(handle)),
sid.encode("utf-8"),
ctypes.c_int64(int(local_id or 0)),
ctypes.c_int32(int(create_time or 0)),
str(new_content or "").encode("utf-8"),
)
def delete_message(
handle: int,
*,
session_id: str,
local_id: int,
create_time: int,
db_path_hint: str | None = None,
) -> None:
"""Delete a single message in the live encrypted db_storage via WCDB.
Requires wcdb_delete_message export in wcdb_api.dll.
"""
_ensure_initialized()
lib = _load_wcdb_lib()
fn = getattr(lib, "wcdb_delete_message", None)
if not fn:
raise WCDBRealtimeError("Current wcdb_api.dll does not support delete_message.")
sid = str(session_id or "").strip()
if not sid:
raise WCDBRealtimeError("Missing session_id for delete_message.")
hint = str(db_path_hint or "").strip()
_call_out_error(
fn,
ctypes.c_int64(int(handle)),
sid.encode("utf-8"),
ctypes.c_int64(int(local_id or 0)),
ctypes.c_int32(int(create_time or 0)),
hint.encode("utf-8"),
)
def get_sns_timeline(
handle: int,
*,
+182
View File
@@ -0,0 +1,182 @@
import os
import sys
import json
import sqlite3
import unittest
import importlib
from pathlib import Path
from tempfile import TemporaryDirectory
ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT / "src"))
class TestChatEditStore(unittest.TestCase):
def setUp(self):
self._prev_data_dir = os.environ.get("WECHAT_TOOL_DATA_DIR")
self._td = TemporaryDirectory()
os.environ["WECHAT_TOOL_DATA_DIR"] = self._td.name
import wechat_decrypt_tool.app_paths as app_paths
import wechat_decrypt_tool.chat_edit_store as chat_edit_store
importlib.reload(app_paths)
importlib.reload(chat_edit_store)
self.app_paths = app_paths
self.store = chat_edit_store
def tearDown(self):
if self._prev_data_dir is None:
os.environ.pop("WECHAT_TOOL_DATA_DIR", None)
else:
os.environ["WECHAT_TOOL_DATA_DIR"] = self._prev_data_dir
self._td.cleanup()
def test_ensure_schema_creates_db(self):
self.store.ensure_schema()
db_path = self.app_paths.get_output_dir() / "message_edits.db"
self.assertTrue(db_path.exists())
conn = sqlite3.connect(str(db_path))
try:
row = conn.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name='message_edits' LIMIT 1"
).fetchone()
self.assertIsNotNone(row)
finally:
conn.close()
def test_blob_hex_roundtrip(self):
payload = {"a": b"\x00\xff", "nested": {"b": memoryview(b"\x01\x02")}}
dumped = self.store.dumps_json_with_blobs(payload)
self.assertIn("0x00ff", dumped.lower())
self.assertIn("0x0102", dumped.lower())
loaded = self.store.loads_json_with_blobs(dumped)
self.assertEqual(loaded["a"], b"\x00\xff")
self.assertEqual(loaded["nested"]["b"], b"\x01\x02")
def test_message_id_format_parse(self):
mid = self.store.format_message_id("message_0", "Msg_foo", 123)
self.assertEqual(mid, "message_0:Msg_foo:123")
db, table, local_id = self.store.parse_message_id(mid)
self.assertEqual(db, "message_0")
self.assertEqual(table, "Msg_foo")
self.assertEqual(local_id, 123)
with self.assertRaises(ValueError):
self.store.parse_message_id("bad")
def test_upsert_original_once_does_not_overwrite_snapshot(self):
now1 = 1000
now2 = 2000
self.store.upsert_original_once(
account="wxid_me",
session_id="wxid_you",
db="message_0",
table_name="Msg_foo",
local_id=1,
original_msg={"local_id": 1, "message_content": "hello", "compress_content": b"\x01"},
original_resource={"message_id": 9, "packed_info": b"\x02"},
now_ms=now1,
)
self.store.upsert_original_once(
account="wxid_me",
session_id="wxid_you",
db="message_0",
table_name="Msg_foo",
local_id=1,
original_msg={"local_id": 1, "message_content": "SHOULD_NOT_OVERWRITE", "compress_content": b"\x03"},
original_resource={"message_id": 9, "packed_info": b"\x04"},
now_ms=now2,
)
mid = self.store.format_message_id("message_0", "Msg_foo", 1)
item = self.store.get_message_edit("wxid_me", "wxid_you", mid)
self.assertIsNotNone(item)
self.assertEqual(int(item["first_edited_at"]), now1)
self.assertEqual(int(item["last_edited_at"]), now2)
self.assertEqual(int(item["edit_count"]), 2)
original_msg = self.store.loads_json_with_blobs(item["original_msg_json"])
self.assertEqual(original_msg["message_content"], "hello")
self.assertEqual(original_msg["compress_content"], b"\x01")
original_res = self.store.loads_json_with_blobs(item["original_resource_json"])
self.assertEqual(int(original_res["message_id"]), 9)
self.assertEqual(original_res["packed_info"], b"\x02")
def test_update_message_edit_local_id_moves_primary_key(self):
self.store.upsert_original_once(
account="wxid_me",
session_id="wxid_you",
db="message_0",
table_name="Msg_foo",
local_id=10,
original_msg={"local_id": 10, "message_content": "hello"},
original_resource=None,
now_ms=1234,
)
ok = self.store.update_message_edit_local_id(
account="wxid_me",
session_id="wxid_you",
db="message_0",
table_name="Msg_foo",
old_local_id=10,
new_local_id=11,
)
self.assertTrue(ok)
old_mid = self.store.format_message_id("message_0", "Msg_foo", 10)
new_mid = self.store.format_message_id("message_0", "Msg_foo", 11)
self.assertIsNone(self.store.get_message_edit("wxid_me", "wxid_you", old_mid))
self.assertIsNotNone(self.store.get_message_edit("wxid_me", "wxid_you", new_mid))
def test_list_sessions_counts(self):
self.store.upsert_original_once(
account="wxid_me",
session_id="u1",
db="message_0",
table_name="Msg_foo",
local_id=1,
original_msg={"local_id": 1, "message_content": "a"},
original_resource=None,
now_ms=100,
)
self.store.upsert_original_once(
account="wxid_me",
session_id="u1",
db="message_0",
table_name="Msg_foo",
local_id=2,
original_msg={"local_id": 2, "message_content": "b"},
original_resource=None,
now_ms=200,
)
self.store.upsert_original_once(
account="wxid_me",
session_id="u2",
db="message_0",
table_name="Msg_foo",
local_id=3,
original_msg={"local_id": 3, "message_content": "c"},
original_resource=None,
now_ms=300,
)
stats = self.store.list_sessions("wxid_me")
by_sid = {s["session_id"]: s for s in stats}
self.assertEqual(int(by_sid["u1"]["msg_count"]), 2)
self.assertEqual(int(by_sid["u1"]["last_edited_at"]), 200)
self.assertEqual(int(by_sid["u2"]["msg_count"]), 1)
self.assertEqual(int(by_sid["u2"]["last_edited_at"]), 300)
if __name__ == "__main__":
unittest.main()