fix(session): make SQLite transcript rewrites transactional

This commit is contained in:
ThomassJonax 2026-04-28 07:02:36 +03:00 committed by Teknium
parent 22ddac4b14
commit 2f9243c333
3 changed files with 113 additions and 17 deletions

View File

@ -1257,25 +1257,11 @@ class SessionStore:
Used by /retry, /undo, and /compress to persist modified conversation history.
Rewrites both SQLite and legacy JSONL storage.
"""
# SQLite: clear old messages and re-insert
# SQLite: replace atomically so a mid-rewrite failure doesn't leave
# the session half-empty in the DB while JSONL still has history.
if self._db:
try:
self._db.clear_messages(session_id)
for msg in messages:
role = msg.get("role", "unknown")
self._db.append_message(
session_id=session_id,
role=role,
content=msg.get("content"),
tool_name=msg.get("tool_name"),
tool_calls=msg.get("tool_calls"),
tool_call_id=msg.get("tool_call_id"),
reasoning=msg.get("reasoning") if role == "assistant" else None,
reasoning_content=msg.get("reasoning_content") if role == "assistant" else None,
reasoning_details=msg.get("reasoning_details") if role == "assistant" else None,
codex_reasoning_items=msg.get("codex_reasoning_items") if role == "assistant" else None,
codex_message_items=msg.get("codex_message_items") if role == "assistant" else None,
)
self._db.replace_messages(session_id, messages)
except Exception as e:
logger.debug("Failed to rewrite transcript in DB: %s", e)

View File

@ -1172,6 +1172,85 @@ class SessionDB:
return self._execute_write(_do)
def replace_messages(self, session_id: str, messages: List[Dict[str, Any]]) -> None:
"""Atomically replace every message for a session.
Used by transcript-rewrite flows such as /retry, /undo, and /compress.
The delete + reinsert sequence must commit as one transaction so a
mid-rewrite failure does not leave SQLite with a partial transcript.
"""
def _do(conn):
conn.execute(
"DELETE FROM messages WHERE session_id = ?", (session_id,)
)
conn.execute(
"UPDATE sessions SET message_count = 0, tool_call_count = 0 WHERE id = ?",
(session_id,),
)
now_ts = time.time()
total_messages = 0
total_tool_calls = 0
for msg in messages:
role = msg.get("role", "unknown")
tool_calls = msg.get("tool_calls")
reasoning_details = msg.get("reasoning_details") if role == "assistant" else None
codex_reasoning_items = (
msg.get("codex_reasoning_items") if role == "assistant" else None
)
codex_message_items = (
msg.get("codex_message_items") if role == "assistant" else None
)
reasoning_details_json = (
json.dumps(reasoning_details) if reasoning_details else None
)
codex_items_json = (
json.dumps(codex_reasoning_items) if codex_reasoning_items else None
)
codex_message_items_json = (
json.dumps(codex_message_items) if codex_message_items else None
)
tool_calls_json = json.dumps(tool_calls) if tool_calls else None
conn.execute(
"""INSERT INTO messages (session_id, role, content, tool_call_id,
tool_calls, tool_name, timestamp, token_count, finish_reason,
reasoning, reasoning_content, reasoning_details, codex_reasoning_items,
codex_message_items)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
(
session_id,
role,
msg.get("content"),
msg.get("tool_call_id"),
tool_calls_json,
msg.get("tool_name"),
now_ts,
msg.get("token_count"),
msg.get("finish_reason"),
msg.get("reasoning") if role == "assistant" else None,
msg.get("reasoning_content") if role == "assistant" else None,
reasoning_details_json,
codex_items_json,
codex_message_items_json,
),
)
total_messages += 1
if tool_calls is not None:
total_tool_calls += (
len(tool_calls) if isinstance(tool_calls, list) else 1
)
now_ts += 1e-6
conn.execute(
"UPDATE sessions SET message_count = ?, tool_call_count = ? WHERE id = ?",
(total_messages, total_tool_calls, session_id),
)
self._execute_write(_do)
def get_messages(self, session_id: str) -> List[Dict[str, Any]]:
"""Load all messages for a session, ordered by timestamp."""
with self._lock:

View File

@ -1233,3 +1233,34 @@ class TestRewriteTranscriptPreservesReasoning:
assert after[0].get("reasoning_content") == "provider scratchpad"
assert after[0].get("reasoning_details") == [{"type": "summary", "text": "step by step"}]
assert after[0].get("codex_reasoning_items") == [{"id": "r1", "type": "reasoning"}]
def test_db_rewrite_is_atomic_on_insert_failure(self, tmp_path):
from hermes_state import SessionDB
db = SessionDB(db_path=tmp_path / "test.db")
session_id = "atomic-rewrite-test"
db.create_session(session_id=session_id, source="cli")
db.append_message(session_id=session_id, role="user", content="before user")
db.append_message(session_id=session_id, role="assistant", content="before assistant")
config = GatewayConfig()
with patch("gateway.session.SessionStore._ensure_loaded"):
store = SessionStore(sessions_dir=tmp_path, config=config)
store._db = db
store._loaded = True
replacement = [
{"role": "user", "content": "after user"},
{
"role": "assistant",
"content": {"not": "sqlite-bindable but JSONL-safe"},
},
]
store.rewrite_transcript(session_id, replacement)
after = db.get_messages_as_conversation(session_id)
assert [msg["content"] for msg in after] == [
"before user",
"before assistant",
]