fix(hindsight): drain retain queue cleanly on shutdown

The plugin used to spawn one daemon thread per sync_turn() to do the
aretain_batch network write. On CLI exit, that pattern raced interpreter
shutdown — the last retain could reach aiohttp after asyncio's
"cannot schedule new futures" guard had fired, producing noisy logs and
silently losing the final unsaved turn:

    WARNING ... Hindsight sync failed: cannot schedule new futures after
            interpreter shutdown
    ERROR asyncio: Unclosed client session
            client_session: <aiohttp.client.ClientSession object at 0x...>

Switch to a single-writer model: each provider owns one long-lived
writer thread plus a queue. sync_turn() snapshots state and enqueues a
job; the writer drains sequentially. Once shutdown() is called:

  - new sync_turn() / queue_prefetch() calls are dropped, not enqueued
  - a sentinel wakes the writer so it finishes in-flight work
  - shutdown joins the writer (10s) before nulling the client

Also register an idempotent atexit hook from the first sync_turn(), so
exit paths that don't go through MemoryManager.shutdown_all() (Ctrl-C,
abrupt exit) still get a chance to drain.

Tests: keep _sync_thread as a legacy alias to the writer, swap join()
calls to _retain_queue.join() (canonical wait-for-drain), add a new
TestShutdownRace suite covering single-writer reuse, post-shutdown drop,
queue draining, and shutdown idempotency.
Author: Nicolò Boschi, 2026-04-28 14:49:14 +02:00 (committed by Teknium)
Parent: 5662ac2afc
Commit: 0565497dcc
2 changed files with 228 additions and 57 deletions


