diff --git a/gateway/session.py b/gateway/session.py index 02d4eb3e..36e187fe 100644 --- a/gateway/session.py +++ b/gateway/session.py @@ -1257,25 +1257,11 @@ class SessionStore: Used by /retry, /undo, and /compress to persist modified conversation history. Rewrites both SQLite and legacy JSONL storage. """ - # SQLite: clear old messages and re-insert + # SQLite: replace atomically so a mid-rewrite failure doesn't leave + # the session half-empty in the DB while JSONL still has history. if self._db: try: - self._db.clear_messages(session_id) - for msg in messages: - role = msg.get("role", "unknown") - self._db.append_message( - session_id=session_id, - role=role, - content=msg.get("content"), - tool_name=msg.get("tool_name"), - tool_calls=msg.get("tool_calls"), - tool_call_id=msg.get("tool_call_id"), - reasoning=msg.get("reasoning") if role == "assistant" else None, - reasoning_content=msg.get("reasoning_content") if role == "assistant" else None, - reasoning_details=msg.get("reasoning_details") if role == "assistant" else None, - codex_reasoning_items=msg.get("codex_reasoning_items") if role == "assistant" else None, - codex_message_items=msg.get("codex_message_items") if role == "assistant" else None, - ) + self._db.replace_messages(session_id, messages) except Exception as e: logger.debug("Failed to rewrite transcript in DB: %s", e) diff --git a/hermes_state.py b/hermes_state.py index 55895d48..e2ca5964 100644 --- a/hermes_state.py +++ b/hermes_state.py @@ -1172,6 +1172,85 @@ class SessionDB: return self._execute_write(_do) + def replace_messages(self, session_id: str, messages: List[Dict[str, Any]]) -> None: + """Atomically replace every message for a session. + + Used by transcript-rewrite flows such as /retry, /undo, and /compress. + The delete + reinsert sequence must commit as one transaction so a + mid-rewrite failure does not leave SQLite with a partial transcript. + """ + + def _do(conn): + conn.execute( + "DELETE FROM messages WHERE session_id = ?", (session_id,) + ) + conn.execute( + "UPDATE sessions SET message_count = 0, tool_call_count = 0 WHERE id = ?", + (session_id,), + ) + + now_ts = time.time() + total_messages = 0 + total_tool_calls = 0 + for msg in messages: + role = msg.get("role", "unknown") + tool_calls = msg.get("tool_calls") + reasoning_details = msg.get("reasoning_details") if role == "assistant" else None + codex_reasoning_items = ( + msg.get("codex_reasoning_items") if role == "assistant" else None + ) + codex_message_items = ( + msg.get("codex_message_items") if role == "assistant" else None + ) + + reasoning_details_json = ( + json.dumps(reasoning_details) if reasoning_details else None + ) + codex_items_json = ( + json.dumps(codex_reasoning_items) if codex_reasoning_items else None + ) + codex_message_items_json = ( + json.dumps(codex_message_items) if codex_message_items else None + ) + tool_calls_json = json.dumps(tool_calls) if tool_calls else None + + conn.execute( + """INSERT INTO messages (session_id, role, content, tool_call_id, + tool_calls, tool_name, timestamp, token_count, finish_reason, + reasoning, reasoning_content, reasoning_details, codex_reasoning_items, + codex_message_items) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", + ( + session_id, + role, + msg.get("content"), + msg.get("tool_call_id"), + tool_calls_json, + msg.get("tool_name"), + now_ts, + msg.get("token_count"), + msg.get("finish_reason"), + msg.get("reasoning") if role == "assistant" else None, + msg.get("reasoning_content") if role == "assistant" else None, + reasoning_details_json, + codex_items_json, + codex_message_items_json, + ), + ) + total_messages += 1 + if tool_calls is not None: + total_tool_calls += ( + len(tool_calls) if isinstance(tool_calls, list) else 1 + ) + now_ts += 1e-6 + + conn.execute( + "UPDATE sessions SET message_count = ?, tool_call_count = ? WHERE id = ?", + (total_messages, total_tool_calls, session_id), + ) + + self._execute_write(_do) + def get_messages(self, session_id: str) -> List[Dict[str, Any]]: """Load all messages for a session, ordered by timestamp.""" with self._lock: diff --git a/tests/gateway/test_session.py b/tests/gateway/test_session.py index 228f414a..45afc671 100644 --- a/tests/gateway/test_session.py +++ b/tests/gateway/test_session.py @@ -1233,3 +1233,34 @@ class TestRewriteTranscriptPreservesReasoning: assert after[0].get("reasoning_content") == "provider scratchpad" assert after[0].get("reasoning_details") == [{"type": "summary", "text": "step by step"}] assert after[0].get("codex_reasoning_items") == [{"id": "r1", "type": "reasoning"}] + + def test_db_rewrite_is_atomic_on_insert_failure(self, tmp_path): + from hermes_state import SessionDB + + db = SessionDB(db_path=tmp_path / "test.db") + session_id = "atomic-rewrite-test" + db.create_session(session_id=session_id, source="cli") + db.append_message(session_id=session_id, role="user", content="before user") + db.append_message(session_id=session_id, role="assistant", content="before assistant") + + config = GatewayConfig() + with patch("gateway.session.SessionStore._ensure_loaded"): + store = SessionStore(sessions_dir=tmp_path, config=config) + store._db = db + store._loaded = True + + replacement = [ + {"role": "user", "content": "after user"}, + { + "role": "assistant", + "content": {"not": "sqlite-bindable but JSONL-safe"}, + }, + ] + + store.rewrite_transcript(session_id, replacement) + + after = db.get_messages_as_conversation(session_id) + assert [msg["content"] for msg in after] == [ + "before user", + "before assistant", + ]