feat(memory): notify providers on mid-process session_id rotation (#17409)
Fixes #6672 Memory providers now receive on_session_switch() whenever AIAgent.session_id rotates mid-process — /resume, /branch, /reset, /new, and context compression. Before this, providers that cached per-session state in initialize() (Hindsight's _session_id, _document_id, accumulated _session_turns, _turn_counter) kept writing into the old session's record after the agent had moved on. MemoryProvider ABC ------------------ - New optional hook on_session_switch(new_session_id, *, parent_session_id='', reset=False, **kwargs) with no-op default for backward compat. reset=True signals /reset or /new — providers should flush accumulated per-session buffers. reset=False for /resume, /branch, compression where the logical conversation continues. MemoryManager ------------- - on_session_switch() fans the hook out to every registered provider. Isolated try/except per provider — one bad provider can't block others. - Empty/None new_session_id is a no-op to avoid corrupting provider state during shutdown paths. run_agent.py ------------ - _sync_external_memory_for_turn now passes session_id=self.session_id into sync_all() and queue_prefetch_all(). Providers with defensive session_id updates in sync_turn (Hindsight already had this at plugins/memory/hindsight/__init__.py:1199) now actually receive the current id. - Compression block at ~L8884 already notified the context engine of the rollover; now also calls _memory_manager.on_session_switch(reason='compression'). cli.py ------ - new_session() fires reset=True, reason='new_session' so providers flush buffers. - _handle_resume_command fires reset=False, reason='resume' with the previous session as parent_session_id. - _handle_branch_command fires reset=False, reason='branch' with the parent session_id already captured for the DB parent link. gateway/run.py -------------- - _handle_resume_command now evicts the cached AIAgent, mirroring /branch and /reset. The next message rebuilds a fresh agent whose memory provider initialize() runs with the correct session_id — matches the pattern the gateway already uses for provider state cross-session transitions. Hindsight reference implementation ---------------------------------- - plugins/memory/hindsight/__init__.py adds on_session_switch that: updates _session_id, mints a fresh _document_id (prevents vectorize-io/hindsight#1303 overwrite), and clears _session_turns / _turn_counter / _turn_index so in-flight batches don't flush under the new document id. parent_session_id only overwritten when provided (avoids clobbering on a bare switch). Tests ----- - tests/agent/test_memory_session_switch.py: new dedicated file. ABC default no-op, manager fan-out, failure isolation, empty-id no-op, session_id propagation through sync_all/queue_prefetch_all, Hindsight state transitions for every reset/non-reset case, parent preservation. - tests/cli/test_branch_command.py: new test verifying /branch fires the hook with correct parent_session_id + reset=False + reason. - tests/gateway/test_resume_command.py: new test verifying /resume evicts the cached agent. - tests/run_agent/test_memory_sync_interrupted.py: updated existing assertions to account for the session_id kwarg on sync_all and queue_prefetch_all. E2E verified (real imports, tmp HERMES_HOME): - /resume: session_id updates, doc_id fresh, buffers cleared, parent set - /branch: session_id forks, parent links to original - /new: reset=True clears accumulated state - compression: reason='compression' propagated, lineage preserved - Empty id: no-op, state preserved - Legacy provider without on_session_switch: no crash Reported by @nicoloboschi (Hindsight maintainer); related scope-widening comment by @kidonng extending coverage to compression.
This commit is contained in:
parent
d244596dba
commit
13683c0842
@ -402,6 +402,41 @@ class MemoryManager:
|
||||
provider.name, e,
|
||||
)
|
||||
|
||||
def on_session_switch(
|
||||
self,
|
||||
new_session_id: str,
|
||||
*,
|
||||
parent_session_id: str = "",
|
||||
reset: bool = False,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""Notify all providers that the agent's session_id has rotated.
|
||||
|
||||
Fires on ``/resume``, ``/branch``, ``/reset``, ``/new``, and
|
||||
context compression — any path that reassigns
|
||||
``AIAgent.session_id`` without tearing the provider down.
|
||||
|
||||
Providers keep running; they only need to refresh cached
|
||||
per-session state so subsequent writes land in the correct
|
||||
session's record. See ``MemoryProvider.on_session_switch`` for
|
||||
the full contract.
|
||||
"""
|
||||
if not new_session_id:
|
||||
return
|
||||
for provider in self._providers:
|
||||
try:
|
||||
provider.on_session_switch(
|
||||
new_session_id,
|
||||
parent_session_id=parent_session_id,
|
||||
reset=reset,
|
||||
**kwargs,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
"Memory provider '%s' on_session_switch failed: %s",
|
||||
provider.name, e,
|
||||
)
|
||||
|
||||
def on_pre_compress(self, messages: List[Dict[str, Any]]) -> str:
|
||||
"""Notify all providers before context compression.
|
||||
|
||||
|
||||
@ -25,6 +25,7 @@ Lifecycle (called by MemoryManager, wired in run_agent.py):
|
||||
Optional hooks (override to opt in):
|
||||
on_turn_start(turn, message, **kwargs) — per-turn tick with runtime context
|
||||
on_session_end(messages) — end-of-session extraction
|
||||
on_session_switch(new_session_id, **kwargs) — mid-process session_id rotation
|
||||
on_pre_compress(messages) -> str — extract before context compression
|
||||
on_memory_write(action, target, content, metadata=None) — mirror built-in memory writes
|
||||
on_delegation(task, result, **kwargs) — parent-side observation of subagent work
|
||||
@ -160,6 +161,45 @@ class MemoryProvider(ABC):
|
||||
(CLI exit, /reset, gateway session expiry).
|
||||
"""
|
||||
|
||||
def on_session_switch(
|
||||
self,
|
||||
new_session_id: str,
|
||||
*,
|
||||
parent_session_id: str = "",
|
||||
reset: bool = False,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""Called when the agent switches session_id mid-process.
|
||||
|
||||
Fires on ``/resume``, ``/branch``, ``/reset``, ``/new`` (CLI), the
|
||||
gateway equivalents, and context compression — any path that
|
||||
reassigns ``AIAgent.session_id`` without tearing the provider down.
|
||||
|
||||
Providers that cache per-session state in ``initialize()``
|
||||
(``_session_id``, ``_document_id``, accumulated turn buffers,
|
||||
counters) should update or reset that state here so subsequent
|
||||
writes land in the correct session's record.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
new_session_id:
|
||||
The session_id the agent just switched to.
|
||||
parent_session_id:
|
||||
The previous session_id, if meaningful — set for ``/branch``
|
||||
(fork lineage), context compression (continuation lineage),
|
||||
and ``/resume`` (the session we're leaving). Empty string
|
||||
when no lineage applies.
|
||||
reset:
|
||||
``True`` when this is a genuinely new conversation, not a
|
||||
resumption of an existing one. Fired by ``/reset`` / ``/new``.
|
||||
Providers should flush accumulated per-session buffers
|
||||
(``_session_turns``, ``_turn_counter``, etc.) when this is
|
||||
set. ``False`` for ``/resume`` / ``/branch`` / compression
|
||||
where the logical conversation continues under the new id.
|
||||
|
||||
Default is no-op for backward compatibility.
|
||||
"""
|
||||
|
||||
def on_pre_compress(self, messages: List[Dict[str, Any]]) -> str:
|
||||
"""Called before context compression discards old messages.
|
||||
|
||||
|
||||
49
cli.py
49
cli.py
@ -4809,6 +4809,22 @@ class HermesCLI:
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
# Notify memory providers that session_id rotated to a fresh
|
||||
# conversation. reset=True signals providers to flush accumulated
|
||||
# per-session state (_session_turns, _turn_counter, _document_id).
|
||||
# Fires BEFORE the plugin on_session_reset hook (shell hooks only
|
||||
# see the new id; Python providers see the transition). See #6672.
|
||||
try:
|
||||
_mm = getattr(self.agent, "_memory_manager", None)
|
||||
if _mm is not None:
|
||||
_mm.on_session_switch(
|
||||
self.session_id,
|
||||
parent_session_id=old_session_id or "",
|
||||
reset=True,
|
||||
reason="new_session",
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
self._notify_session_boundary("on_session_reset")
|
||||
|
||||
if not silent:
|
||||
@ -4861,6 +4877,7 @@ class HermesCLI:
|
||||
_cprint(" Already on that session.")
|
||||
return
|
||||
|
||||
old_session_id = self.session_id
|
||||
# End current session
|
||||
try:
|
||||
self._session_db.end_session(self.session_id, "resumed_other")
|
||||
@ -4898,6 +4915,22 @@ class HermesCLI:
|
||||
if hasattr(self.agent, "_invalidate_system_prompt"):
|
||||
self.agent._invalidate_system_prompt()
|
||||
|
||||
# Notify memory providers that session_id rotated to a resumed
|
||||
# session. reset=False — the provider's accumulated state is
|
||||
# still valid; it just needs to target the new session_id for
|
||||
# subsequent writes. See #6672.
|
||||
try:
|
||||
_mm = getattr(self.agent, "_memory_manager", None)
|
||||
if _mm is not None:
|
||||
_mm.on_session_switch(
|
||||
target_id,
|
||||
parent_session_id=old_session_id or "",
|
||||
reset=False,
|
||||
reason="resume",
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
title_part = f" \"{session_meta['title']}\"" if session_meta.get("title") else ""
|
||||
msg_count = len([m for m in self.conversation_history if m.get("role") == "user"])
|
||||
if self.conversation_history:
|
||||
@ -5018,6 +5051,22 @@ class HermesCLI:
|
||||
if hasattr(self.agent, "_invalidate_system_prompt"):
|
||||
self.agent._invalidate_system_prompt()
|
||||
|
||||
# Notify memory providers that session_id forked to a new branch.
|
||||
# reset=False — the branched session carries the transcript
|
||||
# forward, so provider state tracks the lineage. parent_session_id
|
||||
# links the branch back to the original. See #6672.
|
||||
try:
|
||||
_mm = getattr(self.agent, "_memory_manager", None)
|
||||
if _mm is not None:
|
||||
_mm.on_session_switch(
|
||||
new_session_id,
|
||||
parent_session_id=parent_session_id or "",
|
||||
reset=False,
|
||||
reason="branch",
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
msg_count = len([m for m in self.conversation_history if m.get("role") == "user"])
|
||||
_cprint(
|
||||
f" ⑂ Branched session \"{branch_title}\""
|
||||
|
||||
@ -7817,6 +7817,13 @@ class GatewayRunner:
|
||||
return "Failed to switch session."
|
||||
self._clear_session_boundary_security_state(session_key)
|
||||
|
||||
# Evict any cached agent for this session so the next message
|
||||
# rebuilds with the correct session_id end-to-end — mirrors
|
||||
# /branch and /reset. Without this, the cached AIAgent (and its
|
||||
# memory provider, which cached `_session_id` during initialize())
|
||||
# keeps writing into the wrong session's record. See #6672.
|
||||
self._evict_cached_agent(session_key)
|
||||
|
||||
# Get the title for confirmation
|
||||
title = self._session_db.get_session_title(target_id) or name
|
||||
|
||||
|
||||
@ -1325,6 +1325,51 @@ class HindsightMemoryProvider(MemoryProvider):
|
||||
|
||||
return tool_error(f"Unknown tool: {tool_name}")
|
||||
|
||||
def on_session_switch(
|
||||
self,
|
||||
new_session_id: str,
|
||||
*,
|
||||
parent_session_id: str = "",
|
||||
reset: bool = False,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""Refresh cached per-session state when the agent rotates session_id.
|
||||
|
||||
Fires on /resume, /branch, /reset, /new, and context compression.
|
||||
Without this hook, initialize()-cached state (``_session_id``,
|
||||
``_document_id``, ``_session_turns``, ``_turn_counter``) would keep
|
||||
pointing at the previous session and writes would land in the wrong
|
||||
document. See hermes-agent#6672.
|
||||
|
||||
Always update ``_session_id`` so metadata and tags on subsequent
|
||||
retains reflect the active session. Always mint a fresh
|
||||
``_document_id`` so the new session's retain doesn't overwrite the
|
||||
old session's document on vectorize-io/hindsight#1303. Always clear
|
||||
the accumulated batch buffers (``_session_turns``, ``_turn_counter``,
|
||||
``_turn_index``) — even for /resume and /branch, the new session's
|
||||
batching must start from zero so an in-flight retain doesn't flush
|
||||
under the wrong ``_document_id``.
|
||||
|
||||
``parent_session_id`` is recorded for lineage tags on future retains.
|
||||
``reset`` is accepted but not needed for Hindsight's state model —
|
||||
buffer clearing is correct for every session switch, not only /reset.
|
||||
"""
|
||||
new_id = str(new_session_id or "").strip()
|
||||
if not new_id:
|
||||
return
|
||||
if parent_session_id:
|
||||
self._parent_session_id = str(parent_session_id).strip()
|
||||
self._session_id = new_id
|
||||
start_ts = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
||||
self._document_id = f"{self._session_id}-{start_ts}"
|
||||
self._session_turns = []
|
||||
self._turn_counter = 0
|
||||
self._turn_index = 0
|
||||
logger.debug(
|
||||
"Hindsight on_session_switch: new_session=%s parent=%s reset=%s doc=%s",
|
||||
self._session_id, self._parent_session_id, reset, self._document_id,
|
||||
)
|
||||
|
||||
def shutdown(self) -> None:
|
||||
logger.debug("Hindsight shutdown: waiting for background threads")
|
||||
for t in (self._prefetch_thread, self._sync_thread):
|
||||
|
||||
27
run_agent.py
27
run_agent.py
@ -4565,8 +4565,14 @@ class AIAgent:
|
||||
if not (self._memory_manager and final_response and original_user_message):
|
||||
return
|
||||
try:
|
||||
self._memory_manager.sync_all(original_user_message, final_response)
|
||||
self._memory_manager.queue_prefetch_all(original_user_message)
|
||||
self._memory_manager.sync_all(
|
||||
original_user_message, final_response,
|
||||
session_id=self.session_id or "",
|
||||
)
|
||||
self._memory_manager.queue_prefetch_all(
|
||||
original_user_message,
|
||||
session_id=self.session_id or "",
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@ -8938,6 +8944,23 @@ class AIAgent:
|
||||
except Exception as _ce_err:
|
||||
logger.debug("context engine on_session_start (compression): %s", _ce_err)
|
||||
|
||||
# Notify memory providers of the compression-driven session_id rotation
|
||||
# so provider-cached per-session state (Hindsight's _document_id,
|
||||
# accumulated turn buffers, counters) refreshes. reset=False because
|
||||
# the logical conversation continues; only the id and DB row rolled
|
||||
# over. See #6672.
|
||||
try:
|
||||
_old_sid = locals().get("old_session_id")
|
||||
if _old_sid and self._memory_manager:
|
||||
self._memory_manager.on_session_switch(
|
||||
self.session_id or "",
|
||||
parent_session_id=_old_sid,
|
||||
reset=False,
|
||||
reason="compression",
|
||||
)
|
||||
except Exception as _me_err:
|
||||
logger.debug("memory manager on_session_switch (compression): %s", _me_err)
|
||||
|
||||
# Warn on repeated compressions (quality degrades with each pass)
|
||||
_cc = self.context_compressor.compression_count
|
||||
if _cc >= 2:
|
||||
|
||||
282
tests/agent/test_memory_session_switch.py
Normal file
282
tests/agent/test_memory_session_switch.py
Normal file
@ -0,0 +1,282 @@
|
||||
"""Tests for the on_session_switch hook and session_id propagation.
|
||||
|
||||
Covers #6672: memory providers must be notified when AIAgent.session_id
|
||||
rotates mid-process (via /resume, /branch, /reset, /new, or context
|
||||
compression). Without the notification, providers that cache per-session
|
||||
state in initialize() (Hindsight, and any plugin that stores session_id
|
||||
for scoped writes) keep writing into the old session's record.
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from agent.memory_manager import MemoryManager
|
||||
from agent.memory_provider import MemoryProvider
|
||||
|
||||
|
||||
class _RecordingProvider(MemoryProvider):
|
||||
"""Provider that records every lifecycle call for assertion."""
|
||||
|
||||
def __init__(self, name="rec"):
|
||||
self._name = name
|
||||
self.switch_calls: list[dict] = []
|
||||
self.sync_calls: list[dict] = []
|
||||
self.queue_calls: list[dict] = []
|
||||
self.initialize_calls: list[dict] = []
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
def is_available(self) -> bool: # pragma: no cover - unused
|
||||
return True
|
||||
|
||||
def initialize(self, session_id, **kwargs):
|
||||
self.initialize_calls.append({"session_id": session_id, **kwargs})
|
||||
|
||||
def get_tool_schemas(self):
|
||||
return []
|
||||
|
||||
def sync_turn(self, user_content, assistant_content, *, session_id=""):
|
||||
self.sync_calls.append(
|
||||
{"user": user_content, "asst": assistant_content, "session_id": session_id}
|
||||
)
|
||||
|
||||
def queue_prefetch(self, query, *, session_id=""):
|
||||
self.queue_calls.append({"query": query, "session_id": session_id})
|
||||
|
||||
def on_session_switch(
|
||||
self,
|
||||
new_session_id,
|
||||
*,
|
||||
parent_session_id="",
|
||||
reset=False,
|
||||
**kwargs,
|
||||
):
|
||||
self.switch_calls.append(
|
||||
{
|
||||
"new": new_session_id,
|
||||
"parent": parent_session_id,
|
||||
"reset": reset,
|
||||
"extra": kwargs,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MemoryProvider ABC — default on_session_switch is a no-op
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _MinimalProvider(MemoryProvider):
|
||||
"""Provider that does NOT override on_session_switch — ABC default must no-op."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "minimal"
|
||||
|
||||
def is_available(self) -> bool:
|
||||
return True
|
||||
|
||||
def initialize(self, session_id, **kwargs): # pragma: no cover - unused
|
||||
pass
|
||||
|
||||
def get_tool_schemas(self):
|
||||
return []
|
||||
|
||||
|
||||
def test_abc_default_on_session_switch_is_noop():
|
||||
"""Providers that don't override the hook must not raise."""
|
||||
p = _MinimalProvider()
|
||||
# All three call styles must be accepted without raising
|
||||
p.on_session_switch("new-id")
|
||||
p.on_session_switch("new-id", parent_session_id="old-id")
|
||||
p.on_session_switch("new-id", parent_session_id="old-id", reset=True)
|
||||
p.on_session_switch("new-id", parent_session_id="old-id", reset=True, reason="new_session")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MemoryManager.on_session_switch — fan-out
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_manager_fans_out_to_all_providers():
|
||||
mm = MemoryManager()
|
||||
# Only one external provider is allowed; use the builtin slot for p1.
|
||||
p1 = _RecordingProvider(name="builtin")
|
||||
p2 = _RecordingProvider(name="hindsight")
|
||||
mm.add_provider(p1)
|
||||
mm.add_provider(p2)
|
||||
|
||||
mm.on_session_switch("new-sid", parent_session_id="old-sid", reset=False, reason="resume")
|
||||
|
||||
assert len(p1.switch_calls) == 1
|
||||
assert len(p2.switch_calls) == 1
|
||||
for call in (p1.switch_calls[0], p2.switch_calls[0]):
|
||||
assert call["new"] == "new-sid"
|
||||
assert call["parent"] == "old-sid"
|
||||
assert call["reset"] is False
|
||||
assert call["extra"] == {"reason": "resume"}
|
||||
|
||||
|
||||
def test_manager_ignores_empty_session_id():
|
||||
"""Empty string session_id must not trigger provider hooks.
|
||||
|
||||
Prevents accidental fires during shutdown when self.session_id may be
|
||||
cleared. Providers expect a meaningful id to switch TO.
|
||||
"""
|
||||
mm = MemoryManager()
|
||||
p = _RecordingProvider()
|
||||
mm.add_provider(p)
|
||||
mm.on_session_switch("")
|
||||
mm.on_session_switch(None) # type: ignore[arg-type]
|
||||
assert p.switch_calls == []
|
||||
|
||||
|
||||
def test_manager_isolates_provider_failures():
|
||||
"""A provider that raises must not block other providers."""
|
||||
|
||||
class _Broken(_RecordingProvider):
|
||||
def on_session_switch(self, *args, **kwargs): # type: ignore[override]
|
||||
raise RuntimeError("boom")
|
||||
|
||||
mm = MemoryManager()
|
||||
# MemoryManager rejects a second external provider, so pair broken
|
||||
# (builtin slot) with a good external one.
|
||||
broken = _Broken(name="builtin")
|
||||
good = _RecordingProvider(name="good")
|
||||
mm.add_provider(broken)
|
||||
mm.add_provider(good)
|
||||
|
||||
# Must not raise — exceptions in one provider are swallowed + logged
|
||||
mm.on_session_switch("new-sid", parent_session_id="old-sid")
|
||||
assert len(good.switch_calls) == 1
|
||||
assert good.switch_calls[0]["new"] == "new-sid"
|
||||
|
||||
|
||||
def test_manager_reset_flag_preserved():
|
||||
mm = MemoryManager()
|
||||
p = _RecordingProvider()
|
||||
mm.add_provider(p)
|
||||
mm.on_session_switch("new-sid", reset=True, reason="new_session")
|
||||
assert p.switch_calls[0]["reset"] is True
|
||||
assert p.switch_calls[0]["extra"] == {"reason": "new_session"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MemoryManager.sync_all / queue_prefetch_all — session_id propagation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_sync_all_propagates_session_id_to_providers():
|
||||
"""run_agent.py's sync_all call must pass session_id through to providers.
|
||||
|
||||
Without this, a provider that updates _session_id defensively in
|
||||
sync_turn (as Hindsight does at hindsight/__init__.py:1199) never
|
||||
sees the new id and keeps writing under the old one.
|
||||
"""
|
||||
mm = MemoryManager()
|
||||
p = _RecordingProvider()
|
||||
mm.add_provider(p)
|
||||
mm.sync_all("hello", "world", session_id="sess-42")
|
||||
assert p.sync_calls == [
|
||||
{"user": "hello", "asst": "world", "session_id": "sess-42"}
|
||||
]
|
||||
|
||||
|
||||
def test_queue_prefetch_all_propagates_session_id_to_providers():
|
||||
mm = MemoryManager()
|
||||
p = _RecordingProvider()
|
||||
mm.add_provider(p)
|
||||
mm.queue_prefetch_all("next query", session_id="sess-42")
|
||||
assert p.queue_calls == [{"query": "next query", "session_id": "sess-42"}]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Hindsight reference implementation — state-flush semantics
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_hindsight_provider():
|
||||
"""Build a bare HindsightMemoryProvider that skips network setup.
|
||||
|
||||
We instantiate without importing optional deps at class-level by
|
||||
bypassing __init__ and seeding the attributes on_session_switch
|
||||
reads/writes. This keeps the test hermetic.
|
||||
"""
|
||||
hindsight_mod = pytest.importorskip("plugins.memory.hindsight")
|
||||
provider = object.__new__(hindsight_mod.HindsightMemoryProvider)
|
||||
provider._session_id = "old-sid"
|
||||
provider._parent_session_id = ""
|
||||
provider._document_id = "old-sid-20260101_000000_000000"
|
||||
provider._session_turns = ["turn-1", "turn-2"]
|
||||
provider._turn_counter = 2
|
||||
provider._turn_index = 2
|
||||
return provider
|
||||
|
||||
|
||||
def test_hindsight_on_session_switch_updates_session_id_and_mints_fresh_doc():
|
||||
provider = _make_hindsight_provider()
|
||||
old_doc = provider._document_id
|
||||
|
||||
provider.on_session_switch(
|
||||
"new-sid", parent_session_id="old-sid", reset=False, reason="resume"
|
||||
)
|
||||
|
||||
assert provider._session_id == "new-sid"
|
||||
assert provider._parent_session_id == "old-sid"
|
||||
# Document id MUST be fresh — else next retain overwrites old session doc
|
||||
assert provider._document_id != old_doc
|
||||
assert provider._document_id.startswith("new-sid-")
|
||||
|
||||
|
||||
def test_hindsight_on_session_switch_clears_turn_buffers():
|
||||
"""Accumulated _session_turns must not leak into the next session.
|
||||
|
||||
Hindsight batches turns under a single _document_id. If the buffer
|
||||
isn't cleared on switch, the next retain under the new _document_id
|
||||
flushes turns that belong to the previous session.
|
||||
"""
|
||||
provider = _make_hindsight_provider()
|
||||
provider.on_session_switch("new-sid", parent_session_id="old-sid")
|
||||
assert provider._session_turns == []
|
||||
assert provider._turn_counter == 0
|
||||
assert provider._turn_index == 0
|
||||
|
||||
|
||||
def test_hindsight_on_session_switch_clears_on_reset_true():
|
||||
"""reset=True (from /new, /reset) must also flush buffers."""
|
||||
provider = _make_hindsight_provider()
|
||||
provider.on_session_switch("new-sid", reset=True, reason="new_session")
|
||||
assert provider._session_id == "new-sid"
|
||||
assert provider._session_turns == []
|
||||
assert provider._turn_counter == 0
|
||||
|
||||
|
||||
def test_hindsight_on_session_switch_ignores_empty_id():
|
||||
"""Empty new_session_id must be a no-op to avoid corrupting state."""
|
||||
provider = _make_hindsight_provider()
|
||||
before = (
|
||||
provider._session_id,
|
||||
provider._document_id,
|
||||
list(provider._session_turns),
|
||||
provider._turn_counter,
|
||||
)
|
||||
provider.on_session_switch("")
|
||||
provider.on_session_switch(None) # type: ignore[arg-type]
|
||||
after = (
|
||||
provider._session_id,
|
||||
provider._document_id,
|
||||
list(provider._session_turns),
|
||||
provider._turn_counter,
|
||||
)
|
||||
assert before == after
|
||||
|
||||
|
||||
def test_hindsight_preserves_parent_across_empty_parent_arg():
|
||||
"""Omitting parent_session_id must NOT overwrite an existing one."""
|
||||
provider = _make_hindsight_provider()
|
||||
provider._parent_session_id = "original-parent"
|
||||
provider.on_session_switch("new-sid") # no parent passed
|
||||
assert provider._parent_session_id == "original-parent"
|
||||
@ -192,6 +192,33 @@ class TestBranchCommandCLI:
|
||||
|
||||
assert cli_instance._resumed is True
|
||||
|
||||
def test_branch_fires_on_session_switch_hook(self, cli_instance, session_db):
|
||||
"""The /branch command must notify memory providers of the rotation.
|
||||
|
||||
Without this, providers that cache per-session state in
|
||||
initialize() keep writing under the old session_id. See #6672.
|
||||
"""
|
||||
from cli import HermesCLI
|
||||
|
||||
# Wire a real-ish agent object with a MagicMock memory_manager
|
||||
agent = MagicMock()
|
||||
mm = MagicMock()
|
||||
agent._memory_manager = mm
|
||||
cli_instance.agent = agent
|
||||
original_id = cli_instance.session_id
|
||||
|
||||
HermesCLI._handle_branch_command(cli_instance, "/branch")
|
||||
|
||||
# Hook must have been called exactly once with the new session_id,
|
||||
# parent pointing at the branched-from session, reset=False, and
|
||||
# reason="branch" for diagnostics.
|
||||
assert mm.on_session_switch.call_count == 1
|
||||
_, kwargs = mm.on_session_switch.call_args
|
||||
assert mm.on_session_switch.call_args.args[0] == cli_instance.session_id
|
||||
assert kwargs["parent_session_id"] == original_id
|
||||
assert kwargs["reset"] is False
|
||||
assert kwargs["reason"] == "branch"
|
||||
|
||||
def test_fork_alias(self):
|
||||
"""The /fork alias should resolve to 'branch'."""
|
||||
from hermes_cli.commands import resolve_command
|
||||
|
||||
@ -230,3 +230,30 @@ class TestHandleResumeCommand:
|
||||
|
||||
assert real_key not in runner._running_agents
|
||||
db.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_evicts_cached_agent(self, tmp_path):
|
||||
"""Gateway /resume evicts the cached AIAgent so the next message
|
||||
rebuilds with the correct session_id end-to-end — mirrors /branch
|
||||
and /reset. Without this, the cached agent's memory provider keeps
|
||||
writing into the wrong session. See #6672.
|
||||
"""
|
||||
import threading
|
||||
from hermes_state import SessionDB
|
||||
db = SessionDB(db_path=tmp_path / "state.db")
|
||||
db.create_session("old_session", "telegram")
|
||||
db.set_session_title("old_session", "Old Work")
|
||||
db.create_session("current_session_001", "telegram")
|
||||
|
||||
event = _make_event(text="/resume Old Work")
|
||||
runner = _make_runner(session_db=db, current_session_id="current_session_001",
|
||||
event=event)
|
||||
# Seed the cache with a fake agent
|
||||
real_key = _session_key_for_event(event)
|
||||
runner._agent_cache = {real_key: (MagicMock(), object())}
|
||||
runner._agent_cache_lock = threading.RLock()
|
||||
|
||||
await runner._handle_resume_command(event)
|
||||
|
||||
assert real_key not in runner._agent_cache
|
||||
db.close()
|
||||
|
||||
@ -31,6 +31,10 @@ def _bare_agent():
|
||||
|
||||
agent = AIAgent.__new__(AIAgent)
|
||||
agent._memory_manager = MagicMock()
|
||||
# session_id is now propagated into sync_all / queue_prefetch_all so
|
||||
# providers that cache per-session state can update it mid-process
|
||||
# (see #6672).
|
||||
agent.session_id = "test_session_001"
|
||||
return agent
|
||||
|
||||
|
||||
@ -80,9 +84,11 @@ class TestSyncExternalMemoryForTurn:
|
||||
)
|
||||
agent._memory_manager.sync_all.assert_called_once_with(
|
||||
"What's the weather in Paris?", "It's sunny and 22°C.",
|
||||
session_id="test_session_001",
|
||||
)
|
||||
agent._memory_manager.queue_prefetch_all.assert_called_once_with(
|
||||
"What's the weather in Paris?",
|
||||
session_id="test_session_001",
|
||||
)
|
||||
|
||||
# --- Edge cases (pre-existing behaviour preserved) ------------------
|
||||
|
||||
Loading…
Reference in New Issue
Block a user