@@ -29,10 +29,12 @@ Or via $HERMES_HOME/hindsight/config.json (profile-scoped), falling back to
from __future__ import annotations
import asyncio
import atexit
import importlib
import json
import logging
import os
import queue
import threading
from datetime import datetime, timezone
@@ -100,6 +102,10 @@ _loop: asyncio.AbstractEventLoop | None = None
_loop_thread: threading.Thread | None = None
_loop_lock = threading.Lock()
# Sentinel pushed to the per-provider retain queue to wake the writer for a
# clean exit. A unique object so it can never collide with a real job.
_WRITER_SENTINEL = object()
def _get_loop() -> asyncio.AbstractEventLoop:
"""Return a long-lived event loop running on a background thread."""
@@ -444,6 +450,16 @@ class HindsightMemoryProvider(MemoryProvider):
self._prefetch_result = ""
self._prefetch_lock = threading.Lock()
self._prefetch_thread = None
# Single-writer model for retain. sync_turn() enqueues; the writer
# thread drains sequentially. Avoids spawning ad-hoc threads that
# can race the interpreter shutdown and emit "cannot schedule new
# futures after interpreter shutdown" / "Unclosed client session".
self._retain_queue: queue.Queue = queue.Queue()
self._writer_thread: threading.Thread | None = None
self._shutting_down = threading.Event()
self._atexit_registered = False
# Legacy alias — older tests/callers reference _sync_thread directly.
# Points at _writer_thread once the writer is running.
self._sync_thread = None
self._session_id = ""
self._parent_session_id = ""
@@ -818,6 +834,73 @@ class HindsightMemoryProvider(MemoryProvider):
)
)
def _ensure_writer(self) -> None:
"""Lazy-start the single retain-writer thread.
We don't start the writer in initialize() so providers that never
retain (e.g. tools-only mode) don't pay for an idle thread.
"""
thread = self._writer_thread
if thread is not None and thread.is_alive():
return
# If the previous writer exited (e.g. after a prior shutdown), reset
# the flag so this fresh writer is allowed to drain new jobs.
self._shutting_down.clear()
thread = threading.Thread(
target=self._writer_loop,
daemon=True,
name="hindsight-writer",
)
self._writer_thread = thread
# Keep the legacy _sync_thread alias pointing at the writer so any
# external code that joins _sync_thread keeps working.
self._sync_thread = thread
thread.start()
def _writer_loop(self) -> None:
"""Drain the retain queue serially. Exits on sentinel.
Each job() is wrapped so a single failure can't kill the writer.
task_done() always fires so queue.join() works in tests.
"""
while True:
try:
job = self._retain_queue.get(timeout=1.0)
except queue.Empty:
if self._shutting_down.is_set():
return
continue
try:
if job is _WRITER_SENTINEL:
return
try:
job()
except Exception as exc:
logger.warning("Hindsight retain failed: %s", exc, exc_info=True)
finally:
self._retain_queue.task_done()
def _register_atexit(self) -> None:
"""Register an idempotent atexit hook to drain the writer.
Without this, a CLI exit that doesn't go through MemoryManager.
shutdown_all() would leave in-flight retain jobs racing interpreter
teardown, producing "cannot schedule new futures" warnings and
unclosed aiohttp sessions.
"""
if self._atexit_registered:
return
self._atexit_registered = True
atexit.register(self._atexit_shutdown)
def _atexit_shutdown(self) -> None:
if self._shutting_down.is_set():
return
try:
self.shutdown()
except Exception as exc:
logger.debug("Hindsight atexit shutdown failed: %s", exc)
def _run_hindsight_operation(self, operation):
"""Run an async Hindsight client operation, retrying once after idle shutdown."""
client = self._get_client()
@@ -1081,6 +1164,9 @@ class HindsightMemoryProvider(MemoryProvider):
if not self._auto_recall:
logger.debug("Prefetch: skipped (auto_recall disabled)")
return
if self._shutting_down.is_set():
logger.debug("Prefetch: skipped (shutting down)")
return
# Truncate query to max chars
if self._recall_max_input_chars and len(query) > self._recall_max_input_chars:
query = query[:self._recall_max_input_chars]
@@ -1189,13 +1275,19 @@ class HindsightMemoryProvider(MemoryProvider):
return kwargs
def sync_turn(self, user_content: str, assistant_content: str, *, session_id: str = "") -> None:
"""Retain conversation turn in background (non-blocking).
"""Enqueue a retain for the current turn. Non-blocking.
Respects retain_every_n_turns for batching.
The actual aretain_batch runs on a single long-lived writer thread
that drains an in-memory queue. Once shutdown() has been called,
further sync_turn() calls are dropped; this prevents post-exit
retains from reaching aiohttp after interpreter shutdown begins.
"""
if not self._auto_retain:
logger.debug("sync_turn: skipped (auto_retain disabled)")
return
if self._shutting_down.is_set():
logger.debug("sync_turn: skipped (shutting down)")
return
if session_id:
self._session_id = str(session_id).strip()
@@ -1220,37 +1312,42 @@ class HindsightMemoryProvider(MemoryProvider):
if self._parent_session_id:
lineage_tags.append(f"parent:{self._parent_session_id}")
def _sync():
try:
item = self._build_retain_kwargs(
content,
context=self._retain_context,
metadata=self._build_metadata(
message_count=len(self._session_turns) * 2,
turn_index=self._turn_index,
),
tags=lineage_tags or None,
)
item.pop("bank_id", None)
item.pop("retain_async", None)
logger.debug("Hindsight retain: bank=%s, doc=%s, async=%s, content_len=%d, num_turns=%d",
self._bank_id, self._document_id, self._retain_async, len(content), len(self._session_turns))
self._run_hindsight_operation(
lambda client: client.aretain_batch(
bank_id=self._bank_id,
items=[item],
document_id=self._document_id,
retain_async=self._retain_async,
)
)
logger.debug("Hindsight retain succeeded")
except Exception as e:
logger.warning("Hindsight sync failed: %s", e, exc_info=True)
# Snapshot the state needed for the retain. The writer may run after
# _session_turns / _turn_index are mutated by a later sync_turn().
metadata_snapshot = self._build_metadata(
message_count=len(self._session_turns) * 2,
turn_index=self._turn_index,
)
num_turns = len(self._session_turns)
document_id = self._document_id
bank_id = self._bank_id
retain_async_flag = self._retain_async
retain_context = self._retain_context
if self._sync_thread and self._sync_thread.is_alive():
self._sync_thread.join(timeout=5.0)
self._sync_thread = threading.Thread(target=_sync, daemon=True, name="hindsight-sync")
self._sync_thread.start()
def _do_retain() -> None:
item = self._build_retain_kwargs(
content,
context=retain_context,
metadata=metadata_snapshot,
tags=lineage_tags or None,
)
item.pop("bank_id", None)
item.pop("retain_async", None)
logger.debug("Hindsight retain: bank=%s, doc=%s, async=%s, content_len=%d, num_turns=%d",
bank_id, document_id, retain_async_flag, len(content), num_turns)
self._run_hindsight_operation(
lambda client: client.aretain_batch(
bank_id=bank_id,
items=[item],
document_id=document_id,
retain_async=retain_async_flag,
)
)
logger.debug("Hindsight retain succeeded")
self._ensure_writer()
self._register_atexit()
self._retain_queue.put(_do_retain)
def get_tool_schemas(self) -> List[Dict[str, Any]]:
if self._memory_mode == "context":
@@ -1371,10 +1468,28 @@ class HindsightMemoryProvider(MemoryProvider):
)
def shutdown(self) -> None:
logger.debug("Hindsight shutdown: waiting for background threads")
for t in (self._prefetch_thread, self._sync_thread):
if t and t.is_alive():
t.join(timeout=5.0)
logger.debug("Hindsight shutdown: stopping writer + waiting for background threads")
# Stop accepting new retain jobs first so anyone still calling
# sync_turn() during teardown is dropped, not enqueued.
self._shutting_down.set()
# Drain the writer: it will finish in-flight work, then exit on
# the sentinel. Bounded join keeps shutdown predictable even if
# the daemon is wedged.
writer = self._writer_thread
if writer is not None and writer.is_alive():
try:
self._retain_queue.put(_WRITER_SENTINEL)
except Exception:
pass
writer.join(timeout=10.0)
if writer.is_alive():
logger.warning(
"Hindsight writer did not stop within 10s; "
"abandoning %d pending retain(s)",
self._retain_queue.qsize(),
)
if self._prefetch_thread and self._prefetch_thread.is_alive():
self._prefetch_thread.join(timeout=5.0)
if self._client is not None:
try:
if self._mode == "local_embedded":


@@ -669,7 +669,7 @@ class TestSyncTurn:
p._client = _make_mock_client()
p.sync_turn("hello", "hi there")
p._sync_thread.join(timeout=5.0)
p._retain_queue.join()
p._client.aretain_batch.assert_called_once()
call_kwargs = p._client.aretain_batch.call_args.kwargs
@@ -710,8 +710,7 @@ class TestSyncTurn:
def test_sync_turn_with_tags(self, provider_with_config):
p = provider_with_config(retain_tags=["conv", "session1"])
p.sync_turn("hello", "hi")
if p._sync_thread:
p._sync_thread.join(timeout=5.0)
p._retain_queue.join()
item = p._client.aretain_batch.call_args.kwargs["items"][0]
assert "conv" in item["tags"]
assert "session1" in item["tags"]
@@ -720,8 +719,7 @@ class TestSyncTurn:
def test_sync_turn_uses_aretain_batch(self, provider):
"""sync_turn should use aretain_batch with retain_async."""
provider.sync_turn("hello", "hi")
if provider._sync_thread:
provider._sync_thread.join(timeout=5.0)
provider._retain_queue.join()
provider._client.aretain_batch.assert_called_once()
call_kwargs = provider._client.aretain_batch.call_args.kwargs
assert call_kwargs["document_id"].startswith("test-session-")
@@ -732,8 +730,7 @@ class TestSyncTurn:
def test_sync_turn_custom_context(self, provider_with_config):
p = provider_with_config(retain_context="my-agent")
p.sync_turn("hello", "hi")
if p._sync_thread:
p._sync_thread.join(timeout=5.0)
p._retain_queue.join()
item = p._client.aretain_batch.call_args.kwargs["items"][0]
assert item["context"] == "my-agent"
@@ -744,7 +741,7 @@ class TestSyncTurn:
p.sync_turn("turn2-user", "turn2-asst")
assert p._sync_thread is None
p.sync_turn("turn3-user", "turn3-asst")
p._sync_thread.join(timeout=5.0)
p._retain_queue.join()
p._client.aretain_batch.assert_called_once()
call_kwargs = p._client.aretain_batch.call_args.kwargs
assert call_kwargs["document_id"].startswith("test-session-")
@@ -765,15 +762,13 @@ class TestSyncTurn:
p.sync_turn("turn1-user", "turn1-asst")
p.sync_turn("turn2-user", "turn2-asst")
if p._sync_thread:
p._sync_thread.join(timeout=5.0)
p._retain_queue.join()
p._client.aretain_batch.reset_mock()
p.sync_turn("turn3-user", "turn3-asst")
p.sync_turn("turn4-user", "turn4-asst")
if p._sync_thread:
p._sync_thread.join(timeout=5.0)
p._retain_queue.join()
content = p._client.aretain_batch.call_args.kwargs["items"][0]["content"]
# Should contain ALL turns from the session
@@ -785,8 +780,7 @@ class TestSyncTurn:
def test_sync_turn_passes_document_id(self, provider):
"""sync_turn should pass document_id (session_id + per-startup ts)."""
provider.sync_turn("hello", "hi")
if provider._sync_thread:
provider._sync_thread.join(timeout=5.0)
provider._retain_queue.join()
call_kwargs = provider._client.aretain_batch.call_args.kwargs
# Format: {session_id}-{YYYYMMDD_HHMMSS_microseconds}
assert call_kwargs["document_id"].startswith("test-session-")
@@ -819,8 +813,7 @@ class TestSyncTurn:
def test_sync_turn_session_tag(self, provider):
"""Each retain should be tagged with session:<id> for filtering."""
provider.sync_turn("hello", "hi")
if provider._sync_thread:
provider._sync_thread.join(timeout=5.0)
provider._retain_queue.join()
item = provider._client.aretain_batch.call_args.kwargs["items"][0]
assert "session:test-session" in item["tags"]
@@ -841,8 +834,7 @@ class TestSyncTurn:
)
p._client = _make_mock_client()
p.sync_turn("hello", "hi")
if p._sync_thread:
p._sync_thread.join(timeout=5.0)
p._retain_queue.join()
item = p._client.aretain_batch.call_args.kwargs["items"][0]
assert "session:child-session" in item["tags"]
@@ -851,15 +843,14 @@ class TestSyncTurn:
def test_sync_turn_error_does_not_raise(self, provider):
provider._client.aretain_batch.side_effect = RuntimeError("network error")
provider.sync_turn("hello", "hi")
if provider._sync_thread:
provider._sync_thread.join(timeout=5.0)
provider._retain_queue.join()
def test_sync_turn_preserves_unicode(self, provider_with_config):
"""Non-ASCII text (CJK, ZWJ emoji) must survive JSON round-trip intact."""
p = provider_with_config()
p._client = _make_mock_client()
p.sync_turn("안녕 こんにちは 你好", "👨‍👩‍👧‍👦 family")
p._sync_thread.join(timeout=5.0)
p._retain_queue.join()
p._client.aretain_batch.assert_called_once()
item = p._client.aretain_batch.call_args.kwargs["items"][0]
# ensure_ascii=False means non-ASCII chars appear as-is in the raw JSON,
@@ -871,6 +862,71 @@ class TestSyncTurn:
assert "👨‍👩‍👧‍👦" in raw_json
# ---------------------------------------------------------------------------
# Shutdown / writer tests
# ---------------------------------------------------------------------------
class TestShutdownRace:
def test_sync_turn_uses_single_writer_thread(self, provider):
"""All retains run through one long-lived writer thread."""
provider.sync_turn("a", "b")
provider._retain_queue.join()
first_writer = provider._writer_thread
assert first_writer is not None
assert first_writer.is_alive()
provider.sync_turn("c", "d")
provider._retain_queue.join()
# Same thread reused — no ad-hoc thread per call.
assert provider._writer_thread is first_writer
assert provider._client.aretain_batch.call_count == 2
def test_sync_turn_after_shutdown_is_dropped(self, provider):
"""Once shutdown has fired, new sync_turn() calls are no-ops.
This is the core of the fix: the plugin must not enqueue a retain
during interpreter teardown; that's what causes the
'cannot schedule new futures' RuntimeError + unclosed aiohttp
sessions on CLI exit.
"""
client = provider._client
provider.shutdown()
before_calls = client.aretain_batch.call_count
provider.sync_turn("late", "turn")
# No new enqueue — the retain queue stays empty.
assert provider._retain_queue.empty()
# And no new client call (would be impossible anyway since shutdown
# nulled self._client; we assert via the captured handle).
assert client.aretain_batch.call_count == before_calls
def test_queue_prefetch_after_shutdown_is_dropped(self, provider):
provider.shutdown()
provider.queue_prefetch("late query")
assert provider._prefetch_thread is None
def test_shutdown_drains_pending_retains(self, provider):
"""Shutdown must wait for queued retains to complete, not abandon them.
Otherwise the LAST in-flight turn, typically the most important one,
is silently lost.
"""
client = provider._client
provider.sync_turn("a", "b")
provider.sync_turn("c", "d")
provider.shutdown()
# Both retains drained before shutdown returned.
assert client.aretain_batch.call_count == 2
assert provider._retain_queue.empty()
def test_shutdown_is_idempotent(self, provider):
provider.sync_turn("a", "b")
provider.shutdown()
# Second shutdown shouldn't blow up or re-close the client.
provider.shutdown()
assert provider._shutting_down.is_set()
# ---------------------------------------------------------------------------
# System prompt tests
# ---------------------------------------------------------------------------