Merge remote-tracking branch 'origin/main' into bb/tui-long-session-perf

commit 7da2f07641
@@ -43,10 +43,18 @@ def busy_input_hint_gateway(mode: str) -> str:
            "Send `/busy interrupt` to make new messages stop the current task "
            "immediately, or `/busy status` to check. This notice won't appear again."
        )
    if mode == "steer":
        return (
            "💡 First-time tip — I steered your message into the current run; "
            "it will arrive after the next tool call instead of interrupting. "
            "Send `/busy interrupt` or `/busy queue` to change this, or "
            "`/busy status` to check. This notice won't appear again."
        )
    return (
        "💡 First-time tip — I just interrupted my current task to answer you. "
        "Send `/busy queue` to queue follow-ups for after the current task instead, "
        "or `/busy status` to check. This notice won't appear again."
        "`/busy steer` to inject them mid-run without interrupting, or "
        "`/busy status` to check. This notice won't appear again."
    )
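A quick sketch of how the selector above is driven (the mode value comes from the user's /busy setting; this calls only the function shown in this hunk):

    tip = busy_input_hint_gateway("steer")
    assert tip.startswith("💡 First-time tip")
    assert "after the next tool call" in tip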
@@ -55,13 +63,19 @@ def busy_input_hint_cli(mode: str) -> str:
    if mode == "queue":
        return (
            "(tip) Your message was queued for the next turn. "
            "Use /busy interrupt to make Enter stop the current run instead. "
            "This tip only shows once."
            "Use /busy interrupt to make Enter stop the current run instead, "
            "or /busy steer to inject mid-run. This tip only shows once."
        )
    if mode == "steer":
        return (
            "(tip) Your message was steered into the current run; it arrives "
            "after the next tool call. Use /busy interrupt or /busy queue to "
            "change this. This tip only shows once."
        )
    return (
        "(tip) Your message interrupted the current run. "
        "Use /busy queue to queue messages for the next turn instead. "
        "This tip only shows once."
        "Use /busy queue to queue messages for the next turn instead, "
        "or /busy steer to inject mid-run. This tip only shows once."
    )
@@ -422,6 +422,29 @@ PLATFORM_HINTS = {
        "your response. Images are sent as native photos, and other files arrive as downloadable "
        "documents."
    ),
    "yuanbao": (
        "You are on Yuanbao (腾讯元宝), a Chinese AI assistant platform. "
        "Markdown formatting is supported (code blocks, tables, bold/italic). "
        "You CAN send media files natively — to deliver a file to the user, include "
        "MEDIA:/absolute/path/to/file in your response. The file will be sent as a native "
        "Yuanbao attachment: images (.jpg, .png, .webp, .gif) are sent as photos, "
        "and other files (.pdf, .docx, .txt, .zip, etc.) arrive as downloadable documents "
        "(max 50 MB). You can also include image URLs in markdown format and "
        "they will be downloaded and sent as native photos. "
        "Do NOT tell the user you lack file-sending capability — use MEDIA: syntax "
        "whenever a file delivery is appropriate.\n\n"
        "Stickers (贴纸 / 表情包 / TIM face): Yuanbao has a built-in sticker catalogue. "
        "When the user sends a sticker (you see '[emoji: 名称]' in their message) or asks "
        "you to send/reply-with a 贴纸/表情/表情包, you MUST use the sticker tools:\n"
        " 1. Call yb_search_sticker with a Chinese keyword (e.g. '666', '比心', '吃瓜', "
        "    '捂脸', '合十') to discover matching sticker_ids.\n"
        " 2. Call yb_send_sticker with the chosen sticker_id or name — this sends a real "
        "    TIMFaceElem that renders as a native sticker in the chat.\n"
        "DO NOT draw sticker-like PNGs with execute_code/Pillow/matplotlib and then send "
        "them via MEDIA: or send_image_file. That produces a fake low-quality 'sticker' "
        "image and is the WRONG path. Bare Unicode emoji in text is also not a substitute "
        "— when a sticker is the right response, use yb_send_sticker."
    ),
}

# ---------------------------------------------------------------------------
@@ -606,6 +606,7 @@ platform_toolsets:
  signal: [hermes-signal]
  homeassistant: [hermes-homeassistant]
  qqbot: [hermes-qqbot]
  yuanbao: [hermes-yuanbao]

# =============================================================================
# Gateway Platform Settings
@@ -847,8 +848,12 @@ display:
  # What Enter does when Hermes is already busy (CLI and gateway platforms).
  #   interrupt: Interrupt the current run and redirect Hermes (default)
  #   queue:     Queue your message for the next turn
  #   steer:     Inject your message mid-run via /steer, arriving at the agent
  #              after the next tool call — no interrupt, no role violation.
  #              Falls back to 'queue' if the agent isn't running yet or if
  #              images are attached (steer only carries text).
  # Ctrl+C (or /stop in gateway) always interrupts regardless of this setting.
  # Toggle at runtime with /busy_input_mode <interrupt|queue>.
  # Toggle at runtime with /busy <interrupt|queue|steer>.
  busy_input_mode: interrupt

  # Background process notifications (gateway/messaging only).
cli.py | 130
@@ -974,6 +974,7 @@ def _run_state_db_auto_maintenance(session_db) -> None:
        return
    try:
        from hermes_cli.config import load_config as _load_full_config
        from hermes_constants import get_hermes_home as _get_hermes_home
        cfg = (_load_full_config().get("sessions") or {})
        if not cfg.get("auto_prune", False):
            return
@@ -981,11 +982,35 @@ def _run_state_db_auto_maintenance(session_db) -> None:
            retention_days=int(cfg.get("retention_days", 90)),
            min_interval_hours=int(cfg.get("min_interval_hours", 24)),
            vacuum=bool(cfg.get("vacuum_after_prune", True)),
            sessions_dir=_get_hermes_home() / "sessions",
        )
    except Exception as exc:
        logger.debug("state.db auto-maintenance skipped: %s", exc)


def _run_checkpoint_auto_maintenance() -> None:
    """Call ``checkpoint_manager.maybe_auto_prune_checkpoints`` using current config.

    Reads the ``checkpoints:`` section from config.yaml via
    :func:`hermes_cli.config.load_config`. Honours ``auto_prune`` /
    ``retention_days`` / ``delete_orphans`` / ``min_interval_hours``.
    Never raises — maintenance must never block interactive startup.
    """
    try:
        from hermes_cli.config import load_config as _load_full_config
        cfg = (_load_full_config().get("checkpoints") or {})
        if not cfg.get("auto_prune", False):
            return
        from tools.checkpoint_manager import maybe_auto_prune_checkpoints
        maybe_auto_prune_checkpoints(
            retention_days=int(cfg.get("retention_days", 7)),
            min_interval_hours=int(cfg.get("min_interval_hours", 24)),
            delete_orphans=bool(cfg.get("delete_orphans", True)),
        )
    except Exception as exc:
        logger.debug("checkpoint auto-maintenance skipped: %s", exc)


def _prune_stale_worktrees(repo_root: str, max_age_hours: int = 24) -> None:
    """Remove stale worktrees and orphaned branches on startup.
@@ -1848,9 +1873,16 @@ class HermesCLI:
        self.bell_on_complete = CLI_CONFIG["display"].get("bell_on_complete", False)
        # show_reasoning: display model thinking/reasoning before the response
        self.show_reasoning = CLI_CONFIG["display"].get("show_reasoning", False)
        # busy_input_mode: "interrupt" (Enter interrupts current run) or "queue" (Enter queues for next turn)
        _bim = CLI_CONFIG["display"].get("busy_input_mode", "interrupt")
        self.busy_input_mode = "queue" if str(_bim).strip().lower() == "queue" else "interrupt"
        # busy_input_mode: "interrupt" (Enter interrupts current run),
        # "queue" (Enter queues for next turn), or "steer" (Enter injects
        # mid-run via /steer, arriving after the next tool call).
        _bim = str(CLI_CONFIG["display"].get("busy_input_mode", "interrupt")).strip().lower()
        if _bim == "queue":
            self.busy_input_mode = "queue"
        elif _bim == "steer":
            self.busy_input_mode = "steer"
        else:
            self.busy_input_mode = "interrupt"

        self.verbose = verbose if verbose is not None else (self.tool_progress_mode == "verbose")
@@ -2045,6 +2077,11 @@ class HermesCLI:
        # Never blocks startup on failure.
        _run_state_db_auto_maintenance(self._session_db)

        # Opportunistic shadow-repo cleanup — deletes orphan/stale
        # checkpoint repos under ~/.hermes/checkpoints/. Opt-in via
        # checkpoints.auto_prune, idempotent via .last_prune marker.
        _run_checkpoint_auto_maintenance()

        # Deferred title: stored in memory until the session is created in the DB
        self._pending_title: Optional[str] = None
@@ -4942,22 +4979,37 @@ class HermesCLI:
            _cprint(f"   Branch session: {new_session_id}")

    def save_conversation(self):
        """Save the current conversation to a file."""
        """Save the current conversation to a JSON snapshot under ~/.hermes/sessions/saved/.

        The snapshot is a convenience export for sharing or off-line inspection;
        every message is already persisted incrementally to the SQLite session
        DB, so the live session remains resumable via ``hermes --resume <id>``
        regardless of whether the user ever runs ``/save``.
        """
        if not self.conversation_history:
            print("(;_;) No conversation to save.")
            return

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"hermes_conversation_{timestamp}.json"

        saved_dir = get_hermes_home() / "sessions" / "saved"
        try:
            with open(filename, "w", encoding="utf-8") as f:
            saved_dir.mkdir(parents=True, exist_ok=True)
        except Exception as e:
            print(f"(x_x) Failed to create save directory {saved_dir}: {e}")
            return
        path = saved_dir / f"hermes_conversation_{timestamp}.json"

        try:
            with open(path, "w", encoding="utf-8") as f:
                json.dump({
                    "model": self.model,
                    "session_id": self.session_id,
                    "session_start": self.session_start.isoformat(),
                    "messages": self.conversation_history,
                }, f, indent=2, ensure_ascii=False)
            print(f"(^_^)v Conversation saved to: {filename}")
            print(f"(^_^)v Conversation snapshot saved to: {path}")
            if self.session_id:
                print(f"   Resume the live session with: hermes --resume {self.session_id}")
        except Exception as e:
            print(f"(x_x) Failed to save: {e}")
@@ -6313,6 +6365,12 @@ class HermesCLI:
        turn_route = self._resolve_turn_agent_config(prompt)

        def run_background():
            set_sudo_password_callback(self._sudo_password_callback)
            set_approval_callback(self._approval_callback)
            try:
                set_secret_capture_callback(self._secret_capture_callback)
            except Exception:
                pass
            try:
                bg_agent = AIAgent(
                    model=turn_route["model"],
@@ -6410,6 +6468,12 @@ class HermesCLI:
                print()
                _cprint(f"   ❌ Background task #{task_num} failed: {e}")
            finally:
                try:
                    set_sudo_password_callback(None)
                    set_approval_callback(None)
                    set_secret_capture_callback(None)
                except Exception:
                    pass
                self._background_tasks.pop(task_id, None)
                # Clear spinner only if no foreground agent owns it
                if not self._agent_running:
@@ -6804,24 +6868,36 @@ class HermesCLI:
        /busy             Show current busy input mode
        /busy status      Show current busy input mode
        /busy queue       Queue input for the next turn instead of interrupting
        /busy steer       Inject Enter mid-run via /steer (after next tool call)
        /busy interrupt   Interrupt the current run on Enter (default)
        """
        parts = cmd.strip().split(maxsplit=1)
        if len(parts) < 2 or parts[1].strip().lower() == "status":
            _cprint(f"  {_ACCENT}Busy input mode: {self.busy_input_mode}{_RST}")
            _cprint(f"  {_DIM}Enter while busy: {'queues for next turn' if self.busy_input_mode == 'queue' else 'interrupts current run'}{_RST}")
            _cprint(f"  {_DIM}Usage: /busy [queue|interrupt|status]{_RST}")
            if self.busy_input_mode == "queue":
                _behavior = "queues for next turn"
            elif self.busy_input_mode == "steer":
                _behavior = "steers into current run (after next tool call)"
            else:
                _behavior = "interrupts current run"
            _cprint(f"  {_DIM}Enter while busy: {_behavior}{_RST}")
            _cprint(f"  {_DIM}Usage: /busy [queue|steer|interrupt|status]{_RST}")
            return

        arg = parts[1].strip().lower()
        if arg not in {"queue", "interrupt"}:
        if arg not in {"queue", "interrupt", "steer"}:
            _cprint(f"  {_DIM}(._.) Unknown argument: {arg}{_RST}")
            _cprint(f"  {_DIM}Usage: /busy [queue|interrupt|status]{_RST}")
            _cprint(f"  {_DIM}Usage: /busy [queue|steer|interrupt|status]{_RST}")
            return

        self.busy_input_mode = arg
        if save_config_value("display.busy_input_mode", arg):
            behavior = "Enter will queue follow-up input while Hermes is busy." if arg == "queue" else "Enter will interrupt the current run while Hermes is busy."
            if arg == "queue":
                behavior = "Enter will queue follow-up input while Hermes is busy."
            elif arg == "steer":
                behavior = "Enter will steer your message into the current run (after the next tool call)."
            else:
                behavior = "Enter will interrupt the current run while Hermes is busy."
            _cprint(f"  {_ACCENT}✓ Busy input mode set to '{arg}' (saved to config){_RST}")
            _cprint(f"  {_DIM}{behavior}{_RST}")
        else:
@@ -9198,12 +9274,34 @@ class HermesCLI:
            # Bundle text + images as a tuple when images are present
            payload = (text, images) if images else text
            if self._agent_running and not (text and _looks_like_slash_command(text)):
                if self.busy_input_mode == "queue":
                _effective_mode = self.busy_input_mode
                if _effective_mode == "steer":
                    # Route Enter through /steer — inject mid-run after the
                    # next tool call. Images can't ride along (steer only
                    # appends text), so fall back to queue when images are
                    # attached. If the agent lacks steer() or rejects the
                    # payload, also fall back to queue so nothing is lost.
                    if images or not text:
                        _effective_mode = "queue"
                    else:
                        accepted = False
                        try:
                            if self.agent is not None and hasattr(self.agent, "steer"):
                                accepted = bool(self.agent.steer(text))
                        except Exception as exc:
                            _cprint(f"  {_DIM}Steer failed ({exc}) — queued for next turn.{_RST}")
                            accepted = False
                        if accepted:
                            preview = text[:80] + ("..." if len(text) > 80 else "")
                            _cprint(f"  {_ACCENT}⏩ Steered: '{preview}'{_RST}")
                        else:
                            _effective_mode = "queue"
                if _effective_mode == "queue":
                    # Queue for the next turn instead of interrupting
                    self._pending_input.put(payload)
                    preview = text if text else f"[{len(images)} image{'s' if len(images) != 1 else ''} attached]"
                    _cprint(f"  Queued for the next turn: {preview[:80]}{'...' if len(preview) > 80 else ''}")
                else:
                elif _effective_mode == "interrupt":
                    self._interrupt_queue.put(payload)
                    # Debug: log to file when message enters interrupt queue
                    try:
@@ -77,7 +77,7 @@ _KNOWN_DELIVERY_PLATFORMS = frozenset({
    "telegram", "discord", "slack", "whatsapp", "signal",
    "matrix", "mattermost", "homeassistant", "dingtalk", "feishu",
    "wecom", "wecom_callback", "weixin", "sms", "email", "webhook", "bluebubbles",
    "qqbot",
    "qqbot", "yuanbao",
})

# Platforms that support a configured cron/notification home target, mapped to
@@ -337,6 +337,7 @@ def _deliver_result(job: dict, content: str, adapters=None, loop=None) -> Option
        "sms": Platform.SMS,
        "bluebubbles": Platform.BLUEBUBBLES,
        "qqbot": Platform.QQBOT,
        "yuanbao": Platform.YUANBAO,
    }

    # Optionally wrap the content with a header/footer so the user knows this
@@ -1308,6 +1309,17 @@ def tick(verbose: bool = True, adapters=None, loop=None) -> int:
            _futures.append(_tick_pool.submit(_ctx.run, _process_job, job))
        _results.extend(f.result() for f in _futures)

        # Best-effort sweep of MCP stdio subprocesses that survived their
        # session teardown during this tick. Runs AFTER every job has
        # finished so active sessions (including live user chats) are
        # never touched — only PIDs explicitly detected as orphans in
        # tools.mcp_tool._run_stdio's finally block are reaped.
        try:
            from tools.mcp_tool import _kill_orphaned_mcp_children
            _kill_orphaned_mcp_children()
        except Exception as _e:
            logger.debug("Post-tick MCP orphan cleanup failed: %s", _e)

        return sum(_results)
    finally:
        if fcntl:
@@ -57,7 +57,7 @@ def _session_entry_name(origin: Dict[str, Any]) -> str:
# Build / refresh
# ---------------------------------------------------------------------------

def build_channel_directory(adapters: Dict[Any, Any]) -> Dict[str, Any]:
async def build_channel_directory(adapters: Dict[Any, Any]) -> Dict[str, Any]:
    """
    Build a channel directory from connected platform adapters and session data.
@@ -72,7 +72,7 @@ def build_channel_directory(adapters: Dict[Any, Any]) -> Dict[str, Any]:
            if platform == Platform.DISCORD:
                platforms["discord"] = _build_discord(adapter)
            elif platform == Platform.SLACK:
                platforms["slack"] = _build_slack(adapter)
                platforms["slack"] = await _build_slack(adapter)
        except Exception as e:
            logger.warning("Channel directory: failed to build %s: %s", platform.value, e)
@@ -136,21 +136,66 @@ def _build_discord(adapter) -> List[Dict[str, str]]:
    return channels


def _build_slack(adapter) -> List[Dict[str, str]]:
    """List Slack channels the bot has joined."""
    # Slack adapter may expose a web client
    client = getattr(adapter, "_app", None) or getattr(adapter, "_client", None)
    if not client:
async def _build_slack(adapter) -> List[Dict[str, Any]]:
    """List Slack channels the bot has joined across all workspaces.

    Uses ``users.conversations`` against each workspace's web client. Pulls
    public + private channels the bot is a member of, then merges in DMs
    discovered from session history (IMs aren't useful to enumerate
    proactively).
    """
    team_clients = getattr(adapter, "_team_clients", None) or {}
    if not team_clients:
        return _build_from_sessions("slack")

    try:
        from tools.send_message_tool import _send_slack  # noqa: F401
        # Use the Slack Web API directly if available
    except Exception:
        pass
    channels: List[Dict[str, Any]] = []
    seen_ids: set = set()

    # Fallback to session data
    return _build_from_sessions("slack")
    for team_id, client in team_clients.items():
        try:
            cursor: Optional[str] = None
            for _page in range(20):  # safety cap on pagination
                response = await client.users_conversations(
                    types="public_channel,private_channel",
                    exclude_archived=True,
                    limit=200,
                    cursor=cursor,
                )
                if not response.get("ok"):
                    logger.warning(
                        "Channel directory: users.conversations not ok for team %s: %s",
                        team_id,
                        response.get("error", "unknown"),
                    )
                    break
                for ch in response.get("channels", []):
                    cid = ch.get("id")
                    name = ch.get("name")
                    if not cid or not name or cid in seen_ids:
                        continue
                    seen_ids.add(cid)
                    channels.append({
                        "id": cid,
                        "name": name,
                        "type": "private" if ch.get("is_private") else "channel",
                    })
                cursor = (response.get("response_metadata") or {}).get("next_cursor")
                if not cursor:
                    break
        except Exception as e:
            logger.warning(
                "Channel directory: failed to list Slack channels for team %s: %s",
                team_id, e,
            )
            continue

    # Merge in DM/group entries discovered from session history.
    for entry in _build_from_sessions("slack"):
        if entry.get("id") not in seen_ids:
            channels.append(entry)
            seen_ids.add(entry.get("id"))

    return channels


def _build_from_sessions(platform_name: str) -> List[Dict[str, str]]:
@@ -223,6 +268,14 @@ def resolve_channel_name(platform_name: str, name: str) -> Optional[str]:
    if not channels:
        return None

    # 0. Exact ID match — case-sensitive, no normalization. Lets callers pass
    #    raw platform IDs (e.g. Slack "C0B0QV5434G") even when the format guard
    #    in _parse_target_ref hasn't recognized them as explicit.
    raw = name.strip()
    for ch in channels:
        if ch.get("id") == raw:
            return ch["id"]

    query = _normalize_channel_query(name)

    # 1. Exact name match, including the display labels shown by send_message(action="list")
@@ -67,6 +67,7 @@ class Platform(Enum):
    WEIXIN = "weixin"
    BLUEBUBBLES = "bluebubbles"
    QQBOT = "qqbot"
    YUANBAO = "yuanbao"


@dataclass
@@ -195,6 +196,14 @@ class StreamingConfig:
    edit_interval: float = 1.0   # Seconds between message edits (Telegram rate-limits at ~1/s)
    buffer_threshold: int = 40   # Chars before forcing an edit
    cursor: str = " ▉"           # Cursor shown during streaming
    # Ported from openclaw/openclaw#72038. When >0, the final edit for
    # a long-running streamed response is delivered as a fresh message
    # if the original preview has been visible for at least this many
    # seconds, so the platform's visible timestamp reflects completion
    # time instead of the preview creation time. Currently applied to
    # Telegram only (other platforms ignore the setting). Default 60s
    # matches the OpenClaw rollout. Set to 0 to disable.
    fresh_final_after_seconds: float = 60.0

    def to_dict(self) -> Dict[str, Any]:
        return {
@@ -203,6 +212,7 @@ class StreamingConfig:
            "edit_interval": self.edit_interval,
            "buffer_threshold": self.buffer_threshold,
            "cursor": self.cursor,
            "fresh_final_after_seconds": self.fresh_final_after_seconds,
        }

    @classmethod
@@ -215,6 +225,9 @@ class StreamingConfig:
            edit_interval=float(data.get("edit_interval", 1.0)),
            buffer_threshold=int(data.get("buffer_threshold", 40)),
            cursor=data.get("cursor", " ▉"),
            fresh_final_after_seconds=float(
                data.get("fresh_final_after_seconds", 60.0)
            ),
        )
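Round-tripping the new field through the serializers above, a minimal sketch (it assumes the StreamingConfig fields not shown in these hunks keep their defaults):

    cfg = StreamingConfig.from_dict({"edit_interval": 0.8, "fresh_final_after_seconds": 0})
    assert cfg.fresh_final_after_seconds == 0.0                # 0 disables the fresh-final path
    assert cfg.to_dict()["fresh_final_after_seconds"] == 0.0
    legacy = StreamingConfig.from_dict({})                     # older configs omit the key
    assert legacy.fresh_final_after_seconds == 60.0            # falls back to the OpenClaw default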
@@ -314,6 +327,9 @@ class GatewayConfig:
        # QQBot uses extra dict for app credentials
        elif platform == Platform.QQBOT and config.extra.get("app_id") and config.extra.get("client_secret"):
            connected.append(platform)
        # Yuanbao uses extra dict for app credentials
        elif platform == Platform.YUANBAO and config.extra.get("app_id") and config.extra.get("app_secret"):
            connected.append(platform)
        # DingTalk uses client_id/client_secret from config.extra or env vars
        elif platform == Platform.DINGTALK and (
            config.extra.get("client_id") or os.getenv("DINGTALK_CLIENT_ID")
@@ -570,6 +586,8 @@ def load_gateway_config() -> GatewayConfig:
            )
        if "reply_prefix" in platform_cfg:
            bridged["reply_prefix"] = platform_cfg["reply_prefix"]
        if "reply_in_thread" in platform_cfg:
            bridged["reply_in_thread"] = platform_cfg["reply_in_thread"]
        if "require_mention" in platform_cfg:
            bridged["require_mention"] = platform_cfg["require_mention"]
        if "free_response_channels" in platform_cfg:
@@ -584,7 +602,7 @@ def load_gateway_config() -> GatewayConfig:
            bridged["group_policy"] = platform_cfg["group_policy"]
        if "group_allow_from" in platform_cfg:
            bridged["group_allow_from"] = platform_cfg["group_allow_from"]
        if plat == Platform.DISCORD and "channel_skill_bindings" in platform_cfg:
        if plat in (Platform.DISCORD, Platform.SLACK) and "channel_skill_bindings" in platform_cfg:
            bridged["channel_skill_bindings"] = platform_cfg["channel_skill_bindings"]
        if "channel_prompts" in platform_cfg:
            channel_prompts = platform_cfg["channel_prompts"]
@@ -609,6 +627,8 @@ def load_gateway_config() -> GatewayConfig:
    if isinstance(slack_cfg, dict):
        if "require_mention" in slack_cfg and not os.getenv("SLACK_REQUIRE_MENTION"):
            os.environ["SLACK_REQUIRE_MENTION"] = str(slack_cfg["require_mention"]).lower()
        if "strict_mention" in slack_cfg and not os.getenv("SLACK_STRICT_MENTION"):
            os.environ["SLACK_STRICT_MENTION"] = str(slack_cfg["strict_mention"]).lower()
        if "allow_bots" in slack_cfg and not os.getenv("SLACK_ALLOW_BOTS"):
            os.environ["SLACK_ALLOW_BOTS"] = str(slack_cfg["allow_bots"]).lower()
        frc = slack_cfg.get("free_response_channels")
@@ -918,8 +938,12 @@ def _apply_env_overrides(config: GatewayConfig) -> None:
    slack_token = os.getenv("SLACK_BOT_TOKEN")
    if slack_token:
        if Platform.SLACK not in config.platforms:
            # No yaml config for Slack — env-only setup, enable it
            config.platforms[Platform.SLACK] = PlatformConfig()
        config.platforms[Platform.SLACK].enabled = True
            config.platforms[Platform.SLACK].enabled = True
        # If yaml config exists, respect its enabled flag (don't override
        # explicit enabled: false). Token is still stored so skills that
        # send Slack messages can use it without activating the gateway adapter.
        config.platforms[Platform.SLACK].token = slack_token
    slack_home = os.getenv("SLACK_HOME_CHANNEL")
    if slack_home and Platform.SLACK in config.platforms:
@@ -1276,6 +1300,48 @@ def _apply_env_overrides(config: GatewayConfig) -> None:
            name=os.getenv("QQBOT_HOME_CHANNEL_NAME") or os.getenv(qq_home_name_env, "Home"),
        )

    # Yuanbao — YUANBAO_APP_ID preferred
    yuanbao_app_id = os.getenv("YUANBAO_APP_ID") or os.getenv("YUANBAO_APP_KEY")
    yuanbao_app_secret = os.getenv("YUANBAO_APP_SECRET")
    if yuanbao_app_id and yuanbao_app_secret:
        if Platform.YUANBAO not in config.platforms:
            config.platforms[Platform.YUANBAO] = PlatformConfig()
        config.platforms[Platform.YUANBAO].enabled = True
        extra = config.platforms[Platform.YUANBAO].extra
        extra["app_id"] = yuanbao_app_id
        extra["app_secret"] = yuanbao_app_secret
        yuanbao_bot_id = os.getenv("YUANBAO_BOT_ID")
        if yuanbao_bot_id:
            extra["bot_id"] = yuanbao_bot_id
        yuanbao_ws_url = os.getenv("YUANBAO_WS_URL")
        if yuanbao_ws_url:
            extra["ws_url"] = yuanbao_ws_url
        yuanbao_api_domain = os.getenv("YUANBAO_API_DOMAIN")
        if yuanbao_api_domain:
            extra["api_domain"] = yuanbao_api_domain
        yuanbao_route_env = os.getenv("YUANBAO_ROUTE_ENV")
        if yuanbao_route_env:
            extra["route_env"] = yuanbao_route_env
        yuanbao_home = os.getenv("YUANBAO_HOME_CHANNEL")
        if yuanbao_home:
            config.platforms[Platform.YUANBAO].home_channel = HomeChannel(
                platform=Platform.YUANBAO,
                chat_id=yuanbao_home,
                name=os.getenv("YUANBAO_HOME_CHANNEL_NAME", "Home"),
            )
        yuanbao_dm_policy = os.getenv("YUANBAO_DM_POLICY")
        if yuanbao_dm_policy:
            extra["dm_policy"] = yuanbao_dm_policy.strip().lower()
        yuanbao_dm_allow_from = os.getenv("YUANBAO_DM_ALLOW_FROM")
        if yuanbao_dm_allow_from:
            extra["dm_allow_from"] = yuanbao_dm_allow_from
        yuanbao_group_policy = os.getenv("YUANBAO_GROUP_POLICY")
        if yuanbao_group_policy:
            extra["group_policy"] = yuanbao_group_policy.strip().lower()
        yuanbao_group_allow_from = os.getenv("YUANBAO_GROUP_ALLOW_FROM")
        if yuanbao_group_allow_from:
            extra["group_allow_from"] = yuanbao_group_allow_from

    # Session settings
    idle_minutes = os.getenv("SESSION_IDLE_MINUTES")
    if idle_minutes:
@@ -79,7 +79,9 @@ _PLATFORM_DEFAULTS: dict[str, dict[str, Any]] = {
    "discord": _TIER_HIGH,

    # Tier 2 — edit support, often customer/workspace channels
    "slack": _TIER_MEDIUM,
    # Slack: tool_progress off by default — Bolt posts cannot be edited like CLI;
    # "new"/"all" spam permanent lines in channels (hermes-agent#14663).
    "slack": {**_TIER_MEDIUM, "tool_progress": "off"},
    "mattermost": _TIER_MEDIUM,
    "matrix": _TIER_MEDIUM,
    "feishu": _TIER_MEDIUM,
@@ -28,6 +28,7 @@ def mirror_to_session(
    message_text: str,
    source_label: str = "cli",
    thread_id: Optional[str] = None,
    user_id: Optional[str] = None,
) -> bool:
    """
    Append a delivery-mirror message to the target session's transcript.
@@ -39,9 +40,20 @@ def mirror_to_session(
    All errors are caught -- this is never fatal.
    """
    try:
        session_id = _find_session_id(platform, str(chat_id), thread_id=thread_id)
        session_id = _find_session_id(
            platform,
            str(chat_id),
            thread_id=thread_id,
            user_id=user_id,
        )
        if not session_id:
            logger.debug("Mirror: no session found for %s:%s:%s", platform, chat_id, thread_id)
            logger.debug(
                "Mirror: no session found for %s:%s:%s:%s",
                platform,
                chat_id,
                thread_id,
                user_id,
            )
            return False

        mirror_msg = {
@@ -59,17 +71,33 @@ def mirror_to_session(
        return True

    except Exception as e:
        logger.debug("Mirror failed for %s:%s:%s: %s", platform, chat_id, thread_id, e)
        logger.debug(
            "Mirror failed for %s:%s:%s:%s: %s",
            platform,
            chat_id,
            thread_id,
            user_id,
            e,
        )
        return False


def _find_session_id(platform: str, chat_id: str, thread_id: Optional[str] = None) -> Optional[str]:
def _find_session_id(
    platform: str,
    chat_id: str,
    thread_id: Optional[str] = None,
    user_id: Optional[str] = None,
) -> Optional[str]:
    """
    Find the active session_id for a platform + chat_id pair.

    Scans sessions.json entries and matches where origin.chat_id == chat_id
    on the right platform. DM session keys don't embed the chat_id
    (e.g. "agent:main:telegram:dm"), so we check the origin dict.

    When *user_id* is provided, prefer exact sender matches. If multiple
    same-chat candidates exist and none matches the user, return None instead
    of guessing and contaminating another participant's session.
    """
    if not _SESSIONS_INDEX.exists():
        return None
@@ -81,8 +109,7 @@ def _find_session_id(platform: str, chat_id: str, thread_id: Optional[str] = Non
        return None

    platform_lower = platform.lower()
    best_match = None
    best_updated = ""
    candidates = []

    for _key, entry in data.items():
        origin = entry.get("origin") or {}
@@ -96,12 +123,31 @@ def _find_session_id(platform: str, chat_id: str, thread_id: Optional[str] = Non
        origin_thread_id = origin.get("thread_id")
        if thread_id is not None and str(origin_thread_id or "") != str(thread_id):
            continue
        updated = entry.get("updated_at", "")
        if updated > best_updated:
            best_updated = updated
            best_match = entry.get("session_id")
        candidates.append(entry)

    return best_match
    if not candidates:
        return None

    if user_id:
        exact_user_matches = [
            entry for entry in candidates
            if str((entry.get("origin") or {}).get("user_id") or "") == str(user_id)
        ]
        if exact_user_matches:
            candidates = exact_user_matches
        elif len(candidates) > 1:
            return None
    elif len(candidates) > 1:
        distinct_user_ids = {
            str((entry.get("origin") or {}).get("user_id") or "").strip()
            for entry in candidates
            if str((entry.get("origin") or {}).get("user_id") or "").strip()
        }
        if len(distinct_user_ids) > 1:
            return None

    best_entry = max(candidates, key=lambda entry: entry.get("updated_at", ""))
    return best_entry.get("session_id")


def _append_to_jsonl(session_id: str, message: dict) -> None:
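The disambiguation rules above are easiest to see on concrete entries. A minimal, hypothetical sketch: pick() mirrors the candidate-selection logic of this hunk on in-memory dicts rather than calling the real _find_session_id, which reads sessions.json:

    def pick(candidates, user_id=None):
        # Prefer exact sender matches; refuse to guess when several same-chat
        # sessions exist and the sender is unknown or unmatched.
        if not candidates:
            return None
        if user_id:
            exact = [c for c in candidates if c.get("user_id") == user_id]
            if exact:
                candidates = exact
            elif len(candidates) > 1:
                return None
        elif len(candidates) > 1:
            senders = {c["user_id"] for c in candidates if c.get("user_id")}
            if len(senders) > 1:
                return None
        return max(candidates, key=lambda c: c.get("updated_at", ""))["session_id"]

    a = {"session_id": "s-alice", "user_id": "alice", "updated_at": "2025-01-02"}
    b = {"session_id": "s-bob", "user_id": "bob", "updated_at": "2025-01-03"}
    assert pick([a, b], user_id="alice") == "s-alice"  # exact sender beats recency
    assert pick([a, b], user_id="carol") is None       # ambiguous: don't contaminate
    assert pick([a, b]) is None                        # two distinct senders, no user_id
    assert pick([b]) == "s-bob"                        # single candidate passes through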
@@ -10,10 +10,12 @@ Each adapter handles:

from .base import BasePlatformAdapter, MessageEvent, SendResult
from .qqbot import QQAdapter
from .yuanbao import YuanbaoAdapter

__all__ = [
    "BasePlatformAdapter",
    "MessageEvent",
    "SendResult",
    "QQAdapter",
    "YuanbaoAdapter",
]
@@ -336,6 +336,39 @@ def proxy_kwargs_for_aiohttp(proxy_url: str | None) -> tuple[dict, dict]:
    return {}, {"proxy": proxy_url}


def is_host_excluded_by_no_proxy(hostname: str, no_proxy_value: str | None = None) -> bool:
    """Return True when ``hostname`` matches a ``NO_PROXY`` entry.

    Supports comma- or whitespace-separated entries with optional leading dots
    and ``*.`` wildcards, which match both the apex domain and subdomains.
    """
    raw = no_proxy_value
    if raw is None:
        raw = os.environ.get("NO_PROXY") or os.environ.get("no_proxy") or ""

    raw = raw.strip()
    if not raw:
        return False

    lower_hostname = hostname.lower()
    for entry in re.split(r"[\s,]+", raw):
        normalized = entry.strip().lower()
        if not normalized:
            continue
        if normalized == "*":
            return True

        if normalized.startswith("*."):
            normalized = normalized[2:]
        elif normalized.startswith("."):
            normalized = normalized[1:]

        if lower_hostname == normalized or lower_hostname.endswith(f".{normalized}"):
            return True

    return False


from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
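A quick sanity check of the matching rules implemented above (the hostnames and NO_PROXY values are illustrative; the function is the one added in this hunk):

    assert is_host_excluded_by_no_proxy("api.internal.corp", "internal.corp")    # subdomain match
    assert is_host_excluded_by_no_proxy("internal.corp", "*.internal.corp")      # wildcard also matches the apex
    assert is_host_excluded_by_no_proxy("anything.example", "*")                 # bare "*" excludes everything
    assert not is_host_excluded_by_no_proxy("example.com", ".internal.corp, localhost")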
@@ -693,7 +726,15 @@ SUPPORTED_DOCUMENT_TYPES = {
    ".pdf": "application/pdf",
    ".md": "text/markdown",
    ".txt": "text/plain",
    ".csv": "text/csv",
    ".log": "text/plain",
    ".json": "application/json",
    ".xml": "application/xml",
    ".yaml": "application/yaml",
    ".yml": "application/yaml",
    ".toml": "application/toml",
    ".ini": "text/plain",
    ".cfg": "text/plain",
    ".zip": "application/zip",
    ".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
    ".xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
@@ -982,6 +1023,61 @@ def resolve_channel_prompt(
    return None


def resolve_channel_skills(
    config_extra: dict,
    channel_id: str,
    parent_id: str | None = None,
) -> list[str] | None:
    """Resolve auto-loaded skill(s) for a channel/thread from platform config.

    Looks up ``channel_skill_bindings`` in the adapter's ``config.extra`` dict.

    Config format::

        channel_skill_bindings:
          - id: "C0123"            # Slack channel ID or Discord channel/forum ID
            skills: ["skill-a", "skill-b"]
          - id: "D0ABCDE"
            skill: "solo-skill"    # single string also accepted

    Prefers an exact match on *channel_id*; falls back to *parent_id*
    (useful for forum threads / Slack threads inheriting the parent channel's
    binding).

    Returns a deduplicated list of skill names (order preserved), or None if
    no match is found.
    """
    bindings = config_extra.get("channel_skill_bindings") or []
    if not isinstance(bindings, list) or not bindings:
        return None
    ids_to_check: set[str] = set()
    if channel_id:
        ids_to_check.add(str(channel_id))
    if parent_id:
        ids_to_check.add(str(parent_id))
    if not ids_to_check:
        return None
    for entry in bindings:
        if not isinstance(entry, dict):
            continue
        entry_id = str(entry.get("id", ""))
        if entry_id in ids_to_check:
            skills = entry.get("skills") or entry.get("skill")
            if isinstance(skills, str):
                s = skills.strip()
                return [s] if s else None
            if isinstance(skills, list) and skills:
                seen: list[str] = []
                for name in skills:
                    if not isinstance(name, str):
                        continue
                    nm = name.strip()
                    if nm and nm not in seen:
                        seen.append(nm)
                return seen or None
    return None


class BasePlatformAdapter(ABC):
    """
    Base class for platform adapters.
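A usage sketch for the shared resolver above (the channel IDs and skill names are made up):

    extra = {
        "channel_skill_bindings": [
            {"id": "C0123", "skills": ["deploy-helper", "deploy-helper", "oncall-runbook"]},
            {"id": "D0ABCDE", "skill": "solo-skill"},
        ]
    }
    # Exact channel match, with the duplicate deduplicated in order:
    assert resolve_channel_skills(extra, "C0123") == ["deploy-helper", "oncall-runbook"]
    # A thread falls back to its parent channel's binding:
    assert resolve_channel_skills(extra, "T0999", parent_id="D0ABCDE") == ["solo-skill"]
    # No binding: the caller gets None and auto-loads nothing.
    assert resolve_channel_skills(extra, "C9999") is None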
@@ -1258,6 +1354,27 @@ class BasePlatformAdapter:
        """
        return SendResult(success=False, error="Not supported")

    async def delete_message(
        self,
        chat_id: str,
        message_id: str,
    ) -> bool:
        """
        Delete a previously sent message. Optional — platforms that don't
        support deletion return ``False`` and callers fall back to leaving
        the message in place.

        Used by the stream consumer's fresh-final cleanup path (see
        openclaw/openclaw#72038) to remove long-lived preview messages
        after sending the completed reply as a fresh message so the
        platform's visible timestamp reflects completion time.

        Returns ``True`` on successful deletion, ``False`` otherwise.
        Subclasses should override for platforms with a deletion API
        (e.g. Telegram ``deleteMessage``).
        """
        return False

    async def send_typing(self, chat_id: str, metadata=None) -> None:
        """
        Send a typing indicator.
@@ -2679,21 +2679,8 @@ class DiscordAdapter(BasePlatformAdapter):
                skills: ["skill-a", "skill-b"]
        Also checks parent_id so forum threads inherit the forum's bindings.
        """
        bindings = self.config.extra.get("channel_skill_bindings", [])
        if not bindings:
            return None
        ids_to_check = {channel_id}
        if parent_id:
            ids_to_check.add(parent_id)
        for entry in bindings:
            entry_id = str(entry.get("id", ""))
            if entry_id in ids_to_check:
                skills = entry.get("skills") or entry.get("skill")
                if isinstance(skills, str):
                    return [skills]
                if isinstance(skills, list) and skills:
                    return list(dict.fromkeys(skills))  # dedup, preserve order
        return None
        from gateway.platforms.base import resolve_channel_skills
        return resolve_channel_skills(self.config.extra, channel_id, parent_id)

    def _resolve_channel_prompt(self, channel_id: str, parent_id: str | None = None) -> str | None:
        """Resolve a Discord per-channel prompt, preferring the exact channel over its parent."""
@@ -57,6 +57,15 @@ class MessageDeduplicator:
        if len(self._seen) > self._max_size:
            cutoff = now - self._ttl
            self._seen = {k: v for k, v in self._seen.items() if v > cutoff}
            if len(self._seen) > self._max_size:
                # TTL pruning alone does not cap the cache when every entry is
                # still fresh. Keep the newest entries so the helper's
                # max_size bound is enforced under sustained traffic.
                newest = sorted(
                    self._seen.items(),
                    key=lambda item: item[1],
                )[-self._max_size:]
                self._seen = dict(newest)
        return False

    def clear(self):
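The size-cap fallback above keeps the newest timestamps. A minimal standalone sketch of that selection (a plain dict, not the real class):

    seen = {"msg-a": 100.0, "msg-b": 300.0, "msg-c": 200.0}
    max_size = 2
    newest = sorted(seen.items(), key=lambda item: item[1])[-max_size:]
    assert dict(newest) == {"msg-c": 200.0, "msg-b": 300.0}  # oldest entry dropped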
(file diff suppressed because it is too large)
@@ -1209,6 +1209,31 @@ class TelegramAdapter(BasePlatformAdapter):
            )
            return SendResult(success=False, error=str(e))

    async def delete_message(self, chat_id: str, message_id: str) -> bool:
        """Delete a previously sent Telegram message.

        Used by the stream consumer's fresh-final cleanup path (ported
        from openclaw/openclaw#72038) to remove long-lived preview
        messages after sending the completed reply as a fresh message.
        Telegram's Bot API ``deleteMessage`` works for bot-posted
        messages in the last 48 hours. Failures are non-fatal — the
        caller leaves the preview in place and logs at debug level.
        """
        if not self._bot:
            return False
        try:
            await self._bot.delete_message(
                chat_id=int(chat_id),
                message_id=int(message_id),
            )
            return True
        except Exception as e:
            logger.debug(
                "[%s] Failed to delete Telegram message %s: %s",
                self.name, message_id, e,
            )
            return False

    async def send_update_prompt(
        self, chat_id: str, prompt: str, default: str = "",
        session_key: str = "",
gateway/platforms/yuanbao.py | 4754 (new file; diff suppressed because it is too large)
gateway/platforms/yuanbao_media.py | 647 (new file)
@@ -0,0 +1,647 @@
"""
yuanbao_media.py — media handling for the Yuanbao platform.

Provides COS upload, file download, and TIM media message-body construction.
Ported from the TypeScript media.ts (yuanbao-openclaw-plugin); uses httpx
instead of cos-nodejs-sdk-v5 to avoid pulling in an extra SDK dependency.

COS upload flow:
1. Call genUploadInfo to obtain temporary credentials
   (tmpSecretId/tmpSecretKey/sessionToken).
2. Build the Authorization header via an HMAC-SHA1 signature derived from
   the temporary credentials.
3. HTTP PUT the file to COS.

TIM message-body construction:
- buildImageMsgBody() → TIMImageElem
- buildFileMsgBody() → TIMFileElem
"""

from __future__ import annotations

import hashlib
import hmac
import logging
import os
import re
import secrets
import struct
import time
import urllib.parse
from datetime import datetime, timezone, timedelta
from typing import Optional, Any

import httpx

logger = logging.getLogger(__name__)
# ============ Constants ============

UPLOAD_INFO_PATH = "/api/resource/genUploadInfo"
DEFAULT_API_DOMAIN = "yuanbao.tencent.com"
DEFAULT_MAX_SIZE_MB = 50

# COS acceleration domain suffix (prefer global acceleration)
COS_USE_ACCELERATE = True

# ============ Type mappings ============

# MIME → image_format number (a TIM protocol field)
_MIME_TO_IMAGE_FORMAT: dict[str, int] = {
    "image/jpeg": 1,
    "image/jpg": 1,
    "image/gif": 2,
    "image/png": 3,
    "image/bmp": 4,
    "image/webp": 255,
    "image/heic": 255,
    "image/tiff": 255,
}

# File extension → MIME
_EXT_TO_MIME: dict[str, str] = {
    ".jpg": "image/jpeg",
    ".jpeg": "image/jpeg",
    ".png": "image/png",
    ".gif": "image/gif",
    ".webp": "image/webp",
    ".bmp": "image/bmp",
    ".heic": "image/heic",
    ".tiff": "image/tiff",
    ".ico": "image/x-icon",
    ".pdf": "application/pdf",
    ".doc": "application/msword",
    ".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
    ".xls": "application/vnd.ms-excel",
    ".xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
    ".ppt": "application/vnd.ms-powerpoint",
    ".pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation",
    ".txt": "text/plain",
    ".zip": "application/zip",
    ".tar": "application/x-tar",
    ".gz": "application/gzip",
    ".mp3": "audio/mpeg",
    ".mp4": "video/mp4",
    ".wav": "audio/wav",
    ".ogg": "audio/ogg",
    ".webm": "video/webm",
}
# ============ Utility functions ============

def guess_mime_type(filename: str) -> str:
    """Guess the MIME type from the file extension."""
    ext = os.path.splitext(filename)[-1].lower()
    return _EXT_TO_MIME.get(ext, "application/octet-stream")


def is_image(filename: str, mime_type: str = "") -> bool:
    """Return True when the file is an image type."""
    if mime_type.startswith("image/"):
        return True
    ext = os.path.splitext(filename)[-1].lower()
    return ext in {".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp", ".heic", ".tiff", ".ico"}


def get_image_format(mime_type: str) -> int:
    """Return the TIM image-format number for a MIME type."""
    return _MIME_TO_IMAGE_FORMAT.get(mime_type.lower(), 255)


def md5_hex(data: bytes) -> str:
    """Return the MD5 hex digest of the data."""
    return hashlib.md5(data).hexdigest()


def generate_file_id() -> str:
    """Generate a random file ID (32 hex chars)."""
    return secrets.token_hex(16)
# ============ Image size parsing (pure Python, no Pillow required) ============

def parse_image_size(data: bytes) -> Optional[dict[str, int]]:
    """
    Parse image width/height (supports JPEG/PNG/GIF/WebP) with no third-party
    dependency. Returns {"width": w, "height": h}, or None if unrecognized.
    """
    return (
        _parse_png_size(data)
        or _parse_jpeg_size(data)
        or _parse_gif_size(data)
        or _parse_webp_size(data)
    )


def _parse_png_size(buf: bytes) -> Optional[dict[str, int]]:
    if len(buf) < 24:
        return None
    if buf[:4] != b"\x89PNG":
        return None
    # IHDR width/height are big-endian u32 at offsets 16 and 20.
    w = struct.unpack(">I", buf[16:20])[0]
    h = struct.unpack(">I", buf[20:24])[0]
    return {"width": w, "height": h}


def _parse_jpeg_size(buf: bytes) -> Optional[dict[str, int]]:
    if len(buf) < 4 or buf[0] != 0xFF or buf[1] != 0xD8:
        return None
    # Walk the segment markers until a SOF0/SOF2 frame header carries the size.
    i = 2
    while i < len(buf) - 9:
        if buf[i] != 0xFF:
            i += 1
            continue
        marker = buf[i + 1]
        if marker in (0xC0, 0xC2):
            h = struct.unpack(">H", buf[i + 5: i + 7])[0]
            w = struct.unpack(">H", buf[i + 7: i + 9])[0]
            return {"width": w, "height": h}
        if i + 3 < len(buf):
            i += 2 + struct.unpack(">H", buf[i + 2: i + 4])[0]
        else:
            break
    return None


def _parse_gif_size(buf: bytes) -> Optional[dict[str, int]]:
    if len(buf) < 10:
        return None
    sig = buf[:6].decode("ascii", errors="replace")
    if sig not in ("GIF87a", "GIF89a"):
        return None
    w = struct.unpack("<H", buf[6:8])[0]
    h = struct.unpack("<H", buf[8:10])[0]
    return {"width": w, "height": h}


def _parse_webp_size(buf: bytes) -> Optional[dict[str, int]]:
    if len(buf) < 16:
        return None
    if buf[:4] != b"RIFF" or buf[8:12] != b"WEBP":
        return None
    chunk = buf[12:16].decode("ascii", errors="replace")
    if chunk == "VP8 ":
        # Lossy bitstream: 14-bit width/height follow the 0x9D 0x01 0x2A sync code.
        if len(buf) >= 30 and buf[23] == 0x9D and buf[24] == 0x01 and buf[25] == 0x2A:
            w = struct.unpack("<H", buf[26:28])[0] & 0x3FFF
            h = struct.unpack("<H", buf[28:30])[0] & 0x3FFF
            return {"width": w, "height": h}
    elif chunk == "VP8L":
        # Lossless bitstream: width/height packed into 14-bit fields after 0x2F.
        if len(buf) >= 25 and buf[20] == 0x2F:
            bits = struct.unpack("<I", buf[21:25])[0]
            w = (bits & 0x3FFF) + 1
            h = ((bits >> 14) & 0x3FFF) + 1
            return {"width": w, "height": h}
    elif chunk == "VP8X":
        # Extended header: 24-bit canvas size minus one at offsets 24 and 27.
        if len(buf) >= 30:
            w = (buf[24] | (buf[25] << 8) | (buf[26] << 16)) + 1
            h = (buf[27] | (buf[28] << 8) | (buf[29] << 16)) + 1
            return {"width": w, "height": h}
    return None
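A quick check of the header parsers above using hand-built headers (just the leading bytes each parser inspects, not complete image files):

    import struct

    # PNG: 8-byte signature, IHDR length/type, then big-endian width and height.
    png_head = b"\x89PNG\r\n\x1a\n" + b"\x00\x00\x00\x0d" + b"IHDR" + struct.pack(">II", 640, 480)
    assert parse_image_size(png_head) == {"width": 640, "height": 480}

    # GIF: "GIF89a" signature followed by little-endian width and height.
    gif_head = b"GIF89a" + struct.pack("<HH", 320, 240)
    assert parse_image_size(gif_head) == {"width": 320, "height": 240}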
# ============ URL download ============

async def download_url(
    url: str,
    max_size_mb: int = DEFAULT_MAX_SIZE_MB,
) -> tuple[bytes, str]:
    """
    Download the content at a URL, returning (bytes, content_type).

    Args:
        url: HTTP(S) URL
        max_size_mb: maximum allowed size in MB; raises when exceeded

    Returns:
        (data_bytes, content_type_string)

    Raises:
        ValueError: content exceeds the size limit
        httpx.HTTPError: network/HTTP error
    """
    max_bytes = max_size_mb * 1024 * 1024
    async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
        # HEAD first to check the size up front
        try:
            head = await client.head(url)
            content_length = int(head.headers.get("content-length", 0) or 0)
            if content_length > 0 and content_length > max_bytes:
                raise ValueError(
                    f"File too large: {content_length / 1024 / 1024:.1f} MB > {max_size_mb} MB"
                )
        except httpx.HTTPStatusError:
            pass  # some servers don't support HEAD; ignore

        # GET the body (streamed, so the limit holds as bytes arrive)
        async with client.stream("GET", url) as resp:
            resp.raise_for_status()

            content_type = resp.headers.get("content-type", "").split(";")[0].strip()

            chunks: list[bytes] = []
            downloaded = 0
            async for chunk in resp.aiter_bytes(65536):
                downloaded += len(chunk)
                if downloaded > max_bytes:
                    raise ValueError(
                        f"File too large: exceeded the {max_size_mb} MB limit"
                    )
                chunks.append(chunk)

            data = b"".join(chunks)
            return data, content_type
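Usage sketch (the URL is a placeholder; only the function defined above is assumed):

    import asyncio

    async def fetch_avatar() -> None:
        # Streams the body and aborts past the limit instead of buffering blindly.
        data, content_type = await download_url("https://example.com/avatar.png", max_size_mb=5)
        print(f"got {len(data)} bytes of {content_type}")

    asyncio.run(fetch_avatar())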
# ============ COS auth (HMAC-SHA1) ============

def _cos_sign(
    method: str,
    path: str,
    params: dict[str, str],
    headers: dict[str, str],
    secret_id: str,
    secret_key: str,
    start_time: Optional[int] = None,
    expire_seconds: int = 3600,
) -> str:
    """
    Build a COS request signature (the q-sign-algorithm=sha1 scheme).
    Reference: https://cloud.tencent.com/document/product/436/7778

    Args:
        method: HTTP method, lowercase (e.g. "put")
        path: URL path (URL-encoded, lowercase)
        params: URL query-parameter dict (included in the signature)
        headers: request headers included in the signature (lowercase keys)
        secret_id: temporary SecretId (tmpSecretId)
        secret_key: temporary SecretKey (tmpSecretKey)
        start_time: signature start Unix timestamp (defaults to now)
        expire_seconds: signature validity in seconds (default 3600)

    Returns:
        The complete Authorization header value.
    """
    now = int(time.time())
    q_sign_time = f"{start_time or now};{(start_time or now) + expire_seconds}"

    # Step 1: SignKey = HMAC-SHA1(SecretKey, q-sign-time)
    sign_key = hmac.new(
        secret_key.encode("utf-8"),
        q_sign_time.encode("utf-8"),
        hashlib.sha1,
    ).hexdigest()

    # Step 2: HttpString
    # Params and headers must be sorted lexicographically with lowercase keys.
    sorted_params = sorted((k.lower(), urllib.parse.quote(str(v), safe="")) for k, v in params.items())
    sorted_headers = sorted((k.lower(), urllib.parse.quote(str(v), safe="")) for k, v in headers.items())

    url_param_list = ";".join(k for k, _ in sorted_params)
    url_params = "&".join(f"{k}={v}" for k, v in sorted_params)
    header_list = ";".join(k for k, _ in sorted_headers)
    header_str = "&".join(f"{k}={v}" for k, v in sorted_headers)

    http_string = "\n".join([
        method.lower(),
        path,
        url_params,
        header_str,
        "",
    ])

    # Step 3: StringToSign = sha1 hash of HttpString
    sha1_of_http = hashlib.sha1(http_string.encode("utf-8")).hexdigest()
    string_to_sign = "\n".join([
        "sha1",
        q_sign_time,
        sha1_of_http,
        "",
    ])

    # Step 4: Signature = HMAC-SHA1(SignKey, StringToSign)
    signature = hmac.new(
        sign_key.encode("utf-8"),
        string_to_sign.encode("utf-8"),
        hashlib.sha1,
    ).hexdigest()

    return (
        f"q-sign-algorithm=sha1"
        f"&q-ak={secret_id}"
        f"&q-sign-time={q_sign_time}"
        f"&q-key-time={q_sign_time}"
        f"&q-header-list={header_list}"
        f"&q-url-param-list={url_param_list}"
        f"&q-signature={signature}"
    )
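What a signed PUT looks like with the helper above (dummy credentials; a real call uses the temporary keys returned by genUploadInfo):

    auth = _cos_sign(
        method="put",
        path="/uploads/2025/avatar.png",
        params={},
        headers={
            "host": "chatbot-1250000000.cos.accelerate.myqcloud.com",
            "content-type": "image/png",
            "x-cos-security-token": "dummy-session-token",
        },
        secret_id="dummy-tmp-secret-id",
        secret_key="dummy-tmp-secret-key",
        expire_seconds=600,
    )
    # The result is the full Authorization header value:
    # q-sign-algorithm=sha1&q-ak=...&q-sign-time=...&q-signature=...
    assert auth.startswith("q-sign-algorithm=sha1&q-ak=dummy-tmp-secret-id")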
# ============ Main public API ============

async def get_cos_credentials(
    app_key: str,
    api_domain: str,
    token: str,
    filename: str = "file",
    file_id: Optional[str] = None,
    bot_id: str = "",
    route_env: str = "",
) -> dict:
    """
    Call the genUploadInfo endpoint to obtain temporary COS credentials and
    upload configuration.

    Args:
        app_key: application key (used for the X-ID header)
        api_domain: API domain (e.g. https://bot.yuanbao.tencent.com)
        token: currently valid ticket token (X-Token header)
        filename: name of the file to upload (with extension)
        file_id: client-generated unique file ID (auto-generated if omitted)
        bot_id: bot account ID (used for the X-ID header)
        route_env: optional routing environment (sent as X-Route-Env)

    Returns:
        COS upload-configuration dict with the fields:
            bucketName (str)           — COS bucket name
            region (str)               — COS region
            location (str)             — upload key (object path)
            encryptTmpSecretId (str)   — temporary SecretId
            encryptTmpSecretKey (str)  — temporary SecretKey
            encryptToken (str)         — SessionToken
            startTime (int)            — credential start timestamp (Unix)
            expiredTime (int)          — credential expiry timestamp (Unix)
            resourceUrl (str)          — public URL of the uploaded object
            resourceID (str)           — resource ID (optional)

    Raises:
        RuntimeError: the endpoint returned a non-zero code or fields are missing
    """
    if file_id is None:
        file_id = generate_file_id()

    upload_url = f"{api_domain.rstrip('/')}{UPLOAD_INFO_PATH}"

    headers = {
        "Content-Type": "application/json",
        "X-Token": token,
        "X-ID": bot_id or app_key,
        "X-Source": "web",
    }
    if route_env:
        headers["X-Route-Env"] = route_env
    body = {
        "fileName": filename,
        "fileId": file_id,
        "docFrom": "localDoc",
        "docOpenId": "",
    }

    async with httpx.AsyncClient(timeout=15.0) as client:
        resp = await client.post(upload_url, json=body, headers=headers)
        resp.raise_for_status()
        result: dict[str, Any] = resp.json()

    code = result.get("code")
    if code != 0 and code is not None:
        raise RuntimeError(
            f"genUploadInfo failed: code={code}, msg={result.get('msg', '')}"
        )

    data = result.get("data") or result
    required_fields = ["bucketName", "location"]
    missing = [f for f in required_fields if not data.get(f)]
    if missing:
        raise RuntimeError(
            f"genUploadInfo response incomplete: missing fields {missing}"
        )

    return data
async def upload_to_cos(
|
||||
file_bytes: bytes,
|
||||
filename: str,
|
||||
content_type: str,
|
||||
credentials: dict,
|
||||
bucket: str,
|
||||
region: str,
|
||||
) -> dict:
|
||||
"""
|
||||
通过 httpx PUT 请求将文件上传到 COS。
|
||||
使用临时凭证(tmpSecretId/tmpSecretKey/sessionToken)构建 HMAC-SHA1 签名。
|
||||
|
||||
Args:
|
||||
file_bytes: 文件二进制内容
|
||||
filename: 文件名(用于辅助计算 MIME、UUID)
|
||||
content_type: MIME 类型(如 "image/jpeg")
|
||||
credentials: get_cos_credentials() 返回的 dict,包含:
|
||||
encryptTmpSecretId → tmpSecretId
|
||||
encryptTmpSecretKey → tmpSecretKey
|
||||
encryptToken → sessionToken
|
||||
location → COS key(对象路径)
|
||||
resourceUrl → 上传后公网 URL
|
||||
startTime → 凭证起始时间(Unix)
|
||||
expiredTime → 凭证过期时间(Unix)
|
||||
bucket: COS Bucket 名称(如 chatbot-1234567890)
|
||||
region: COS 地域(如 ap-guangzhou)
|
||||
|
||||
Returns:
|
||||
上传结果 dict,包含:
|
||||
url (str) — COS 公网访问 URL
|
||||
uuid (str) — 文件内容 MD5
|
||||
size (int) — 文件大小(字节)
|
||||
width (int, optional) — 图片宽度(仅图片)
|
||||
height (int, optional) — 图片高度(仅图片)
|
||||
|
||||
Raises:
|
||||
httpx.HTTPStatusError: COS 返回非 2xx 状态
|
||||
RuntimeError: credentials 字段缺失
|
||||
"""
|
||||
    secret_id: str = credentials.get("encryptTmpSecretId", "")
    secret_key: str = credentials.get("encryptTmpSecretKey", "")
    session_token: str = credentials.get("encryptToken", "")
    cos_key: str = credentials.get("location", "")
    resource_url: str = credentials.get("resourceUrl", "")
    start_time: Optional[int] = credentials.get("startTime")
    expired_time: Optional[int] = credentials.get("expiredTime")

    if not secret_id or not secret_key or not cos_key:
        raise RuntimeError(
            f"COS credentials 不完整: secretId={bool(secret_id)}, "
            f"secretKey={bool(secret_key)}, location={bool(cos_key)}"
        )

    # Build the COS upload URL (prefer the global acceleration domain)
    if COS_USE_ACCELERATE:
        cos_host = f"{bucket}.cos.accelerate.myqcloud.com"
    else:
        cos_host = f"{bucket}.cos.{region}.myqcloud.com"

    # URL-encode cos_key (keep "/")
    encoded_key = urllib.parse.quote(cos_key, safe="/")
    cos_url = f"https://{cos_host}/{encoded_key.lstrip('/')}"

    # Determine the Content-Type
    if not content_type or content_type == "application/octet-stream":
        if is_image(filename):
            content_type = guess_mime_type(filename)
        else:
            content_type = "application/octet-stream"

    # Compute file MD5 + size
    file_uuid = md5_hex(file_bytes)
    file_size = len(file_bytes)

    # Request headers that participate in the signature
    sign_headers = {
        "host": cos_host,
        "content-type": content_type,
        "x-cos-security-token": session_token,
    }

    # Compute the signature validity window
    now = int(time.time())
    sign_start = start_time if start_time else now
    sign_expire = (expired_time - now) if expired_time and expired_time > now else 3600

    authorization = _cos_sign(
        method="put",
        path=f"/{encoded_key.lstrip('/')}",
        params={},
        headers=sign_headers,
        secret_id=secret_id,
        secret_key=secret_key,
        start_time=sign_start,
        expire_seconds=sign_expire,
    )

    put_headers = {
        "Authorization": authorization,
        "Content-Type": content_type,
        "x-cos-security-token": session_token,
    }

    logger.info(
        "COS PUT: bucket=%s region=%s key=%s size=%d mime=%s",
        bucket, region, cos_key, file_size, content_type,
    )

    async with httpx.AsyncClient(timeout=120.0) as client:
        resp = await client.put(
            cos_url,
            content=file_bytes,
            headers=put_headers,
        )
        resp.raise_for_status()

    # Parse image dimensions (image types only)
    result: dict[str, Any] = {
        "url": resource_url or cos_url,
        "uuid": file_uuid,
        "size": file_size,
    }

    if content_type.startswith("image/"):
        size_info = parse_image_size(file_bytes)
        if size_info:
            result["width"] = size_info["width"]
            result["height"] = size_info["height"]

    logger.info(
        "COS 上传成功: url=%s size=%d",
        result["url"], file_size,
    )
    return result

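
The `_cos_sign` helper used above is not included in this excerpt. For orientation, a minimal sketch of Tencent COS XML-API (v5) request signing follows — based on the public COS signing docs, not this repo's actual helper, so treat the exact function signature and parameter handling as assumptions:

import hashlib
import hmac
import urllib.parse


def _cos_sign_sketch(method: str, path: str, params: dict, headers: dict,
                     secret_id: str, secret_key: str,
                     start_time: int, expire_seconds: int) -> str:
    # Validity window, e.g. "1700000000;1700003600".
    key_time = f"{start_time};{start_time + expire_seconds}"
    # SignKey = HMAC-SHA1(SecretKey, KeyTime), hex-encoded.
    sign_key = hmac.new(secret_key.encode(), key_time.encode(), hashlib.sha1).hexdigest()

    def fmt(d: dict) -> tuple[str, str]:
        # Keys lowercased and sorted; values URL-encoded.
        items = sorted((k.lower(), urllib.parse.quote(str(v), safe="")) for k, v in d.items())
        return ";".join(k for k, _ in items), "&".join(f"{k}={v}" for k, v in items)

    param_list, param_str = fmt(params)
    header_list, header_str = fmt(headers)
    http_string = f"{method.lower()}\n{path}\n{param_str}\n{header_str}\n"
    string_to_sign = f"sha1\n{key_time}\n{hashlib.sha1(http_string.encode()).hexdigest()}\n"
    signature = hmac.new(sign_key.encode(), string_to_sign.encode(), hashlib.sha1).hexdigest()
    return (
        f"q-sign-algorithm=sha1&q-ak={secret_id}&q-sign-time={key_time}"
        f"&q-key-time={key_time}&q-header-list={header_list}"
        f"&q-url-param-list={param_list}&q-signature={signature}"
    )
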
# ============ TIM media message builders ============


def build_image_msg_body(
    url: str,
    uuid: Optional[str] = None,
    filename: Optional[str] = None,
    size: int = 0,
    width: int = 0,
    height: int = 0,
    mime_type: str = "",
) -> list[dict]:
    """
    Build a Tencent IM TIMImageElem message body.
    Reference: https://cloud.tencent.com/document/product/269/2720

    Args:
        url: public image URL (COS resourceUrl)
        uuid: file UUID (MD5 or another unique identifier)
        filename: file name (fallback when uuid is empty)
        size: file size in bytes
        width: image width in pixels
        height: image height in pixels
        mime_type: MIME type (used to determine image_format)

    Returns:
        TIMImageElem message body list (ready to drop into msg_body)
    """
    _uuid = uuid or filename or _basename_from_url(url) or "image"
    image_format = get_image_format(mime_type) if mime_type else 255

    return [
        {
            "msg_type": "TIMImageElem",
            "msg_content": {
                "uuid": _uuid,
                "image_format": image_format,
                "image_info_array": [
                    {
                        "type": 1,  # 1 = original image
                        "size": size,
                        "width": width,
                        "height": height,
                        "url": url,
                    }
                ],
            },
        }
    ]


def build_file_msg_body(
    url: str,
    filename: str,
    uuid: Optional[str] = None,
    size: int = 0,
) -> list[dict]:
    """
    Build a Tencent IM TIMFileElem message body.
    Reference: https://cloud.tencent.com/document/product/269/2720

    Args:
        url: public file URL (COS resourceUrl)
        filename: file name (with extension)
        uuid: file UUID (MD5 or another unique identifier; defaults to filename)
        size: file size in bytes

    Returns:
        TIMFileElem message body list (ready to drop into msg_body)
    """
    _uuid = uuid or filename

    return [
        {
            "msg_type": "TIMFileElem",
            "msg_content": {
                "uuid": _uuid,
                "file_name": filename,
                "file_size": size,
                "url": url,
            },
        }
    ]

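
End-to-end, the intended call flow is roughly the following sketch. get_cos_credentials is assumed from the upload_to_cos docstring; its exact signature is not shown in this diff, and the bucket/region values are illustrative:

async def send_photo_sketch(file_bytes: bytes, filename: str) -> list[dict]:
    # 1. Obtain temporary COS credentials from the gateway API
    #    (hypothetical helper name, per the upload_to_cos docstring).
    creds = await get_cos_credentials(filename)
    # 2. PUT the bytes to COS with a signed request.
    uploaded = await upload_to_cos(
        file_bytes, filename, "image/jpeg",
        credentials=creds,
        bucket="chatbot-1234567890",  # illustrative values
        region="ap-guangzhou",
    )
    # 3. Wrap the resulting public URL in a TIMImageElem msg_body.
    return build_image_msg_body(
        url=uploaded["url"], uuid=uploaded["uuid"], size=uploaded["size"],
        width=uploaded.get("width", 0), height=uploaded.get("height", 0),
        mime_type="image/jpeg",
    )
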
# ============ Internal helpers ============


def _basename_from_url(url: str) -> str:
    """Extract the file name from a URL."""
    try:
        parsed = urllib.parse.urlparse(url)
        return os.path.basename(parsed.path)
    except Exception:
        return ""
1210  gateway/platforms/yuanbao_proto.py  Normal file
File diff suppressed because it is too large

558  gateway/platforms/yuanbao_sticker.py  Normal file
@ -0,0 +1,558 @@
"""
Yuanbao sticker (TIMFaceElem) support.

Ported from yuanbao-openclaw-plugin/src/sticker/.

TIMFaceElem wire format:
    {
        "msg_type": "TIMFaceElem",
        "msg_content": {
            "index": 0,        # always 0 per Yuanbao convention
            "data": "<json>",  # serialised sticker metadata
        }
    }

The `data` field carries a JSON string with the sticker's metadata so the
receiver can look up the correct asset in the emoji pack.
"""

from __future__ import annotations

import json
import random
import re
import unicodedata
from typing import Optional

# ---------------------------------------------------------------------------
# Sticker catalogue – ported from builtin-stickers.json
# Key   : canonical name (Chinese)
# Value : {sticker_id, package_id, name, description, width, height, formats}
# ---------------------------------------------------------------------------
STICKER_MAP: dict[str, dict] = {
    "六六六": {
        "sticker_id": "278", "package_id": "1003", "name": "六六六",
        "description": "666 厉害 牛 棒 绝了 好强 awesome",
        "width": 128, "height": 128, "formats": "png",
    },
    "我想开了": {
        "sticker_id": "262", "package_id": "1003", "name": "我想开了",
        "description": "想开 佛系 释怀 顿悟 看淡了 无所谓",
        "width": 128, "height": 128, "formats": "png",
    },
    "害羞": {
        "sticker_id": "130", "package_id": "1003", "name": "害羞",
        "description": "腼腆 不好意思 脸红 娇羞 羞涩 捂脸",
        "width": 128, "height": 128, "formats": "png",
    },
    "比心": {
        "sticker_id": "252", "package_id": "1003", "name": "比心",
        "description": "笔芯 爱你 爱心手势 love heart 喜欢你",
        "width": 128, "height": 128, "formats": "png",
    },
    "委屈": {
        "sticker_id": "125", "package_id": "1003", "name": "委屈",
        "description": "难过 想哭 可怜巴巴 瘪嘴 受伤 被欺负",
        "width": 128, "height": 128, "formats": "png",
    },
    "亲亲": {
        "sticker_id": "146", "package_id": "1003", "name": "亲亲",
        "description": "么么 mua 亲一下 kiss 飞吻 啵",
        "width": 128, "height": 128, "formats": "png",
    },
    "酷": {
        "sticker_id": "131", "package_id": "1003", "name": "酷",
        "description": "帅 墨镜 cool 高冷 有型 swagger",
        "width": 128, "height": 128, "formats": "png",
    },
    "睡": {
        "sticker_id": "145", "package_id": "1003", "name": "睡",
        "description": "睡觉 困 zzZ 打盹 躺平 休眠 sleepy",
        "width": 128, "height": 128, "formats": "png",
    },
    "发呆": {
        "sticker_id": "152", "package_id": "1003", "name": "发呆",
        "description": "懵 愣住 放空 呆滞 出神 脑子空白",
        "width": 128, "height": 128, "formats": "png",
    },
    "可怜": {
        "sticker_id": "157", "package_id": "1003", "name": "可怜",
        "description": "卖萌 求饶 委屈巴巴 弱小 拜托 眼巴巴",
        "width": 128, "height": 128, "formats": "png",
    },
    "摊手": {
        "sticker_id": "200", "package_id": "1003", "name": "摊手",
        "description": "无奈 没办法 耸肩 随便 那咋整 whatever",
        "width": 128, "height": 128, "formats": "png",
    },
    "头大": {
        "sticker_id": "213", "package_id": "1003", "name": "头大",
        "description": "头疼 烦恼 郁闷 难搞 崩溃 一团乱",
        "width": 128, "height": 128, "formats": "png",
    },
    "吓": {
        "sticker_id": "256", "package_id": "1003", "name": "吓",
        "description": "害怕 惊恐 震惊 吓一跳 恐怖 怂",
        "width": 128, "height": 128, "formats": "png",
    },
    "吐血": {
        "sticker_id": "203", "package_id": "1003", "name": "吐血",
        "description": "无语 崩溃 被雷 内伤 一口老血 屮",
        "width": 128, "height": 128, "formats": "png",
    },
    "哼": {
        "sticker_id": "185", "package_id": "1003", "name": "哼",
        "description": "傲娇 生气 不满 撇嘴 不理 赌气",
        "width": 128, "height": 128, "formats": "png",
    },
    "嘿嘿": {
        "sticker_id": "220", "package_id": "1003", "name": "嘿嘿",
        "description": "坏笑 猥琐笑 偷笑 憨笑 得意 你懂的",
        "width": 128, "height": 128, "formats": "png",
    },
    "头秃": {
        "sticker_id": "218", "package_id": "1003", "name": "头秃",
        "description": "程序员 加班 焦虑 没头发 秃了 肝爆",
        "width": 128, "height": 128, "formats": "png",
    },
    "暗中观察": {
        "sticker_id": "221", "package_id": "1003", "name": "暗中观察",
        "description": "窥屏 潜水 偷偷看 角落 围观 屏住呼吸",
        "width": 128, "height": 128, "formats": "png",
    },
    "我酸了": {
        "sticker_id": "224", "package_id": "1003", "name": "我酸了",
        "description": "嫉妒 柠檬精 羡慕 吃柠檬 眼红 恰柠檬",
        "width": 128, "height": 128, "formats": "png",
    },
    "打call": {
        "sticker_id": "246", "package_id": "1003", "name": "打call",
        "description": "应援 加油 支持 喝彩 助威 call",
        "width": 128, "height": 128, "formats": "png",
    },
    "庆祝": {
        "sticker_id": "251", "package_id": "1003", "name": "庆祝",
        "description": "祝贺 开心 耶 party 胜利 干杯",
        "width": 128, "height": 128, "formats": "png",
    },
    "奋斗": {
        "sticker_id": "151", "package_id": "1003", "name": "奋斗",
        "description": "努力 加油 拼搏 冲 干劲 卷起来",
        "width": 128, "height": 128, "formats": "png",
    },
    "惊讶": {
        "sticker_id": "143", "package_id": "1003", "name": "惊讶",
        "description": "震惊 哇 不敢相信 OMG 居然 这么离谱",
        "width": 128, "height": 128, "formats": "png",
    },
    "疑问": {
        "sticker_id": "144", "package_id": "1003", "name": "疑问",
        "description": "问号 不懂 啥 为什么 啥情况 懵逼问",
        "width": 128, "height": 128, "formats": "png",
    },
    "仔细分析": {
        "sticker_id": "248", "package_id": "1003", "name": "仔细分析",
        "description": "思考 推敲 认真 研究 琢磨 让我想想",
        "width": 128, "height": 128, "formats": "png",
    },
    "撅嘴": {
        "sticker_id": "184", "package_id": "1003", "name": "撅嘴",
        "description": "嘟嘴 卖萌 不高兴 撒娇 嘴翘",
        "width": 128, "height": 128, "formats": "png",
    },
    "泪奔": {
        "sticker_id": "199", "package_id": "1003", "name": "泪奔",
        "description": "大哭 伤心 破防 感动哭 泪流满面 呜呜",
        "width": 128, "height": 128, "formats": "png",
    },
    "尊嘟假嘟": {
        "sticker_id": "276", "package_id": "1003", "name": "尊嘟假嘟",
        "description": "真的假的 真假 可爱问 你骗我 是不是",
        "width": 128, "height": 128, "formats": "png",
    },
    "略略略": {
        "sticker_id": "113", "package_id": "1003", "name": "略略略",
        "description": "调皮 吐舌 不服 略 气死你 鬼脸",
        "width": 128, "height": 128, "formats": "png",
    },
    "困": {
        "sticker_id": "180", "package_id": "1003", "name": "困",
        "description": "想睡 倦 打哈欠 睁不开眼 好困啊 sleepy",
        "width": 128, "height": 128, "formats": "png",
    },
    "折磨": {
        "sticker_id": "181", "package_id": "1003", "name": "折磨",
        "description": "难受 痛苦 煎熬 蚌埠住了 受不了 要命",
        "width": 128, "height": 128, "formats": "png",
    },
    "抠鼻": {
        "sticker_id": "182", "package_id": "1003", "name": "抠鼻",
        "description": "不屑 无聊 淡定 无所谓 鄙视 挖鼻",
        "width": 128, "height": 128, "formats": "png",
    },
    "鼓掌": {
        "sticker_id": "183", "package_id": "1003", "name": "鼓掌",
        "description": "拍手 叫好 赞同 666 喝彩 掌声",
        "width": 128, "height": 128, "formats": "png",
    },
    "斜眼笑": {
        "sticker_id": "204", "package_id": "1003", "name": "斜眼笑",
        "description": "滑稽 坏笑 doge 意味深长 阴阳怪气 嘿嘿嘿",
        "width": 128, "height": 128, "formats": "png",
    },
    "辣眼睛": {
        "sticker_id": "216", "package_id": "1003", "name": "辣眼睛",
        "description": "看不下去 cringe 毁三观 太丑了 瞎了",
        "width": 128, "height": 128, "formats": "png",
    },
    "哦哟": {
        "sticker_id": "217", "package_id": "1003", "name": "哦哟",
        "description": "惊讶 起哄 哇哦 有戏 不简单 哟",
        "width": 128, "height": 128, "formats": "png",
    },
    "吃瓜": {
        "sticker_id": "222", "package_id": "1003", "name": "吃瓜",
        "description": "围观 看戏 八卦 路人 看热闹 板凳",
        "width": 128, "height": 128, "formats": "png",
    },
    "狗头": {
        "sticker_id": "225", "package_id": "1003", "name": "狗头",
        "description": "doge 保命 开玩笑 滑稽 反讽 懂的都懂",
        "width": 128, "height": 128, "formats": "png",
    },
    "敬礼": {
        "sticker_id": "227", "package_id": "1003", "name": "敬礼",
        "description": "salute 尊重 收到 遵命 致敬 报告",
        "width": 128, "height": 128, "formats": "png",
    },
    "哦": {
        "sticker_id": "231", "package_id": "1003", "name": "哦",
        "description": "知道了 明白 敷衍 嗯 这样啊 收到",
        "width": 128, "height": 128, "formats": "png",
    },
    "拿到红包": {
        "sticker_id": "236", "package_id": "1003", "name": "拿到红包",
        "description": "红包 谢谢老板 发财 开心 抢到了 欧气",
        "width": 128, "height": 128, "formats": "png",
    },
    "牛吖": {
        "sticker_id": "239", "package_id": "1003", "name": "牛吖",
        "description": "牛 厉害 强 666 佩服 大佬",
        "width": 128, "height": 128, "formats": "png",
    },
    "贴贴": {
        "sticker_id": "272", "package_id": "1003", "name": "贴贴",
        "description": "抱抱 亲昵 蹭蹭 亲密 靠靠 撒娇贴",
        "width": 128, "height": 128, "formats": "png",
    },
    "爱心": {
        "sticker_id": "138", "package_id": "1003", "name": "爱心",
        "description": "心 love 喜欢你 红心 示爱 么么哒",
        "width": 128, "height": 128, "formats": "png",
    },
    "晚安": {
        "sticker_id": "170", "package_id": "1003", "name": "晚安",
        "description": "好梦 睡了 night 早点休息 安啦 moon",
        "width": 128, "height": 128, "formats": "png",
    },
    "太阳": {
        "sticker_id": "176", "package_id": "1003", "name": "太阳",
        "description": "晴天 早上好 阳光 morning 好天气 日",
        "width": 128, "height": 128, "formats": "png",
    },
    "柠檬": {
        "sticker_id": "266", "package_id": "1003", "name": "柠檬",
        "description": "酸 嫉妒 柠檬精 羡慕 我酸 恰柠檬",
        "width": 128, "height": 128, "formats": "png",
    },
    "大冤种": {
        "sticker_id": "267", "package_id": "1003", "name": "大冤种",
        "description": "倒霉 吃亏 自嘲 好心没好报 背锅 工具人",
        "width": 128, "height": 128, "formats": "png",
    },
    "吐了": {
        "sticker_id": "132", "package_id": "1003", "name": "吐了",
        "description": "恶心 yue 受不了 嫌弃 想吐 生理不适",
        "width": 128, "height": 128, "formats": "png",
    },
    "怒": {
        "sticker_id": "134", "package_id": "1003", "name": "怒",
        "description": "生气 愤怒 火大 暴躁 气炸 怼",
        "width": 128, "height": 128, "formats": "png",
    },
    "玫瑰": {
        "sticker_id": "165", "package_id": "1003", "name": "玫瑰",
        "description": "花 示爱 表白 浪漫 送你花 情人节",
        "width": 128, "height": 128, "formats": "png",
    },
    "凋谢": {
        "sticker_id": "119", "package_id": "1003", "name": "凋谢",
        "description": "花谢 失恋 难过 枯萎 心碎 凉了",
        "width": 128, "height": 128, "formats": "png",
    },
    "点赞": {
        "sticker_id": "159", "package_id": "1003", "name": "点赞",
        "description": "赞 认同 好棒 good like 大拇指 顶",
        "width": 128, "height": 128, "formats": "png",
    },
    "握手": {
        "sticker_id": "164", "package_id": "1003", "name": "握手",
        "description": "合作 你好 商务 hello deal 成交 友好",
        "width": 128, "height": 128, "formats": "png",
    },
    "抱拳": {
        "sticker_id": "163", "package_id": "1003", "name": "抱拳",
        "description": "谢谢 失敬 江湖 承让 拜托 有礼",
        "width": 128, "height": 128, "formats": "png",
    },
    "ok": {
        "sticker_id": "169", "package_id": "1003", "name": "ok",
        "description": "好的 收到 没问题 okay 行 可以 懂了",
        "width": 128, "height": 128, "formats": "png",
    },
    "拳头": {
        "sticker_id": "174", "package_id": "1003", "name": "拳头",
        "description": "加油 干 冲 fight 力量 击拳 硬气",
        "width": 128, "height": 128, "formats": "png",
    },
    "鞭炮": {
        "sticker_id": "191", "package_id": "1003", "name": "鞭炮",
        "description": "过年 喜庆 爆竹 春节 噼里啪啦 红",
        "width": 128, "height": 128, "formats": "png",
    },
    "烟花": {
        "sticker_id": "258", "package_id": "1003", "name": "烟花",
        "description": "庆典 漂亮 新年 嘭 绽放 节日快乐",
        "width": 128, "height": 128, "formats": "png",
    },
}


def get_sticker_by_name(name: str) -> Optional[dict]:
    """
    Look up a sticker by name, with fuzzy matching.

    Match priority:
      1. Exact name match
      2. Name contains the query (prefix/substring)
      3. Description contains the query (synonym search)
      4. Generic fuzzy scoring (same algorithm as the sticker search);
         on any hit, return the highest-scoring entry

    Returns the sticker dict, or None if nothing matches.
    """
    if not name:
        return None

    query = name.strip()

    if query in STICKER_MAP:
        return STICKER_MAP[query]

    for key, sticker in STICKER_MAP.items():
        if query in key or key in query:
            return sticker

    for sticker in STICKER_MAP.values():
        desc = sticker.get("description", "")
        if query in desc:
            return sticker

    matches = search_stickers(query, limit=1)
    return matches[0] if matches else None


def get_random_sticker(category: Optional[str] = None) -> dict:
    """
    Return a random sticker.

    If category is given, pick randomly among stickers whose description
    or name contains that keyword; with category=None, pick from the
    whole table.
    """
    if category:
        candidates = [
            s for s in STICKER_MAP.values()
            if category in s.get("description", "") or category in s.get("name", "")
        ]
        if candidates:
            return random.choice(candidates)
    return random.choice(list(STICKER_MAP.values()))


def get_sticker_by_id(sticker_id: str) -> Optional[dict]:
    """Look up a sticker by exact sticker_id."""
    if not sticker_id:
        return None
    sid = str(sticker_id).strip()
    for sticker in STICKER_MAP.values():
        if sticker.get("sticker_id") == sid:
            return sticker
    return None


# ---------------------------------------------------------------------------
# Fuzzy search (aligned with chatbot-web
# yuanbao-openclaw-plugin/sticker-cache.ts.searchStickers)
# ---------------------------------------------------------------------------

_PUNCT_RE = re.compile(r"[\s\u3000\-_·.,,。!!??\"“”'‘’、/\\]+")


def _normalize_text(raw: str) -> str:
    return unicodedata.normalize("NFKC", str(raw or "")).strip().lower()


def _compact_text(raw: str) -> str:
    return _PUNCT_RE.sub("", _normalize_text(raw))


def _multiset_char_hit_ratio(needle: str, haystack: str) -> float:
    if not needle:
        return 0.0
    bag: dict[str, int] = {}
    for ch in haystack:
        bag[ch] = bag.get(ch, 0) + 1
    hits = 0
    for ch in needle:
        n = bag.get(ch, 0)
        if n > 0:
            hits += 1
            bag[ch] = n - 1
    return hits / len(needle)


def _bigram_jaccard(a: str, b: str) -> float:
    if len(a) < 2 or len(b) < 2:
        return 0.0
    A = {a[i:i + 2] for i in range(len(a) - 1)}
    B = {b[i:i + 2] for i in range(len(b) - 1)}
    inter = len(A & B)
    union = len(A) + len(B) - inter
    return inter / union if union else 0.0


def _longest_subsequence_ratio(needle: str, haystack: str) -> float:
    if not needle:
        return 0.0
    j = 0
    for ch in haystack:
        if j >= len(needle):
            break
        if ch == needle[j]:
            j += 1
    return j / len(needle)


def _score_field(haystack: str, query: str) -> float:
    hay = _normalize_text(haystack)
    q = _normalize_text(query)
    if not hay or not q:
        return 0.0
    hay_c = _compact_text(haystack)
    q_c = _compact_text(query)
    best = 0.0
    if hay == q:
        best = max(best, 100.0)
    if q in hay:
        best = max(best, 92 + min(6, len(q)))
    if len(q) >= 2 and hay.startswith(q):
        best = max(best, 88.0)
    if q_c and q_c in hay_c:
        best = max(best, 86.0)
    best = max(best, _multiset_char_hit_ratio(q_c, hay_c) * 62)
    best = max(best, _bigram_jaccard(q_c, hay_c) * 58)
    best = max(best, _longest_subsequence_ratio(q_c, hay_c) * 52)
    if len(q) == 1 and q in hay:
        best = max(best, 68.0)
    return best
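
A few worked examples of the scoring branches above; the values follow directly from the rules (shown as assertions for concreteness):

assert _score_field("吃瓜", "吃瓜") == 100.0    # exact match wins outright
assert _score_field("吃瓜群众", "吃瓜") == 94   # substring: 92 + min(6, len(q))
assert _score_field("围观看戏", "吃瓜") == 0.0  # no shared characters, all components zero
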


def search_stickers(query: str, limit: int = 10) -> list[dict]:
    """
    Fuzzy-rank the built-in sticker table and return the top N results.

    The score combines substring matches, multiset character coverage,
    bigram Jaccard, and subsequence ratio over the name/description
    fields. name is weighted slightly above description (×0.88). An
    empty query returns the first N entries in table order.
    """
    safe_limit = max(1, min(500, int(limit) if limit else 10))
    if not query or not _normalize_text(query):
        return list(STICKER_MAP.values())[:safe_limit]

    scored: list[tuple[float, dict]] = []
    for sticker in STICKER_MAP.values():
        name_s = _score_field(sticker.get("name", ""), query)
        desc_s = _score_field(sticker.get("description", ""), query) * 0.88
        sid = str(sticker.get("sticker_id", "")).strip()
        q_norm = _normalize_text(query)
        id_s = 0.0
        if sid and q_norm:
            sid_norm = _normalize_text(sid)
            if sid_norm == q_norm:
                id_s = 100.0
            elif q_norm in sid_norm:
                id_s = 84.0
        scored.append((max(name_s, desc_s, id_s), sticker))

    scored.sort(key=lambda x: x[0], reverse=True)
    top = scored[0][0] if scored else 0
    if top <= 0:
        return [s for _, s in scored[:safe_limit]]

    if top >= 22:
        floor = 18.0
    elif top >= 12:
        floor = max(10.0, top * 0.5)
    else:
        floor = max(6.0, top * 0.35)

    filtered = [pair for pair in scored if pair[0] >= floor]
    out = filtered if filtered else scored
    return [s for _, s in out[:safe_limit]]


def build_face_msg_body(
    face_index: int,
    face_type: int = 1,
    data: Optional[str] = None,
) -> list:
    """
    Build a TIMFaceElem message body.

    Yuanbao conventions:
      - index is always 0 (the server identifies the sticker via the
        data field)
      - data is a JSON string containing sticker_id / package_id etc.

    Args:
        face_index: reserved; does not currently affect the wire format
            (Yuanbao pins index=0). When face_index > 0 it is treated as
            a legacy QQ emoji ID and placed into index directly.
        face_type: reserved (kept for interface compatibility; unused).
        data: pre-serialised JSON string; when None, only index is sent.

    Returns:
        msg_body list conforming to the Yuanbao TIM protocol, e.g.::

            [{"msg_type": "TIMFaceElem", "msg_content": {"index": 0, "data": "..."}}]
    """
    msg_content: dict = {"index": face_index}
    if data is not None:
        msg_content["data"] = data
    return [{"msg_type": "TIMFaceElem", "msg_content": msg_content}]


def build_sticker_msg_body(sticker: dict) -> list:
    """
    Build a TIMFaceElem message body directly from a STICKER_MAP sticker dict.

    Internal helper for send_sticker(); keeps the data field identical to
    the original JS plugin's output.
    """
    data_payload = json.dumps(
        {
            "sticker_id": sticker["sticker_id"],
            "package_id": sticker["package_id"],
            "width": sticker.get("width", 128),
            "height": sticker.get("height", 128),
            "formats": sticker.get("formats", "png"),
            "name": sticker["name"],
        },
        ensure_ascii=False,
        separators=(",", ":"),
    )
    return build_face_msg_body(face_index=0, data=data_payload)
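
For reference, a typical call chain through these helpers when fulfilling a "send a 吃瓜 sticker" request (illustrative only; the actual send path lives in the adapter):

hits = search_stickers("吃瓜", limit=3)        # fuzzy lookup in the catalogue
if hits:
    best = hits[0]                              # e.g. {"sticker_id": "222", "name": "吃瓜", ...}
    msg_body = build_sticker_msg_body(best)     # TIMFaceElem ready for the send path
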

356  gateway/run.py
@ -682,6 +682,16 @@ class GatewayRunner:
        self._running_agents: Dict[str, Any] = {}
        self._running_agents_ts: Dict[str, float] = {}  # start timestamp per session
        self._pending_messages: Dict[str, str] = {}  # Queued messages during interrupt
        # Overflow buffer for explicit /queue commands. The adapter-level
        # _pending_messages dict is a single slot per session (designed for
        # "next-turn" follow-ups where repeated sends collapse into one
        # event). /queue has different semantics: each invocation must
        # produce its own full agent turn, in FIFO order, with no merging.
        # When the slot is occupied, additional /queue items land here and
        # are promoted one-at-a-time after each run's drain. Cleared on
        # /new and /reset. /model and other mid-session operations
        # preserve the queue.
        self._queued_events: Dict[str, List[MessageEvent]] = {}
        self._busy_ack_ts: Dict[str, float] = {}  # last busy-ack timestamp per session (debounce)
        self._session_run_generation: Dict[str, int] = {}

@ -753,10 +763,27 @@ class GatewayRunner:
                    retention_days=int(_sess_cfg.get("retention_days", 90)),
                    min_interval_hours=int(_sess_cfg.get("min_interval_hours", 24)),
                    vacuum=bool(_sess_cfg.get("vacuum_after_prune", True)),
                    sessions_dir=self.config.sessions_dir,
                )
        except Exception as exc:
            logger.debug("state.db auto-maintenance skipped: %s", exc)

        # Opportunistic shadow-repo cleanup — deletes orphan/stale
        # checkpoint repos under ~/.hermes/checkpoints/. Opt-in via
        # checkpoints.auto_prune, idempotent via .last_prune marker.
        try:
            from hermes_cli.config import load_config as _load_full_config
            _ckpt_cfg = (_load_full_config().get("checkpoints") or {})
            if _ckpt_cfg.get("auto_prune", False):
                from tools.checkpoint_manager import maybe_auto_prune_checkpoints
                maybe_auto_prune_checkpoints(
                    retention_days=int(_ckpt_cfg.get("retention_days", 7)),
                    min_interval_hours=int(_ckpt_cfg.get("min_interval_hours", 24)),
                    delete_orphans=bool(_ckpt_cfg.get("delete_orphans", True)),
                )
        except Exception as exc:
            logger.debug("checkpoint auto-maintenance skipped: %s", exc)

        # DM pairing store for code-based user authorization
        from gateway.pairing import PairingStore
        self.pairing_store = PairingStore()

@ -1202,7 +1229,80 @@ class GatewayRunner:
        return "restarting" if self._restart_requested else "shutting down"

    def _queue_during_drain_enabled(self) -> bool:
        return self._restart_requested and self._busy_input_mode == "queue"
        # Both "queue" and "steer" modes imply the user doesn't want messages
        # to be lost during restart — queue them for the newly-spawned gateway
        # process to pick up. "interrupt" mode drops them (current behaviour).
        return self._restart_requested and self._busy_input_mode in ("queue", "steer")

    # -------- /queue FIFO helpers --------------------------------------
    # /queue must produce one full agent turn per invocation, in FIFO
    # order, with no merging. The adapter's _pending_messages dict is a
    # single "next-up" slot (shared with photo-burst follow-ups), so we
    # use it for the head of the queue and an overflow list for the
    # tail. Enqueue puts new items in the slot when free, otherwise in
    # the overflow. Promotion (called after each run's drain) moves the
    # next overflow item into the slot so the following recursion picks
    # it up. Clearing happens on /new and /reset via
    # _handle_reset_command.

    def _enqueue_fifo(self, session_key: str, queued_event: "MessageEvent", adapter: Any) -> None:
        """Append a /queue event to the FIFO chain for a session."""
        if adapter is None:
            return
        pending_slot = getattr(adapter, "_pending_messages", None)
        if pending_slot is None:
            return
        queued_events = getattr(self, "_queued_events", None)
        if queued_events is None:
            queued_events = {}
            self._queued_events = queued_events
        if session_key in pending_slot:
            queued_events.setdefault(session_key, []).append(queued_event)
        else:
            pending_slot[session_key] = queued_event

    def _promote_queued_event(
        self,
        session_key: str,
        adapter: Any,
        pending_event: Optional["MessageEvent"],
    ) -> Optional["MessageEvent"]:
        """Promote the next overflow item after the slot was drained.

        Called at the drain site after _dequeue_pending_event consumed
        (or failed to consume) the slot. If there's an overflow item:
        - When pending_event is None (slot was empty), return the
          overflow head as the new pending_event.
        - When pending_event already exists (slot was populated by an
          interrupt follow-up or similar), stage the overflow head in
          the slot so the NEXT recursion picks it up.
        Returns the (possibly updated) pending_event for drain to use.
        """
        queued_events = getattr(self, "_queued_events", None)
        if not queued_events:
            return pending_event
        overflow = queued_events.get(session_key)
        if not overflow:
            return pending_event
        next_queued = overflow.pop(0)
        if not overflow:
            queued_events.pop(session_key, None)
        if pending_event is None:
            return next_queued
        if adapter is not None and hasattr(adapter, "_pending_messages"):
            adapter._pending_messages[session_key] = next_queued
        else:
            # No adapter — push back so we don't silently drop the item.
            queued_events.setdefault(session_key, []).insert(0, next_queued)
        return pending_event

    def _queue_depth(self, session_key: str, *, adapter: Any = None) -> int:
        """Total pending /queue items for a session — slot + overflow."""
        queued_events = getattr(self, "_queued_events", None) or {}
        depth = len(queued_events.get(session_key, []))
        if adapter is not None and session_key in getattr(adapter, "_pending_messages", {}):
            depth += 1
        return depth

    def _update_runtime_status(self, gateway_state: Optional[str] = None, exit_reason: Optional[str] = None) -> None:
        try:

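A toy trace of the slot-plus-overflow mechanics described above (plain strings stand in for MessageEvent; the real dicts live on the adapter and the runner):

slot: dict[str, str] = {}             # stands in for adapter._pending_messages
overflow: dict[str, list[str]] = {}   # stands in for self._queued_events

def enqueue(key: str, ev: str) -> None:
    if key in slot:
        overflow.setdefault(key, []).append(ev)  # slot busy: append to FIFO tail
    else:
        slot[key] = ev                           # slot free: becomes the head

enqueue("s1", "a"); enqueue("s1", "b"); enqueue("s1", "c")
assert slot["s1"] == "a" and overflow["s1"] == ["b", "c"]
# After a run drains "a", promotion moves "b" into the slot, so each
# /queue item gets its own full turn, strictly in order.
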
@ -1433,7 +1533,11 @@ class GatewayRunner:
            mode = str(cfg.get("display", {}).get("busy_input_mode", "") or "").strip().lower()
        except Exception:
            pass
        return "queue" if mode == "queue" else "interrupt"
        if mode == "queue":
            return "queue"
        if mode == "steer":
            return "steer"
        return "interrupt"

    @staticmethod
    def _load_restart_drain_timeout() -> float:

@ -1571,18 +1675,46 @@ class GatewayRunner:
        if not adapter:
            return False  # let default path handle it

        running_agent = self._running_agents.get(session_key)

        # Steer mode: inject mid-run via running_agent.steer() instead of
        # queueing + interrupting. If the agent isn't running yet
        # (sentinel) or lacks steer(), or the payload is empty, fall back
        # to queue semantics so nothing is lost.
        effective_mode = self._busy_input_mode
        steered = False
        if effective_mode == "steer":
            steer_text = (event.text or "").strip()
            can_steer = (
                steer_text
                and running_agent is not None
                and running_agent is not _AGENT_PENDING_SENTINEL
                and hasattr(running_agent, "steer")
            )
            if can_steer:
                try:
                    steered = bool(running_agent.steer(steer_text))
                except Exception as exc:
                    logger.warning("Gateway steer failed for session %s: %s", session_key, exc)
                    steered = False
            if not steered:
                # Fall back to queue (merge into pending messages, no interrupt)
                effective_mode = "queue"

        # Store the message so it's processed as the next turn after the
        # current run finishes (or is interrupted).
        from gateway.platforms.base import merge_pending_message_event
        merge_pending_message_event(adapter._pending_messages, session_key, event)
        # current run finishes (or is interrupted). Skip this for a
        # successful steer — the text already landed inside the run and
        # must NOT also be replayed as a next-turn user message.
        if not steered:
            merge_pending_message_event(adapter._pending_messages, session_key, event)

        is_queue_mode = self._busy_input_mode == "queue"
        is_queue_mode = effective_mode == "queue"
        is_steer_mode = effective_mode == "steer"

        # If not in queue mode, interrupt the running agent immediately.
        # If not in queue/steer mode, interrupt the running agent immediately.
        # This aborts in-flight tool calls and causes the agent loop to exit
        # at the next checkpoint.
        running_agent = self._running_agents.get(session_key)
        if not is_queue_mode and running_agent and running_agent is not _AGENT_PENDING_SENTINEL:
        if effective_mode == "interrupt" and running_agent and running_agent is not _AGENT_PENDING_SENTINEL:
            try:
                running_agent.interrupt(event.text)
            except Exception:

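The three busy-input behaviours reduce to a small dispatch, sketched here under simplifying assumptions (the real path also merges pending events and debounces busy acks):

def dispatch_busy_input(mode: str, agent, text: str) -> str:
    if mode == "steer" and agent is not None and hasattr(agent, "steer"):
        try:
            if agent.steer(text):     # lands after the next tool call
                return "steered"
        except Exception:
            pass                      # steer rejected or failed
        mode = "queue"                # fall back so nothing is lost
    if mode == "queue":
        return "queued"               # replayed as the next turn
    agent.interrupt(text)             # default: abort in-flight tool calls
    return "interrupted"
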
@ -1619,7 +1751,12 @@
                pass

        status_detail = f" ({', '.join(status_parts)})" if status_parts else ""
        if is_queue_mode:
        if is_steer_mode:
            message = (
                f"⏩ Steered into current run{status_detail}. "
                f"Your message arrives after the next tool call."
            )
        elif is_queue_mode:
            message = (
                f"⏳ Queued for the next turn{status_detail}. "
                f"I'll respond once the current task finishes."

@ -1643,9 +1780,15 @@
                )
                _user_cfg = _load_gateway_config()
                if not is_seen(_user_cfg, BUSY_INPUT_FLAG):
                    if is_steer_mode:
                        _hint_mode = "steer"
                    elif is_queue_mode:
                        _hint_mode = "queue"
                    else:
                        _hint_mode = "interrupt"
                    message = (
                        f"{message}\n\n"
                        f"{busy_input_hint_gateway('queue' if is_queue_mode else 'interrupt')}"
                        f"{busy_input_hint_gateway(_hint_mode)}"
                    )
                    mark_seen(_hermes_home / "config.yaml", BUSY_INPUT_FLAG)
            except Exception as _onb_err:

@ -1996,6 +2139,7 @@ class GatewayRunner:
                "WEIXIN_ALLOWED_USERS",
                "BLUEBUBBLES_ALLOWED_USERS",
                "QQ_ALLOWED_USERS",
                "YUANBAO_ALLOWED_USERS",
                "GATEWAY_ALLOWED_USERS")
        )
        _allow_all = os.getenv("GATEWAY_ALLOW_ALL_USERS", "").lower() in ("true", "1", "yes") or any(

@ -2010,7 +2154,8 @@ class GatewayRunner:
                "WECOM_CALLBACK_ALLOW_ALL_USERS",
                "WEIXIN_ALLOW_ALL_USERS",
                "BLUEBUBBLES_ALLOW_ALL_USERS",
                "QQ_ALLOW_ALL_USERS")
                "QQ_ALLOW_ALL_USERS",
                "YUANBAO_ALLOW_ALL_USERS")
        )
        if not _any_allowlist and not _allow_all:
            logger.warning(

@ -2254,7 +2399,7 @@ class GatewayRunner:
        # Build initial channel directory for send_message name resolution
        try:
            from gateway.channel_directory import build_channel_directory
            directory = build_channel_directory(self.adapters)
            directory = await build_channel_directory(self.adapters)
            ch_count = sum(len(chs) for chs in directory.get("platforms", {}).values())
            logger.info("Channel directory built: %d target(s)", ch_count)
        except Exception as e:

@ -2538,7 +2683,7 @@ class GatewayRunner:
            # Rebuild channel directory with the new adapter
            try:
                from gateway.channel_directory import build_channel_directory
                build_channel_directory(self.adapters)
                await build_channel_directory(self.adapters)
            except Exception:
                pass
        else:

@ -2720,6 +2865,23 @@ class GatewayRunner:

        self._finalize_shutdown_agents(active_agents)

        # Also shut down memory providers on idle cached agents.
        # _finalize_shutdown_agents only handles agents that were
        # mid-turn at drain time; the _agent_cache may still hold
        # idle agents whose MemoryProviders never received
        # on_session_end().
        _cache_lock = getattr(self, "_agent_cache_lock", None)
        _cache = getattr(self, "_agent_cache", None)
        if _cache_lock is not None and _cache is not None:
            with _cache_lock:
                _idle_agents = list(_cache.values())
                _cache.clear()
            for _entry in _idle_agents:
                _agent = (
                    _entry[0] if isinstance(_entry, tuple) else _entry
                )
                self._cleanup_agent_resources(_agent)

        for platform, adapter in list(self.adapters.items()):
            try:
                await adapter.cancel_background_tasks()

@ -2970,8 +3132,14 @@ class GatewayRunner:
                return None
            return QQAdapter(config)

        return None
        elif platform == Platform.YUANBAO:
            from gateway.platforms.yuanbao import YuanbaoAdapter, WEBSOCKETS_AVAILABLE
            if not WEBSOCKETS_AVAILABLE:
                logger.warning("Yuanbao: websockets not installed. Run: pip install websockets")
                return None
            return YuanbaoAdapter(config)

        return None

    def _is_user_authorized(self, source: SessionSource) -> bool:
        """
        Check if a user is authorized to use the bot.

@ -3012,6 +3180,7 @@ class GatewayRunner:
            Platform.WEIXIN: "WEIXIN_ALLOWED_USERS",
            Platform.BLUEBUBBLES: "BLUEBUBBLES_ALLOWED_USERS",
            Platform.QQBOT: "QQ_ALLOWED_USERS",
            Platform.YUANBAO: "YUANBAO_ALLOWED_USERS",
        }
        platform_group_env_map = {
            Platform.TELEGRAM: "TELEGRAM_GROUP_ALLOWED_USERS",

@ -3034,6 +3203,7 @@ class GatewayRunner:
            Platform.WEIXIN: "WEIXIN_ALLOW_ALL_USERS",
            Platform.BLUEBUBBLES: "BLUEBUBBLES_ALLOW_ALL_USERS",
            Platform.QQBOT: "QQ_ALLOW_ALL_USERS",
            Platform.YUANBAO: "YUANBAO_ALLOW_ALL_USERS",
        }

        # Per-platform allow-all flag (e.g., DISCORD_ALLOW_ALL_USERS=true)

@ -3282,6 +3452,10 @@ class GatewayRunner:
        # The update process (detached) wrote .update_prompt.json; the watcher
        # forwarded it to the user; now the user's reply goes back via
        # .update_response so the update process can continue.
        #
        # IMPORTANT: recognized slash commands must bypass this interception.
        # Otherwise control/session commands like /new or /help get silently
        # consumed as update answers instead of being dispatched normally.
        _quick_key = self._session_key_for_source(source)
        _update_prompts = getattr(self, "_update_prompt_pending", {})
        if _update_prompts.get(_quick_key):

@ -3293,7 +3467,22 @@ class GatewayRunner:
            elif cmd in ("deny", "no"):
                response_text = "n"
            else:
                response_text = raw
                _recognized_cmd = None
                if cmd:
                    try:
                        from hermes_cli.commands import resolve_command as _resolve_update_cmd
                    except Exception:
                        _resolve_update_cmd = None
                    if _resolve_update_cmd is not None:
                        try:
                            _cmd_def = _resolve_update_cmd(cmd)
                            _recognized_cmd = _cmd_def.name if _cmd_def else None
                        except Exception:
                            _recognized_cmd = None
                if _recognized_cmd:
                    response_text = ""
                else:
                    response_text = raw
            if response_text:
                response_path = _hermes_home / ".update_response"
                try:

@ -3306,6 +3495,30 @@ class GatewayRunner:
                _update_prompts.pop(_quick_key, None)
                label = response_text if len(response_text) <= 20 else response_text[:20] + "…"
                return f"✓ Sent `{label}` to the update process."
            # Recognized slash command during a pending update prompt:
            # unblock the detached update subprocess by writing a blank
            # response so ``_gateway_prompt`` returns the prompt's default
            # (typically a safe "n" / skip) and exits cleanly instead of
            # blocking on stdin until the 30-minute watcher timeout.
            # The slash command then falls through to normal dispatch.
            if _recognized_cmd:
                response_path = _hermes_home / ".update_response"
                try:
                    tmp = response_path.with_suffix(".tmp")
                    tmp.write_text("")
                    tmp.replace(response_path)
                    logger.info(
                        "Recognized /%s during pending update prompt for %s; "
                        "cancelled prompt with default and dispatching command",
                        _recognized_cmd,
                        _quick_key,
                    )
                except OSError as e:
                    logger.warning(
                        "Failed to write cancel response for pending update prompt: %s",
                        e,
                    )
                _update_prompts.pop(_quick_key, None)

        # PRIORITY handling when an agent is already running for this session.
        # Default behavior is to interrupt immediately so user text/stop messages

@ -3416,7 +3629,10 @@ class GatewayRunner:
            # doesn't think an agent is still active.
            return await self._handle_reset_command(event)

        # /queue <prompt> — queue without interrupting
        # /queue <prompt> — queue without interrupting.
        # Semantics: each /queue invocation produces its own full agent
        # turn, processed in FIFO order after the current run (and any
        # earlier /queue items) finishes. Messages are NOT merged.
        if event.get_command() in ("queue", "q"):
            queued_text = event.get_command_args().strip()
            if not queued_text:

@ -3430,8 +3646,11 @@ class GatewayRunner:
                message_id=event.message_id,
                channel_prompt=event.channel_prompt,
            )
            adapter._pending_messages[_quick_key] = queued_event
            return "Queued for the next turn."
            self._enqueue_fifo(_quick_key, queued_event, adapter)
            depth = self._queue_depth(_quick_key, adapter=self.adapters.get(source.platform))
            if depth <= 1:
                return "Queued for the next turn."
            return f"Queued for the next turn. ({depth} queued)"

        # /steer <prompt> — inject mid-run after the next tool call.
        # Unlike /queue (turn boundary), /steer lands BETWEEN tool-call

@ -3608,6 +3827,24 @@ class GatewayRunner:
            logger.debug("PRIORITY queue follow-up for session %s", _quick_key)
            self._queue_or_replace_pending_event(_quick_key, event)
            return None
        if self._busy_input_mode == "steer":
            # Steer mode: inject text into the running agent mid-run via
            # agent.steer(). Falls back to queue semantics if the payload
            # is empty, the agent lacks steer(), or steer() rejects.
            steer_text = (event.text or "").strip()
            steered = False
            if steer_text and hasattr(running_agent, "steer"):
                try:
                    steered = bool(running_agent.steer(steer_text))
                except Exception as exc:
                    logger.warning("PRIORITY steer failed for session %s: %s", _quick_key, exc)
                    steered = False
            if steered:
                logger.debug("PRIORITY steer for session %s", _quick_key)
                return None
            logger.debug("PRIORITY steer-fallback-to-queue for session %s", _quick_key)
            self._queue_or_replace_pending_event(_quick_key, event)
            return None
        logger.debug("PRIORITY interrupt for session %s", _quick_key)
        running_agent.interrupt(event.text)
        if _quick_key in self._pending_messages:

@ -4118,7 +4355,14 @@ class GatewayRunner:
        session_entry = self.session_store.get_or_create_session(source)
        session_key = session_entry.session_key
        if getattr(session_entry, "was_auto_reset", False):
            # Treat auto-reset as a full conversation boundary — drop all
            # session-scoped transient state so the fresh session does not
            # inherit the previous conversation's model/reasoning overrides
            # or a queued "/model switched" note.
            self._session_model_overrides.pop(session_key, None)
            self._set_session_reasoning_override(session_key, None)
            if hasattr(self, "_pending_model_notes"):
                self._pending_model_notes.pop(session_key, None)

        # Emit session:start for new or auto-reset sessions
        _is_new_session = (

@ -4520,12 +4764,20 @@ class GatewayRunner:
            if not os.getenv(env_key):
                adapter = self.adapters.get(source.platform)
                if adapter:
                    # Slack dispatches all Hermes commands through a single
                    # parent slash command `/hermes`; bare `/sethome` is not
                    # registered and would fail with "app did not respond".
                    sethome_cmd = (
                        "/hermes sethome"
                        if source.platform == Platform.SLACK
                        else "/sethome"
                    )
                    await adapter.send(
                        source.chat_id,
                        f"📬 No home channel is set for {platform_name.title()}. "
                        f"A home channel is where Hermes delivers cron job results "
                        f"and cross-platform messages.\n\n"
                        f"Type /sethome to make this chat your home channel, "
                        f"Type {sethome_cmd} to make this chat your home channel, "
                        f"or ignore to skip."
                    )

@ -4790,6 +5042,8 @@ class GatewayRunner:
            self._evict_cached_agent(session_key)
            self._session_model_overrides.pop(session_key, None)
            self._set_session_reasoning_override(session_key, None)
            if hasattr(self, "_pending_model_notes"):
                self._pending_model_notes.pop(session_key, None)
            response = (response or "") + (
                "\n\n🔄 Session auto-reset — the conversation exceeded the "
                "maximum context size and could not be compressed further. "

@ -5058,6 +5312,13 @@ class GatewayRunner:
            self._cleanup_agent_resources(_old_agent)
        self._evict_cached_agent(session_key)

        # Discard any /queue overflow for this session — /new is a
        # conversation-boundary operation, queued follow-ups from the
        # previous conversation must not bleed into the new one.
        _qe = getattr(self, "_queued_events", None)
        if _qe is not None:
            _qe.pop(session_key, None)

        try:
            from tools.env_passthrough import clear_env_passthrough
            clear_env_passthrough()

@ -5077,6 +5338,8 @@ class GatewayRunner:
        # picks up configured defaults instead of previous session switches.
        self._session_model_overrides.pop(session_key, None)
        self._set_session_reasoning_override(session_key, None)
        if hasattr(self, "_pending_model_notes"):
            self._pending_model_notes.pop(session_key, None)

        # Clear session-scoped dangerous-command approvals and /yolo state.
        # /new is a conversation-boundary operation — approval state from the

@ -5165,6 +5428,10 @@ class GatewayRunner:
        session_key = session_entry.session_key
        is_running = session_key in self._running_agents

        # Count pending /queue follow-ups (slot + overflow).
        adapter = self.adapters.get(source.platform) if source else None
        queue_depth = self._queue_depth(session_key, adapter=adapter)

        title = None
        if self._session_db:
            try:

@ -5184,6 +5451,10 @@ class GatewayRunner:
            f"**Last Activity:** {session_entry.updated_at.strftime('%Y-%m-%d %H:%M')}",
            f"**Tokens:** {session_entry.total_tokens:,}",
            f"**Agent Running:** {'Yes ⚡' if is_running else 'No'}",
        ])
        if queue_depth:
            lines.append(f"**Queued follow-ups:** {queue_depth}")
        lines.extend([
            "",
            f"**Connected Platforms:** {', '.join(connected_platforms)}",
        ])

@ -6640,6 +6911,7 @@ class GatewayRunner:
                        chat_id=source.chat_id,
                        image_url=image_url,
                        caption=alt_text,
                        metadata=_thread_metadata,
                    )
                except Exception:
                    pass

@ -6650,6 +6922,7 @@ class GatewayRunner:
                    await adapter.send_document(
                        chat_id=source.chat_id,
                        file_path=media_path,
                        metadata=_thread_metadata,
                    )
                except Exception:
                    pass

@ -8615,7 +8888,7 @@ class GatewayRunner:
        return True

    def _clear_session_boundary_security_state(self, session_key: str) -> None:
        """Clear approval state that must not survive a real conversation switch."""
        """Clear per-session control state that must not survive a boundary switch."""
        if not session_key:
            return

@ -8623,6 +8896,10 @@ class GatewayRunner:
        if isinstance(pending_approvals, dict):
            pending_approvals.pop(session_key, None)

        update_prompt_pending = getattr(self, "_update_prompt_pending", None)
        if isinstance(update_prompt_pending, dict):
            update_prompt_pending.pop(session_key, None)

        try:
            from tools.approval import clear_session as _clear_approval_session
        except Exception:

@ -9028,11 +9305,21 @@ class GatewayRunner:
                        if source.platform == Platform.MATRIX:
                            _effective_cursor = ""
                            _buffer_only = True
                        # Fresh-final applies to Telegram only — other
                        # platforms either edit in place cheaply (Discord,
                        # Slack) or don't have the timestamp-on-edit
                        # problem. (Ported from openclaw/openclaw#72038.)
                        _fresh_final_secs = (
                            float(getattr(_scfg, "fresh_final_after_seconds", 0.0) or 0.0)
                            if source.platform == Platform.TELEGRAM
                            else 0.0
                        )
                        _consumer_cfg = StreamConsumerConfig(
                            edit_interval=_scfg.edit_interval,
                            buffer_threshold=_scfg.buffer_threshold,
                            cursor=_effective_cursor,
                            buffer_only=_buffer_only,
                            fresh_final_after_seconds=_fresh_final_secs,
                        )
                        _stream_consumer = GatewayStreamConsumer(
                            adapter=_adapter,

@ -9716,11 +10003,21 @@ class GatewayRunner:
                        if source.platform == Platform.MATRIX:
                            _effective_cursor = ""
                            _buffer_only = True
                        # Fresh-final applies to Telegram only — other
                        # platforms either edit in place cheaply or don't
                        # have the edit-timestamp-stays-stale problem.
                        # (Ported from openclaw/openclaw#72038.)
                        _fresh_final_secs = (
                            float(getattr(_scfg, "fresh_final_after_seconds", 0.0) or 0.0)
                            if source.platform == Platform.TELEGRAM
                            else 0.0
                        )
                        _consumer_cfg = StreamConsumerConfig(
                            edit_interval=_scfg.edit_interval,
                            buffer_threshold=_scfg.buffer_threshold,
                            cursor=_effective_cursor,
                            buffer_only=_buffer_only,
                            fresh_final_after_seconds=_fresh_final_secs,
                        )
                        _stream_consumer = GatewayStreamConsumer(
                            adapter=_adapter,

@ -10568,6 +10865,13 @@ class GatewayRunner:
            pending = None
            if result and adapter and session_key:
                pending_event = _dequeue_pending_event(adapter, session_key)
                # /queue overflow: after consuming the adapter's "next-up"
                # slot, promote the next queued event into it so the
                # recursive run's drain will see it. This keeps the slot
                # occupied for the full FIFO chain, which (a) preserves
                # order, and (b) causes any mid-chain /queue to correctly
                # route to overflow rather than jumping the queue.
                pending_event = self._promote_queued_event(session_key, adapter, pending_event)
                if result.get("interrupted") and not pending_event and result.get("interrupt_message"):
                    interrupt_message = result.get("interrupt_message")
                    if _is_control_interrupt_message(interrupt_message):

@ -10862,7 +11166,15 @@ def _start_cron_ticker(stop_event: threading.Event, adapters=None, loop=None, in
        if tick_count % CHANNEL_DIR_EVERY == 0 and adapters:
            try:
                from gateway.channel_directory import build_channel_directory
                build_channel_directory(adapters)
                if loop is not None:
                    # build_channel_directory is async (Slack web calls), and
                    # this ticker runs in a background thread. Schedule onto
                    # the gateway event loop and wait briefly for completion
                    # so refresh failures are still logged via the except.
                    fut = asyncio.run_coroutine_threadsafe(
                        build_channel_directory(adapters), loop
                    )
                    fut.result(timeout=30)
            except Exception as e:
                logger.debug("Channel directory refresh error: %s", e)

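For context, this is the standard pattern for driving a coroutine from a non-async thread — a self-contained sketch:

import asyncio

async def refresh() -> str:
    return "directory rebuilt"

def ticker_thread_body(loop: asyncio.AbstractEventLoop) -> None:
    # Runs in a plain background thread: hand the coroutine to the
    # running event loop, then block (briefly) on the result so any
    # exception surfaces in this thread's try/except.
    fut = asyncio.run_coroutine_threadsafe(refresh(), loop)
    print(fut.result(timeout=30))
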
@ -310,8 +310,9 @@ def build_session_context_prompt(
            "**Platform notes:** You are running inside Slack. "
            "You do NOT have access to Slack-specific APIs — you cannot search "
            "channel history, pin/unpin messages, manage channels, or list users. "
            "Do not promise to perform these actions. If the user asks, explain "
            "that you can only read messages sent directly to you and respond."
            "Do not promise to perform these actions. The gateway may inline the "
            "current message's Slack block/attachment payload when available, but "
            "you still cannot call Slack APIs yourself."
        )
    elif context.source.platform == Platform.DISCORD:
        # Inject the Discord IDs block only when the agent actually has

@ -353,6 +354,14 @@
            "If the user needs a detailed answer, give the short version first "
            "and offer to elaborate."
        )
    elif context.source.platform == Platform.YUANBAO:
        lines.append("")
        lines.append(
            "**Platform notes:** You are running inside Yuanbao. "
            "You CAN send private (DM) messages via the send_message tool. "
            "Use target='yuanbao:direct:<account_id>' for DM "
            "and target='yuanbao:group:<group_code>' for group chat."
        )

    # Connected platforms
    platforms_list = ["local (files on this machine)"]

@ -44,6 +44,14 @@ class StreamConsumerConfig:
    buffer_threshold: int = 40
    cursor: str = " ▉"
    buffer_only: bool = False
    # When >0, the final edit for a streamed response is delivered as a
    # fresh message if the original preview has been visible for at least
    # this many seconds. This makes the platform's visible timestamp
    # reflect completion time instead of first-token time for long-running
    # responses (e.g. reasoning models that stream slowly). Ported from
    # openclaw/openclaw#72038. Default 0 = always edit in place (legacy
    # behavior). The gateway enables this selectively per-platform.
    fresh_final_after_seconds: float = 0.0


class GatewayStreamConsumer:

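A hedged configuration example matching the gateway wiring shown earlier — all values here are illustrative, not defaults from this diff apart from the documented 0-disables behaviour:

cfg = StreamConsumerConfig(
    edit_interval=1.5,               # illustrative
    buffer_threshold=40,
    fresh_final_after_seconds=90.0,  # replace previews older than 90s with a fresh final message
)
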
@ -91,6 +99,12 @@ class GatewayStreamConsumer:
        self._queue: queue.Queue = queue.Queue()
        self._accumulated = ""
        self._message_id: Optional[str] = None
        # Monotonic timestamp (time.monotonic) from when ``_message_id``
        # was first assigned by a successful first-send. Used by the
        # fresh-final logic to detect long-lived previews whose edit
        # timestamps would be stale by completion time. Ported from
        # openclaw/openclaw#72038.
        self._message_created_ts: Optional[float] = None
        self._already_sent = False
        self._edit_supported = True  # Disabled when progressive edits are no longer usable
        self._last_edit_time = 0.0

@ -136,6 +150,7 @@
        if preserve_no_edit and self._message_id == "__no_edit__":
            return
        self._message_id = None
        self._message_created_ts = None
        self._accumulated = ""
        self._last_sent_text = ""
        self._fallback_final_send = False

@ -734,6 +749,81 @@ class GatewayStreamConsumer:
                logger.error("Commentary send error: %s", e)
                return False

    def _should_send_fresh_final(self) -> bool:
        """Return True when a long-lived preview should be replaced with a
        fresh final message instead of an edit.

        Conditions:
        - Fresh-final is enabled (``fresh_final_after_seconds > 0``).
        - We have a real preview message id (not the ``__no_edit__`` sentinel
          and not ``None``).
        - The preview has been visible for at least the configured threshold.

        Ported from openclaw/openclaw#72038.
        """
        threshold = getattr(self.cfg, "fresh_final_after_seconds", 0.0) or 0.0
        if threshold <= 0:
            return False
        if not self._message_id or self._message_id == "__no_edit__":
            return False
        if self._message_created_ts is None:
            return False
        age = time.monotonic() - self._message_created_ts
        return age >= threshold

    async def _try_fresh_final(self, text: str) -> bool:
        """Send ``text`` as a brand-new message (best-effort delete the old
        preview) so the platform's visible timestamp reflects completion
        time. Returns True on successful delivery, False on any failure so
        the caller falls back to the normal edit path.

        Ported from openclaw/openclaw#72038.
        """
        old_message_id = self._message_id
        try:
            result = await self.adapter.send(
                chat_id=self.chat_id,
                content=text,
                metadata=self.metadata,
            )
        except Exception as e:
            logger.debug("Fresh-final send failed, falling back to edit: %s", e)
            return False
        if not getattr(result, "success", False):
            return False
        # Successful fresh send — try to delete the stale preview so the
        # user doesn't see the old edit-stuck message underneath. Cleanup
        # is best-effort; platforms that don't implement ``delete_message``
        # just leave the preview behind (still an acceptable outcome —
        # the visible final timestamp is the important part).
        if old_message_id and old_message_id != "__no_edit__":
            delete_fn = getattr(self.adapter, "delete_message", None)
            if delete_fn is not None:
                try:
                    await delete_fn(self.chat_id, old_message_id)
                except Exception as e:
                    logger.debug(
                        "Fresh-final preview cleanup failed (%s): %s",
                        old_message_id, e,
                    )
        # Adopt the new message id as the current message so subsequent
        # callers (e.g. overflow split loops, finalize retries) see a
        # consistent state.
        new_message_id = getattr(result, "message_id", None)
        if new_message_id:
            self._message_id = new_message_id
            self._message_created_ts = time.monotonic()
        else:
            # Send succeeded but platform didn't return an id — treat the
            # delivery as final-only and fall back to "__no_edit__" so we
            # don't try to edit something we can't address.
            self._message_id = "__no_edit__"
            self._message_created_ts = None
        self._already_sent = True
        self._last_sent_text = text
        self._final_response_sent = True
        return True

    async def _send_or_edit(self, text: str, *, finalize: bool = False) -> bool:
        """Send or edit the streaming message.
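Concretely: with a 90-second threshold, a preview finalized 30 s after its first send is edited in place, while one finalized 120 s later is re-sent as a fresh message. A minimal timing sketch of the same check (the threshold value is illustrative, not a shipped default):

```python
import time

threshold = 90.0                  # fresh_final_after_seconds
created_ts = time.monotonic()     # set when the preview is first sent
# ... streaming and progressive edits happen here ...
age = time.monotonic() - created_ts
send_fresh = age >= threshold     # True → fresh final message + preview cleanup
```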
@ -786,6 +876,22 @@ class GatewayStreamConsumer:
                finalize and self._adapter_requires_finalize
            ):
                return True
            # Fresh-final for long-lived previews: when finalizing
            # the last edit in a streaming sequence, if the
            # original preview has been visible for at least
            # ``fresh_final_after_seconds``, send the completed
            # reply as a fresh message so the platform's visible
            # timestamp reflects completion time instead of the
            # preview creation time. Best-effort cleanup of the
            # old preview follows. Ported from
            # openclaw/openclaw#72038. Gated by config so the
            # legacy edit-in-place path stays the default.
            if (
                finalize
                and self._should_send_fresh_final()
                and await self._try_fresh_final(text)
            ):
                return True
            # Edit existing message
            result = await self.adapter.edit_message(
                chat_id=self.chat_id,
@ -852,6 +958,10 @@ class GatewayStreamConsumer:
        if result.success:
            if result.message_id:
                self._message_id = result.message_id
                # Track when the preview first became visible to
                # the user so fresh-final logic can detect stale
                # preview timestamps on long-running responses.
                self._message_created_ts = time.monotonic()
            else:
                self._edit_supported = False
            self._already_sent = True
@ -126,8 +126,8 @@ COMMAND_REGISTRY: list[CommandDef] = [
    CommandDef("voice", "Toggle voice mode", "Configuration",
               args_hint="[on|off|tts|status]", subcommands=("on", "off", "tts", "status")),
    CommandDef("busy", "Control what Enter does while Hermes is working", "Configuration",
               cli_only=True, args_hint="[queue|interrupt|status]",
               subcommands=("queue", "interrupt", "status")),
               cli_only=True, args_hint="[queue|steer|interrupt|status]",
               subcommands=("queue", "steer", "interrupt", "status")),

    # Tools & Skills
    CommandDef("tools", "Manage tools: /tools [list|disable|enable] [name...]", "Tools & Skills",
@ -487,6 +487,19 @@ DEFAULT_CONFIG = {
    "checkpoints": {
        "enabled": True,
        "max_snapshots": 50,  # Max checkpoints to keep per directory
        # Auto-maintenance: shadow repos accumulate forever under
        # ~/.hermes/checkpoints/ (one per cd'd working directory). Field
        # reports put the typical offender at 1000+ repos / ~12 GB. When
        # auto_prune is on, hermes sweeps at startup (at most once per
        # min_interval_hours) and deletes:
        #   * orphan repos: HERMES_WORKDIR no longer exists on disk
        #   * stale repos: newest mtime older than retention_days
        # Opt-in so users who rely on /rollback against long-ago sessions
        # never lose data silently.
        "auto_prune": False,
        "retention_days": 7,
        "delete_orphans": True,
        "min_interval_hours": 24,
    },

    # Maximum characters returned by a single read_file call. Reads that
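For example, a user opting into the sweep with a longer retention window might set (keys from the block above; the values are illustrative):

```python
"checkpoints": {
    "enabled": True,
    "max_snapshots": 50,
    "auto_prune": True,       # opt in to the startup sweep
    "retention_days": 14,     # keep stale shadow repos two weeks instead of 7
    "delete_orphans": True,
    "min_interval_hours": 24,
},
```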
@ -627,7 +640,7 @@ DEFAULT_CONFIG = {
        "compact": False,
        "personality": "kawaii",
        "resume_display": "full",
        "busy_input_mode": "interrupt",
        "busy_input_mode": "interrupt",  # interrupt | queue | steer
        "bell_on_complete": False,
        "show_reasoning": False,
        "streaming": False,
@ -1582,6 +1595,44 @@ OPTIONAL_ENV_VARS = {
        "category": "tool",
    },

    # ── Bundled skills (opt-in: only needed if the user uses that skill) ──
    # These use category="skill" (distinct from "tool") so the sandbox
    # env blocklist in tools/environments/local.py does NOT rewrite them —
    # skills legitimately need these passed through to curl via
    # tools/env_passthrough.py when the user's skill calls out.
    "NOTION_API_KEY": {
        "description": "Notion integration token (used by the `notion` skill)",
        "prompt": "Notion API key",
        "url": "https://www.notion.so/my-integrations",
        "password": True,
        "category": "skill",
        "advanced": True,
    },
    "LINEAR_API_KEY": {
        "description": "Linear personal API key (used by the `linear` skill)",
        "prompt": "Linear API key",
        "url": "https://linear.app/settings/api",
        "password": True,
        "category": "skill",
        "advanced": True,
    },
    "AIRTABLE_API_KEY": {
        "description": "Airtable personal access token (used by the `airtable` skill)",
        "prompt": "Airtable API key",
        "url": "https://airtable.com/create/tokens",
        "password": True,
        "category": "skill",
        "advanced": True,
    },
    "TENOR_API_KEY": {
        "description": "Tenor API key for GIF search (used by the `gif-search` skill)",
        "prompt": "Tenor API key",
        "url": "https://developers.google.com/tenor/guides/quickstart",
        "password": True,
        "category": "skill",
        "advanced": True,
    },

    # ── Honcho ──
    "HONCHO_API_KEY": {
        "description": "Honcho API key for AI-native persistent memory",
@ -2724,6 +2724,24 @@ _PLATFORMS = [
         "help": "OpenID to deliver cron results and notifications to."},
        ],
    },
    {
        "key": "yuanbao",
        "label": "Yuanbao",
        "emoji": "💎",
        "token_var": "YUANBAO_APP_ID",
        "setup_instructions": [
            "1. Download the Yuanbao app from https://yuanbao.tencent.com/",
            "2. In the app, go to PAI → My Bot and create a new bot",
            "3. After the bot is created, copy the App ID and App Secret",
            "4. Enter them below and Hermes will connect automatically over WebSocket",
        ],
        "vars": [
            {"name": "YUANBAO_APP_ID", "prompt": "App ID", "password": False,
             "help": "The App ID from your Yuanbao IM Bot credentials."},
            {"name": "YUANBAO_APP_SECRET", "prompt": "App Secret", "password": True,
             "help": "The App Secret (used for HMAC signing) from your Yuanbao IM Bot."},
        ],
    },
]
@ -3108,6 +3126,12 @@ def _setup_wecom():
    print_success("💬 WeCom configured!")


def _setup_yuanbao():
    """Configure Yuanbao via the standard platform setup."""
    yuanbao_platform = next(p for p in _PLATFORMS if p["key"] == "yuanbao")
    _setup_standard_platform(yuanbao_platform)


def _is_service_installed() -> bool:
    """Check if the gateway is installed as a system service."""
    if supports_systemd_services():
@ -4452,8 +4452,14 @@ def _model_flow_api_key_provider(config, provider_id, current_model=""):
        from hermes_cli.models import fetch_ollama_cloud_models

        api_key_for_probe = existing_key or (get_env_value(key_env) if key_env else "")
        # During setup, force a live refresh so the picker reflects newly
        # released models (e.g. deepseek v4 flash, kimi k2.6) the moment
        # the user enters their key — not an hour later when the disk
        # cache TTL expires.
        model_list = fetch_ollama_cloud_models(
            api_key=api_key_for_probe, base_url=effective_base
            api_key=api_key_for_probe,
            base_url=effective_base,
            force_refresh=True,
        )
        if model_list:
            print(f"  Found {len(model_list)} model(s) from Ollama Cloud")
@ -5024,6 +5030,83 @@ def _gateway_prompt(prompt_text: str, default: str = "", timeout: float = 300.0)
    return default


def _web_ui_build_needed(web_dir: Path) -> bool:
    """Return True if the web UI dist is missing or stale.

    Mirrors the staleness logic used by ``_tui_build_needed()`` for the TUI.
    The Vite build outputs to ``hermes_cli/web_dist/`` (per vite.config.ts
    outDir: "../hermes_cli/web_dist"), NOT to ``web/dist/``. Uses the Vite
    manifest as the sentinel because it is written last and therefore has the
    newest mtime of any build output.
    """
    dist_dir = web_dir.parent / "hermes_cli" / "web_dist"
    sentinel = dist_dir / ".vite" / "manifest.json"
    if not sentinel.exists():
        sentinel = dist_dir / "index.html"
    if not sentinel.exists():
        return True
    dist_mtime = sentinel.stat().st_mtime
    skip = frozenset({"node_modules", "dist"})
    for dirpath, dirnames, filenames in os.walk(web_dir, topdown=True):
        dirnames[:] = [d for d in dirnames if d not in skip]
        for fn in filenames:
            if fn.endswith((".ts", ".tsx", ".js", ".jsx", ".css", ".html", ".vue")):
                if os.path.getmtime(os.path.join(dirpath, fn)) > dist_mtime:
                    return True
    for meta in (
        "package.json",
        "package-lock.json",
        "yarn.lock",
        "pnpm-lock.yaml",
        "vite.config.ts",
        "vite.config.js",
    ):
        mp = web_dir / meta
        if mp.exists() and mp.stat().st_mtime > dist_mtime:
            return True
    return False


def _run_npm_install_deterministic(
    npm: str,
    cwd: Path,
    *,
    extra_args: tuple[str, ...] = (),
    capture_output: bool = True,
) -> subprocess.CompletedProcess:
    """Run a deterministic npm install that does not mutate ``package-lock.json``.

    Prefers ``npm ci`` (strict, lockfile-preserving) when a lockfile is present;
    falls back to ``npm install`` only if ``npm ci`` fails (e.g. lockfile out of
    sync on a WIP checkout). Without this, ``npm install`` on npm ≥ 10 silently
    rewrites committed lockfiles (stripping ``"peer": true`` etc.), which leaves
    the working tree dirty and causes the next ``hermes update`` to stash the
    lockfile — repeatedly.
    """
    lockfile = cwd / "package-lock.json"
    if lockfile.exists():
        ci_cmd = [npm, "ci", *extra_args]
        ci_result = subprocess.run(
            ci_cmd,
            cwd=cwd,
            capture_output=capture_output,
            text=True,
            check=False,
        )
        if ci_result.returncode == 0:
            return ci_result
        # Fall through to `npm install` — lockfile may be out of sync on a
        # WIP fork/branch, or `npm ci` may not be available on very old npm.
    install_cmd = [npm, "install", *extra_args]
    return subprocess.run(
        install_cmd,
        cwd=cwd,
        capture_output=capture_output,
        text=True,
        check=False,
    )


def _build_web_ui(web_dir: Path, *, fatal: bool = False) -> bool:
    """Build the web UI frontend if npm is available.
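In shell terms, the helper above is roughly equivalent to the following (a sketch for orientation, not code from this change):

```bash
# Prefer the strict, lockfile-preserving install; fall back when the
# lockfile is missing or out of sync (WIP forks, very old npm).
if [ -f package-lock.json ] && npm ci --silent; then
    :
else
    npm install --silent
fi
```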
@ -5037,6 +5120,9 @@ def _build_web_ui(web_dir: Path, *, fatal: bool = False) -> bool:
    if not (web_dir / "package.json").exists():
        return True

    if not _web_ui_build_needed(web_dir):
        return True

    npm = shutil.which("npm")
    if not npm:
        if fatal:
@ -5044,7 +5130,7 @@
        print("Install Node.js, then run: cd web && npm install && npm run build")
        return not fatal
    print("→ Building web UI...")
    r1 = subprocess.run([npm, "install", "--silent"], cwd=web_dir, capture_output=True)
    r1 = _run_npm_install_deterministic(npm, web_dir, extra_args=("--silent",))
    if r1.returncode != 0:
        print(
            f"  {'✗' if fatal else '⚠'} Web UI npm install failed"
@ -5755,12 +5841,10 @@ def _update_node_dependencies() -> None:
        if not (path / "package.json").exists():
            continue

        result = subprocess.run(
            [npm, "install", "--silent", "--no-fund", "--no-audit", "--progress=false"],
            cwd=path,
            capture_output=True,
            text=True,
            check=False,
        result = _run_npm_install_deterministic(
            npm,
            path,
            extra_args=("--silent", "--no-fund", "--no-audit", "--progress=false"),
        )
        if result.returncode == 0:
            print(f"  ✓ {label}")
@ -5996,6 +6080,88 @@ def _cmd_update_check():
        print(f"  Run '{recommended_update_command()}' to install.")


def _ensure_fhs_path_guard() -> None:
    """Ensure /usr/local/bin is on PATH for RHEL-family root non-login shells.

    Mirrors the post-symlink probe added to ``scripts/install.sh`` so that
    existing FHS-layout root installs on RHEL/CentOS/Rocky/Alma 8+ get
    repaired on ``hermes update`` without requiring a reinstall. The
    installer's assumption that ``/usr/local/bin`` is on PATH for every
    standard shell breaks on those distros in non-login interactive shells
    (su, sudo -s, tmux panes, some web terminals): /etc/bashrc doesn't
    add /usr/local/bin and /root/.bash_profile doesn't either. Symptom:
    ``hermes`` prints ``command not found`` even though the symlink lives
    at /usr/local/bin/hermes.

    Silent no-op on: non-Linux, non-root, non-FHS installs, and any system
    where ``bash -i -c 'command -v hermes'`` already resolves. Idempotent.
    """
    if sys.platform != "linux":
        return
    try:
        if os.geteuid() != 0:
            return
    except AttributeError:
        return
    # Only act when this is actually an FHS-layout install (command link at
    # /usr/local/bin/hermes, code at /usr/local/lib/hermes-agent).
    fhs_link = Path("/usr/local/bin/hermes")
    if not fhs_link.is_symlink() and not fhs_link.exists():
        return

    # Probe a fresh non-login interactive bash the way the user will use it.
    # ``bash -i -c`` sources ~/.bashrc but NOT ~/.bash_profile or /etc/profile,
    # which is the exact scenario where RHEL root loses /usr/local/bin.
    home = os.environ.get("HOME") or "/root"
    try:
        probe = subprocess.run(
            ["env", "-i",
             f"HOME={home}",
             f"TERM={os.environ.get('TERM', 'dumb')}",
             "bash", "-i", "-c", "command -v hermes"],
            capture_output=True, text=True, timeout=10,
        )
    except (FileNotFoundError, subprocess.TimeoutExpired):
        return  # no bash or probe hung — don't block update on this
    if probe.returncode == 0:
        return  # already on PATH, nothing to do

    path_line = 'export PATH="/usr/local/bin:$PATH"'
    path_comment = (
        "# Hermes Agent — ensure /usr/local/bin is on PATH "
        "(RHEL non-login shells)"
    )
    wrote_any = False
    for candidate in (".bashrc", ".bash_profile"):
        cfg = Path(home) / candidate
        if not cfg.is_file():
            continue
        try:
            existing = cfg.read_text(errors="replace")
        except OSError:
            continue
        # Idempotency: skip if any uncommented PATH= line already references
        # /usr/local/bin. Mirrors the grep pattern used by install.sh.
        already_guarded = any(
            "/usr/local/bin" in line
            and "PATH" in line
            and not line.lstrip().startswith("#")
            for line in existing.splitlines()
        )
        if already_guarded:
            continue
        try:
            with cfg.open("a", encoding="utf-8") as f:
                f.write("\n" + path_comment + "\n" + path_line + "\n")
        except OSError as e:
            print(f"  ⚠ Could not update {cfg}: {e}")
            continue
        print(f"  ✓ Added /usr/local/bin to PATH in {cfg}")
        wrote_any = True
    if wrote_any:
        print("  (reload your shell or run 'source ~/.bashrc' to pick it up)")


def cmd_update(args):
    """Update Hermes Agent to the latest version.
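To reproduce the probe by hand on an affected machine, run the same command the function spawns — it prints a path on healthy installs and exits non-zero on broken ones:

```bash
env -i HOME="$HOME" TERM=dumb bash -i -c 'command -v hermes'
```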
@ -6439,6 +6605,13 @@ def _cmd_update_impl(args, gateway_mode: bool):
    print()
    print("✓ Update complete!")

    # Repair RHEL-family root installs where /usr/local/bin isn't on PATH
    # for non-login interactive shells. No-op on every other platform.
    try:
        _ensure_fhs_path_guard()
    except Exception as e:
        logger.debug("FHS PATH guard check failed: %s", e)

    # Write exit code *before* the gateway restart attempt.
    # When running as ``hermes update --gateway`` (spawned by the gateway's
    # /update command), this process lives inside the gateway's systemd
@ -9084,7 +9257,7 @@ Examples:
        "--source", help="Filter by source (cli, telegram, discord, etc.)"
    )
    sessions_browse.add_argument(
        "--limit", type=int, default=50, help="Max sessions to load (default: 50)"
        "--limit", type=int, default=500, help="Max sessions to load (default: 500)"
    )

    def _confirm_prompt(prompt: str) -> bool:

@ -9181,7 +9354,8 @@ Examples:
        ):
            print("Cancelled.")
            return
        if db.delete_session(resolved_session_id):
        sessions_dir = get_hermes_home() / "sessions"
        if db.delete_session(resolved_session_id, sessions_dir=sessions_dir):
            print(f"Deleted session '{resolved_session_id}'.")
        else:
            print(f"Session '{args.session_id}' not found.")

@ -9195,7 +9369,9 @@ Examples:
        ):
            print("Cancelled.")
            return
        count = db.prune_sessions(older_than_days=days, source=args.source)
        sessions_dir = get_hermes_home() / "sessions"
        count = db.prune_sessions(older_than_days=days, source=args.source,
                                  sessions_dir=sessions_dir)
        print(f"Pruned {count} session(s).")

    elif action == "rename":

@ -9213,7 +9389,7 @@ Examples:
            print(f"Error: {e}")

    elif action == "browse":
        limit = getattr(args, "limit", 50) or 50
        limit = getattr(args, "limit", 500) or 500
        source = getattr(args, "source", None)
        _browse_exclude = None if source else ["tool"]
        sessions = db.list_sessions_rich(
@ -33,8 +33,6 @@ COPILOT_REASONING_EFFORTS_O_SERIES = ["low", "medium", "high"]
# (model_id, display description shown in menus)
OPENROUTER_MODELS: list[tuple[str, str]] = [
    ("moonshotai/kimi-k2.6", "recommended"),
    ("deepseek/deepseek-v4-pro", ""),
    ("deepseek/deepseek-v4-flash", ""),
    ("anthropic/claude-opus-4.7", ""),
    ("anthropic/claude-opus-4.6", ""),
    ("anthropic/claude-sonnet-4.6", ""),

@ -111,8 +109,6 @@ def _codex_curated_models() -> list[str]:
_PROVIDER_MODELS: dict[str, list[str]] = {
    "nous": [
        "moonshotai/kimi-k2.6",
        "deepseek/deepseek-v4-pro",
        "deepseek/deepseek-v4-flash",
        "xiaomi/mimo-v2.5-pro",
        "xiaomi/mimo-v2.5",
        "anthropic/claude-opus-4.7",
@ -9,6 +9,7 @@ from typing import Dict, Iterable, Optional, Set
from hermes_cli.auth import get_nous_auth_status
from hermes_cli.config import get_env_value, load_config
from tools.managed_tool_gateway import is_managed_tool_gateway_ready
from utils import is_truthy_value
from tools.tool_backend_helpers import (
    fal_key_is_configured,
    has_direct_modal_credentials,

@ -25,6 +26,13 @@ _DEFAULT_PLATFORM_TOOLSETS = {
}


def _uses_gateway(section: object) -> bool:
    """Return True when a config section explicitly opts into the gateway."""
    if not isinstance(section, dict):
        return False
    return is_truthy_value(section.get("use_gateway"), default=False)


@dataclass(frozen=True)
class NousFeatureState:
    key: str
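The point of routing through ``is_truthy_value`` rather than ``bool()`` is string-typed config values. A sketch of the difference (assuming ``is_truthy_value`` treats strings like ``"false"`` and ``"0"`` as falsy, which is the behavior this change relies on):

```python
section = {"use_gateway": "false"}   # e.g. a hand-edited config.json

bool(section.get("use_gateway"))     # True  — any non-empty string is truthy
_uses_gateway(section)               # False — the string's meaning is parsed
_uses_gateway(None)                  # False — non-dict sections never opt in
```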
@ -262,11 +270,11 @@ def get_nous_subscription_features(
    # use_gateway flags — when True, the user explicitly opted into the
    # Tool Gateway via `hermes model`, so direct credentials should NOT
    # prevent gateway routing.
    web_use_gateway = bool(web_cfg.get("use_gateway"))
    tts_use_gateway = bool(tts_cfg.get("use_gateway"))
    browser_use_gateway = bool(browser_cfg.get("use_gateway"))
    web_use_gateway = _uses_gateway(web_cfg)
    tts_use_gateway = _uses_gateway(tts_cfg)
    browser_use_gateway = _uses_gateway(browser_cfg)
    image_gen_cfg = config.get("image_gen") if isinstance(config.get("image_gen"), dict) else {}
    image_use_gateway = bool(image_gen_cfg.get("use_gateway"))
    image_use_gateway = _uses_gateway(image_gen_cfg)

    direct_exa = bool(get_env_value("EXA_API_KEY"))
    direct_firecrawl = bool(get_env_value("FIRECRAWL_API_KEY") or get_env_value("FIRECRAWL_API_URL"))

@ -601,10 +609,10 @@ def get_gateway_eligible_tools(
    # no direct keys exist — we only skip the prompt for tools where
    # use_gateway was explicitly set.
    opted_in = {
        "web": bool((config.get("web") if isinstance(config.get("web"), dict) else {}).get("use_gateway")),
        "image_gen": bool((config.get("image_gen") if isinstance(config.get("image_gen"), dict) else {}).get("use_gateway")),
        "tts": bool((config.get("tts") if isinstance(config.get("tts"), dict) else {}).get("use_gateway")),
        "browser": bool((config.get("browser") if isinstance(config.get("browser"), dict) else {}).get("use_gateway")),
        "web": _uses_gateway(config.get("web")),
        "image_gen": _uses_gateway(config.get("image_gen")),
        "tts": _uses_gateway(config.get("tts")),
        "browser": _uses_gateway(config.get("browser")),
    }

    unconfigured: list[str] = []
@ -36,6 +36,7 @@ PLATFORMS: OrderedDict[str, PlatformInfo] = OrderedDict([
    ("wecom_callback", PlatformInfo(label="💬 WeCom Callback", default_toolset="hermes-wecom-callback")),
    ("weixin", PlatformInfo(label="💬 Weixin", default_toolset="hermes-weixin")),
    ("qqbot", PlatformInfo(label="💬 QQBot", default_toolset="hermes-qqbot")),
    ("yuanbao", PlatformInfo(label="🤖 Yuanbao", default_toolset="hermes-yuanbao")),
    ("webhook", PlatformInfo(label="🔗 Webhook", default_toolset="hermes-webhook")),
    ("api_server", PlatformInfo(label="🌐 API Server", default_toolset="hermes-api-server")),
    ("cron", PlatformInfo(label="⏰ Cron", default_toolset="hermes-cron")),

@ -2133,6 +2133,12 @@ def _setup_feishu():
    _gateway_setup_feishu()


def _setup_yuanbao():
    """Configure Yuanbao via gateway setup."""
    from hermes_cli.gateway import _setup_yuanbao as _gateway_setup_yuanbao
    _gateway_setup_yuanbao()


def _setup_wecom():
    """Configure WeCom (Enterprise WeChat) via gateway setup."""
    from hermes_cli.gateway import _setup_wecom as _gateway_setup_wecom

@ -2277,6 +2283,7 @@ _GATEWAY_PLATFORMS = [
    ("WhatsApp", "WHATSAPP_ENABLED", _setup_whatsapp),
    ("DingTalk", "DINGTALK_CLIENT_ID", _setup_dingtalk),
    ("Feishu / Lark", "FEISHU_APP_ID", _setup_feishu),
    ("Yuanbao", "YUANBAO_APP_ID", _setup_yuanbao),
    ("WeCom (Enterprise WeChat)", "WECOM_BOT_ID", _setup_wecom),
    ("WeCom Callback (Self-Built App)", "WECOM_CALLBACK_CORP_ID", _setup_wecom_callback),
    ("Weixin (WeChat)", "WEIXIN_ACCOUNT_ID", _setup_weixin),

@ -326,7 +326,8 @@ def show_status(args):
        "WeCom Callback": ("WECOM_CALLBACK_CORP_ID", None),
        "Weixin": ("WEIXIN_ACCOUNT_ID", "WEIXIN_HOME_CHANNEL"),
        "BlueBubbles": ("BLUEBUBBLES_SERVER_URL", "BLUEBUBBLES_HOME_CHANNEL"),
        "QQBot": ("QQ_APP_ID", "QQBOT_HOME_CHANNEL"),
        "QQBot": ("QQ_APP_ID", "QQ_HOME_CHANNEL"),
        "Yuanbao": ("YUANBAO_APP_ID", "YUANBAO_HOME_CHANNEL"),
    }

    for name, (token_var, home_var) in platforms.items():

@ -106,7 +106,7 @@ TIPS = [
    "Set display.streaming: true to see tokens appear in real time as the model generates.",
    "Set display.show_reasoning: true to watch the model's chain-of-thought reasoning.",
    "Set display.compact: true to reduce whitespace in output for denser information.",
    "Set display.busy_input_mode: queue to queue messages instead of interrupting the agent.",
    "Set display.busy_input_mode: queue to queue messages instead of interrupting the agent, or steer to inject them mid-run via /steer.",
    "Set display.resume_display: minimal to skip the full conversation recap on session resume.",
    "Set compression.threshold: 0.50 to control when auto-compression fires (default: 50% of context).",
    "Set agent.max_turns: 200 to let the agent take more tool-calling steps per turn.",
@ -11,6 +11,7 @@ the `platform_toolsets` key.

import json as _json
import logging
import os
import sys
from pathlib import Path
from typing import Dict, List, Optional, Set

@ -25,7 +26,7 @@ from hermes_cli.nous_subscription import (
    get_nous_subscription_features,
)
from tools.tool_backend_helpers import fal_key_is_configured, managed_nous_tools_enabled
from utils import base_url_hostname
from utils import base_url_hostname, is_truthy_value

logger = logging.getLogger(__name__)

@ -70,6 +71,7 @@ CONFIGURABLE_TOOLSETS = [
    ("spotify", "🎵 Spotify", "playback, search, playlists, library"),
    ("discord", "💬 Discord (read/participate)", "fetch messages, search members, create thread"),
    ("discord_admin", "🛡️ Discord Server Admin", "list channels/roles, pin, assign roles"),
    ("yuanbao", "🤖 Yuanbao", "group info, member queries, DM"),
]

# Toolsets that are OFF by default for new installs.
@ -676,6 +678,15 @@ def _get_platform_tools(
    # their own platform (e.g. `discord` + `discord` should stay OFF).
    if platform in default_off and platform not in _TOOLSET_PLATFORM_RESTRICTIONS:
        default_off.remove(platform)
    # Home Assistant is already runtime-gated by its check_fn (requires
    # HASS_TOKEN to register any tools). When a user has configured
    # HASS_TOKEN, they've explicitly opted in — don't also strip it via
    # _DEFAULT_OFF_TOOLSETS, which would silently drop HA from platforms
    # (e.g. cron) that run through _get_platform_tools without an
    # explicit saved toolset list. Without this, Norbert's HA cron jobs
    # regressed after #14798 made cron honor per-platform tool config.
    if "homeassistant" in default_off and os.getenv("HASS_TOKEN"):
        default_off.remove("homeassistant")
    enabled_toolsets -= default_off

    # Recover non-configurable platform toolsets (e.g. discord, feishu_doc,

@ -1177,7 +1188,7 @@ def _is_provider_active(provider: dict, config: dict) -> bool:
        configured_provider = image_cfg.get("provider")
        if configured_provider not in (None, "", "fal"):
            return False
        if image_cfg.get("use_gateway") is False:
        if image_cfg.get("use_gateway") is not None and not is_truthy_value(image_cfg.get("use_gateway"), default=False):
            return False
        return feature.managed_by_nous
    if provider.get("tts_provider"):

@ -1209,7 +1220,7 @@ def _is_provider_active(provider: dict, config: dict) -> bool:
        return (
            provider["imagegen_backend"] == "fal"
            and configured_provider in (None, "", "fal")
            and not image_cfg.get("use_gateway")
            and not is_truthy_value(image_cfg.get("use_gateway"), default=False)
        )
    return False
@ -287,7 +287,7 @@ _SCHEMA_OVERRIDES: Dict[str, Dict[str, Any]] = {
    "display.busy_input_mode": {
        "type": "select",
        "description": "Input behavior while agent is running",
        "options": ["interrupt", "queue"],
        "options": ["interrupt", "queue", "steer"],
    },
    "memory.provider": {
        "type": "select",
@ -195,10 +195,6 @@ def setup_logging(
        The ``logs/`` directory where files are written.
    """
    global _logging_initialized
    if _logging_initialized and not force:
        home = hermes_home or get_hermes_home()
        return home / "logs"

    home = hermes_home or get_hermes_home()
    log_dir = home / "logs"
    log_dir.mkdir(parents=True, exist_ok=True)

@ -248,6 +244,9 @@ def setup_logging(
        log_filter=_ComponentFilter(COMPONENT_PREFIXES["gateway"]),
    )

    if _logging_initialized and not force:
        return log_dir

    # Ensure root logger level is low enough for the handlers to fire.
    if root.level == logging.NOTSET or root.level > level:
        root.setLevel(level)
@ -1573,12 +1573,45 @@ class SessionDB:
        )
        self._execute_write(_do)

    def delete_session(self, session_id: str) -> bool:
    @staticmethod
    def _remove_session_files(sessions_dir: Optional[Path], session_id: str) -> None:
        """Remove on-disk transcript files for a session.

        Cleans up ``{session_id}.json``, ``{session_id}.jsonl``, and any
        ``request_dump_{session_id}_*.json`` files left by the gateway.
        Silently skips files that don't exist and swallows OSError so a
        filesystem hiccup never blocks a DB operation.
        """
        if sessions_dir is None:
            return
        for suffix in (".json", ".jsonl"):
            p = sessions_dir / f"{session_id}{suffix}"
            try:
                p.unlink(missing_ok=True)
            except OSError:
                pass
        # request_dump files use session_id as a prefix component
        try:
            for p in sessions_dir.glob(f"request_dump_{session_id}_*.json"):
                try:
                    p.unlink(missing_ok=True)
                except OSError:
                    pass
        except OSError:
            pass

    def delete_session(
        self,
        session_id: str,
        sessions_dir: Optional[Path] = None,
    ) -> bool:
        """Delete a session and all its messages.

        Child sessions are orphaned (parent_session_id set to NULL) rather
        than cascade-deleted, so they remain accessible independently.
        Returns True if the session was found and deleted.
        When *sessions_dir* is provided, also removes on-disk transcript
        files (``.json`` / ``.jsonl`` / ``request_dump_*``) for the deleted
        session. Returns True if the session was found and deleted.
        """
        def _do(conn):
            cursor = conn.execute(

@ -1595,16 +1628,29 @@ class SessionDB:
            conn.execute("DELETE FROM messages WHERE session_id = ?", (session_id,))
            conn.execute("DELETE FROM sessions WHERE id = ?", (session_id,))
            return True
        return self._execute_write(_do)

    def prune_sessions(self, older_than_days: int = 90, source: str = None) -> int:
        deleted = self._execute_write(_do)
        if deleted:
            self._remove_session_files(sessions_dir, session_id)
        return deleted

    def prune_sessions(
        self,
        older_than_days: int = 90,
        source: str = None,
        sessions_dir: Optional[Path] = None,
    ) -> int:
        """Delete sessions older than N days. Returns count of deleted sessions.

        Only prunes ended sessions (not active ones). Child sessions outside
        the prune window are orphaned (parent_session_id set to NULL) rather
        than cascade-deleted.
        than cascade-deleted. When *sessions_dir* is provided, also removes
        on-disk transcript files (``.json`` / ``.jsonl`` /
        ``request_dump_*``) for every pruned session, outside the DB
        transaction.
        """
        cutoff = time.time() - (older_than_days * 86400)
        removed_ids: list[str] = []

        def _do(conn):
            if source:
@ -1634,9 +1680,14 @@ class SessionDB:
            for sid in session_ids:
                conn.execute("DELETE FROM messages WHERE session_id = ?", (sid,))
                conn.execute("DELETE FROM sessions WHERE id = ?", (sid,))
                removed_ids.append(sid)
            return len(session_ids)

        return self._execute_write(_do)
        count = self._execute_write(_do)
        # Clean up on-disk files outside the DB transaction
        for sid in removed_ids:
            self._remove_session_files(sessions_dir, sid)
        return count

    # ── Meta key/value (for scheduler bookkeeping) ──

@ -1690,6 +1741,7 @@ class SessionDB:
        retention_days: int = 90,
        min_interval_hours: int = 24,
        vacuum: bool = True,
        sessions_dir: Optional[Path] = None,
    ) -> Dict[str, Any]:
        """Idempotent auto-maintenance: prune old sessions + optional VACUUM.

@ -1697,6 +1749,10 @@ class SessionDB:
        within ``min_interval_hours`` no-op. Designed to be called once at
        startup from long-lived entrypoints (CLI, gateway, cron scheduler).

        When *sessions_dir* is provided, on-disk transcript files
        (``.json`` / ``.jsonl`` / ``request_dump_*``) for pruned sessions
        are removed as part of the same sweep (issue #3015).

        Never raises. On any failure, logs a warning and returns a dict
        with ``"error"`` set.

@ -1720,7 +1776,10 @@ class SessionDB:
        except (TypeError, ValueError):
            pass  # corrupt meta; treat as no prior run

        pruned = self.prune_sessions(older_than_days=retention_days)
        pruned = self.prune_sessions(
            older_than_days=retention_days,
            sessions_dir=sessions_dir,
        )
        result["pruned"] = pruned

        # Only VACUUM if we actually freed rows — VACUUM on a tight DB
@ -3,7 +3,9 @@
Long-term memory with knowledge graph, entity resolution, and multi-strategy
retrieval. Supports cloud (API key) and local modes.

Configurable timeout via HINDSIGHT_TIMEOUT env var or config.json.
Configurable request timeout via HINDSIGHT_TIMEOUT env var or config.json.
Configurable embedded daemon idle timeout via HINDSIGHT_IDLE_TIMEOUT env var
or config.json idle_timeout.

Original PR #1811 by benfrank241, adapted to MemoryProvider ABC.

@ -14,6 +16,7 @@ Config via environment variables:
    HINDSIGHT_API_URL — API endpoint
    HINDSIGHT_MODE — cloud or local (default: cloud)
    HINDSIGHT_TIMEOUT — API request timeout in seconds (default: 120)
    HINDSIGHT_IDLE_TIMEOUT — embedded daemon idle timeout seconds; 0 disables shutdown (default: 300)
    HINDSIGHT_RETAIN_TAGS — comma-separated tags attached to retained memories
    HINDSIGHT_RETAIN_SOURCE — metadata source value attached to retained memories
    HINDSIGHT_RETAIN_USER_PREFIX — label used before user turns in retained transcripts

@ -45,6 +48,7 @@ _DEFAULT_API_URL = "https://api.hindsight.vectorize.io"
_DEFAULT_LOCAL_URL = "http://localhost:8888"
_MIN_CLIENT_VERSION = "0.4.22"
_DEFAULT_TIMEOUT = 120  # seconds — cloud API can take 30-40s per request
_DEFAULT_IDLE_TIMEOUT = 300  # seconds — Hindsight embedded daemon default
_VALID_BUDGETS = {"low", "mid", "high"}
_PROVIDER_DEFAULT_MODELS = {
    "openai": "gpt-4o-mini",
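An ``~/.hermes/.env`` overriding both knobs might look like this (variable names from the docstring above; the values are illustrative):

```bash
HINDSIGHT_TIMEOUT=180       # allow slower cloud API responses
HINDSIGHT_IDLE_TIMEOUT=0    # never auto-stop the embedded daemon
```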
@ -59,6 +63,17 @@ _PROVIDER_DEFAULT_MODELS = {
}


def _parse_int_setting(value: Any, default: int) -> int:
    """Parse an integer config/env value, falling back on invalid input."""
    if value is None or value == "":
        return default
    try:
        return int(value)
    except (TypeError, ValueError):
        logger.warning("Invalid integer Hindsight setting %r; using default %s", value, default)
        return default


def _check_local_runtime() -> tuple[bool, str | None]:
    """Return whether local embedded Hindsight imports cleanly.
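Behavior of the helper on typical inputs, derived directly from the code above:

```python
_parse_int_setting("300", 120)   # 300 — valid integer string
_parse_int_setting("", 120)      # 120 — unset/empty env var falls back
_parse_int_setting("5m", 120)    # 120 — invalid input logs a warning first
```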
@ -203,6 +218,8 @@ def _load_config() -> dict:
    return {
        "mode": os.environ.get("HINDSIGHT_MODE", "cloud"),
        "apiKey": os.environ.get("HINDSIGHT_API_KEY", ""),
        "timeout": _parse_int_setting(os.environ.get("HINDSIGHT_TIMEOUT"), _DEFAULT_TIMEOUT),
        "idle_timeout": _parse_int_setting(os.environ.get("HINDSIGHT_IDLE_TIMEOUT"), _DEFAULT_IDLE_TIMEOUT),
        "retain_tags": os.environ.get("HINDSIGHT_RETAIN_TAGS", ""),
        "retain_source": os.environ.get("HINDSIGHT_RETAIN_SOURCE", ""),
        "retain_user_prefix": os.environ.get("HINDSIGHT_RETAIN_USER_PREFIX", "User"),

@ -304,6 +321,16 @@ def _build_embedded_profile_env(config: dict[str, Any], *, llm_api_key: str | No
    }
    if current_base_url:
        env_values["HINDSIGHT_API_LLM_BASE_URL"] = str(current_base_url)

    idle_timeout = (
        config.get("idle_timeout")
        if config.get("idle_timeout") is not None
        else os.environ.get("HINDSIGHT_IDLE_TIMEOUT")
    )
    if idle_timeout is not None and idle_timeout != "":
        env_values["HINDSIGHT_EMBED_DAEMON_IDLE_TIMEOUT"] = str(
            _parse_int_setting(idle_timeout, _DEFAULT_IDLE_TIMEOUT)
        )
    return env_values
@ -412,6 +439,7 @@ class HindsightMemoryProvider(MemoryProvider):
        self._turn_index = 0
        self._client = None
        self._timeout = _DEFAULT_TIMEOUT
        self._idle_timeout = _DEFAULT_IDLE_TIMEOUT
        self._prefetch_result = ""
        self._prefetch_lock = threading.Lock()
        self._prefetch_thread = None

@ -592,10 +620,17 @@ class HindsightMemoryProvider(MemoryProvider):
            sys.stdout.write("  LLM API key: ")
            sys.stdout.flush()
            llm_key = getpass.getpass(prompt="") if sys.stdin.isatty() else sys.stdin.readline().strip()
            # Always write explicitly (including empty) so the provider sees ""
            # rather than a missing variable. The daemon reads from .env at
            # startup and fails when HINDSIGHT_LLM_API_KEY is unset.
            env_writes["HINDSIGHT_LLM_API_KEY"] = llm_key
            if llm_key:
                env_writes["HINDSIGHT_LLM_API_KEY"] = llm_key
            else:
                env_path = Path(hermes_home) / ".env"
                existing_llm_key = ""
                if env_path.exists():
                    for line in env_path.read_text().splitlines():
                        if line.startswith("HINDSIGHT_LLM_API_KEY="):
                            existing_llm_key = line.split("=", 1)[1]
                            break
                env_writes["HINDSIGHT_LLM_API_KEY"] = existing_llm_key

        # Step 4: Save everything
        provider_config["bank_id"] = "hermes"
@ -605,6 +640,11 @@ class HindsightMemoryProvider(MemoryProvider):
        timeout_val = existing_timeout if existing_timeout else _DEFAULT_TIMEOUT
        provider_config["timeout"] = timeout_val
        env_writes["HINDSIGHT_TIMEOUT"] = str(timeout_val)
        if mode == "local_embedded":
            existing_idle_timeout = self._config.get("idle_timeout") if self._config else None
            idle_timeout_val = existing_idle_timeout if existing_idle_timeout is not None else _DEFAULT_IDLE_TIMEOUT
            provider_config["idle_timeout"] = idle_timeout_val
            env_writes["HINDSIGHT_IDLE_TIMEOUT"] = str(idle_timeout_val)
        config["memory"]["provider"] = "hindsight"
        save_config(config)

@ -693,6 +733,7 @@ class HindsightMemoryProvider(MemoryProvider):
            {"key": "recall_max_input_chars", "description": "Maximum input query length for auto-recall", "default": 800},
            {"key": "recall_prompt_preamble", "description": "Custom preamble for recalled memories in context"},
            {"key": "timeout", "description": "API request timeout in seconds", "default": _DEFAULT_TIMEOUT},
            {"key": "idle_timeout", "description": "Embedded daemon idle timeout in seconds (0 disables auto-shutdown)", "default": _DEFAULT_IDLE_TIMEOUT, "when": {"mode": "local_embedded"}},
        ]

    def _get_client(self):
@ -720,6 +761,14 @@ class HindsightMemoryProvider(MemoryProvider):
            )
            if self._llm_base_url:
                kwargs["llm_base_url"] = self._llm_base_url
            idle_timeout = _parse_int_setting(
                self._config.get("idle_timeout")
                if self._config.get("idle_timeout") is not None
                else os.environ.get("HINDSIGHT_IDLE_TIMEOUT", self._idle_timeout),
                _DEFAULT_IDLE_TIMEOUT,
            )
            self._idle_timeout = idle_timeout
            kwargs["idle_timeout"] = idle_timeout
            self._client = HindsightEmbedded(**kwargs)
        else:
            from hindsight_client import Hindsight

@ -736,6 +785,38 @@ class HindsightMemoryProvider(MemoryProvider):
        """Schedule *coro* on the shared loop using the configured timeout."""
        return _run_sync(coro, timeout=self._timeout)

    def _is_retriable_embedded_connection_error(self, exc: Exception) -> bool:
        """Return True for stale embedded-daemon connection failures."""
        if self._mode != "local_embedded":
            return False
        text = f"{type(exc).__name__}: {exc}".lower()
        return any(
            marker in text
            for marker in (
                "cannot connect to host",
                "connection refused",
                "connect call failed",
                "clientconnectorerror",
            )
        )

    def _run_hindsight_operation(self, operation):
        """Run an async Hindsight client operation, retrying once after idle shutdown."""
        client = self._get_client()
        try:
            return self._run_sync(operation(client))
        except Exception as exc:
            if not self._is_retriable_embedded_connection_error(exc):
                raise
            logger.info(
                "Hindsight embedded daemon appears unreachable; recreating client and retrying once: %s",
                exc,
            )
            self._client = None
            client = self._get_client()
            self._client = client
            return self._run_sync(operation(client))

    def initialize(self, session_id: str, **kwargs) -> None:
        self._session_id = str(session_id or "").strip()
        self._parent_session_id = str(kwargs.get("parent_session_id", "") or "").strip()
@ -790,7 +871,14 @@ class HindsightMemoryProvider(MemoryProvider):
        self._session_turns = []
        self._mode = self._config.get("mode", "cloud")
        # Read timeout from config or env var, fall back to default
        self._timeout = self._config.get("timeout") or int(os.environ.get("HINDSIGHT_TIMEOUT", str(_DEFAULT_TIMEOUT)))
        self._timeout = _parse_int_setting(
            self._config.get("timeout") if self._config.get("timeout") is not None else os.environ.get("HINDSIGHT_TIMEOUT"),
            _DEFAULT_TIMEOUT,
        )
        self._idle_timeout = _parse_int_setting(
            self._config.get("idle_timeout") if self._config.get("idle_timeout") is not None else os.environ.get("HINDSIGHT_IDLE_TIMEOUT"),
            _DEFAULT_IDLE_TIMEOUT,
        )
        # "local" is a legacy alias for "local_embedded"
        if self._mode == "local":
            self._mode = "local_embedded"

@ -981,10 +1069,9 @@ class HindsightMemoryProvider(MemoryProvider):

        def _run():
            try:
                client = self._get_client()
                if self._prefetch_method == "reflect":
                    logger.debug("Prefetch: calling reflect (bank=%s, query_len=%d)", self._bank_id, len(query))
                    resp = self._run_sync(client.areflect(bank_id=self._bank_id, query=query, budget=self._budget))
                    resp = self._run_hindsight_operation(lambda client: client.areflect(bank_id=self._bank_id, query=query, budget=self._budget))
                    text = resp.text or ""
                else:
                    recall_kwargs: dict = {

@ -998,7 +1085,7 @@ class HindsightMemoryProvider(MemoryProvider):
                        recall_kwargs["types"] = self._recall_types
                    logger.debug("Prefetch: calling recall (bank=%s, query_len=%d, budget=%s)",
                                 self._bank_id, len(query), self._budget)
                    resp = self._run_sync(client.arecall(**recall_kwargs))
                    resp = self._run_hindsight_operation(lambda client: client.arecall(**recall_kwargs))
                    num_results = len(resp.results) if resp.results else 0
                    logger.debug("Prefetch: recall returned %d results", num_results)
                    text = "\n".join(f"- {r.text}" for r in resp.results if r.text) if resp.results else ""
@ -1131,12 +1218,14 @@ class HindsightMemoryProvider(MemoryProvider):
            item.pop("retain_async", None)
            logger.debug("Hindsight retain: bank=%s, doc=%s, async=%s, content_len=%d, num_turns=%d",
                         self._bank_id, self._document_id, self._retain_async, len(content), len(self._session_turns))
            self._run_sync(client.aretain_batch(
                bank_id=self._bank_id,
                items=[item],
                document_id=self._document_id,
                retain_async=self._retain_async,
            ))
            self._run_hindsight_operation(
                lambda client: client.aretain_batch(
                    bank_id=self._bank_id,
                    items=[item],
                    document_id=self._document_id,
                    retain_async=self._retain_async,
                )
            )
            logger.debug("Hindsight retain succeeded")
        except Exception as e:
            logger.warning("Hindsight sync failed: %s", e, exc_info=True)

@ -1152,12 +1241,6 @@ class HindsightMemoryProvider(MemoryProvider):
        return [RETAIN_SCHEMA, RECALL_SCHEMA, REFLECT_SCHEMA]

    def handle_tool_call(self, tool_name: str, args: dict, **kwargs) -> str:
        try:
            client = self._get_client()
        except Exception as e:
            logger.warning("Hindsight client init failed: %s", e)
            return tool_error(f"Hindsight client unavailable: {e}")

        if tool_name == "hindsight_retain":
            content = args.get("content", "")
            if not content:

@ -1171,7 +1254,7 @@ class HindsightMemoryProvider(MemoryProvider):
                )
                logger.debug("Tool hindsight_retain: bank=%s, content_len=%d, context=%s",
                             self._bank_id, len(content), context)
                self._run_sync(client.aretain(**retain_kwargs))
                self._run_hindsight_operation(lambda client: client.aretain(**retain_kwargs))
                logger.debug("Tool hindsight_retain: success")
                return json.dumps({"result": "Memory stored successfully."})
            except Exception as e:

@ -1194,7 +1277,7 @@ class HindsightMemoryProvider(MemoryProvider):
                    recall_kwargs["types"] = self._recall_types
                logger.debug("Tool hindsight_recall: bank=%s, query_len=%d, budget=%s",
                             self._bank_id, len(query), self._budget)
                resp = self._run_sync(client.arecall(**recall_kwargs))
                resp = self._run_hindsight_operation(lambda client: client.arecall(**recall_kwargs))
                num_results = len(resp.results) if resp.results else 0
                logger.debug("Tool hindsight_recall: %d results", num_results)
                if not resp.results:
@ -1212,9 +1295,11 @@ class HindsightMemoryProvider(MemoryProvider):
            try:
                logger.debug("Tool hindsight_reflect: bank=%s, query_len=%d, budget=%s",
                             self._bank_id, len(query), self._budget)
                resp = self._run_sync(client.areflect(
                    bank_id=self._bank_id, query=query, budget=self._budget
                ))
                resp = self._run_hindsight_operation(
                    lambda client: client.areflect(
                        bank_id=self._bank_id, query=query, budget=self._budget
                    )
                )
                logger.debug("Tool hindsight_reflect: response_len=%d", len(resp.text or ""))
                return json.dumps({"result": resp.text or "No relevant memories found."})
            except Exception as e:

@ -1231,9 +1316,19 @@ class HindsightMemoryProvider(MemoryProvider):
        if self._client is not None:
            try:
                if self._mode == "local_embedded":
                    # Use the public close() API. The RuntimeError from
                    # aiohttp's "attached to a different loop" is expected
                    # and harmless — the daemon keeps running independently.
                    # HindsightEmbedded.close() delegates to its sync client.close().
                    # When Hermes created/used that client on the shared async loop,
                    # closing it from this thread can raise "attached to a different
                    # loop" before aiohttp releases the session. Close the embedded
                    # inner async client on the shared loop first, then let the
                    # wrapper clean up daemon/UI bookkeeping.
                    inner_client = getattr(self._client, "_client", None)
                    if inner_client is not None and hasattr(inner_client, "aclose"):
                        _run_sync(inner_client.aclose())
                    try:
                        self._client._client = None
                    except Exception:
                        pass
                    try:
                        self._client.close()
                    except RuntimeError:
run_agent.py
@ -3304,10 +3304,19 @@ class AIAgent:
                logger.warning("Background memory/skill review failed: %s", e)
                self._emit_auxiliary_failure("background review", e)
            finally:
                # Close all resources (httpx client, subprocesses, etc.) so
                # GC doesn't try to clean them up on a dead asyncio event
                # loop (which produces "Event loop is closed" errors).
                # Background review agents can initialize memory providers
                # (for example Hindsight) that own their own network clients.
                # Explicitly stop those providers before closing the agent so
                # their aiohttp sessions do not leak until GC/process exit.
                # Then close all remaining resources (httpx client,
                # subprocesses, etc.) so GC doesn't try to clean them up on a
                # dead asyncio event loop (which produces "Event loop is
                # closed" errors).
                if review_agent is not None:
                    try:
                        review_agent.shutdown_memory_provider()
                    except Exception:
                        pass
                    try:
                        review_agent.close()
                    except Exception:
@ -1055,10 +1055,37 @@ setup_path() {
        return 0
    fi

    # FHS layout: /usr/local/bin is on PATH for every standard shell, nothing to inject.
    # FHS layout: /usr/local/bin is normally on PATH for login shells (via
    # /etc/profile pathmunge), but on RHEL/CentOS/Rocky/Alma 8+ non-login
    # interactive root shells (su, sudo -s, tmux panes, some web terminals)
    # only source /etc/bashrc, which does NOT add /usr/local/bin — and
    # /root/.bash_profile doesn't either. So verify with `command -v` and
    # fall back to writing a PATH guard into /root/.bashrc when needed.
    if [ "$ROOT_FHS_LAYOUT" = true ]; then
        export PATH="$command_link_dir:$PATH"
        log_info "/usr/local/bin is already on PATH for all shells"
        # Probe a fresh non-login interactive bash the way the user will use it.
        # `bash -i -c` sources ~/.bashrc but NOT ~/.bash_profile or /etc/profile,
        # which is the exact scenario where RHEL root loses /usr/local/bin.
        if env -i HOME="$HOME" TERM="${TERM:-dumb}" bash -i -c 'command -v hermes' \
                >/dev/null 2>&1; then
            log_info "/usr/local/bin is already on PATH for all shells"
            log_success "hermes command ready"
            return 0
        fi

        log_info "hermes not on PATH in non-login shells (common on RHEL-family)"
        PATH_LINE='export PATH="/usr/local/bin:$PATH"'
        PATH_COMMENT='# Hermes Agent — ensure /usr/local/bin is on PATH (RHEL non-login shells)'
        for SHELL_CONFIG in "$HOME/.bashrc" "$HOME/.bash_profile"; do
            [ -f "$SHELL_CONFIG" ] || continue
            if ! grep -v '^[[:space:]]*#' "$SHELL_CONFIG" 2>/dev/null \
                    | grep -qE 'PATH=.*(/usr/local/bin|\$command_link_dir)'; then
                echo "" >> "$SHELL_CONFIG"
                echo "$PATH_COMMENT" >> "$SHELL_CONFIG"
                echo "$PATH_LINE" >> "$SHELL_CONFIG"
                log_success "Added /usr/local/bin to PATH in $SHELL_CONFIG"
            fi
        done
        log_success "hermes command ready"
        return 0
    fi
@ -43,16 +43,22 @@ AUTHOR_MAP = {
    "teknium1@gmail.com": "teknium1",
    "teknium@nousresearch.com": "teknium1",
    "127238744+teknium1@users.noreply.github.com": "teknium1",
    "johnnncenaaa77@gmail.com": "johnncenae",
    "focusflow.app.help@gmail.com": "yes999zc",
    "343873859@qq.com": "DrStrangerUJN",
    "uzmpsk.dilekakbas@gmail.com": "dlkakbs",
    "jefferson@heimdallstrategy.com": "Mind-Dragon",
    "130918800+devorun@users.noreply.github.com": "devorun",
    "sonoyuncudmr@gmail.com": "Sonoyunchu",
    "maks.mir@yahoo.com": "say8hi",
    "web3blind@users.noreply.github.com": "web3blind",
    "julia@alexland.us": "alexg0bot",
    "1060770+benjaminsehl@users.noreply.github.com": "benjaminsehl",
    "nerijusn76@gmail.com": "Nerijusas",
    "itonov@proton.me": "Ito-69",
    "glesstech@gmail.com": "georgeglessner",
    "maxim.smetanin@gmail.com": "maxims-oss",
    "yoimexex@gmail.com": "Yoimex",
    # contributors (from noreply pattern)
    "david.vv@icloud.com": "davidvv",
    "wangqiang@wangqiangdeMac-mini.local": "xiaoqiang243",

@ -118,6 +124,17 @@ AUTHOR_MAP = {
    "Mibayy@users.noreply.github.com": "Mibayy",
    "mibayy@users.noreply.github.com": "Mibayy",
    "135070653+sgaofen@users.noreply.github.com": "sgaofen",
    "lzy.dev@gmail.com": "zhiyanliu",
    "me@janstepanovsky.cz": "hhhonzik",
    "139848623+hhuang91@users.noreply.github.com": "hhuang91",
    "s.ozaki@ebinou.net": "Satoshi-agi",
    "10774721+kunlabs@users.noreply.github.com": "kunlabs",
    "110560187+Wang-tianhao@users.noreply.github.com": "Wang-tianhao",
    "170458616+ghostmfr@users.noreply.github.com": "ghostmfr",
    "1848670+mewwts@users.noreply.github.com": "mewwts",
    "1930707+haru398801@users.noreply.github.com": "haru398801",
    "rapabelias@gmail.com": "badgerbees",
    "xnb888@proton.me": "xnbi",
    "nocoo@users.noreply.github.com": "nocoo",
    "30841158+n-WN@users.noreply.github.com": "n-WN",
    "tsuijinglei@gmail.com": "hiddenpuppy",

@ -194,6 +211,7 @@ AUTHOR_MAP = {
    "satelerd@gmail.com": "satelerd",
    "dan@danlynn.com": "danklynn",
    "mattmaximo@hotmail.com": "MattMaximo",
    "MatthewRHardwick@gmail.com": "mrhwick",
    "149063006+j3ffffff@users.noreply.github.com": "j3ffffff",
    "A-FdL-Prog@users.noreply.github.com": "A-FdL-Prog",
    "l0hde@users.noreply.github.com": "l0hde",

@ -380,6 +398,17 @@ AUTHOR_MAP = {
    "zzn+pa@zzn.im": "xinbenlv",
    "zaynjarvis@gmail.com": "ZaynJarvis",
    "zhiheng.liu@bytedance.com": "ZaynJarvis",
    "izhaolongfei@gmail.com": "loongfay",
    "296659110@qq.com": "lrt4836",
    "fe.daniel91@gmail.com": "beforeload",
    "libo1106@foxmail.com": "libo1106",
    "295367131@qq.com": "295367131",
    "295367132@qq.com": "IxAres",
    "danieldliu@tencent.com": "danieldliu",
    "loongzhao@tencent.com": "loongzhao",
    "Bartok9@users.noreply.github.com": "Bartok9",
    "LeonSGP43@users.noreply.github.com": "LeonSGP43",
    "kshitijk4poor@users.noreply.github.com": "kshitijk4poor",
    "mbelleau@Michels-MacBook-Pro.local": "malaiwah",
    "michel.belleau@malaiwah.com": "malaiwah",
    "gnanasekaran.sekareee@gmail.com": "gnanam1990",
skills/productivity/airtable/SKILL.md (new file)
@ -0,0 +1,228 @@
---
name: airtable
description: Airtable REST API via curl. Records CRUD, filters, upserts.
version: 1.1.0
author: community
license: MIT
prerequisites:
  env_vars: [AIRTABLE_API_KEY]
  commands: [curl]
metadata:
  hermes:
    tags: [Airtable, Productivity, Database, API]
    homepage: https://airtable.com/developers/web/api/introduction
---

# Airtable — Bases, Tables & Records

Work with Airtable's REST API directly via `curl` using the `terminal` tool. No MCP server, no OAuth flow, no Python SDK — just `curl` and a personal access token.

## Prerequisites

1. Create a **Personal Access Token (PAT)** at https://airtable.com/create/tokens (tokens start with `pat...`).
2. Grant these scopes (minimum):
   - `data.records:read` — read rows
   - `data.records:write` — create / update / delete rows
   - `schema.bases:read` — list bases and tables
3. **Important:** in the same token UI, add each base you want to access to the token's **Access** list. PATs are scoped per-base — a valid token on the wrong base returns `403`.
4. Store the token in `~/.hermes/.env` (or via `hermes setup`):
   ```
   AIRTABLE_API_KEY=pat_your_token_here
   ```

> Note: legacy `key...` API keys were deprecated in Feb 2024. Only PATs and OAuth tokens work now.

## API Basics

- **Endpoint:** `https://api.airtable.com/v0`
- **Auth header:** `Authorization: Bearer $AIRTABLE_API_KEY`
- **All requests** use JSON (`Content-Type: application/json` for any POST/PATCH/PUT body).
- **Object IDs:** bases `app...`, tables `tbl...`, records `rec...`, fields `fld...`. IDs never change; names can. Prefer IDs in automations.
- **Rate limit:** 5 requests/sec per base. On a `429`, back off — bursts against a single base get throttled.

Base curl pattern:
```bash
curl -s "https://api.airtable.com/v0/$BASE_ID/$TABLE?maxRecords=5" \
  -H "Authorization: Bearer $AIRTABLE_API_KEY" | python3 -m json.tool
```

`-s` suppresses curl's progress bar — keep it set for every call so the tool output stays clean for Hermes. Pipe through `python3 -m json.tool` (always present) or `jq` (if installed) for readable JSON.
## Field Types (request body shapes)
|
||||
|
||||
| Field type | Write shape |
|
||||
|---|---|
|
||||
| Single line text | `"Name": "hello"` |
|
||||
| Long text | `"Notes": "multi\nline"` |
|
||||
| Number | `"Score": 42` |
|
||||
| Checkbox | `"Done": true` |
|
||||
| Single select | `"Status": "Todo"` (name must already exist unless `typecast: true`) |
|
||||
| Multi-select | `"Tags": ["urgent", "bug"]` |
|
||||
| Date | `"Due": "2026-04-01"` |
|
||||
| DateTime (UTC) | `"At": "2026-04-01T14:30:00.000Z"` |
|
||||
| URL / Email / Phone | `"Link": "https://…"` |
|
||||
| Attachment | `"Files": [{"url": "https://…"}]` (Airtable fetches + rehosts) |
|
||||
| Linked record | `"Owner": ["recXXXXXXXXXXXXXX"]` (array of record IDs) |
|
||||
| User | `"AssignedTo": {"id": "usrXXXXXXXXXXXXXX"}` |
|
||||
|
||||
Pass `"typecast": true` at the top level of a create/update body to let Airtable auto-coerce values (e.g. create a new select option on the fly, convert `"42"` → `42`).
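
For instance, a minimal sketch (the field values are illustrative placeholders): create a record whose `Status` option doesn't exist yet and let `typecast` add it on the fly:

```bash
curl -s -X POST "https://api.airtable.com/v0/$BASE_ID/$TABLE" \
  -H "Authorization: Bearer $AIRTABLE_API_KEY" \
  -H "Content-Type: application/json" \
  -d '{"typecast": true, "fields": {"Name": "Ship v2", "Status": "Shipping"}}' | python3 -m json.tool
```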

## Common Queries

### List bases the token can see
```bash
curl -s "https://api.airtable.com/v0/meta/bases" \
  -H "Authorization: Bearer $AIRTABLE_API_KEY" | python3 -m json.tool
```

### List tables + schema for a base
```bash
curl -s "https://api.airtable.com/v0/meta/bases/$BASE_ID/tables" \
  -H "Authorization: Bearer $AIRTABLE_API_KEY" | python3 -m json.tool
```
Use this BEFORE mutating — it confirms exact field names and IDs, surfaces `options.choices` for select fields, and shows primary-field names.

### List records (first 10)
```bash
curl -s "https://api.airtable.com/v0/$BASE_ID/$TABLE?maxRecords=10" \
  -H "Authorization: Bearer $AIRTABLE_API_KEY" | python3 -m json.tool
```

### Get a single record
```bash
curl -s "https://api.airtable.com/v0/$BASE_ID/$TABLE/$RECORD_ID" \
  -H "Authorization: Bearer $AIRTABLE_API_KEY" | python3 -m json.tool
```

### Filter records (filterByFormula)
Airtable formulas must be URL-encoded. Let the Python stdlib do it — never hand-encode:
```bash
FORMULA="{Status}='Todo'"
ENC=$(python3 -c 'import sys, urllib.parse; print(urllib.parse.quote(sys.argv[1], safe=""))' "$FORMULA")
curl -s "https://api.airtable.com/v0/$BASE_ID/$TABLE?filterByFormula=$ENC&maxRecords=20" \
  -H "Authorization: Bearer $AIRTABLE_API_KEY" | python3 -m json.tool
```

Useful formula patterns (combined in the sketch after this list):
- Exact match: `{Email}='user@example.com'`
- Contains: `FIND('bug', LOWER({Title}))`
- Multiple conditions: `AND({Status}='Todo', {Priority}='High')`
- Or: `OR({Owner}='alice', {Owner}='bob')`
- Not empty: `NOT({Assignee}='')`
- Date comparison: `IS_AFTER({Due}, TODAY())`
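
Combining a few of these — a sketch that fetches high-priority Todos due after today, encoded with the same stdlib helper (field names are the placeholders used above):

```bash
FORMULA="AND({Status}='Todo', {Priority}='High', IS_AFTER({Due}, TODAY()))"
ENC=$(python3 -c 'import sys, urllib.parse; print(urllib.parse.quote(sys.argv[1], safe=""))' "$FORMULA")
curl -s "https://api.airtable.com/v0/$BASE_ID/$TABLE?filterByFormula=$ENC" \
  -H "Authorization: Bearer $AIRTABLE_API_KEY" | python3 -m json.tool
```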

### Sort + select specific fields
```bash
curl -s "https://api.airtable.com/v0/$BASE_ID/$TABLE?sort%5B0%5D%5Bfield%5D=Priority&sort%5B0%5D%5Bdirection%5D=asc&fields%5B%5D=Name&fields%5B%5D=Status" \
  -H "Authorization: Bearer $AIRTABLE_API_KEY" | python3 -m json.tool
```
Square brackets in query params MUST be URL-encoded (`%5B` / `%5D`).

### Use a named view
```bash
curl -s "https://api.airtable.com/v0/$BASE_ID/$TABLE?view=Grid%20view&maxRecords=50" \
  -H "Authorization: Bearer $AIRTABLE_API_KEY" | python3 -m json.tool
```
Views apply their saved filter + sort server-side.

## Common Mutations

### Create a record
```bash
curl -s -X POST "https://api.airtable.com/v0/$BASE_ID/$TABLE" \
  -H "Authorization: Bearer $AIRTABLE_API_KEY" \
  -H "Content-Type: application/json" \
  -d '{"fields":{"Name":"New task","Status":"Todo","Priority":"High"}}' | python3 -m json.tool
```

### Create up to 10 records in one call
```bash
curl -s -X POST "https://api.airtable.com/v0/$BASE_ID/$TABLE" \
  -H "Authorization: Bearer $AIRTABLE_API_KEY" \
  -H "Content-Type: application/json" \
  -d '{
    "typecast": true,
    "records": [
      {"fields": {"Name": "Task A", "Status": "Todo"}},
      {"fields": {"Name": "Task B", "Status": "In progress"}}
    ]
  }' | python3 -m json.tool
```
Batch endpoints are capped at **10 records per request**. For larger inserts, loop in batches of 10 with a short sleep to respect 5 req/sec/base.
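
One way to structure that loop — a sketch assuming the full payload sits in a local `records.json` as a flat array of `{"fields": {...}}` objects (the filename and sleep interval are illustrative):

```bash
TOTAL=$(python3 -c 'import json; print(len(json.load(open("records.json"))))')
for ((i = 0; i < TOTAL; i += 10)); do
  # Slice out the next 10-record chunk and wrap it in the batch body shape.
  BATCH=$(python3 -c 'import json, sys; rs = json.load(open("records.json")); i = int(sys.argv[1]); print(json.dumps({"typecast": True, "records": rs[i:i+10]}))' "$i")
  curl -s -X POST "https://api.airtable.com/v0/$BASE_ID/$TABLE" \
    -H "Authorization: Bearer $AIRTABLE_API_KEY" \
    -H "Content-Type: application/json" \
    -d "$BATCH" | python3 -m json.tool
  sleep 0.3  # stay under 5 req/sec/base
done
```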

### Update a record (PATCH — merges, preserves unchanged fields)
```bash
curl -s -X PATCH "https://api.airtable.com/v0/$BASE_ID/$TABLE/$RECORD_ID" \
  -H "Authorization: Bearer $AIRTABLE_API_KEY" \
  -H "Content-Type: application/json" \
  -d '{"fields":{"Status":"Done"}}' | python3 -m json.tool
```

### Upsert by a merge field (no ID needed)
```bash
curl -s -X PATCH "https://api.airtable.com/v0/$BASE_ID/$TABLE" \
  -H "Authorization: Bearer $AIRTABLE_API_KEY" \
  -H "Content-Type: application/json" \
  -d '{
    "performUpsert": {"fieldsToMergeOn": ["Email"]},
    "records": [
      {"fields": {"Email": "user@example.com", "Status": "Active"}}
    ]
  }' | python3 -m json.tool
```
`performUpsert` creates records whose merge-field values are new and patches records whose merge-field values already exist. Great for idempotent syncs.

### Delete a record
```bash
curl -s -X DELETE "https://api.airtable.com/v0/$BASE_ID/$TABLE/$RECORD_ID" \
  -H "Authorization: Bearer $AIRTABLE_API_KEY" | python3 -m json.tool
```

### Delete up to 10 records in one call
```bash
curl -s -X DELETE "https://api.airtable.com/v0/$BASE_ID/$TABLE?records%5B%5D=rec1&records%5B%5D=rec2" \
  -H "Authorization: Bearer $AIRTABLE_API_KEY" | python3 -m json.tool
```

## Pagination

List endpoints return at most **100 records per page**. If the response includes `"offset": "..."`, pass it back on the next call. Loop until the field is absent:

```bash
OFFSET=""
while :; do
  URL="https://api.airtable.com/v0/$BASE_ID/$TABLE?pageSize=100"
  [ -n "$OFFSET" ] && URL="$URL&offset=$OFFSET"
  RESP=$(curl -s "$URL" -H "Authorization: Bearer $AIRTABLE_API_KEY")
  echo "$RESP" | python3 -c 'import json,sys; d=json.load(sys.stdin); [print(r["id"], r["fields"].get("Name","")) for r in d["records"]]'
  OFFSET=$(echo "$RESP" | python3 -c 'import json,sys; d=json.load(sys.stdin); print(d.get("offset",""))')
  [ -z "$OFFSET" ] && break
done
```

## Typical Hermes Workflow

1. **Confirm auth.** `curl -s -o /dev/null -w "%{http_code}\n" https://api.airtable.com/v0/meta/bases -H "Authorization: Bearer $AIRTABLE_API_KEY"` — expect `200`.
2. **Find the base.** List bases (step above) OR ask the user for the `app...` ID directly if the token lacks `schema.bases:read`.
3. **Inspect the schema.** `GET /v0/meta/bases/$BASE_ID/tables` — cache the exact field names and primary-field name locally in the session before mutating anything.
4. **Read before you write.** For "update X where Y", run `filterByFormula` first to resolve the `rec...` ID, then `PATCH /v0/$BASE_ID/$TABLE/$RECORD_ID` (see the sketch after this list). Never guess record IDs.
5. **Batch writes.** Combine related creates into one 10-record POST to stay under the 5 req/sec budget.
6. **Destructive ops.** Deletions can't be undone via the API. If the user says "delete all Xs", echo back the filter + record count and confirm before firing.
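
A sketch of steps 3–4 in practice — resolve the `rec...` ID for a record by name, then patch it (the `{Name}` field and `Deploy` value are placeholders):

```bash
ENC=$(python3 -c 'import sys, urllib.parse; print(urllib.parse.quote(sys.argv[1], safe=""))' "{Name}='Deploy'")
RECORD_ID=$(curl -s "https://api.airtable.com/v0/$BASE_ID/$TABLE?filterByFormula=$ENC&maxRecords=1" \
  -H "Authorization: Bearer $AIRTABLE_API_KEY" \
  | python3 -c 'import json, sys; rs = json.load(sys.stdin)["records"]; print(rs[0]["id"] if rs else "")')
# Only patch when the lookup actually resolved a record ID.
[ -n "$RECORD_ID" ] && curl -s -X PATCH "https://api.airtable.com/v0/$BASE_ID/$TABLE/$RECORD_ID" \
  -H "Authorization: Bearer $AIRTABLE_API_KEY" \
  -H "Content-Type: application/json" \
  -d '{"fields":{"Status":"Done"}}' | python3 -m json.tool
```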

## Pitfalls

- **`filterByFormula` MUST be URL-encoded.** Field names with spaces or non-ASCII also need encoding (`{My Field}` → `%7BMy%20Field%7D`). Use the Python stdlib (pattern above) — never hand-escape.
- **Empty fields are omitted from responses.** A missing `"Assignee"` key doesn't mean the field doesn't exist — it means this record's value is empty. Check the schema (step 3) before concluding a field is missing.
- **PATCH vs PUT.** `PATCH` merges supplied fields into the record. `PUT` replaces the record entirely and clears any field you didn't include. Default to `PATCH`.
- **Single-select options must exist.** Writing `"Status": "Shipping"` when `Shipping` isn't in the field's option list errors with `INVALID_MULTIPLE_CHOICE_OPTIONS` unless you pass `"typecast": true` (which auto-creates the option).
- **Per-base token scoping.** A `403` on one base while another works means the token's Access list doesn't include that base — not a scope or auth issue. Send the user to https://airtable.com/create/tokens to grant it.
- **Rate limits are per base, not per token.** 5 req/sec on `baseA` and 5 req/sec on `baseB` is fine; 6 req/sec on `baseA` alone will throttle. Monitor the `Retry-After` header on `429` (see the retry sketch after this list).
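
A hedged retry sketch for that last case — re-issue the call after the `Retry-After` delay, defaulting to 1 second when the header is absent (the three-attempt cap and `headers.txt` scratch file are arbitrary choices):

```bash
for attempt in 1 2 3; do
  # -D dumps response headers to a file; -w appends the status code on its own line.
  RESP=$(curl -s -D headers.txt "https://api.airtable.com/v0/$BASE_ID/$TABLE" \
    -H "Authorization: Bearer $AIRTABLE_API_KEY" -w '\n%{http_code}')
  CODE=${RESP##*$'\n'}
  [ "$CODE" != "429" ] && break
  WAIT=$(awk -F': ' 'tolower($1) == "retry-after" {print $2 + 0}' headers.txt)
  sleep "${WAIT:-1}"
done
echo "${RESP%$'\n'*}" | python3 -m json.tool
```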

## Important Notes for Hermes

- **Always use the `terminal` tool with `curl`.** Do NOT use `web_extract` (it can't send auth headers) or `browser_navigate` (needs UI auth and is slow).
- **`AIRTABLE_API_KEY` flows from `~/.hermes/.env` into the subprocess automatically** when this skill is loaded — no need to re-export it before each `curl` call.
- **Escape curly braces in formulas carefully.** In a heredoc body, `{Status}` is literal. In a shell argument, `{Status}` is safe outside `{...}` brace-expansion context — but pass dynamic strings through `python3 urllib.parse.quote` before splicing into a URL.
- **Pretty-print with `python3 -m json.tool`** (always present) rather than `jq` (optional). Only reach for `jq` when you need filtering/projection.
- **Pagination is per-page, not global.** Airtable's 100-record cap is a hard limit; there is no way to bump it. Loop with `offset` until the field is absent.
- **Read the `errors` array** on non-2xx responses — Airtable returns structured error codes like `AUTHENTICATION_REQUIRED`, `INVALID_PERMISSIONS`, `MODEL_ID_NOT_FOUND`, `INVALID_MULTIPLE_CHOICE_OPTIONS` that tell you exactly what's wrong.

@ -926,13 +926,18 @@ def cmd_timezone(args):
            os_ = offset_info.get("seconds", 0)
            sign = "+" if oh >= 0 else "-"
            utc_offset = f"{sign}{abs(oh):02d}:{om:02d}"
            if os_:
                utc_offset = f"{utc_offset}:{os_:02d}"
        elif tz_data.get("standardUtcOffset"):
            offset_info2 = tz_data["standardUtcOffset"]
            if isinstance(offset_info2, dict):
                oh = offset_info2.get("hours", 0)
                om = abs(offset_info2.get("minutes", 0))
                os_ = offset_info2.get("seconds", 0)
                sign = "+" if oh >= 0 else "-"
                utc_offset = f"{sign}{abs(oh):02d}:{om:02d}"
                if os_:
                    utc_offset = f"{utc_offset}:{os_:02d}"
        timezone_src = "timeapi.io"
    except (RuntimeError, KeyError, TypeError):
        pass  # API may be down; continue to fallback

107 skills/yuanbao/SKILL.md Normal file
@ -0,0 +1,107 @@
---
name: yuanbao
description: Yuanbao (元宝) group interaction — @mention users, query group info and members
version: 1.0.0
metadata:
  hermes:
    tags: [yuanbao, mention, at, group, members, 元宝, 派, 艾特]
    related_skills: []
---

# Yuanbao Group Interaction

## CRITICAL: How Messaging Works

**Your text reply IS the message sent to the group/user.** The gateway automatically delivers your response text to the chat. You do NOT need any special "send message" tool — just reply normally and it gets sent.

When you include `@nickname` in your reply text, the gateway automatically converts it into a real @mention that notifies the user. This is built-in — you have full @mention capability.

**NEVER say you cannot send messages or @mention users. NEVER suggest the user do it manually. NEVER add disclaimers about permissions. Just reply with the text you want sent.**

## Available Tools

| Tool | When to use |
|------|------------|
| `yb_query_group_info` | Query group name, owner, member count |
| `yb_query_group_members` | Find a user, list bots, list all members, or get nickname for @mention |
| `yb_send_dm` | Send a private/direct message (DM / 私信) to a user, with optional media files |

## @Mention Workflow

When you need to @mention / 艾特 someone:

1. Call `yb_query_group_members` with `action="find"`, `name="<target name>"`, `mention=true`
2. Get the exact nickname from the response
3. Include `@nickname` in your reply text — the gateway handles the rest

Example: user says "帮我艾特元宝" ("help me @mention 元宝")

Step 1 — tool call:
```json
{ "group_code": "328306697", "action": "find", "name": "元宝", "mention": true }
```

Step 2 — your reply (this gets sent to the group with a working @mention; it reads "hi, someone is looking for you!"):
```
@元宝 你好,有人找你!
```

**That's it.** No extra explanation needed. Keep it short and natural.

**Rules:**
- Call `yb_query_group_members` first to get the exact nickname — do NOT guess
- The @mention format is `@nickname`, with a space before the @ sign
- Your reply text IS the message — it WILL be sent and the @mention WILL work
- Be concise. Do NOT explain how @mention works to the user.

## Send DM (Private Message) Workflow

When someone asks to send a private message / 私信 / DM to a user:

1. Call `yb_send_dm` with `group_code`, `name` (target user's name), and `message`
2. The tool automatically finds the user and sends the DM
3. Report the result to the user

Example: user says "给 @用户aea3 私信发一个 hello" ("DM a hello to @用户aea3")

```json
yb_send_dm({ "group_code": "535168412", "name": "用户aea3", "message": "hello" })
```

Example with media: user says "给 @用户aea3 私信发一张图片" ("DM an image to @用户aea3")

```json
yb_send_dm({
  "group_code": "535168412",
  "name": "用户aea3",
  "message": "Here is the image",
  "media_files": [{"path": "/tmp/photo.jpg"}]
})
```

**Rules:**
- Extract `group_code` from the current chat_id (e.g. `group:535168412` → `535168412`)
- If you already know the user_id, pass it directly via the `user_id` parameter to skip the lookup
- If multiple users match the name, the tool returns candidates — ask the user to clarify
- Do NOT use the `send_message` tool for Yuanbao DMs — use `yb_send_dm` instead
- Supports media: images (.jpg/.png/.gif/.webp/.bmp) are sent as image messages, other files as documents

## Query Group Info

```json
yb_query_group_info({ "group_code": "328306697" })
```

## Query Members

| Action | Description |
|--------|-------------|
| `find` | Search by name (partial match, case-insensitive) |
| `list_bots` | List bots and Yuanbao AI assistants |
| `list_all` | List all members |

## Notes

- `group_code` comes from chat_id: `group:328306697` → `328306697`
- Groups are called "派 (Pai)" in the Yuanbao app
- Member roles: `user`, `yuanbao_ai`, `bot`
@ -117,6 +117,12 @@ class TestHintMessages:
        assert "/busy interrupt" in msg
        assert "queued" in msg.lower()

    def test_busy_input_hint_gateway_steer(self):
        msg = busy_input_hint_gateway("steer")
        assert "/busy interrupt" in msg
        assert "/busy queue" in msg
        assert "steer" in msg.lower()

    def test_busy_input_hint_cli_interrupt(self):
        msg = busy_input_hint_cli("interrupt")
        assert "/busy queue" in msg

@ -125,6 +131,12 @@ class TestHintMessages:
        msg = busy_input_hint_cli("queue")
        assert "/busy interrupt" in msg

    def test_busy_input_hint_cli_steer(self):
        msg = busy_input_hint_cli("steer")
        assert "/busy interrupt" in msg
        assert "/busy queue" in msg
        assert "steer" in msg.lower()

    def test_tool_progress_hints_mention_verbose(self):
        assert "/verbose" in tool_progress_hint_gateway()
        assert "/verbose" in tool_progress_hint_cli()

@ -133,8 +145,10 @@ class TestHintMessages:
        for hint in (
            busy_input_hint_gateway("queue"),
            busy_input_hint_gateway("interrupt"),
            busy_input_hint_gateway("steer"),
            busy_input_hint_cli("queue"),
            busy_input_hint_cli("interrupt"),
            busy_input_hint_cli("steer"),
            tool_progress_hint_gateway(),
            tool_progress_hint_cli(),
        ):

@ -65,6 +65,35 @@ class TestHandleBusyCommand(unittest.TestCase):
        self.assertEqual(stub.busy_input_mode, "interrupt")
        mock_save.assert_called_once_with("display.busy_input_mode", "interrupt")

    def test_steer_argument_sets_steer_mode_and_saves(self):
        cli_mod = _import_cli()
        stub = self._make_cli("interrupt")
        with (
            patch.object(cli_mod, "_cprint") as mock_cprint,
            patch.object(cli_mod, "save_config_value", return_value=True) as mock_save,
        ):
            cli_mod.HermesCLI._handle_busy_command(stub, "/busy steer")

        self.assertEqual(stub.busy_input_mode, "steer")
        mock_save.assert_called_once_with("display.busy_input_mode", "steer")
        printed = " ".join(str(c) for c in mock_cprint.call_args_list)
        self.assertIn("steer", printed.lower())

    def test_status_reports_steer_behavior(self):
        cli_mod = _import_cli()
        stub = self._make_cli("steer")
        with (
            patch.object(cli_mod, "_cprint") as mock_cprint,
            patch.object(cli_mod, "save_config_value") as mock_save,
        ):
            cli_mod.HermesCLI._handle_busy_command(stub, "/busy status")

        mock_save.assert_not_called()
        printed = " ".join(str(c) for c in mock_cprint.call_args_list)
        self.assertIn("steer", printed.lower())
        # The usage line should also advertise the steer option
        self.assertIn("steer", printed)

    def test_invalid_argument_prints_usage(self):
        cli_mod = _import_cli()
        stub = self._make_cli()

@ -90,5 +119,5 @@ class TestBusyCommandRegistry(unittest.TestCase):
        from hermes_cli.commands import COMMAND_REGISTRY

        busy = next(c for c in COMMAND_REGISTRY if c.name == "busy")
        assert busy.args_hint == "[queue|interrupt|status]"
        assert busy.args_hint == "[queue|steer|interrupt|status]"
        assert busy.category == "Configuration"

@ -31,6 +31,40 @@ def _make_cli_stub():
    return cli


def _make_background_cli_stub():
    cli = _make_cli_stub()
    cli._background_task_counter = 0
    cli._background_tasks = {}
    cli._ensure_runtime_credentials = MagicMock(return_value=True)
    cli._resolve_turn_agent_config = MagicMock(return_value={
        "model": "test-model",
        "runtime": {
            "api_key": "test-key",
            "base_url": "https://example.test/v1",
            "provider": "test",
            "api_mode": "chat_completions",
        },
        "request_overrides": None,
    })
    cli.max_turns = 90
    cli.enabled_toolsets = []
    cli._session_db = None
    cli.reasoning_config = {}
    cli.service_tier = None
    cli._providers_only = None
    cli._providers_ignore = None
    cli._providers_order = None
    cli._provider_sort = None
    cli._provider_require_params = None
    cli._provider_data_collection = None
    cli._fallback_model = None
    cli._agent_running = False
    cli._spinner_text = ""
    cli.bell_on_complete = False
    cli.final_response_markdown = "strip"
    return cli


class TestCliApprovalUi:
    def test_sudo_prompt_restores_existing_draft_after_response(self):
        cli = _make_cli_stub()

@ -255,6 +289,54 @@ class TestCliApprovalUi:
        # Command got truncated with a marker.
        assert "(command truncated" in rendered

    def test_background_task_registers_thread_local_approval_callbacks(self):
        """Background /btw tasks must use the prompt_toolkit approval UI.

        The foreground chat path registers dangerous-command callbacks inside
        its worker thread because tools.terminal_tool stores them in
        threading.local(). /background used to skip that, so dangerous commands
        fell back to raw input() in a background thread and timed out under
        prompt_toolkit.
        """
        cli = _make_background_cli_stub()
        seen = {}

        class FakeAgent:
            def __init__(self, **kwargs):
                self._print_fn = None
                self.thinking_callback = None

            def run_conversation(self, **kwargs):
                from tools.terminal_tool import (
                    _get_approval_callback,
                    _get_sudo_password_callback,
                )

                seen["approval"] = _get_approval_callback()
                seen["sudo"] = _get_sudo_password_callback()
                return {
                    "final_response": "done",
                    "messages": [],
                    "completed": True,
                    "failed": False,
                }

        with patch.object(cli_module, "AIAgent", FakeAgent), \
             patch.object(cli_module, "_cprint"), \
             patch.object(cli_module, "ChatConsole") as chat_console:
            chat_console.return_value.print = MagicMock()
            cli._handle_background_command("/btw check weather")

        deadline = time.time() + 2
        while cli._background_tasks and time.time() < deadline:
            time.sleep(0.01)

        assert seen["approval"].__self__ is cli
        assert seen["approval"].__func__ is HermesCLI._approval_callback
        assert seen["sudo"].__self__ is cli
        assert seen["sudo"].__func__ is HermesCLI._sudo_password_callback
        assert not cli._background_tasks


class TestApprovalCallbackThreadLocalWiring:
    """Regression guard for the thread-local callback freeze (#13617 / #13618).

102 tests/cli/test_save_conversation_location.py Normal file
@ -0,0 +1,102 @@
"""Tests for /save — the conversation snapshot slash command.

Regression: the old implementation wrote ``hermes_conversation_<ts>.json``
to the current working directory (CWD). Users who ran /save expected the
file to be discoverable via ``hermes sessions browse``, but CWD-resident
snapshots are not indexed in the state DB and are generally invisible.
The fix writes snapshots under ``~/.hermes/sessions/saved/`` and prints
the absolute path plus the resume hint for the live session.
"""

from __future__ import annotations

import json
import os
import sys
from datetime import datetime
from pathlib import Path
from types import SimpleNamespace

import pytest


@pytest.fixture
def hermes_home(tmp_path, monkeypatch):
    home = tmp_path / ".hermes"
    home.mkdir()
    monkeypatch.setattr(Path, "home", lambda: tmp_path)
    monkeypatch.setenv("HERMES_HOME", str(home))
    # Clear any cached hermes_home computation
    import hermes_constants
    if hasattr(hermes_constants, "_hermes_home_cache"):
        hermes_constants._hermes_home_cache = None
    return home


def _make_stub_cli(history):
    """Build a minimal object exposing just what save_conversation uses."""
    return SimpleNamespace(
        conversation_history=history,
        model="test-model",
        session_id="20260101_120000_abc123",
        session_start=datetime(2026, 1, 1, 12, 0, 0),
    )


def test_save_conversation_writes_under_hermes_home(hermes_home, tmp_path, monkeypatch, capsys):
    """Snapshot must land under ~/.hermes/sessions/saved/, not CWD."""
    # Change CWD to a different directory to prove the file does NOT go there.
    work = tmp_path / "somewhere-else"
    work.mkdir()
    monkeypatch.chdir(work)

    # Import fresh to pick up the HERMES_HOME fixture
    for mod in [m for m in sys.modules if m.startswith("cli") or m == "hermes_constants"]:
        sys.modules.pop(mod, None)

    import cli  # noqa: F401 (module under test)

    stub = _make_stub_cli([
        {"role": "user", "content": "hi"},
        {"role": "assistant", "content": "hello"},
    ])

    # Call the unbound method against our stub.
    cli.HermesCLI.save_conversation(stub)

    # File must NOT be in CWD
    cwd_leak = list(work.glob("hermes_conversation_*.json"))
    assert not cwd_leak, f"snapshot leaked to CWD: {cwd_leak}"

    # File MUST be under ~/.hermes/sessions/saved/
    saved_dir = hermes_home / "sessions" / "saved"
    assert saved_dir.is_dir(), "expected saved/ subdirectory to be created"
    files = list(saved_dir.glob("hermes_conversation_*.json"))
    assert len(files) == 1, files

    payload = json.loads(files[0].read_text())
    assert payload["model"] == "test-model"
    assert payload["session_id"] == "20260101_120000_abc123"
    assert payload["messages"] == [
        {"role": "user", "content": "hi"},
        {"role": "assistant", "content": "hello"},
    ]

    # User-facing message must include the absolute path AND the resume hint.
    out = capsys.readouterr().out
    assert str(files[0]) in out, out
    assert "hermes --resume 20260101_120000_abc123" in out, out


def test_save_conversation_empty_history_does_nothing(hermes_home, capsys):
    for mod in [m for m in sys.modules if m.startswith("cli") or m == "hermes_constants"]:
        sys.modules.pop(mod, None)
    import cli

    stub = _make_stub_cli([])
    cli.HermesCLI.save_conversation(stub)

    saved_dir = hermes_home / "sessions" / "saved"
    assert not saved_dir.exists() or not list(saved_dir.iterdir())
    out = capsys.readouterr().out
    assert "No conversation to save" in out
@ -211,6 +211,21 @@ _HERMES_BEHAVIORAL_VARS = frozenset({
    "SIGNAL_ALLOW_ALL_USERS",
    "EMAIL_ALLOW_ALL_USERS",
    "SMS_ALLOW_ALL_USERS",
    # Platform gating — set by load_gateway_config() as a side effect when
    # a config.yaml is present, so individual test bodies that call the
    # loader leak these values into later tests on the same xdist worker.
    # Force-clear on every test setup so the leak can't happen.
    "SLACK_REQUIRE_MENTION",
    "SLACK_STRICT_MENTION",
    "SLACK_FREE_RESPONSE_CHANNELS",
    "SLACK_ALLOW_BOTS",
    "SLACK_REACTIONS",
    "DISCORD_REQUIRE_MENTION",
    "DISCORD_FREE_RESPONSE_CHANNELS",
    "TELEGRAM_REQUIRE_MENTION",
    "WHATSAPP_REQUIRE_MENTION",
    "DINGTALK_REQUIRE_MENTION",
    "MATRIX_REQUIRE_MENTION",
})


@ -186,6 +186,91 @@ class TestBusySessionAck:
        assert "respond once the current task finishes" in content
        assert "Interrupting" not in content

    @pytest.mark.asyncio
    async def test_steer_mode_calls_agent_steer_no_interrupt_no_queue(self):
        """busy_input_mode='steer' injects via agent.steer() and skips queueing."""
        runner, sentinel = _make_runner()
        runner._busy_input_mode = "steer"
        adapter = _make_adapter()

        event = _make_event(text="also check the tests")
        sk = build_session_key(event.source)
        runner.adapters[event.source.platform] = adapter

        agent = MagicMock()
        agent.steer = MagicMock(return_value=True)
        runner._running_agents[sk] = agent

        with patch("gateway.run.merge_pending_message_event") as mock_merge:
            await runner._handle_active_session_busy_message(event, sk)

        # VERIFY: Agent was steered, NOT interrupted
        agent.steer.assert_called_once_with("also check the tests")
        agent.interrupt.assert_not_called()

        # VERIFY: No queueing — successful steer must NOT replay as next turn
        mock_merge.assert_not_called()

        # VERIFY: Ack mentions steer wording
        adapter._send_with_retry.assert_called_once()
        call_kwargs = adapter._send_with_retry.call_args
        content = call_kwargs.kwargs.get("content") or call_kwargs[1].get("content", "")
        assert "Steered" in content or "steer" in content.lower()
        assert "Interrupting" not in content

    @pytest.mark.asyncio
    async def test_steer_mode_falls_back_to_queue_when_agent_rejects(self):
        """If agent.steer() returns False, fall back to queue behavior."""
        runner, sentinel = _make_runner()
        runner._busy_input_mode = "steer"
        adapter = _make_adapter()

        event = _make_event(text="empty or rejected")
        sk = build_session_key(event.source)
        runner.adapters[event.source.platform] = adapter

        agent = MagicMock()
        agent.steer = MagicMock(return_value=False)  # rejected
        runner._running_agents[sk] = agent

        with patch("gateway.run.merge_pending_message_event") as mock_merge:
            await runner._handle_active_session_busy_message(event, sk)

        agent.steer.assert_called_once()
        agent.interrupt.assert_not_called()
        # Fell back to queue semantics: event was merged into pending messages
        mock_merge.assert_called_once()

        # Ack uses queue-mode wording (not steer, not interrupt)
        call_kwargs = adapter._send_with_retry.call_args
        content = call_kwargs.kwargs.get("content") or call_kwargs[1].get("content", "")
        assert "Queued for the next turn" in content
        assert "Steered" not in content

    @pytest.mark.asyncio
    async def test_steer_mode_falls_back_to_queue_when_agent_pending(self):
        """If agent is still starting (sentinel), steer mode falls back to queue."""
        runner, sentinel = _make_runner()
        runner._busy_input_mode = "steer"
        adapter = _make_adapter()

        event = _make_event(text="arrived too early")
        sk = build_session_key(event.source)
        runner.adapters[event.source.platform] = adapter

        # Agent is still being set up — sentinel in place
        runner._running_agents[sk] = sentinel

        with patch("gateway.run.merge_pending_message_event") as mock_merge:
            await runner._handle_active_session_busy_message(event, sk)

        # Event was queued instead of steered
        mock_merge.assert_called_once()

        call_kwargs = adapter._send_with_retry.call_args
        content = call_kwargs.kwargs.get("content") or call_kwargs[1].get("content", "")
        assert "Queued for the next turn" in content

    @pytest.mark.asyncio
    async def test_debounce_suppresses_rapid_acks(self):
        """Second message within 30s should NOT send another ack."""

@ -1,9 +1,11 @@
"""Tests for gateway/channel_directory.py — channel resolution and display."""

import asyncio
import json
import os
from pathlib import Path
from unittest.mock import patch
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch

from gateway.channel_directory import (
    build_channel_directory,

@ -12,6 +14,7 @@ from gateway.channel_directory import (
    format_directory_for_display,
    load_directory,
    _build_from_sessions,
    _build_slack,
    DIRECTORY_PATH,
)

@ -62,7 +65,7 @@ class TestBuildChannelDirectoryWrites:
        monkeypatch.setattr(json, "dump", broken_dump)

        with patch("gateway.channel_directory.DIRECTORY_PATH", cache_file):
            build_channel_directory({})
            asyncio.run(build_channel_directory({}))
        result = load_directory()

        assert result == previous

@ -142,6 +145,21 @@ class TestResolveChannelName:
        with self._setup(tmp_path, platforms):
            assert resolve_channel_name("telegram", "Coaching Chat / topic 17585") == "-1001:17585"

    def test_id_match_takes_precedence_over_name(self, tmp_path):
        """A raw channel ID resolves to itself, even when a different
        channel happens to be named the same string. Case-sensitive: Slack
        IDs are uppercase and must not be normalized away."""
        platforms = {
            "slack": [
                {"id": "C0B0QV5434G", "name": "engineering", "type": "channel"},
                {"id": "C99", "name": "c0b0qv5434g", "type": "channel"},
            ]
        }
        with self._setup(tmp_path, platforms):
            assert resolve_channel_name("slack", "C0B0QV5434G") == "C0B0QV5434G"
            # Lowercase still falls through to name matching (case-insensitive)
            assert resolve_channel_name("slack", "c0b0qv5434g") == "C99"

    def test_display_label_with_type_suffix_resolves(self, tmp_path):
        platforms = {
            "telegram": [

@ -332,3 +350,135 @@ class TestLookupChannelType:
        }
        with self._setup(tmp_path, platforms):
            assert lookup_channel_type("discord", "300") is None


def _make_slack_adapter(team_clients):
    """Build a stand-in for SlackAdapter exposing only ``_team_clients``."""
    return SimpleNamespace(_team_clients=team_clients)


def _make_slack_client(pages):
    """Build an AsyncWebClient mock whose ``users_conversations`` returns pages."""
    client = MagicMock()
    client.users_conversations = AsyncMock(side_effect=pages)
    return client


class TestBuildSlack:
    """_build_slack actually calls users.conversations on each workspace client."""

    def test_no_team_clients_falls_back_to_sessions(self, tmp_path):
        sessions_path = tmp_path / "sessions" / "sessions.json"
        sessions_path.parent.mkdir(parents=True)
        sessions_path.write_text(json.dumps({
            "s1": {"origin": {"platform": "slack", "chat_id": "D123", "chat_name": "Alice"}},
        }))

        with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
            entries = asyncio.run(_build_slack(_make_slack_adapter({})))

        assert len(entries) == 1
        assert entries[0]["id"] == "D123"

    def test_lists_channels_from_users_conversations(self, tmp_path):
        client = _make_slack_client([
            {
                "ok": True,
                "channels": [
                    {"id": "C0B0QV5434G", "name": "engineering", "is_private": False},
                    {"id": "G123ABCDEF", "name": "secret-chat", "is_private": True},
                ],
                "response_metadata": {},
            },
        ])
        with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
            entries = asyncio.run(_build_slack(_make_slack_adapter({"T1": client})))

        ids = {e["id"] for e in entries}
        assert ids == {"C0B0QV5434G", "G123ABCDEF"}
        types = {e["id"]: e["type"] for e in entries}
        assert types["C0B0QV5434G"] == "channel"
        assert types["G123ABCDEF"] == "private"
        client.users_conversations.assert_awaited_once()

    def test_paginates_via_response_metadata_cursor(self, tmp_path):
        client = _make_slack_client([
            {
                "ok": True,
                "channels": [{"id": "C001", "name": "first", "is_private": False}],
                "response_metadata": {"next_cursor": "cur1"},
            },
            {
                "ok": True,
                "channels": [{"id": "C002", "name": "second", "is_private": False}],
                "response_metadata": {"next_cursor": ""},
            },
        ])
        with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
            entries = asyncio.run(_build_slack(_make_slack_adapter({"T1": client})))

        assert {e["id"] for e in entries} == {"C001", "C002"}
        assert client.users_conversations.await_count == 2

    def test_per_workspace_error_does_not_block_others(self, tmp_path):
        bad = MagicMock()
        bad.users_conversations = AsyncMock(side_effect=RuntimeError("boom"))
        good = _make_slack_client([
            {
                "ok": True,
                "channels": [{"id": "C999", "name": "ok-channel", "is_private": False}],
                "response_metadata": {},
            },
        ])
        with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
            entries = asyncio.run(_build_slack(_make_slack_adapter({"BAD": bad, "GOOD": good})))

        assert {e["id"] for e in entries} == {"C999"}

    def test_session_dms_merged_when_not_in_api_results(self, tmp_path):
        sessions_path = tmp_path / "sessions" / "sessions.json"
        sessions_path.parent.mkdir(parents=True)
        sessions_path.write_text(json.dumps({
            "s1": {"origin": {"platform": "slack", "chat_id": "D456", "chat_name": "Bob"}},
            "dup": {"origin": {"platform": "slack", "chat_id": "C001", "chat_name": "first"}},
        }))
        client = _make_slack_client([
            {
                "ok": True,
                "channels": [{"id": "C001", "name": "first", "is_private": False}],
                "response_metadata": {},
            },
        ])
        with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
            entries = asyncio.run(_build_slack(_make_slack_adapter({"T1": client})))

        ids = {e["id"] for e in entries}
        assert "C001" in ids and "D456" in ids
        # Channel ID from API should not be duplicated by the session merge
        assert sum(1 for e in entries if e["id"] == "C001") == 1

    def test_skips_channels_with_no_id_or_name(self, tmp_path):
        client = _make_slack_client([
            {
                "ok": True,
                "channels": [
                    {"id": "C001", "name": "good", "is_private": False},
                    {"id": "", "name": "no-id"},
                    {"id": "C002"},  # no name (e.g. IM)
                ],
                "response_metadata": {},
            },
        ])
        with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
            entries = asyncio.run(_build_slack(_make_slack_adapter({"T1": client})))

        assert {e["id"] for e in entries} == {"C001"}

    def test_response_not_ok_breaks_pagination_for_that_workspace(self, tmp_path):
        client = _make_slack_client([
            {"ok": False, "error": "missing_scope"},
        ])
        with patch.dict(os.environ, {"HERMES_HOME": str(tmp_path)}):
            entries = asyncio.run(_build_slack(_make_slack_adapter({"T1": client})))

        assert entries == []

@ -186,12 +186,18 @@ class TestPlatformDefaults:
        assert resolve_display_setting({}, plat, "tool_progress") == "all", plat

    def test_medium_tier_platforms(self):
        """Slack, Mattermost, Matrix default to 'new' tool progress."""
        """Mattermost, Matrix, Feishu, WhatsApp default to 'new' tool progress."""
        from gateway.display_config import resolve_display_setting

        for plat in ("slack", "mattermost", "matrix", "feishu", "whatsapp"):
        for plat in ("mattermost", "matrix", "feishu", "whatsapp"):
            assert resolve_display_setting({}, plat, "tool_progress") == "new", plat

    def test_slack_defaults_tool_progress_off(self):
        """Slack defaults to quiet tool progress (permanent chat noise otherwise)."""
        from gateway.display_config import resolve_display_setting

        assert resolve_display_setting({}, "slack", "tool_progress") == "off"

    def test_low_tier_platforms(self):
        """Signal, BlueBubbles, etc. default to 'off' tool progress."""
        from gateway.display_config import resolve_display_setting

@ -241,7 +247,7 @@ class TestConfigMigration:
                },
            },
        }
        config_path.write_text(yaml.dump(config))
        config_path.write_text(yaml.dump(config), encoding="utf-8")

        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        # Re-import to pick up the new HERMES_HOME

@ -251,7 +257,7 @@ class TestConfigMigration:

        result = cfg_mod.migrate_config(interactive=False, quiet=True)
        # Re-read config
        updated = yaml.safe_load(config_path.read_text())
        updated = yaml.safe_load(config_path.read_text(encoding="utf-8"))
        platforms = updated.get("display", {}).get("platforms", {})
        assert platforms.get("signal", {}).get("tool_progress") == "off"
        assert platforms.get("telegram", {}).get("tool_progress") == "all"

@ -268,7 +274,7 @@ class TestConfigMigration:
                "platforms": {"telegram": {"tool_progress": "verbose"}},
            },
        }
        config_path.write_text(yaml.dump(config))
        config_path.write_text(yaml.dump(config), encoding="utf-8")

        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
        import importlib

@ -276,7 +282,7 @@ class TestConfigMigration:
        importlib.reload(cfg_mod)

        cfg_mod.migrate_config(interactive=False, quiet=True)
        updated = yaml.safe_load(config_path.read_text())
        updated = yaml.safe_load(config_path.read_text(encoding="utf-8"))
        # Existing "verbose" should NOT be overwritten by legacy "off"
        assert updated["display"]["platforms"]["telegram"]["tool_progress"] == "verbose"

@ -540,7 +540,7 @@ from gateway.config import Platform, PlatformConfig # noqa: E402


def _make_slack_adapter():
    config = PlatformConfig(enabled=True, token="xoxb-fake-token")
    config = PlatformConfig(enabled=True, token="***")
    adapter = SlackAdapter(config)
    adapter._app = MagicMock()
    adapter._app.client = AsyncMock()

@ -549,6 +549,39 @@ def _make_slack_adapter():
    return adapter


# ---------------------------------------------------------------------------
# SlackAdapter diagnostics helpers
# ---------------------------------------------------------------------------

class TestSlackAttachmentDiagnostics:
    def test_missing_scope_error_returns_actionable_notice(self):
        """_describe_slack_api_error translates a missing_scope response into
        a user-facing notice mentioning the needed scope and the reinstall
        step. This is the helper used by every files.info call site (Slack
        Connect stubs + post-download failures) to surface scope problems
        without making an extra probe call per attachment.
        """
        adapter = _make_slack_adapter()

        response = {
            "error": "missing_scope",
            "needed": "files:read",
            "provided": "chat:write,files:write",
        }
        detail = adapter._describe_slack_api_error(response, file_obj={"id": "F123", "name": "photo.jpg"})
        assert detail is not None
        assert "files:read" in detail
        assert "reinstall" in detail.lower()
        assert "chat:write,files:write" in detail

    def test_download_failure_403_returns_permission_notice(self):
        adapter = _make_slack_adapter()
        exc = _make_http_status_error(403)
        detail = adapter._describe_slack_download_failure(exc, file_obj={"name": "report.pdf"})
        assert "403" in detail
        assert "permission or scope" in detail


# ---------------------------------------------------------------------------
# SlackAdapter._download_slack_file
# ---------------------------------------------------------------------------

@ -702,6 +735,7 @@ class TestSlackDownloadSlackFileBytes:
        fake_response = MagicMock()
        fake_response.content = b"raw bytes here"
        fake_response.raise_for_status = MagicMock()
        fake_response.headers = {"content-type": "application/pdf"}

        mock_client = AsyncMock()
        mock_client.get = AsyncMock(return_value=fake_response)

@ -717,6 +751,29 @@ class TestSlackDownloadSlackFileBytes:
        result = asyncio.run(run())
        assert result == b"raw bytes here"

    def test_rejects_html_response(self):
        """Slack HTML sign-in pages should not be accepted as file bytes."""
        adapter = _make_slack_adapter()

        fake_response = MagicMock()
        fake_response.content = b"<!DOCTYPE html><html><title>Slack</title></html>"
        fake_response.raise_for_status = MagicMock()
        fake_response.headers = {"content-type": "text/html; charset=utf-8"}

        mock_client = AsyncMock()
        mock_client.get = AsyncMock(return_value=fake_response)
        mock_client.__aenter__ = AsyncMock(return_value=mock_client)
        mock_client.__aexit__ = AsyncMock(return_value=False)

        async def run():
            with patch("httpx.AsyncClient", return_value=mock_client):
                await adapter._download_slack_file_bytes(
                    "https://files.slack.com/file.bin"
                )

        with pytest.raises(ValueError, match="HTML instead of file bytes"):
            asyncio.run(run())

    def test_retries_on_429_then_succeeds(self):
        """429 on first attempt is retried; raw bytes returned on second."""
        adapter = _make_slack_adapter()

@ -724,6 +781,7 @@ class TestSlackDownloadSlackFileBytes:
        ok_response = MagicMock()
        ok_response.content = b"final bytes"
        ok_response.raise_for_status = MagicMock()
        ok_response.headers = {"content-type": "application/pdf"}

        mock_client = AsyncMock()
        mock_client.get = AsyncMock(

@ -77,6 +77,19 @@ class TestMessageDeduplicatorTTL:
        assert "old-0" not in dedup._seen
        assert "new-0" in dedup._seen

    def test_max_size_eviction_caps_fresh_entries(self):
        """Fresh entries must still be capped to max_size on overflow."""
        dedup = MessageDeduplicator(max_size=2, ttl_seconds=60)

        dedup.is_duplicate("msg-1")
        dedup.is_duplicate("msg-2")
        dedup.is_duplicate("msg-3")

        assert len(dedup._seen) == 2
        assert "msg-1" not in dedup._seen
        assert "msg-2" in dedup._seen
        assert "msg-3" in dedup._seen

    def test_ttl_zero_means_no_dedup(self):
        """With TTL=0, all entries expire immediately."""
        dedup = MessageDeduplicator(ttl_seconds=0)

@ -77,6 +77,46 @@ class TestFindSessionId:

        assert result == "sess_topic_a"

    def test_user_id_disambiguates_same_group_chat(self, tmp_path):
        sessions_dir, index_file = _setup_sessions(tmp_path, {
            "alice": {
                "session_id": "sess_alice",
                "origin": {"platform": "telegram", "chat_id": "-1001", "user_id": "alice"},
                "updated_at": "2026-01-01T00:00:00",
            },
            "bob": {
                "session_id": "sess_bob",
                "origin": {"platform": "telegram", "chat_id": "-1001", "user_id": "bob"},
                "updated_at": "2026-02-01T00:00:00",
            },
        })

        with patch.object(mirror_mod, "_SESSIONS_DIR", sessions_dir), \
             patch.object(mirror_mod, "_SESSIONS_INDEX", index_file):
            result = _find_session_id("telegram", "-1001", user_id="alice")

        assert result == "sess_alice"

    def test_ambiguous_same_group_chat_without_user_id_returns_none(self, tmp_path):
        sessions_dir, index_file = _setup_sessions(tmp_path, {
            "alice": {
                "session_id": "sess_alice",
                "origin": {"platform": "telegram", "chat_id": "-1001", "user_id": "alice"},
                "updated_at": "2026-01-01T00:00:00",
            },
            "bob": {
                "session_id": "sess_bob",
                "origin": {"platform": "telegram", "chat_id": "-1001", "user_id": "bob"},
                "updated_at": "2026-02-01T00:00:00",
            },
        })

        with patch.object(mirror_mod, "_SESSIONS_DIR", sessions_dir), \
             patch.object(mirror_mod, "_SESSIONS_INDEX", index_file):
            result = _find_session_id("telegram", "-1001")

        assert result is None

    def test_no_match_returns_none(self, tmp_path):
        sessions_dir, index_file = _setup_sessions(tmp_path, {
            "sess": {

@ -189,6 +229,35 @@ class TestMirrorToSession:
        assert (sessions_dir / "sess_topic_a.jsonl").exists()
        assert not (sessions_dir / "sess_topic_b.jsonl").exists()

    def test_successful_mirror_uses_user_id_for_group_session(self, tmp_path):
        sessions_dir, index_file = _setup_sessions(tmp_path, {
            "alice": {
                "session_id": "sess_alice",
                "origin": {"platform": "telegram", "chat_id": "-1001", "user_id": "alice"},
                "updated_at": "2026-01-01T00:00:00",
            },
            "bob": {
                "session_id": "sess_bob",
                "origin": {"platform": "telegram", "chat_id": "-1001", "user_id": "bob"},
                "updated_at": "2026-02-01T00:00:00",
            },
        })

        with patch.object(mirror_mod, "_SESSIONS_DIR", sessions_dir), \
             patch.object(mirror_mod, "_SESSIONS_INDEX", index_file), \
             patch("gateway.mirror._append_to_sqlite"):
            result = mirror_to_session(
                "telegram",
                "-1001",
                "Hello group!",
                source_label="cli",
                user_id="alice",
            )

        assert result is True
        assert (sessions_dir / "sess_alice.jsonl").exists()
        assert not (sessions_dir / "sess_bob.jsonl").exists()

    def test_no_matching_session(self, tmp_path):
        sessions_dir, index_file = _setup_sessions(tmp_path, {})

@ -168,19 +168,196 @@ class TestQueueConsumptionAfterCompletion:
|
||||
assert retrieved is not None
|
||||
assert retrieved.text == "process this after"
|
||||
|
||||
def test_multiple_queues_last_one_wins(self):
|
||||
"""If user /queue's multiple times, last message overwrites."""
|
||||
def test_multiple_queues_overflow_fifo(self):
|
||||
"""Multiple /queue commands must stack in FIFO order, no merging.
|
||||
|
||||
The adapter's _pending_messages dict has a single slot per session,
|
||||
but GatewayRunner layers an overflow buffer on top so repeated
|
||||
/queue invocations all get their own turn in order.
|
||||
"""
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
runner = GatewayRunner.__new__(GatewayRunner)
|
||||
runner._queued_events = {}
|
||||
adapter = _StubAdapter()
|
||||
session_key = "telegram:user:123"
|
||||
|
||||
for text in ["first", "second", "third"]:
|
||||
event = MessageEvent(
|
||||
events = [
|
||||
MessageEvent(
|
||||
text=text,
|
||||
message_type=MessageType.TEXT,
|
||||
source=MagicMock(),
|
||||
source=MagicMock(chat_id="123", platform=Platform.TELEGRAM),
|
||||
message_id=f"q-{text}",
|
||||
)
|
||||
adapter._pending_messages[session_key] = event
|
||||
for text in ("first", "second", "third")
|
||||
]
|
||||
|
||||
retrieved = adapter.get_pending_message(session_key)
|
||||
assert retrieved.text == "third"
|
||||
for ev in events:
|
||||
runner._enqueue_fifo(session_key, ev, adapter)
|
||||
|
||||
# Slot holds head; overflow holds the tail in order.
|
||||
assert adapter._pending_messages[session_key].text == "first"
|
||||
assert [e.text for e in runner._queued_events[session_key]] == ["second", "third"]
|
||||
assert runner._queue_depth(session_key, adapter=adapter) == 3
|
||||
|
||||
def test_promote_advances_queue_fifo(self):
|
||||
"""After the slot drains, the next overflow item is promoted."""
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
runner = GatewayRunner.__new__(GatewayRunner)
|
||||
runner._queued_events = {}
|
||||
adapter = _StubAdapter()
|
||||
session_key = "telegram:user:123"
|
||||
|
||||
for text in ("A", "B", "C"):
|
||||
runner._enqueue_fifo(
|
||||
session_key,
|
||||
MessageEvent(
|
||||
text=text,
|
||||
message_type=MessageType.TEXT,
|
||||
source=MagicMock(),
|
||||
message_id=f"q-{text}",
|
||||
),
|
||||
adapter,
|
||||
)
|
||||
|
||||
# Simulate turn 1 drain: consume slot, promote next.
|
||||
pending_event = _dequeue_pending_event(adapter, session_key)
|
||||
pending_event = runner._promote_queued_event(session_key, adapter, pending_event)
|
||||
assert pending_event is not None and pending_event.text == "A"
|
||||
assert adapter._pending_messages[session_key].text == "B"
|
||||
assert runner._queue_depth(session_key, adapter=adapter) == 2
|
||||
|
||||
# Simulate turn 2 drain.
|
||||
pending_event = _dequeue_pending_event(adapter, session_key)
|
||||
pending_event = runner._promote_queued_event(session_key, adapter, pending_event)
|
||||
assert pending_event.text == "B"
|
||||
assert adapter._pending_messages[session_key].text == "C"
|
||||
assert session_key not in runner._queued_events # overflow emptied
|
||||
|
||||
# Simulate turn 3 drain.
|
||||
pending_event = _dequeue_pending_event(adapter, session_key)
|
||||
pending_event = runner._promote_queued_event(session_key, adapter, pending_event)
|
||||
assert pending_event.text == "C"
|
||||
assert session_key not in adapter._pending_messages
|
||||
assert runner._queue_depth(session_key, adapter=adapter) == 0
|
||||
|
||||
# Turn 4: nothing pending.
|
||||
pending_event = _dequeue_pending_event(adapter, session_key)
|
||||
pending_event = runner._promote_queued_event(session_key, adapter, pending_event)
|
||||
assert pending_event is None
|
||||
|
||||
def test_promote_stages_overflow_when_slot_already_populated(self):
|
||||
"""If the slot was re-populated (e.g. by an interrupt follow-up),
|
||||
promotion must stage the overflow head without clobbering it."""
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
runner = GatewayRunner.__new__(GatewayRunner)
|
||||
runner._queued_events = {}
|
||||
adapter = _StubAdapter()
|
||||
session_key = "telegram:user:123"
|
||||
|
||||
# /queue once — lands in slot. Second /queue — overflow.
|
||||
for text in ("Q1", "Q2"):
|
||||
runner._enqueue_fifo(
|
||||
session_key,
|
||||
MessageEvent(
|
||||
text=text,
|
||||
message_type=MessageType.TEXT,
|
||||
source=MagicMock(),
|
||||
message_id=f"q-{text}",
|
||||
),
|
||||
adapter,
|
||||
)
|
||||
|
||||
# Drain consumes Q1.
|
||||
pending_event = _dequeue_pending_event(adapter, session_key)
|
||||
assert pending_event.text == "Q1"
|
||||
|
||||
# Someone else (interrupt path) re-populates the slot.
|
||||
interrupt_follow_up = MessageEvent(
|
||||
text="urgent",
|
||||
message_type=MessageType.TEXT,
|
||||
source=MagicMock(),
|
||||
message_id="m-urg",
|
||||
)
|
||||
adapter._pending_messages[session_key] = interrupt_follow_up
|
||||
|
||||

        # The interrupt follow-up is already staged as the event for the next
        # turn (it is what we pass in as `pending_event`), so promotion must
        # return it unchanged. The current implementation then stages the
        # overflow head Q2 in the slot unconditionally, replacing the slot's
        # duplicate copy of the interrupt. Nothing is lost: the interrupt runs
        # next via the returned event, and Q2 runs on the turn after it. This
        # is the accepted edge-case trade-off: /queue items always run after
        # the currently staged pending_event.
        returned = runner._promote_queued_event(session_key, adapter, interrupt_follow_up)
        assert returned is interrupt_follow_up
        assert adapter._pending_messages[session_key].text == "Q2"

    def test_queue_depth_counts_slot_plus_overflow(self):
        from gateway.run import GatewayRunner

        runner = GatewayRunner.__new__(GatewayRunner)
        runner._queued_events = {}
        adapter = _StubAdapter()
        session_key = "telegram:user:depth"

        assert runner._queue_depth(session_key, adapter=adapter) == 0

        runner._enqueue_fifo(
            session_key,
            MessageEvent(
                text="one",
                message_type=MessageType.TEXT,
                source=MagicMock(),
                message_id="q1",
            ),
            adapter,
        )
        assert runner._queue_depth(session_key, adapter=adapter) == 1

        for text in ("two", "three"):
            runner._enqueue_fifo(
                session_key,
                MessageEvent(
                    text=text,
                    message_type=MessageType.TEXT,
                    source=MagicMock(),
                    message_id=f"q-{text}",
                ),
                adapter,
            )
        assert runner._queue_depth(session_key, adapter=adapter) == 3

    def test_enqueue_preserves_text_no_merging(self):
        """Each /queue item keeps its own text — never merged with neighbors."""
        from gateway.run import GatewayRunner

        runner = GatewayRunner.__new__(GatewayRunner)
        runner._queued_events = {}
        adapter = _StubAdapter()
        session_key = "telegram:user:nomerge"

        texts = ["deploy the branch", "then run tests", "finally push"]
        for text in texts:
            runner._enqueue_fifo(
                session_key,
                MessageEvent(
                    text=text,
                    message_type=MessageType.TEXT,
                    source=MagicMock(),
                    message_id=f"q-{text[:4]}",
                ),
                adapter,
            )

        # Slot + overflow contain exactly the three texts, unmodified.
        collected = [adapter._pending_messages[session_key].text] + [
            e.text for e in runner._queued_events[session_key]
        ]
        assert collected == texts

@ -90,9 +90,21 @@ def test_load_busy_input_mode_prefers_env_then_config_then_default(tmp_path, mon
    )
    assert gateway_run.GatewayRunner._load_busy_input_mode() == "queue"

    (tmp_path / "config.yaml").write_text(
        "display:\n busy_input_mode: steer\n", encoding="utf-8"
    )
    assert gateway_run.GatewayRunner._load_busy_input_mode() == "steer"

    monkeypatch.setenv("HERMES_GATEWAY_BUSY_INPUT_MODE", "interrupt")
    assert gateway_run.GatewayRunner._load_busy_input_mode() == "interrupt"

    monkeypatch.setenv("HERMES_GATEWAY_BUSY_INPUT_MODE", "steer")
    assert gateway_run.GatewayRunner._load_busy_input_mode() == "steer"

    # Unknown values fall through to the safe default
    monkeypatch.setenv("HERMES_GATEWAY_BUSY_INPUT_MODE", "bogus")
    assert gateway_run.GatewayRunner._load_busy_input_mode() == "interrupt"
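
# --- Editor's sketch (not part of the commit): the precedence this test pins
# down. The env var wins; an unrecognized env value falls back to the default,
# not to the config; then config.yaml; then "interrupt". How the config dict is
# loaded, and the exact set of valid modes, are assumptions.
import os

_VALID_BUSY_MODES = {"interrupt", "queue", "steer"}

def _load_busy_input_mode_sketch(display_cfg: dict) -> str:
    env = os.environ.get("HERMES_GATEWAY_BUSY_INPUT_MODE")
    if env is not None:
        env = env.strip().lower()
        return env if env in _VALID_BUSY_MODES else "interrupt"
    cfg = display_cfg.get("busy_input_mode")
    if isinstance(cfg, str) and cfg.strip().lower() in _VALID_BUSY_MODES:
        return cfg.strip().lower()
    return "interrupt"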


def test_load_restart_drain_timeout_prefers_env_then_config_then_default(
    tmp_path, monkeypatch, caplog

@ -245,6 +245,7 @@ class TestBuildSessionContextPrompt:
        assert "Slack" in prompt
        assert "cannot search" in prompt.lower()
        assert "pin" in prompt.lower()
        assert "current message's slack block/attachment payload" in prompt.lower()

    def test_discord_prompt_with_channel_topic(self):
        """Channel topic should appear in the session context prompt."""

@ -76,6 +76,7 @@ def _make_resume_runner():
    runner._running_agents_ts = {}
    runner._busy_ack_ts = {}
    runner._pending_approvals = {}
    runner._update_prompt_pending = {}
    runner._agent_cache_lock = None
    runner.session_store = MagicMock()
    runner.session_store.get_or_create_session.return_value = current_entry
@ -102,6 +103,7 @@ def _make_branch_runner():
    runner._running_agents_ts = {}
    runner._busy_ack_ts = {}
    runner._pending_approvals = {}
    runner._update_prompt_pending = {}
    runner._agent_cache_lock = None
    runner.session_store = MagicMock()
    runner.session_store.get_or_create_session.return_value = current_entry
@ -127,6 +129,8 @@ async def test_resume_clears_session_scoped_approval_and_yolo_state():
    enable_session_yolo(other_key)
    runner._pending_approvals[session_key] = {"command": "rm -rf /tmp/demo"}
    runner._pending_approvals[other_key] = {"command": "rm -rf /tmp/other"}
    runner._update_prompt_pending[session_key] = True
    runner._update_prompt_pending[other_key] = True

    result = await runner._handle_resume_command(_make_event("/resume Resumed Work"))

@ -134,9 +138,11 @@ async def test_resume_clears_session_scoped_approval_and_yolo_state():
    assert is_approved(session_key, "recursive delete") is False
    assert is_session_yolo_enabled(session_key) is False
    assert session_key not in runner._pending_approvals
    assert session_key not in runner._update_prompt_pending
    assert is_approved(other_key, "recursive delete") is True
    assert is_session_yolo_enabled(other_key) is True
    assert other_key in runner._pending_approvals
    assert other_key in runner._update_prompt_pending


@pytest.mark.asyncio
@ -150,6 +156,8 @@ async def test_branch_clears_session_scoped_approval_and_yolo_state():
    enable_session_yolo(other_key)
    runner._pending_approvals[session_key] = {"command": "rm -rf /tmp/demo"}
    runner._pending_approvals[other_key] = {"command": "rm -rf /tmp/other"}
    runner._update_prompt_pending[session_key] = True
    runner._update_prompt_pending[other_key] = True

    result = await runner._handle_branch_command(_make_event("/branch"))

@ -157,9 +165,11 @@ async def test_branch_clears_session_scoped_approval_and_yolo_state():
    assert is_approved(session_key, "recursive delete") is False
    assert is_session_yolo_enabled(session_key) is False
    assert session_key not in runner._pending_approvals
    assert session_key not in runner._update_prompt_pending
    assert is_approved(other_key, "recursive delete") is True
    assert is_session_yolo_enabled(other_key) is True
    assert other_key in runner._pending_approvals
    assert other_key in runner._update_prompt_pending


def test_clear_session_boundary_security_state_is_scoped():
@ -172,6 +182,7 @@ def test_clear_session_boundary_security_state_is_scoped():

    runner = object.__new__(GatewayRunner)
    runner._pending_approvals = {}
    runner._update_prompt_pending = {}

    source = _make_source()
    session_key = build_session_key(source)
@ -183,6 +194,8 @@ def test_clear_session_boundary_security_state_is_scoped():
    enable_session_yolo(other_key)
    runner._pending_approvals[session_key] = {"command": "rm -rf /tmp/demo"}
    runner._pending_approvals[other_key] = {"command": "rm -rf /tmp/other"}
    runner._update_prompt_pending[session_key] = True
    runner._update_prompt_pending[other_key] = True

    runner._clear_session_boundary_security_state(session_key)

@ -190,11 +203,14 @@ def test_clear_session_boundary_security_state_is_scoped():
    assert is_approved(session_key, "recursive delete") is False
    assert is_session_yolo_enabled(session_key) is False
    assert session_key not in runner._pending_approvals
    assert session_key not in runner._update_prompt_pending
    # Other session untouched
    assert is_approved(other_key, "recursive delete") is True
    assert is_session_yolo_enabled(other_key) is True
    assert other_key in runner._pending_approvals
    assert other_key in runner._update_prompt_pending

    # Empty session_key is a no-op
    runner._clear_session_boundary_security_state("")
    assert is_approved(other_key, "recursive delete") is True
    assert other_key in runner._update_prompt_pending

@ -1,11 +1,16 @@
"""Regression tests for the TUI gateway's ``session.list`` handler.

Reported during TUI v2 blitz retest: the ``/resume`` modal inside a TUI
session only surfaced ``tui``/``cli`` rows, hiding telegram sessions users
could still resume directly via ``hermes --tui --resume <id>``.

The fix widens the picker to a curated allowlist of user-facing sources
(tui/cli + chat adapters) while still filtering internal/system sources.
History:
- The original implementation hardcoded an allow-list of known gateway
  sources (``tui, cli, telegram, discord, slack, ...``). New or unlisted
  sources (``acp``, ``webhook``, user-defined ``HERMES_SESSION_SOURCE``
  values, newly-added platforms) were silently dropped from the resume
  picker — users reported "lots of sessions are missing from browse
  but exist in .hermes/sessions."
- The handler now deny-lists only the internal/noisy source ``tool``
  (sub-agent runs) and surfaces every other source to the picker.
- The default ``limit`` raised from 20 to 200 so longer-running users
  can scroll through their history without hitting an artificial cap.
"""

from __future__ import annotations
@ -23,42 +28,64 @@ class _StubDB:
        return list(self.rows)


def _call(limit: int = 20):
def _call(limit: int | None = None):
    params: dict = {}
    if limit is not None:
        params["limit"] = limit
    return server.handle_request({
        "id": "1",
        "method": "session.list",
        "params": {"limit": limit},
        "params": params,
    })


def test_session_list_includes_telegram_but_filters_internal_sources(monkeypatch):
def test_session_list_surfaces_all_user_facing_sources(monkeypatch):
    """acp / webhook / custom sources should all appear; only ``tool`` is hidden."""
    rows = [
        {"id": "tui-1", "source": "tui", "started_at": 9},
        {"id": "tool-1", "source": "tool", "started_at": 8},
        {"id": "tg-1", "source": "telegram", "started_at": 7},
        {"id": "acp-1", "source": "acp", "started_at": 6},
        {"id": "cli-1", "source": "cli", "started_at": 5},
        {"id": "webhook-1", "source": "webhook", "started_at": 4},
        {"id": "custom-1", "source": "my-custom-source", "started_at": 3},
    ]
    db = _StubDB(rows)
    monkeypatch.setattr(server, "_get_db", lambda: db)

    resp = _call(limit=10)
    sessions = resp["result"]["sessions"]
    ids = [s["id"] for s in sessions]
    ids = [s["id"] for s in resp["result"]["sessions"]]

    assert "tg-1" in ids and "tui-1" in ids and "cli-1" in ids, ids
    assert "tool-1" not in ids and "acp-1" not in ids, ids
    # Every human-facing source — including previously-hidden acp, webhook,
    # and custom sources — must surface in the picker now.
    assert "tg-1" in ids
    assert "tui-1" in ids
    assert "cli-1" in ids
    assert "acp-1" in ids, "acp sessions were being hidden by the old allow-list"
    assert "webhook-1" in ids, "webhook sessions were being hidden by the old allow-list"
    assert "custom-1" in ids, "custom HERMES_SESSION_SOURCE values were being hidden"

    # Only internal sub-agent runs stay hidden.
    assert "tool-1" not in ids


def test_session_list_fetches_wider_window_before_filtering(monkeypatch):
def test_session_list_default_limit_is_200(monkeypatch):
    """Default limit should be wide enough for long-running users."""
    db = _StubDB([{"id": "x", "source": "cli", "started_at": 1}])
    monkeypatch.setattr(server, "_get_db", lambda: db)

    _call()  # no explicit limit
    # fetch_limit = max(limit * 2, 200); limit defaults to 200, so 400.
    assert db.calls[0].get("limit") == 400, db.calls[0]


def test_session_list_respects_explicit_limit(monkeypatch):
    db = _StubDB([{"id": "x", "source": "cli", "started_at": 1}])
    monkeypatch.setattr(server, "_get_db", lambda: db)

    _call(limit=10)

    assert len(db.calls) == 1
    assert db.calls[0].get("source") is None, db.calls[0]
    assert db.calls[0].get("limit") == 100, db.calls[0]
    # fetch_limit = max(limit * 2, 200) = 200 when limit is small.
    assert db.calls[0].get("limit") == 200, db.calls[0]


def test_session_list_preserves_ordering_after_filter(monkeypatch):
@ -66,6 +93,7 @@ def test_session_list_preserves_ordering_after_filter(monkeypatch):
        {"id": "newest", "source": "telegram", "started_at": 5},
        {"id": "internal", "source": "tool", "started_at": 4},
        {"id": "middle", "source": "tui", "started_at": 3},
        {"id": "also-visible", "source": "webhook", "started_at": 2},
        {"id": "oldest", "source": "discord", "started_at": 1},
    ]
    monkeypatch.setattr(server, "_get_db", lambda: _StubDB(rows))
@ -73,4 +101,4 @@ def test_session_list_preserves_ordering_after_filter(monkeypatch):
    resp = _call()
    ids = [s["id"] for s in resp["result"]["sessions"]]

    assert ids == ["newest", "middle", "oldest"]
    assert ids == ["newest", "middle", "also-visible", "oldest"]

@ -81,11 +81,13 @@ async def test_new_command_clears_session_model_override():
        "api_mode": "openai",
    }
    runner._session_reasoning_overrides[session_key] = {"enabled": True, "effort": "high"}
    runner._pending_model_notes[session_key] = "[Note: switched to gpt-4o.]"

    await runner._handle_reset_command(_make_event("/new"))

    assert session_key not in runner._session_model_overrides
    assert session_key not in runner._session_reasoning_overrides
    assert session_key not in runner._pending_model_notes


@pytest.mark.asyncio
@ -126,6 +128,8 @@ async def test_new_command_only_clears_own_session():
    }
    runner._session_reasoning_overrides[session_key] = {"enabled": True, "effort": "high"}
    runner._session_reasoning_overrides[other_key] = {"enabled": True, "effort": "low"}
    runner._pending_model_notes[session_key] = "[Note: switched to gpt-4o.]"
    runner._pending_model_notes[other_key] = "[Note: switched to claude-sonnet-4-6.]"

    await runner._handle_reset_command(_make_event("/new"))

@ -133,3 +137,5 @@ async def test_new_command_only_clears_own_session():
    assert other_key in runner._session_model_overrides
    assert session_key not in runner._session_reasoning_overrides
    assert other_key in runner._session_reasoning_overrides
    assert session_key not in runner._pending_model_notes
    assert other_key in runner._pending_model_notes

210 tests/gateway/test_shutdown_cache_cleanup.py Normal file
@ -0,0 +1,210 @@
"""Regression tests for gateway shutdown cleaning up cached agent memory providers (issue #11205).

When the gateway shuts down, ``stop()`` called ``_finalize_shutdown_agents()``
which only drained agents in ``_running_agents``. Idle agents sitting in
``_agent_cache`` (LRU cache) were never cleaned up, so their
``MemoryProvider.on_session_end()`` hooks never fired.

The fix adds an explicit sweep of ``_agent_cache`` after
``_finalize_shutdown_agents`` in the ``_stop_impl`` coroutine.
"""

import asyncio
import threading
from collections import OrderedDict
from unittest.mock import MagicMock, patch

import pytest

# Import the module (not the class) to reach stop() and helpers
import gateway.run as gw_mod


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

class _FakeGateway:
    """Minimal stand-in with just enough state for ``stop()`` to run."""

    def __init__(self):
        self._running = True
        self._draining = False
        self._restart_requested = False
        self._restart_detached = False
        self._restart_via_service = False
        self._stop_task = None
        self._exit_cleanly = False
        self._exit_with_failure = False
        self._exit_reason = None
        self._exit_code = None
        self._restart_drain_timeout = 0.01
        self._running_agents = {}
        self._running_agents_ts = {}
        self._agent_cache = OrderedDict()
        self._agent_cache_lock = threading.Lock()
        self.adapters = {}
        self._background_tasks = set()
        self._failed_platforms = []
        self._shutdown_event = asyncio.Event()
        self._pending_messages = {}
        self._pending_approvals = {}
        self._busy_ack_ts = {}

    def _running_agent_count(self):
        return len(self._running_agents)

    def _update_runtime_status(self, *_a, **_kw):
        pass

    async def _notify_active_sessions_of_shutdown(self):
        pass

    async def _drain_active_agents(self, timeout):
        return {}, False

    def _finalize_shutdown_agents(self, agents):
        for agent in agents.values():
            self._cleanup_agent_resources(agent)

    def _cleanup_agent_resources(self, agent):
        if agent is None:
            return
        try:
            if hasattr(agent, "shutdown_memory_provider"):
                agent.shutdown_memory_provider()
        except Exception:
            pass
        try:
            if hasattr(agent, "close"):
                agent.close()
        except Exception:
            pass

    def _evict_cached_agent(self, key):
        pass


def _make_mock_agent():
    a = MagicMock()
    a.shutdown_memory_provider = MagicMock()
    a.close = MagicMock()
    return a


# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------

class TestCachedAgentCleanupOnShutdown:
    """Verify that ``stop()`` calls ``_cleanup_agent_resources`` on idle
    cached agents, triggering ``shutdown_memory_provider()`` (which calls
    ``on_session_end``)."""

    @pytest.mark.asyncio
    async def test_cached_agent_memory_provider_shut_down(self):
        """A cached agent's shutdown_memory_provider is called during gateway stop."""
        gw = _FakeGateway()
        agent = _make_mock_agent()
        gw._agent_cache["session-1"] = (agent, "sig-123")

        # Call the real stop() from GatewayRunner
        await gw_mod.GatewayRunner.stop(gw)

        agent.shutdown_memory_provider.assert_called_once()

    @pytest.mark.asyncio
    async def test_cache_cleared_after_shutdown(self):
        """The _agent_cache dict is cleared after stop."""
        gw = _FakeGateway()
        agent = _make_mock_agent()
        gw._agent_cache["s1"] = (agent, "sig1")

        await gw_mod.GatewayRunner.stop(gw)

        assert len(gw._agent_cache) == 0

    @pytest.mark.asyncio
    async def test_no_cached_agents_no_error(self):
        """stop() works fine when _agent_cache is empty."""
        gw = _FakeGateway()

        await gw_mod.GatewayRunner.stop(gw)  # Should not raise

        assert len(gw._agent_cache) == 0

    @pytest.mark.asyncio
    async def test_multiple_cached_agents_all_cleaned(self):
        """All cached agents get cleaned up."""
        gw = _FakeGateway()
        agents = []
        for i in range(5):
            a = _make_mock_agent()
            agents.append(a)
            gw._agent_cache[f"s{i}"] = (a, f"sig{i}")

        await gw_mod.GatewayRunner.stop(gw)

        for a in agents:
            a.shutdown_memory_provider.assert_called_once()

    @pytest.mark.asyncio
    async def test_cleanup_survives_agent_exception(self):
        """An exception from one agent's shutdown doesn't prevent others."""
        gw = _FakeGateway()

        bad = _make_mock_agent()
        bad.shutdown_memory_provider.side_effect = RuntimeError("boom")
        bad.close.side_effect = RuntimeError("boom")

        good = _make_mock_agent()

        gw._agent_cache["bad"] = (bad, "sig-bad")
        gw._agent_cache["good"] = (good, "sig-good")

        await gw_mod.GatewayRunner.stop(gw)

        # The good agent should still be cleaned up
        good.shutdown_memory_provider.assert_called_once()

    @pytest.mark.asyncio
    async def test_plain_agent_not_tuple(self):
        """Cache entries that aren't tuples (just bare agents) are also cleaned."""
        gw = _FakeGateway()
        agent = _make_mock_agent()
        gw._agent_cache["s1"] = agent  # Not a tuple

        await gw_mod.GatewayRunner.stop(gw)

        agent.shutdown_memory_provider.assert_called_once()
        assert len(gw._agent_cache) == 0

    @pytest.mark.asyncio
    async def test_none_entry_skipped(self):
        """A None cache entry doesn't cause errors."""
        gw = _FakeGateway()
        gw._agent_cache["s1"] = None

        await gw_mod.GatewayRunner.stop(gw)

        assert len(gw._agent_cache) == 0


class TestRunningAgentsNotDoubleCleaned:
    """Verify behavior when agents appear in both _running_agents and _agent_cache."""

    @pytest.mark.asyncio
    async def test_running_and_cached_agent_cleaned_at_least_once(self):
        """An agent in both _running_agents and _agent_cache gets
        shutdown_memory_provider called at least once."""
        gw = _FakeGateway()
        shared = _make_mock_agent()

        gw._running_agents["s1"] = shared
        gw._agent_cache["s1"] = (shared, "sig1")

        await gw_mod.GatewayRunner.stop(gw)

        # Called at least once — either from _finalize_shutdown_agents
        # or from the cache sweep (or both)
        assert shared.shutdown_memory_provider.call_count >= 1
@ -11,7 +11,7 @@ We mock the slack modules at import time to avoid collection errors.
import asyncio
import os
import sys
from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import AsyncMock, MagicMock, patch, call

import pytest

@ -21,6 +21,7 @@ from gateway.platforms.base import (
    MessageType,
    SendResult,
    SUPPORTED_DOCUMENT_TYPES,
    is_host_excluded_by_no_proxy,
)


@ -188,6 +189,198 @@ class TestSlackConnectCleanup:
        assert adapter._platform_lock_identity is None


# ---------------------------------------------------------------------------
# TestSlackProxyBehavior
# ---------------------------------------------------------------------------

class TestSlackProxyBehavior:
    def test_no_proxy_helper_matches_slack_hosts(self):
        assert is_host_excluded_by_no_proxy("slack.com", "localhost,.slack.com")
        assert is_host_excluded_by_no_proxy("files.slack.com", "localhost slack.com")
        assert is_host_excluded_by_no_proxy("wss-primary.slack.com", "*")
        assert not is_host_excluded_by_no_proxy("slack.com", "localhost,.internal.corp")

    def test_resolve_slack_proxy_url_ignores_unsupported_proxy_schemes(self):
        with patch.object(_slack_mod, "resolve_proxy_url", return_value="socks5://proxy.example.com:1080"):
            assert _slack_mod._resolve_slack_proxy_url() is None

    def test_resolve_slack_proxy_url_checks_all_slack_hosts(self):
        with patch.object(_slack_mod, "resolve_proxy_url", return_value="http://proxy.example.com:3128"), \
                patch.object(_slack_mod, "is_host_excluded_by_no_proxy", side_effect=lambda host: host == "wss-primary.slack.com") as excluded:
            assert _slack_mod._resolve_slack_proxy_url() is None
            excluded.assert_has_calls([
                call("slack.com"),
                call("files.slack.com"),
                call("wss-primary.slack.com"),
            ])
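
# --- Editor's sketch (not part of the commit): the resolver the two tests above
# pin down. `resolve_proxy_url` is the assumed shared helper; the one-argument
# `is_host_excluded_by_no_proxy` (NO_PROXY read from the environment) mirrors
# how the helper is patched here. Names are placeholders.
_SLACK_HOSTS = ("slack.com", "files.slack.com", "wss-primary.slack.com")

def _resolve_slack_proxy_url_sketch():
    proxy = resolve_proxy_url()
    if not proxy or not proxy.startswith(("http://", "https://")):
        return None  # unsupported proxy schemes (e.g. socks5) are ignored
    # If NO_PROXY excludes any Slack host, disable the proxy entirely rather
    # than proxying some Slack hosts and not others.
    if any(is_host_excluded_by_no_proxy(host) for host in _SLACK_HOSTS):
        return None
    return proxy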

    @pytest.mark.asyncio
    async def test_connect_uses_proxy_when_not_bypassed(self):
        created_apps = []
        created_clients = []

        class FakeWebClient:
            def __init__(self, token):
                self.token = token
                self.proxy = "constructor-default"
                suffix = token.split("-")[-1]
                self.auth_test = AsyncMock(return_value={
                    "team_id": f"T_{suffix}",
                    "user_id": f"U_{suffix}",
                    "user": f"bot-{suffix}",
                    "team": f"Team {suffix}",
                })
                created_clients.append(self)

        class FakeApp:
            def __init__(self, token):
                self.token = token
                self.client = FakeWebClient(token)
                self.registered_events = []
                self.registered_commands = []
                self.registered_actions = []
                created_apps.append(self)

            def event(self, event_type):
                self.registered_events.append(event_type)

                def decorator(fn):
                    return fn

                return decorator

            def command(self, command_name):
                self.registered_commands.append(command_name)

                def decorator(fn):
                    return fn

                return decorator

            def action(self, action_id):
                self.registered_actions.append(action_id)

                def decorator(fn):
                    return fn

                return decorator

        class FakeSocketModeHandler:
            def __init__(self, app, app_token, proxy=None):
                self.app = app
                self.app_token = app_token
                self.proxy = proxy
                self.client = MagicMock(proxy="constructor-default")

            def start_async(self):
                return None

            async def close_async(self):
                return None

        config = PlatformConfig(enabled=True, token="xoxb-primary,xoxb-secondary")
        adapter = SlackAdapter(config)

        with patch.object(_slack_mod, "AsyncApp", side_effect=FakeApp), \
                patch.object(_slack_mod, "AsyncWebClient", side_effect=FakeWebClient), \
                patch.object(_slack_mod, "AsyncSocketModeHandler", FakeSocketModeHandler), \
                patch.object(_slack_mod, "_resolve_slack_proxy_url", return_value="http://proxy.example.com:3128"), \
                patch.dict(os.environ, {"SLACK_APP_TOKEN": "xapp-fake"}, clear=False), \
                patch("gateway.status.acquire_scoped_lock", return_value=(True, None)), \
                patch("asyncio.create_task", return_value=MagicMock(name="socket-mode-task")):
            result = await adapter.connect()

        assert result is True
        assert created_apps[0].client.proxy == "http://proxy.example.com:3128"
        assert all(client.proxy == "http://proxy.example.com:3128" for client in created_clients)
        assert adapter._handler is not None
        assert adapter._handler.proxy == "http://proxy.example.com:3128"
        assert adapter._handler.client.proxy == "http://proxy.example.com:3128"

    @pytest.mark.asyncio
    async def test_connect_clears_proxy_when_no_proxy_matches_slack(self):
        created_apps = []
        created_clients = []

        class FakeWebClient:
            def __init__(self, token):
                self.token = token
                self.proxy = "constructor-default"
                suffix = token.split("-")[-1]
                self.auth_test = AsyncMock(return_value={
                    "team_id": f"T_{suffix}",
                    "user_id": f"U_{suffix}",
                    "user": f"bot-{suffix}",
                    "team": f"Team {suffix}",
                })
                created_clients.append(self)

        class FakeApp:
            def __init__(self, token):
                self.token = token
                self.client = FakeWebClient(token)
                self.registered_events = []
                self.registered_commands = []
                self.registered_actions = []
                created_apps.append(self)

            def event(self, event_type):
                self.registered_events.append(event_type)

                def decorator(fn):
                    return fn

                return decorator

            def command(self, command_name):
                self.registered_commands.append(command_name)

                def decorator(fn):
                    return fn

                return decorator

            def action(self, action_id):
                self.registered_actions.append(action_id)

                def decorator(fn):
                    return fn

                return decorator

        class FakeSocketModeHandler:
            def __init__(self, app, app_token, proxy=None):
                self.app = app
                self.app_token = app_token
                self.proxy = proxy
                self.client = MagicMock(proxy="constructor-default")

            def start_async(self):
                return None

            async def close_async(self):
                return None

        config = PlatformConfig(enabled=True, token="xoxb-primary")
        adapter = SlackAdapter(config)

        with patch.object(_slack_mod, "AsyncApp", side_effect=FakeApp), \
                patch.object(_slack_mod, "AsyncWebClient", side_effect=FakeWebClient), \
                patch.object(_slack_mod, "AsyncSocketModeHandler", FakeSocketModeHandler), \
                patch.object(_slack_mod, "_resolve_slack_proxy_url", return_value=None), \
                patch.dict(os.environ, {"SLACK_APP_TOKEN": "xapp-fake"}, clear=False), \
                patch("gateway.status.acquire_scoped_lock", return_value=(True, None)), \
                patch("asyncio.create_task", return_value=MagicMock(name="socket-mode-task")):
            result = await adapter.connect()

        assert result is True
        assert created_apps[0].client.proxy is None
        assert all(client.proxy is None for client in created_clients)
        assert adapter._handler is not None
        assert adapter._handler.proxy is None
        assert adapter._handler.client.proxy is None

# ---------------------------------------------------------------------------
# TestSendDocument
# ---------------------------------------------------------------------------
@ -287,6 +480,40 @@ class TestSendDocument:
        call_kwargs = adapter._app.client.files_upload_v2.call_args[1]
        assert call_kwargs["thread_ts"] == "1234567890.123456"

    @pytest.mark.asyncio
    async def test_send_document_thread_upload_marks_bot_participation(self, adapter, tmp_path):
        test_file = tmp_path / "notes.txt"
        test_file.write_bytes(b"some notes")

        adapter._app.client.files_upload_v2 = AsyncMock(return_value={"ok": True})

        await adapter.send_document(
            chat_id="C123",
            file_path=str(test_file),
            metadata={"thread_id": "1234567890.123456"},
        )

        assert "1234567890.123456" in adapter._bot_message_ts

    @pytest.mark.asyncio
    async def test_send_document_retries_transient_upload_error(self, adapter, tmp_path):
        test_file = tmp_path / "notes.txt"
        test_file.write_bytes(b"some notes")

        adapter._app.client.files_upload_v2 = AsyncMock(
            side_effect=[RuntimeError("Connection reset by peer"), {"ok": True}]
        )

        with patch("asyncio.sleep", new_callable=AsyncMock) as sleep_mock:
            result = await adapter.send_document(
                chat_id="C123",
                file_path=str(test_file),
            )

        assert result.success
        assert adapter._app.client.files_upload_v2.await_count == 2
        sleep_mock.assert_awaited_once()


# ---------------------------------------------------------------------------
# TestSendVideo
@ -355,15 +582,17 @@ class TestSendVideo:
# ---------------------------------------------------------------------------

class TestIncomingDocumentHandling:
    def _make_event(self, files=None, text="hello", channel_type="im"):
    def _make_event(self, files=None, text="hello", channel_type="im", blocks=None, attachments=None):
        """Build a mock Slack message event with file attachments."""
        return {
            "text": text,
            "user": "U_USER",
            "channel": "C123",
            "channel": "D123",
            "channel_type": channel_type,
            "ts": "1234567890.000001",
            "files": files or [],
            "blocks": blocks or [],
            "attachments": attachments or [],
        }

    @pytest.mark.asyncio
@ -428,6 +657,36 @@ class TestIncomingDocumentHandling:
        msg_event = adapter.handle_message.call_args[0][0]
        assert "# Title" in msg_event.text

    @pytest.mark.asyncio
    async def test_json_snippet_injects_content(self, adapter):
        """A .json snippet should be treated as a text document and injected."""
        content = b'{"hello": "world", "count": 2}'

        with patch.object(adapter, "_download_slack_file_bytes", new_callable=AsyncMock) as dl:
            dl.return_value = content
            event = self._make_event(
                text="can you parse this",
                files=[{
                    "mimetype": "text/plain",
                    "name": "zapfile.json",
                    "filetype": "json",
                    "pretty_type": "JSON",
                    "mode": "snippet",
                    "editable": True,
                    "url_private_download": "https://files.slack.com/zapfile.json",
                    "size": len(content),
                }],
            )
            await adapter._handle_slack_message(event)

        msg_event = adapter.handle_message.call_args[0][0]
        assert msg_event.message_type == MessageType.DOCUMENT
        assert len(msg_event.media_urls) == 1
        assert msg_event.media_types == ["application/json"]
        assert '[Content of zapfile.json]' in msg_event.text
        assert '"hello": "world"' in msg_event.text
        assert 'can you parse this' in msg_event.text

    @pytest.mark.asyncio
    async def test_large_txt_not_injected(self, adapter):
        """A .txt file over 100KB should be cached but NOT injected."""
@ -511,6 +770,207 @@ class TestIncomingDocumentHandling:
        msg_event = adapter.handle_message.call_args[0][0]
        assert msg_event.message_type == MessageType.PHOTO

    @pytest.mark.asyncio
    async def test_download_failure_is_surfaced_in_message_text(self, adapter):
        """Attachment download failures (401/403/HTML-body/etc.) should be
        translated into a user-facing `[Slack attachment notice]` block so
        the agent can tell the user what to fix (e.g. missing files:read
        scope). No proactive files.info probe is made — the diagnostic
        runs only when the download actually fails.
        """
        import httpx
        req = httpx.Request("GET", "https://files.slack.com/photo.jpg")
        resp = httpx.Response(403, request=req)

        with patch.object(adapter, "_download_slack_file", new_callable=AsyncMock) as dl:
            dl.side_effect = httpx.HTTPStatusError("403", request=req, response=resp)
            event = self._make_event(text="what's in this?", files=[{
                "id": "F123",
                "mimetype": "image/jpeg",
                "name": "photo.jpg",
                "url_private_download": "https://files.slack.com/photo.jpg",
                "size": 1024,
            }])
            await adapter._handle_slack_message(event)

        msg_event = adapter.handle_message.call_args[0][0]
        assert msg_event.message_type == MessageType.TEXT
        assert "[Slack attachment notice]" in msg_event.text
        assert "403" in msg_event.text
        assert "what's in this?" in msg_event.text

    @pytest.mark.asyncio
    async def test_rich_text_blocks_do_not_duplicate_plain_text(self, adapter):
        """Plain rich_text composer blocks match the plain text field exactly,
        so the dedupe guard keeps the message clean."""
        event = self._make_event(
            text="hello world",
            blocks=[
                {
                    "type": "rich_text",
                    "elements": [
                        {
                            "type": "rich_text_section",
                            "elements": [
                                {"type": "text", "text": "hello world"},
                            ],
                        }
                    ],
                }
            ],
        )

        await adapter._handle_slack_message(event)

        msg_event = adapter.handle_message.call_args[0][0]
        assert msg_event.text == "hello world"

    @pytest.mark.asyncio
    async def test_rich_text_quotes_and_lists_are_extracted(self, adapter):
        """Nested quote and list content should be surfaced from rich_text blocks."""
        event = self._make_event(
            text="Can you summarize this?",
            blocks=[
                {
                    "type": "rich_text",
                    "elements": [
                        {
                            "type": "rich_text_quote",
                            "elements": [
                                {
                                    "type": "rich_text_section",
                                    "elements": [{"type": "text", "text": "Quoted line"}],
                                }
                            ],
                        },
                        {
                            "type": "rich_text_list",
                            "style": "bullet",
                            "elements": [
                                {
                                    "type": "rich_text_section",
                                    "elements": [{"type": "text", "text": "First bullet"}],
                                },
                                {
                                    "type": "rich_text_section",
                                    "elements": [{"type": "text", "text": "Second bullet"}],
                                },
                            ],
                        },
                    ],
                }
            ],
        )

        await adapter._handle_slack_message(event)

        msg_event = adapter.handle_message.call_args[0][0]
        assert "Can you summarize this?" in msg_event.text
        assert "> Quoted line" in msg_event.text
        assert "• First bullet" in msg_event.text
        assert "• Second bullet" in msg_event.text

    @pytest.mark.asyncio
    async def test_attachments_unfurl_text_is_appended_even_when_url_is_in_message(self, adapter):
        """Shared URLs should still expose unfurl preview text to the agent."""
        event = self._make_event(
            text="Look at this doc https://example.com/spec",
            attachments=[
                {
                    "title": "Spec",
                    "from_url": "https://example.com/spec",
                    "text": "The latest product spec preview",
                    "footer": "Notion",
                }
            ],
        )

        await adapter._handle_slack_message(event)

        msg_event = adapter.handle_message.call_args[0][0]
        assert "Look at this doc https://example.com/spec" in msg_event.text
        assert "📎 [Spec](https://example.com/spec)" in msg_event.text
        assert "The latest product spec preview" in msg_event.text
        assert "_Notion_" in msg_event.text

    @pytest.mark.asyncio
    async def test_message_unfurl_attachments_are_skipped(self, adapter):
        """Message unfurls should be skipped to avoid echoing Slack message copies."""
        event = self._make_event(
            text="https://example.com/thread",
            attachments=[
                {
                    "is_msg_unfurl": True,
                    "title": "Thread copy",
                    "text": "This should not be appended",
                }
            ],
        )

        await adapter._handle_slack_message(event)

        msg_event = adapter.handle_message.call_args[0][0]
        assert msg_event.text == "https://example.com/thread"

    @pytest.mark.asyncio
    async def test_channel_routing_ignores_bot_mentions_inside_block_text(self, adapter):
        """Block-extracted text with a bot mention must not satisfy mention
        gating in channels — routing decisions use the original user text so
        quoted/forwarded content can't trick the bot into responding."""
        event = self._make_event(
            text="please review",
            channel_type="channel",
            blocks=[
                {
                    "type": "rich_text",
                    "elements": [
                        {
                            "type": "rich_text_quote",
                            "elements": [
                                {
                                    "type": "rich_text_section",
                                    "elements": [{"type": "text", "text": "Contains <@U_BOT> in quoted text"}],
                                }
                            ],
                        }
                    ],
                }
            ],
        )

        await adapter._handle_slack_message(event)

        adapter.handle_message.assert_not_called()

    @pytest.mark.asyncio
    async def test_quoted_slash_command_text_does_not_change_message_type(self, adapter):
        """Quoted slash-like content should not convert a normal message into a command."""
        event = self._make_event(
            text="",
            blocks=[
                {
                    "type": "rich_text",
                    "elements": [
                        {
                            "type": "rich_text_quote",
                            "elements": [
                                {
                                    "type": "rich_text_section",
                                    "elements": [{"type": "text", "text": "/deploy now"}],
                                }
                            ],
                        }
                    ],
                }
            ],
        )

        await adapter._handle_slack_message(event)

        msg_event = adapter.handle_message.call_args[0][0]
        assert msg_event.message_type == MessageType.TEXT
        assert "> /deploy now" in msg_event.text

# ---------------------------------------------------------------------------
# TestMessageRouting
@ -1887,6 +2347,48 @@ class TestSendImageSSRFGuards:
        assert "see this" in call_kwargs["text"]
        assert "https://public.example/image.png" in call_kwargs["text"]

    @pytest.mark.asyncio
    async def test_send_image_fallback_preserves_thread_metadata(self, adapter):
        redirect_response = MagicMock()
        redirect_response.is_redirect = True
        redirect_response.next_request = MagicMock(
            url="http://169.254.169.254/latest/meta-data"
        )

        client_kwargs = {}
        mock_client = AsyncMock()
        mock_client.__aenter__ = AsyncMock(return_value=mock_client)
        mock_client.__aexit__ = AsyncMock(return_value=False)

        async def fake_get(_url):
            for hook in client_kwargs["event_hooks"]["response"]:
                await hook(redirect_response)

        mock_client.get = AsyncMock(side_effect=fake_get)
        adapter._app.client.files_upload_v2 = AsyncMock(return_value={"ok": True})
        adapter._app.client.chat_postMessage = AsyncMock(return_value={"ts": "reply_ts"})

        def fake_async_client(*args, **kwargs):
            client_kwargs.update(kwargs)
            return mock_client

        def fake_is_safe_url(url):
            return url == "https://public.example/image.png"

        with (
            patch("tools.url_safety.is_safe_url", side_effect=fake_is_safe_url),
            patch("httpx.AsyncClient", side_effect=fake_async_client),
        ):
            await adapter.send_image(
                chat_id="C123",
                image_url="https://public.example/image.png",
                caption="see this",
                metadata={"thread_id": "parent_ts_789"},
            )

        call_kwargs = adapter._app.client.chat_postMessage.call_args.kwargs
        assert call_kwargs.get("thread_ts") == "parent_ts_789"


# ---------------------------------------------------------------------------
# TestProgressMessageThread
@ -2011,3 +2513,76 @@ class TestProgressMessageThread:
            "so each @mention starts its own thread"
        )
        assert msg_event.message_id == "2000000000.000001"


class TestSlackReplyToText:
    """Ensure MessageEvent.reply_to_text is populated on thread replies so
    gateway.run can inject a ``[Replying to: "..."]`` prefix (parity with
    Telegram/Discord/Feishu/WeCom)."""

    @pytest.mark.asyncio
    async def test_slack_reply_to_text_set_on_thread_reply(self, adapter):
        """When a thread reply arrives and the parent was posted by a bot
        (e.g. cron summary), reply_to_text must carry the parent's text."""
        adapter._channel_team = {}  # primary workspace only
        adapter._team_bot_user_ids = {}

        # Mock conversations_replies to return a bot-posted parent
        adapter._app.client.conversations_replies = AsyncMock(return_value={
            "messages": [
                {
                    "ts": "1000.0",
                    "bot_id": "B_CRON",
                    "text": "メール要約: 新着メール3件あります",
                },
                {"ts": "1000.5", "user": "U_USER", "text": "詳細を教えて"},
            ]
        })

        # Use a DM so mention-gating doesn't short-circuit the handler.
        event = {
            "text": "詳細を教えて",
            "user": "U_USER",
            "channel": "D123",
            "channel_type": "im",
            "ts": "1000.5",
            "thread_ts": "1000.0",  # thread reply
        }

        with patch.object(
            adapter, "_resolve_user_name", new=AsyncMock(return_value="Alice")
        ):
            await adapter._handle_slack_message(event)

        assert adapter.handle_message.call_args is not None, (
            "handle_message must be invoked for thread-reply DM"
        )
        msg_event = adapter.handle_message.call_args[0][0]
        assert msg_event.reply_to_message_id == "1000.0"
        # The critical assertion: parent text is exposed as reply_to_text so the
        # gateway can inject it when not already in the session history.
        assert msg_event.reply_to_text is not None
        assert "メール要約" in msg_event.reply_to_text

    @pytest.mark.asyncio
    async def test_slack_reply_to_text_none_for_top_level_message(self, adapter):
        """Top-level messages (no thread_ts) must not set reply_to_text."""
        event = {
            "text": "hello",
            "user": "U_USER",
            "channel": "D123",
            "channel_type": "im",
            "ts": "1000.0",
            # no thread_ts — top-level DM
        }

        with patch.object(
            adapter, "_resolve_user_name", new=AsyncMock(return_value="Alice")
        ):
            await adapter._handle_slack_message(event)

        assert adapter.handle_message.call_args is not None
        msg_event = adapter.handle_message.call_args[0][0]
        assert msg_event.reply_to_text is None
        # Top-level message: reply_to_message_id must be falsy (None or empty).
        assert not msg_event.reply_to_message_id

@ -276,23 +276,44 @@ class TestSlackThreadContext:

    @pytest.mark.asyncio
    async def test_skips_bot_messages(self):
        """Self-bot child replies are skipped to avoid circular context,
        but non-self bots (e.g. cron posts, third-party integrations) are kept.

        Regression guard for the fix in _fetch_thread_context: previously ALL
        bot messages were dropped, which lost context when the bot was replying
        to a cron-posted thread parent."""
        adapter = _make_adapter()
        mock_client = adapter._team_clients["T1"]
        mock_client.conversations_replies = AsyncMock(return_value={
            "messages": [
                {"ts": "1000.0", "user": "U1", "text": "Parent"},
                {"ts": "1000.1", "bot_id": "B1", "text": "Bot reply (should be skipped)"},
                # Self-bot reply -> must be skipped (circular)
                {
                    "ts": "1000.1",
                    "bot_id": "B_SELF",
                    "user": "U_BOT",
                    "text": "Previous bot self-reply (should be skipped)",
                },
                # Third-party bot child -> kept (useful context)
                {
                    "ts": "1000.15",
                    "bot_id": "B_OTHER",
                    "user": "U_OTHER_BOT",
                    "text": "Deploy succeeded",
                },
                {"ts": "1000.2", "user": "U1", "text": "Current"},
            ]
        })
        adapter._user_name_cache = {"U1": "Alice"}
        adapter._user_name_cache = {"U1": "Alice", "U_OTHER_BOT": "DeployBot"}

        context = await adapter._fetch_thread_context(
            channel_id="C1", thread_ts="1000.0", current_ts="1000.2", team_id="T1"
        )

        assert "Bot reply" not in context
        assert "Previous bot self-reply" not in context
        assert "Alice: Parent" in context
        # Third-party bot message must now be included
        assert "Deploy succeeded" in context

    @pytest.mark.asyncio
    async def test_empty_thread(self):
@ -316,6 +337,166 @@ class TestSlackThreadContext:
        )
        assert context == ""

    @pytest.mark.asyncio
    async def test_fetch_thread_context_includes_bot_parent(self):
        """The thread parent posted by a bot (e.g. a cron summary) must be
        included in the context, prefixed with ``[thread parent]``."""
        adapter = _make_adapter()
        mock_client = adapter._team_clients["T1"]
        mock_client.conversations_replies = AsyncMock(return_value={
            "messages": [
                # Bot-posted parent (cron job)
                {
                    "ts": "1000.0",
                    "bot_id": "B123",
                    "subtype": "bot_message",
                    "username": "cron",
                    "text": "メール要約: 本日の新着3件",
                },
                # User reply that triggered the fetch
                {"ts": "1000.1", "user": "U1", "text": "詳細を教えて"},
            ]
        })
        adapter._user_name_cache = {"U1": "Alice"}

        context = await adapter._fetch_thread_context(
            channel_id="C1",
            thread_ts="1000.0",
            current_ts="1000.1",  # exclude the trigger message itself
            team_id="T1",
        )

        assert "[thread parent]" in context
        assert "メール要約: 本日の新着3件" in context

    @pytest.mark.asyncio
    async def test_fetch_thread_context_excludes_self_bot_replies(self):
        """Parent (non-self bot) is kept, self-bot child replies are dropped,
        user replies are kept."""
        adapter = _make_adapter()
        mock_client = adapter._team_clients["T1"]
        mock_client.conversations_replies = AsyncMock(return_value={
            "messages": [
                {"ts": "1000.0", "bot_id": "B_CRON", "text": "Cron summary"},
                # Self-bot child reply -> excluded
                {
                    "ts": "1000.1",
                    "bot_id": "B_SELF",
                    "user": "U_BOT",  # matches adapter._bot_user_id
                    "text": "Previous self reply",
                },
                # User reply -> kept
                {"ts": "1000.2", "user": "U1", "text": "Follow-up question"},
                # Current trigger (excluded by current_ts match)
                {"ts": "1000.3", "user": "U1", "text": "Current"},
            ]
        })
        adapter._user_name_cache = {"U1": "Alice"}

        context = await adapter._fetch_thread_context(
            channel_id="C1", thread_ts="1000.0", current_ts="1000.3", team_id="T1"
        )

        assert "Cron summary" in context
        assert "[thread parent]" in context
        assert "Previous self reply" not in context
        assert "Follow-up question" in context
        assert "Current" not in context

    @pytest.mark.asyncio
    async def test_fetch_thread_context_multi_workspace(self):
        """Self-bot filtering must use the per-workspace bot user id so a
        self-bot id that belongs to a different workspace does not accidentally
        filter out a legitimate message in the current workspace."""
        adapter = _make_adapter()
        # Add a second workspace with a different bot user id
        adapter._team_clients["T2"] = AsyncMock()
        adapter._team_bot_user_ids = {"T1": "U_BOT_T1", "T2": "U_BOT_T2"}
        adapter._bot_user_id = "U_BOT_T1"
        adapter._channel_team["C2"] = "T2"

        mock_client = adapter._team_clients["T2"]
        mock_client.conversations_replies = AsyncMock(return_value={
            "messages": [
                {"ts": "2000.0", "user": "U2", "text": "Parent T2"},
                # This has the *T1* bot's user id — from T2's perspective this
                # is a third-party bot, so it must be kept.
                {
                    "ts": "2000.1",
                    "bot_id": "B_FOREIGN",
                    "user": "U_BOT_T1",
                    "team": "T2",
                    "text": "Cross-workspace bot reply",
                },
                # Self-bot for T2 — must be skipped
                {
                    "ts": "2000.2",
                    "bot_id": "B_SELF_T2",
                    "user": "U_BOT_T2",
                    "team": "T2",
                    "text": "Own T2 bot reply",
                },
                {"ts": "2000.3", "user": "U2", "text": "Current"},
            ]
        })
        adapter._user_name_cache = {"U2": "Bob"}

        context = await adapter._fetch_thread_context(
            channel_id="C2", thread_ts="2000.0", current_ts="2000.3", team_id="T2"
        )

        assert "Parent T2" in context
        assert "Cross-workspace bot reply" in context
        assert "Own T2 bot reply" not in context

    @pytest.mark.asyncio
    async def test_fetch_thread_context_current_ts_excluded(self):
        """Regression guard: the message whose ts == current_ts must never
        appear in the context output (it will be delivered as the user
        message itself)."""
        adapter = _make_adapter()
        mock_client = adapter._team_clients["T1"]
        mock_client.conversations_replies = AsyncMock(return_value={
            "messages": [
                {"ts": "1000.0", "user": "U1", "text": "Parent"},
                {"ts": "1000.1", "user": "U1", "text": "DO NOT INCLUDE THIS"},
            ]
        })
        adapter._user_name_cache = {"U1": "Alice"}

        context = await adapter._fetch_thread_context(
            channel_id="C1", thread_ts="1000.0", current_ts="1000.1", team_id="T1"
        )

        assert "Parent" in context
        assert "DO NOT INCLUDE THIS" not in context

    @pytest.mark.asyncio
    async def test_fetch_thread_parent_text_from_cache(self):
        """_fetch_thread_parent_text should reuse the thread-context cache
        when it is warm, avoiding an extra conversations.replies call."""
        adapter = _make_adapter()
        mock_client = adapter._team_clients["T1"]
        mock_client.conversations_replies = AsyncMock(return_value={
            "messages": [
                {"ts": "1000.0", "bot_id": "B123", "text": "Parent summary"},
                {"ts": "1000.1", "user": "U1", "text": "reply"},
            ]
        })

        # Warm the cache via _fetch_thread_context
        await adapter._fetch_thread_context(
            channel_id="C1", thread_ts="1000.0", current_ts="1000.1", team_id="T1"
        )
        assert mock_client.conversations_replies.await_count == 1

        parent = await adapter._fetch_thread_parent_text(
            channel_id="C1", thread_ts="1000.0", team_id="T1"
        )
        assert parent == "Parent summary"
        # No additional API call
        assert mock_client.conversations_replies.await_count == 1


# ===========================================================================
# _has_active_session_for_thread — session key fix (#5833)

133 tests/gateway/test_slack_channel_skills.py Normal file
@ -0,0 +1,133 @@
"""Tests for Slack channel_skill_bindings auto-skill resolution."""
from unittest.mock import MagicMock


def _make_adapter(extra=None):
    """Create a minimal SlackAdapter stub with the given ``config.extra``."""
    from gateway.platforms.slack import SlackAdapter
    adapter = object.__new__(SlackAdapter)
    adapter.config = MagicMock()
    adapter.config.extra = extra or {}
    return adapter


def _resolve(adapter, channel_id, parent_id=None):
    from gateway.platforms.base import resolve_channel_skills
    return resolve_channel_skills(adapter.config.extra, channel_id, parent_id)
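
# --- Editor's sketch (not part of the commit): the resolution semantics the
# tests below pin down: match on channel id or thread-parent id, accept a
# ``skills`` list or a single ``skill`` string, skip malformed entries, dedupe
# preserving first-seen order, and return None when nothing usable matches.
def _resolve_channel_skills_sketch(extra: dict, channel_id: str, parent_id):
    for binding in extra.get("channel_skill_bindings") or []:
        if not isinstance(binding, dict):
            continue  # malformed entries are skipped, never raised on
        if binding.get("id") not in (channel_id, parent_id):
            continue
        skills = binding.get("skills")
        if skills is None and binding.get("skill"):
            skills = [binding["skill"]]
        deduped = []
        for skill in skills or []:
            if skill and skill not in deduped:
                deduped.append(skill)
        return deduped or None
    return None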
|
||||
|
||||
|
||||
class TestSlackResolveChannelSkills:
|
||||
def test_no_bindings_returns_none(self):
|
||||
adapter = _make_adapter()
|
||||
assert _resolve(adapter, "D0ABC") is None
|
||||
|
||||
def test_match_by_dm_channel_id(self):
|
||||
"""The primary use case: binding a skill to a Slack DM channel."""
|
||||
adapter = _make_adapter({
|
||||
"channel_skill_bindings": [
|
||||
{"id": "D0ATH9TQ0G6", "skills": ["german-flashcards"]},
|
||||
]
|
||||
})
|
||||
assert _resolve(adapter, "D0ATH9TQ0G6") == ["german-flashcards"]
|
||||
|
||||
def test_match_by_parent_id_for_thread(self):
|
||||
"""Slack threads inherit the parent channel's binding."""
|
||||
adapter = _make_adapter({
|
||||
"channel_skill_bindings": [
|
||||
{"id": "C0PARENT", "skills": ["parent-skill"]},
|
||||
]
|
||||
})
|
||||
assert _resolve(adapter, "thread-ts-123", parent_id="C0PARENT") == ["parent-skill"]
|
||||
|
||||
def test_no_match_returns_none(self):
|
||||
adapter = _make_adapter({
|
||||
"channel_skill_bindings": [
|
||||
{"id": "D0AAA", "skills": ["skill-a"]},
|
||||
]
|
||||
})
|
||||
assert _resolve(adapter, "D0BBB") is None
|
||||
|
||||
def test_single_skill_string(self):
|
||||
adapter = _make_adapter({
|
||||
"channel_skill_bindings": [
|
||||
{"id": "D0ATH9TQ0G6", "skill": "german-flashcards"},
|
||||
]
|
||||
})
|
||||
assert _resolve(adapter, "D0ATH9TQ0G6") == ["german-flashcards"]
|
||||
|
||||
def test_dedup_preserves_order(self):
|
||||
adapter = _make_adapter({
|
||||
"channel_skill_bindings": [
|
||||
{"id": "D0ATH9TQ0G6", "skills": ["a", "b", "a", "c", "b"]},
|
||||
]
|
||||
})
|
||||
assert _resolve(adapter, "D0ATH9TQ0G6") == ["a", "b", "c"]
|
||||
|
||||
def test_multiple_bindings_pick_correct(self):
|
||||
adapter = _make_adapter({
|
||||
"channel_skill_bindings": [
|
||||
{"id": "D0AAA", "skills": ["skill-a"]},
|
||||
{"id": "D0BBB", "skills": ["skill-b"]},
|
||||
{"id": "D0CCC", "skills": ["skill-c"]},
|
||||
]
|
||||
})
|
||||
assert _resolve(adapter, "D0BBB") == ["skill-b"]
|
||||
|
||||
def test_malformed_entry_skipped(self):
|
||||
"""Non-dict entries should be ignored, not raise."""
|
||||
adapter = _make_adapter({
|
||||
"channel_skill_bindings": [
|
||||
"not-a-dict",
|
||||
{"id": "D0ABC", "skills": ["good"]},
|
||||
]
|
||||
})
|
||||
assert _resolve(adapter, "D0ABC") == ["good"]
|
||||
|
||||
def test_empty_skills_list_returns_none(self):
|
||||
adapter = _make_adapter({
|
||||
"channel_skill_bindings": [
|
||||
{"id": "D0ABC", "skills": []},
|
||||
]
|
||||
})
|
||||
assert _resolve(adapter, "D0ABC") is None
|
||||
|
||||
def test_empty_skill_string_returns_none(self):
|
||||
adapter = _make_adapter({
|
||||
"channel_skill_bindings": [
|
||||
{"id": "D0ABC", "skill": ""},
|
||||
]
|
||||
})
|
||||
assert _resolve(adapter, "D0ABC") is None
|
||||
|
||||
|
||||
class TestSlackMessageEventAutoSkill:
|
||||
"""Integration-style test: verify auto_skill propagates to MessageEvent."""
|
||||
|
||||
def test_message_event_carries_auto_skill(self):
|
||||
"""Simulate the handler wiring: resolve + attach to MessageEvent."""
|
||||
from gateway.platforms.base import MessageEvent, MessageType, Platform, SessionSource, resolve_channel_skills
|
||||
|
||||
config_extra = {
|
||||
"channel_skill_bindings": [
|
||||
{"id": "D0ATH9TQ0G6", "skills": ["german-flashcards"]},
|
||||
]
|
||||
}
|
||||
auto_skill = resolve_channel_skills(config_extra, "D0ATH9TQ0G6", None)
|
||||
|
||||
source = SessionSource(
|
||||
platform=Platform.SLACK,
|
||||
chat_id="D0ATH9TQ0G6",
|
||||
chat_name="Mats",
|
||||
chat_type="dm",
|
||||
user_id="U0ABC",
|
||||
user_name="Mats",
|
||||
)
|
||||
event = MessageEvent(
|
||||
text="work",
|
||||
message_type=MessageType.TEXT,
|
||||
source=source,
|
||||
raw_message={},
|
||||
message_id="123.456",
|
||||
auto_skill=auto_skill,
|
||||
)
|
||||
assert event.auto_skill == ["german-flashcards"]
|
||||
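For orientation, a minimal resolver that satisfies every expectation above could look like the sketch below. It is reconstructed from the tests, not copied from the shipped gateway.platforms.base implementation, so treat the structure as illustrative:

def resolve_channel_skills(extra, channel_id, parent_id=None):
    """Return the ordered, de-duplicated skill list bound to a channel, or None."""
    for binding in (extra or {}).get("channel_skill_bindings") or []:
        if not isinstance(binding, dict):
            continue  # malformed entries ("not-a-dict") are skipped, never raised on
        if binding.get("id") not in (channel_id, parent_id):
            continue  # matching on parent_id lets Slack threads inherit the binding
        skills = binding.get("skills")
        if skills is None and binding.get("skill"):
            skills = [binding["skill"]]  # accept the single "skill" string form
        if not skills:
            return None  # empty list or empty string means no auto-skill
        return list(dict.fromkeys(skills))  # order-preserving dedup
    return None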
@ -55,10 +55,12 @@ CHANNEL_ID = "C0AQWDLHY9M"
OTHER_CHANNEL_ID = "C9999999999"


def _make_adapter(require_mention=None, free_response_channels=None):
def _make_adapter(require_mention=None, strict_mention=None, free_response_channels=None):
    extra = {}
    if require_mention is not None:
        extra["require_mention"] = require_mention
    if strict_mention is not None:
        extra["strict_mention"] = strict_mention
    if free_response_channels is not None:
        extra["free_response_channels"] = free_response_channels

@ -134,6 +136,48 @@ def test_require_mention_env_var_default_true(monkeypatch):
    assert adapter._slack_require_mention() is True


# ---------------------------------------------------------------------------
# Tests: _slack_strict_mention
# ---------------------------------------------------------------------------

def test_strict_mention_defaults_to_false(monkeypatch):
    monkeypatch.delenv("SLACK_STRICT_MENTION", raising=False)
    adapter = _make_adapter()
    assert adapter._slack_strict_mention() is False


def test_strict_mention_true():
    adapter = _make_adapter(strict_mention=True)
    assert adapter._slack_strict_mention() is True


def test_strict_mention_false():
    adapter = _make_adapter(strict_mention=False)
    assert adapter._slack_strict_mention() is False


def test_strict_mention_string_true():
    adapter = _make_adapter(strict_mention="true")
    assert adapter._slack_strict_mention() is True


def test_strict_mention_string_off():
    adapter = _make_adapter(strict_mention="off")
    assert adapter._slack_strict_mention() is False


def test_strict_mention_malformed_stays_false():
    """Unrecognised values keep strict mode OFF (fail-open to legacy behavior)."""
    adapter = _make_adapter(strict_mention="maybe")
    assert adapter._slack_strict_mention() is False


def test_strict_mention_env_var_fallback(monkeypatch):
    monkeypatch.setenv("SLACK_STRICT_MENTION", "true")
    adapter = _make_adapter()  # no config value -> falls back to env
    assert adapter._slack_strict_mention() is True


# ---------------------------------------------------------------------------
# Tests: _slack_free_response_channels
# ---------------------------------------------------------------------------
@ -310,3 +354,109 @@ def test_config_bridges_slack_free_response_channels(monkeypatch, tmp_path):
    import os as _os
    assert _os.environ["SLACK_REQUIRE_MENTION"] == "false"
    assert _os.environ["SLACK_FREE_RESPONSE_CHANNELS"] == "C0AQWDLHY9M,C9999999999"


def test_config_bridges_slack_reply_in_thread(monkeypatch, tmp_path):
    from gateway.config import load_gateway_config

    hermes_home = tmp_path / ".hermes"
    hermes_home.mkdir()
    (hermes_home / "config.yaml").write_text(
        "slack:\n"
        "  reply_in_thread: false\n",
        encoding="utf-8",
    )

    monkeypatch.setenv("HERMES_HOME", str(hermes_home))
    monkeypatch.setenv("SLACK_BOT_TOKEN", "xoxb-test")

    config = load_gateway_config()

    assert config is not None
    slack_config = config.platforms[Platform.SLACK]
    assert slack_config.extra.get("reply_in_thread") is False

    adapter = SlackAdapter(slack_config)
    assert adapter._resolve_thread_ts(reply_to="171.000", metadata={}) is None

    # Top-level channel messages arrive with metadata.thread_id == reply_to
    # because the inbound handler uses event.ts as a session-keying fallback.
    # Those must be treated as non-threaded so reply_in_thread=false takes
    # effect in channels, not just DMs.
    assert adapter._resolve_thread_ts(
        reply_to="171.000",
        metadata={"thread_id": "171.000"},
    ) is None

    # Real thread replies (reply_to differs from thread parent) must still
    # resolve to the parent thread so conversation context is preserved.
    assert adapter._resolve_thread_ts(
        reply_to="171.500",
        metadata={"thread_id": "171.000"},
    ) == "171.000"


def test_config_bridges_slack_strict_mention(monkeypatch, tmp_path):
    from gateway.config import load_gateway_config

    hermes_home = tmp_path / ".hermes"
    hermes_home.mkdir()
    (hermes_home / "config.yaml").write_text(
        "slack:\n"
        "  strict_mention: true\n",
        encoding="utf-8",
    )

    monkeypatch.setenv("HERMES_HOME", str(hermes_home))
    monkeypatch.delenv("SLACK_STRICT_MENTION", raising=False)

    config = load_gateway_config()

    assert config is not None
    import os as _os
    assert _os.environ["SLACK_STRICT_MENTION"] == "true"


# ---------------------------------------------------------------------------
# Regression: strict mode must NOT persist mentions into _mentioned_threads
# ---------------------------------------------------------------------------
# Prevents agent-to-agent ack loops — if a strict-mode bot remembered every
# thread it was mentioned in, the next message from the other agent in that
# thread would re-trigger the bot and defeat the entire feature.

def test_mention_in_strict_mode_does_not_register_thread():
    adapter = _make_adapter(strict_mention=True)
    adapter._bot_user_id = "U_BOT"
    adapter._mentioned_threads = set()
    adapter._MENTIONED_THREADS_MAX = 5000

    thread_ts = "1700000000.100200"
    event_thread_ts = thread_ts  # incoming message is inside an existing thread

    # Mirror the handler's @mention + strict-mode guard that protects
    # _mentioned_threads.add(). If strict is on, we must skip the add.
    text = "<@U_BOT> hello"
    is_mentioned = f"<@{adapter._bot_user_id}>" in text
    assert is_mentioned
    if event_thread_ts and not adapter._slack_strict_mention():
        adapter._mentioned_threads.add(event_thread_ts)

    assert thread_ts not in adapter._mentioned_threads


def test_mention_outside_strict_mode_still_registers_thread():
    adapter = _make_adapter(strict_mention=False)
    adapter._bot_user_id = "U_BOT"
    adapter._mentioned_threads = set()
    adapter._MENTIONED_THREADS_MAX = 5000

    thread_ts = "1700000000.100200"
    event_thread_ts = thread_ts

    text = "<@U_BOT> hello"
    is_mentioned = f"<@{adapter._bot_user_id}>" in text
    assert is_mentioned
    if event_thread_ts and not adapter._slack_strict_mention():
        adapter._mentioned_threads.add(event_thread_ts)

    assert thread_ts in adapter._mentioned_threads
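The value handling these tests pin down is a standard truthy-string coercion with an env fallback. A minimal method body consistent with them (an illustrative sketch of the adapter method, not the real code):

import os

_TRUTHY = {"1", "true", "yes", "on"}

def _slack_strict_mention(self) -> bool:
    # Config extra wins; otherwise fall back to the SLACK_STRICT_MENTION env var.
    value = self.config.extra.get("strict_mention")
    if value is None:
        value = os.environ.get("SLACK_STRICT_MENTION", "")
    if isinstance(value, bool):
        return value
    # Unrecognised strings ("maybe", "off", ...) stay False, failing open
    # to the legacy non-strict behavior.
    return str(value).strip().lower() in _TRUTHY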
@ -12,9 +12,9 @@ from gateway.platforms.base import MessageEvent
from gateway.session import SessionEntry, SessionSource, build_session_key


def _make_source() -> SessionSource:
def _make_source(platform: Platform = Platform.TELEGRAM) -> SessionSource:
    return SessionSource(
        platform=Platform.TELEGRAM,
        platform=platform,
        user_id="u1",
        chat_id="c1",
        user_name="tester",
@ -22,24 +22,24 @@ def _make_source() -> SessionSource:
    )


def _make_event(text: str) -> MessageEvent:
def _make_event(text: str, *, platform: Platform = Platform.TELEGRAM) -> MessageEvent:
    return MessageEvent(
        text=text,
        source=_make_source(),
        source=_make_source(platform),
        message_id="m1",
    )


def _make_runner(session_entry: SessionEntry):
def _make_runner(session_entry: SessionEntry, *, platform: Platform = Platform.TELEGRAM):
    from gateway.run import GatewayRunner

    runner = object.__new__(GatewayRunner)
    runner.config = GatewayConfig(
        platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")}
        platforms={platform: PlatformConfig(enabled=True, token="***")}
    )
    adapter = MagicMock()
    adapter.send = AsyncMock()
    runner.adapters = {Platform.TELEGRAM: adapter}
    runner.adapters = {platform: adapter}
    runner._voice_mode = {}
    runner.hooks = SimpleNamespace(emit=AsyncMock(), loaded_hooks=False)
    runner.session_store = MagicMock()
@ -224,6 +224,93 @@ async def test_handle_message_persists_agent_token_counts(monkeypatch):
    )


@pytest.mark.asyncio
async def test_first_run_slack_home_channel_onboarding_uses_parent_command(monkeypatch):
    import gateway.run as gateway_run

    session_entry = SessionEntry(
        session_key=build_session_key(_make_source(Platform.SLACK)),
        session_id="sess-1",
        created_at=datetime.now(),
        updated_at=datetime.now(),
        platform=Platform.SLACK,
        chat_type="dm",
    )
    runner = _make_runner(session_entry, platform=Platform.SLACK)
    runner.session_store.load_transcript.return_value = []
    runner.session_store.has_any_sessions.return_value = False
    runner._run_agent = AsyncMock(
        return_value={
            "final_response": "ok",
            "messages": [],
            "tools": [],
            "history_offset": 0,
            "last_prompt_tokens": 0,
            "input_tokens": 0,
            "output_tokens": 0,
            "model": "openai/test-model",
        }
    )

    monkeypatch.delenv("SLACK_HOME_CHANNEL", raising=False)
    monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", lambda: {"api_key": "***"})
    monkeypatch.setattr(
        "agent.model_metadata.get_model_context_length",
        lambda *_args, **_kwargs: 100000,
    )

    result = await runner._handle_message(_make_event("hello", platform=Platform.SLACK))

    assert result == "ok"
    runner.adapters[Platform.SLACK].send.assert_awaited_once()
    onboarding = runner.adapters[Platform.SLACK].send.await_args.args[1]
    assert "/hermes sethome" in onboarding
    assert "Type /sethome" not in onboarding


@pytest.mark.asyncio
async def test_first_run_non_slack_home_channel_onboarding_keeps_direct_command(monkeypatch):
    import gateway.run as gateway_run

    session_entry = SessionEntry(
        session_key=build_session_key(_make_source(Platform.TELEGRAM)),
        session_id="sess-1",
        created_at=datetime.now(),
        updated_at=datetime.now(),
        platform=Platform.TELEGRAM,
        chat_type="dm",
    )
    runner = _make_runner(session_entry, platform=Platform.TELEGRAM)
    runner.session_store.load_transcript.return_value = []
    runner.session_store.has_any_sessions.return_value = False
    runner._run_agent = AsyncMock(
        return_value={
            "final_response": "ok",
            "messages": [],
            "tools": [],
            "history_offset": 0,
            "last_prompt_tokens": 0,
            "input_tokens": 0,
            "output_tokens": 0,
            "model": "openai/test-model",
        }
    )

    monkeypatch.delenv("TELEGRAM_HOME_CHANNEL", raising=False)
    monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", lambda: {"api_key": "***"})
    monkeypatch.setattr(
        "agent.model_metadata.get_model_context_length",
        lambda *_args, **_kwargs: 100000,
    )

    result = await runner._handle_message(_make_event("hello", platform=Platform.TELEGRAM))

    assert result == "ok"
    runner.adapters[Platform.TELEGRAM].send.assert_awaited_once()
    onboarding = runner.adapters[Platform.TELEGRAM].send.await_args.args[1]
    assert "Type /sethome" in onboarding


@pytest.mark.asyncio
async def test_handle_message_discards_stale_result_after_session_invalidation(monkeypatch):
    import gateway.run as gateway_run
236 tests/gateway/test_stream_consumer_fresh_final.py Normal file
@ -0,0 +1,236 @@
"""Regression tests for the fresh-final-for-long-lived-previews path.

Ported from openclaw/openclaw#72038. When a streamed preview has been
visible long enough that the platform's edit timestamp would be
noticeably stale by completion time, the stream consumer delivers the
final reply as a brand-new message and best-effort deletes the old
preview. This makes Telegram's visible timestamp reflect completion
time instead of first-token time.
"""

from __future__ import annotations

from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock

import pytest

from gateway.stream_consumer import GatewayStreamConsumer, StreamConsumerConfig


def _make_adapter(*, supports_delete: bool = True) -> MagicMock:
    """Build a minimal MagicMock adapter wired for send/edit/delete."""
    adapter = MagicMock()
    adapter.REQUIRES_EDIT_FINALIZE = False
    adapter.MAX_MESSAGE_LENGTH = 4096
    adapter.send = AsyncMock(return_value=SimpleNamespace(
        success=True, message_id="initial_preview",
    ))
    adapter.edit_message = AsyncMock(return_value=SimpleNamespace(
        success=True, message_id="initial_preview",
    ))
    if supports_delete:
        adapter.delete_message = AsyncMock(return_value=True)
    else:
        # Adapter without the optional delete_message method — fresh-final
        # should still work, it just leaves the stale preview in place.
        del adapter.delete_message  # type: ignore[attr-defined]
    return adapter


class TestFreshFinalForLongLivedPreviews:
    """openclaw#72038 port — send fresh final when preview is old."""

    @pytest.mark.asyncio
    async def test_disabled_by_default_still_edits_in_place(self):
        """``fresh_final_after_seconds=0`` preserves the legacy edit path."""
        adapter = _make_adapter()
        consumer = GatewayStreamConsumer(
            adapter=adapter,
            chat_id="chat",
            config=StreamConsumerConfig(fresh_final_after_seconds=0.0),
        )
        await consumer._send_or_edit("hello")
        # Pretend the preview has been visible for a long time.
        consumer._message_created_ts = 0.0  # far in the past
        await consumer._send_or_edit("hello world", finalize=True)
        # Should edit, not send a fresh message.
        assert adapter.send.call_count == 1  # only the initial send
        adapter.edit_message.assert_called_once()

    @pytest.mark.asyncio
    async def test_short_lived_preview_edits_in_place(self):
        """Finalizing a preview younger than the threshold → normal edit."""
        adapter = _make_adapter()
        consumer = GatewayStreamConsumer(
            adapter=adapter,
            chat_id="chat",
            config=StreamConsumerConfig(fresh_final_after_seconds=60.0),
        )
        await consumer._send_or_edit("hello")
        # Preview is "new" — leave _message_created_ts at its real value.
        await consumer._send_or_edit("hello world", finalize=True)
        assert adapter.send.call_count == 1
        adapter.edit_message.assert_called_once()

    @pytest.mark.asyncio
    async def test_long_lived_preview_sends_fresh_final(self):
        """Finalizing a preview older than the threshold → fresh send."""
        adapter = _make_adapter()
        adapter.send.side_effect = [
            SimpleNamespace(success=True, message_id="initial_preview"),
            SimpleNamespace(success=True, message_id="fresh_final"),
        ]
        consumer = GatewayStreamConsumer(
            adapter=adapter,
            chat_id="chat",
            config=StreamConsumerConfig(fresh_final_after_seconds=60.0),
        )
        await consumer._send_or_edit("hello")
        # Force the preview to look stale (visible for > 60s).
        consumer._message_created_ts = 0.0  # zero = ~uptime seconds old
        await consumer._send_or_edit("hello world", finalize=True)
        # Fresh send happened; no edit of the old preview.
        assert adapter.send.call_count == 2
        adapter.edit_message.assert_not_called()
        # The old preview was deleted as cleanup.
        adapter.delete_message.assert_awaited_once_with("chat", "initial_preview")
        # State was updated to the new message id.
        assert consumer._message_id == "fresh_final"
        assert consumer._final_response_sent is True

    @pytest.mark.asyncio
    async def test_fresh_final_without_delete_support_is_best_effort(self):
        """Adapter lacking ``delete_message`` still gets the fresh send."""
        adapter = _make_adapter(supports_delete=False)
        adapter.send.side_effect = [
            SimpleNamespace(success=True, message_id="initial_preview"),
            SimpleNamespace(success=True, message_id="fresh_final"),
        ]
        consumer = GatewayStreamConsumer(
            adapter=adapter,
            chat_id="chat",
            config=StreamConsumerConfig(fresh_final_after_seconds=60.0),
        )
        await consumer._send_or_edit("hello")
        consumer._message_created_ts = 0.0
        await consumer._send_or_edit("hello world", finalize=True)
        assert adapter.send.call_count == 2
        adapter.edit_message.assert_not_called()
        # No delete attempt — just the fresh send.
        assert consumer._message_id == "fresh_final"

    @pytest.mark.asyncio
    async def test_fresh_final_fallback_to_edit_on_send_failure(self):
        """If the fresh send fails, fall back to the normal edit path."""
        adapter = _make_adapter()
        adapter.send.side_effect = [
            SimpleNamespace(success=True, message_id="initial_preview"),
            SimpleNamespace(success=False, error="network"),
        ]
        consumer = GatewayStreamConsumer(
            adapter=adapter,
            chat_id="chat",
            config=StreamConsumerConfig(fresh_final_after_seconds=60.0),
        )
        await consumer._send_or_edit("hello")
        consumer._message_created_ts = 0.0
        ok = await consumer._send_or_edit("hello world", finalize=True)
        # Fresh send was attempted and failed → edit happened instead.
        assert adapter.send.call_count == 2
        adapter.edit_message.assert_called_once()
        assert ok is True

    @pytest.mark.asyncio
    async def test_only_finalize_triggers_fresh_final(self):
        """Intermediate edits (``finalize=False``) never switch to fresh send."""
        adapter = _make_adapter()
        consumer = GatewayStreamConsumer(
            adapter=adapter,
            chat_id="chat",
            config=StreamConsumerConfig(fresh_final_after_seconds=60.0),
        )
        await consumer._send_or_edit("hello")
        consumer._message_created_ts = 0.0  # stale
        await consumer._send_or_edit("hello partial")  # no finalize
        assert adapter.send.call_count == 1
        adapter.edit_message.assert_called_once()

    @pytest.mark.asyncio
    async def test_no_edit_sentinel_is_not_affected(self):
        """Platforms with the ``__no_edit__`` sentinel never go fresh-final."""
        adapter = _make_adapter()
        adapter.send.return_value = SimpleNamespace(success=True, message_id=None)
        consumer = GatewayStreamConsumer(
            adapter=adapter,
            chat_id="chat",
            config=StreamConsumerConfig(fresh_final_after_seconds=60.0),
        )
        await consumer._send_or_edit("hello")
        assert consumer._message_id == "__no_edit__"
        assert consumer._message_created_ts is None
        # Even with finalize=True, no fresh send — the sentinel gates it.
        assert consumer._should_send_fresh_final() is False


class TestStreamConsumerConfigFreshFinalField:
    """The dataclass field must exist and default to 0 (disabled)."""

    def test_default_is_disabled(self):
        cfg = StreamConsumerConfig()
        assert cfg.fresh_final_after_seconds == 0.0

    def test_field_is_configurable(self):
        cfg = StreamConsumerConfig(fresh_final_after_seconds=120.0)
        assert cfg.fresh_final_after_seconds == 120.0


class TestStreamingConfigFreshFinalField:
    """The gateway-level StreamingConfig carries the setting."""

    def test_default_enables_with_60s(self):
        from gateway.config import StreamingConfig
        cfg = StreamingConfig()
        assert cfg.fresh_final_after_seconds == 60.0

    def test_from_dict_uses_default_when_missing(self):
        from gateway.config import StreamingConfig
        cfg = StreamingConfig.from_dict({"enabled": True})
        assert cfg.fresh_final_after_seconds == 60.0

    def test_from_dict_respects_explicit_zero(self):
        from gateway.config import StreamingConfig
        cfg = StreamingConfig.from_dict({
            "enabled": True,
            "fresh_final_after_seconds": 0,
        })
        assert cfg.fresh_final_after_seconds == 0.0

    def test_to_dict_round_trip(self):
        from gateway.config import StreamingConfig
        original = StreamingConfig(fresh_final_after_seconds=90.0)
        restored = StreamingConfig.from_dict(original.to_dict())
        assert restored.fresh_final_after_seconds == 90.0


class TestTelegramAdapterDeleteMessage:
    """Contract: Telegram adapter implements ``delete_message``."""

    def test_delete_message_method_exists(self):
        telegram = pytest.importorskip("gateway.platforms.telegram")
        import inspect
        cls = telegram.TelegramAdapter
        assert hasattr(cls, "delete_message"), (
            "TelegramAdapter.delete_message is required for the fresh-final "
            "cleanup path (openclaw/openclaw#72038 port)."
        )
        sig = inspect.signature(cls.delete_message)
        params = list(sig.parameters)
        assert params[:3] == ["self", "chat_id", "message_id"]

    def test_base_adapter_default_returns_false(self):
        """BasePlatformAdapter.delete_message default = no-op returning False."""
        from gateway.platforms.base import BasePlatformAdapter
        import inspect
        sig = inspect.signature(BasePlatformAdapter.delete_message)
        assert list(sig.parameters)[:3] == ["self", "chat_id", "message_id"]
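Taken together, the tests above determine the gate for the fresh-final path. A minimal predicate consistent with them, reconstructed from the expectations (attribute names mirror the test's accesses; the shipped method may differ):

import time

def _should_send_fresh_final(self) -> bool:
    threshold = self.config.fresh_final_after_seconds
    if threshold <= 0:
        return False  # feature disabled: keep the legacy edit-in-place path
    if self._message_id in (None, "__no_edit__"):
        return False  # no editable preview to replace (sentinel platforms)
    if self._message_created_ts is None:
        return False
    # _message_created_ts = 0.0 in the tests reads as "~uptime seconds old",
    # which points at a monotonic clock rather than wall time.
    return (time.monotonic() - self._message_created_ts) >= threshold

On finalize, a True result triggers a fresh send; if that send fails the consumer falls back to editing the old preview, and the delete_message cleanup stays best-effort.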
@ -251,7 +251,7 @@ class TestWatchUpdateProgress:
            "session_key": "agent:main:telegram:dm:111"}
        (hermes_home / ".update_pending.json").write_text(json.dumps(pending))
        # Write output
        (hermes_home / ".update_output.txt").write_text("→ Fetching updates...\n")
        (hermes_home / ".update_output.txt").write_text("→ Fetching updates...\n", encoding="utf-8")

        mock_adapter = AsyncMock()
        runner.adapters = {Platform.TELEGRAM: mock_adapter}
@ -261,7 +261,7 @@ class TestWatchUpdateProgress:
            await asyncio.sleep(0.3)
            (hermes_home / ".update_output.txt").write_text(
                "→ Fetching updates...\n✓ Code updated!\n"
            )
            , encoding="utf-8")
            (hermes_home / ".update_exit_code").write_text("0")

        with patch("gateway.run._hermes_home", hermes_home):
@ -489,6 +489,63 @@ class TestUpdatePromptInterception:
        # Should clear the pending flag
        assert session_key not in runner._update_prompt_pending

    @pytest.mark.asyncio
    async def test_recognized_slash_command_bypasses_pending_update_prompt(self, tmp_path):
        """Known slash commands must dispatch normally instead of being consumed.

        The update subprocess is still blocked on stdin waiting for
        ``.update_response``, so the gateway writes a blank response to
        unblock it (``_gateway_prompt`` returns the prompt's default on
        empty) before falling through to normal command dispatch.
        """
        runner = _make_runner()
        hermes_home = tmp_path / "hermes"
        hermes_home.mkdir()

        event = _make_event(text="/new", chat_id="67890")
        session_key = "agent:main:telegram:dm:67890"
        runner._update_prompt_pending[session_key] = True
        runner._is_user_authorized = MagicMock(return_value=True)
        runner._session_key_for_source = MagicMock(return_value=session_key)
        runner._handle_reset_command = AsyncMock(return_value="reset ok")

        with patch("gateway.run._hermes_home", hermes_home):
            result = await runner._handle_message(event)

        assert result == "reset ok"
        runner._handle_reset_command.assert_awaited_once_with(event)
        # .update_response was written (empty) to unblock the update
        # subprocess; _gateway_prompt will read "", strip to "", and
        # return the prompt's default.
        response_path = hermes_home / ".update_response"
        assert response_path.exists()
        assert response_path.read_text() == ""
        # Pending flag is cleared so stray future input won't be
        # re-intercepted for a prompt that is no longer outstanding.
        assert session_key not in runner._update_prompt_pending

    @pytest.mark.asyncio
    async def test_unrecognized_slash_command_still_consumed_as_response(self, tmp_path):
        """Unknown /foo is written verbatim to .update_response (legacy behavior)."""
        runner = _make_runner()
        hermes_home = tmp_path / "hermes"
        hermes_home.mkdir()

        event = _make_event(text="/foobarbaz", chat_id="67890")
        session_key = "agent:main:telegram:dm:67890"
        runner._update_prompt_pending[session_key] = True
        runner._is_user_authorized = MagicMock(return_value=True)
        runner._session_key_for_source = MagicMock(return_value=session_key)

        with patch("gateway.run._hermes_home", hermes_home):
            result = await runner._handle_message(event)

        response_path = hermes_home / ".update_response"
        assert response_path.exists()
        assert response_path.read_text() == "/foobarbaz"
        assert "Sent" in (result or "")
        assert session_key not in runner._update_prompt_pending

    @pytest.mark.asyncio
    async def test_normal_message_when_no_prompt_pending(self, tmp_path):
        """Messages pass through normally when no prompt is pending."""
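As a rough guide, the guard these two tests exercise could be shaped like the sketch below. This is an illustrative reconstruction only; KNOWN_COMMANDS and the helper names are assumptions, not the real gateway.run code:

# Inside _handle_message, before normal dispatch (illustrative):
if self._update_prompt_pending.get(session_key):
    is_known_command = event.text.startswith("/") and event.text.split()[0] in KNOWN_COMMANDS
    response_path = _hermes_home() / ".update_response"
    self._update_prompt_pending.pop(session_key, None)  # prompt no longer outstanding
    if is_known_command:
        # Blank response unblocks the updater (_gateway_prompt falls back
        # to the prompt's default); /new etc. then dispatch normally below.
        response_path.write_text("")
    else:
        # Legacy behavior: consume the message verbatim as the answer.
        response_path.write_text(event.text)
        return "Sent response to updater."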
@ -134,7 +134,7 @@ class TestVerboseCommand:
        """Cycling /verbose on Telegram doesn't change Slack's setting.

        Without a global tool_progress, each platform uses its built-in
        default: Telegram = 'all' (high tier), Slack = 'new' (medium tier).
        default: Telegram = 'all' (high tier), Slack = 'off' (quiet Slack default).
        """
        hermes_home = tmp_path / "hermes"
        hermes_home.mkdir()
@ -161,8 +161,8 @@ class TestVerboseCommand:
        platforms = saved["display"]["platforms"]
        # Telegram: all -> verbose (high tier default = all)
        assert platforms["telegram"]["tool_progress"] == "verbose"
        # Slack: new -> all (medium tier default = new, cycle to all)
        assert platforms["slack"]["tool_progress"] == "all"
        # Slack: off -> new (first /verbose cycle from quiet default)
        assert platforms["slack"]["tool_progress"] == "new"

    @pytest.mark.asyncio
    async def test_no_config_file_returns_disabled(self, tmp_path, monkeypatch):
@ -149,3 +149,46 @@ def test_get_nous_subscription_features_requires_agent_browser_for_browserbase(m
    assert features.browser.active is False
    assert features.browser.managed_by_nous is False
    assert features.browser.current_provider == "Browserbase"


def test_get_nous_subscription_features_does_not_treat_quoted_false_as_gateway_opt_in(monkeypatch):
    env = {"EXA_API_KEY": "exa-test"}

    monkeypatch.setattr(ns, "get_env_value", lambda name: env.get(name, ""))
    monkeypatch.setattr(ns, "get_nous_auth_status", lambda: {"logged_in": True})
    monkeypatch.setattr(ns, "managed_nous_tools_enabled", lambda: True)
    monkeypatch.setattr(ns, "_toolset_enabled", lambda config, key: key == "web")
    monkeypatch.setattr(ns, "_has_agent_browser", lambda: False)
    monkeypatch.setattr(ns, "resolve_openai_audio_api_key", lambda: "")
    monkeypatch.setattr(ns, "has_direct_modal_credentials", lambda: False)
    monkeypatch.setattr(ns, "is_managed_tool_gateway_ready", lambda vendor: vendor == "firecrawl")

    features = ns.get_nous_subscription_features(
        {"web": {"backend": "exa", "use_gateway": "false"}}
    )

    assert features.web.available is True
    assert features.web.active is True
    assert features.web.managed_by_nous is False
    assert features.web.direct_override is True
    assert features.web.current_provider == "exa"


def test_get_gateway_eligible_tools_ignores_quoted_false_opt_in(monkeypatch):
    monkeypatch.setattr(ns, "managed_nous_tools_enabled", lambda: True)
    monkeypatch.setattr(
        ns,
        "_get_gateway_direct_credentials",
        lambda: {"web": True, "image_gen": False, "tts": False, "browser": False},
    )

    unconfigured, has_direct, already_managed = ns.get_gateway_eligible_tools(
        {
            "model": {"provider": "nous"},
            "web": {"use_gateway": "false"},
        }
    )

    assert "web" in has_direct
    assert "web" not in already_managed
    assert set(unconfigured) == {"image_gen", "tts", "browser"}
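Both tests come down to one coercion rule: a YAML-quoted "false" must behave like boolean False when deciding gateway opt-in. A hypothetical helper capturing that rule (the real module may inline it differently):

def _use_gateway_opt_in(toolset_config: dict) -> bool:
    """Treat quoted "false"/"no"/"off"/"0" the same as boolean False."""
    value = toolset_config.get("use_gateway")
    if isinstance(value, str):
        return value.strip().lower() not in {"", "0", "false", "no", "off"}
    return bool(value)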
@ -401,14 +401,21 @@ class TestSessionBrowseArgparse:
        from hermes_cli.main import _session_browse_picker
        assert callable(_session_browse_picker)

    def test_browse_default_limit_is_50(self):
        """The default --limit for browse should be 50."""
        # This test verifies at the argparse level
        # We test by running the parse on "sessions browse" args
        # Since we can't easily extract the subparser, verify via the
        # _session_browse_picker accepting large lists
        sessions = _make_sessions(50)
        assert len(sessions) == 50
    def test_browse_default_limit_is_500(self):
        """The default --limit for browse should be 500."""
        # Build the same argparse tree cmd_sessions uses and verify the default.
        import argparse
        parser = argparse.ArgumentParser()
        subparsers = parser.add_subparsers(dest="sessions_action")
        browse = subparsers.add_parser("browse")
        browse.add_argument("--source")
        browse.add_argument("--limit", type=int, default=500)

        args = parser.parse_args(["browse"])
        assert args.limit == 500

        args = parser.parse_args(["browse", "--limit", "42"])
        assert args.limit == 42


# ─── Integration: cmd_sessions browse action ────────────────────────────────
@ -12,7 +12,7 @@ def test_sessions_delete_accepts_unique_id_prefix(monkeypatch, capsys):
        captured["resolved_from"] = session_id
        return "20260315_092437_c9a6ff"

    def delete_session(self, session_id):
    def delete_session(self, session_id, **kwargs):
        captured["deleted"] = session_id
        return True

@ -45,7 +45,7 @@ def test_sessions_delete_reports_not_found_when_prefix_is_unknown(monkeypatch, c
    def resolve_session_id(self, session_id):
        return None

    def delete_session(self, session_id):
    def delete_session(self, session_id, **kwargs):
        raise AssertionError("delete_session should not be called when resolution fails")

    def close(self):
@ -73,7 +73,7 @@ def test_sessions_delete_handles_eoferror_on_confirm(monkeypatch, capsys):
    def resolve_session_id(self, session_id):
        return "20260315_092437_c9a6ff"

    def delete_session(self, session_id):
    def delete_session(self, session_id, **kwargs):
        raise AssertionError("delete_session should not be called when cancelled")

    def close(self):
30 tests/hermes_cli/test_setup_ollama_cloud_force_refresh.py Normal file
@ -0,0 +1,30 @@
"""Regression: ``hermes setup`` for the ollama-cloud provider must force-refresh
the model cache after the user supplies a key, otherwise the picker keeps
serving a stale cache (models.dev only, no live API probe) for up to an hour.
"""

from __future__ import annotations

from unittest.mock import patch


def test_setup_ollama_cloud_passes_force_refresh(monkeypatch):
    """The provider-setup model-fetch for ollama-cloud must pass ``force_refresh=True``."""
    import hermes_cli.main as main_mod
    import inspect

    src = inspect.getsource(main_mod)

    # Locate the ollama-cloud branch in the provider setup flow.
    marker = 'provider_id == "ollama-cloud"'
    assert marker in src, "ollama-cloud branch missing from provider setup"
    idx = src.index(marker)
    # The call to fetch_ollama_cloud_models should be within the next ~2000 chars.
    snippet = src[idx:idx + 2000]
    assert "fetch_ollama_cloud_models(" in snippet, snippet[:500]
    assert "force_refresh=True" in snippet, (
        "ollama-cloud setup must pass force_refresh=True so newly released "
        "models (e.g. deepseek v4 flash, kimi k2.6) appear the moment the "
        "user enters their key, not an hour later when the cache TTL expires. "
        f"Snippet: {snippet[:500]}"
    )
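The failure mode the docstring describes is a plain TTL cache without a bypass. A generic sketch of the pattern being enforced (the one-hour TTL and the function shape are assumptions drawn from the docstring, not the real fetch_ollama_cloud_models signature):

import time

_CACHE: dict[str, tuple[float, list[str]]] = {}
_TTL_SECONDS = 3600.0  # assumed one-hour TTL, per the test docstring

def fetch_models_cached(provider: str, fetch, *, force_refresh: bool = False) -> list[str]:
    """Serve the cached model list unless expired or force_refresh is set."""
    now = time.time()
    hit = _CACHE.get(provider)
    if hit and not force_refresh and (now - hit[0]) < _TTL_SECONDS:
        return hit[1]  # stale-but-valid cache; skipped when forcing
    models = fetch()  # live API probe with the freshly entered key
    _CACHE[provider] = (now, models)
    return models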
@ -41,6 +41,36 @@ def test_get_platform_tools_homeassistant_platform_keeps_homeassistant_toolset()
    assert "homeassistant" in enabled


def test_get_platform_tools_homeassistant_toolset_enabled_for_cron_when_hass_token_set(monkeypatch):
    """HA toolset is runtime-gated by check_fn (requires HASS_TOKEN).

    When HASS_TOKEN is set, the user has explicitly opted in — _DEFAULT_OFF_TOOLSETS
    shouldn't also strip HA from platforms (like cron) that run through
    _get_platform_tools without an explicit saved toolset list.

    Regression guard for Norbert's HA cron breakage after #14798 made cron
    honor per-platform tool config.
    """
    monkeypatch.setenv("HASS_TOKEN", "fake-test-token")

    cron_enabled = _get_platform_tools({}, "cron")
    assert "homeassistant" in cron_enabled
    # moa must stay off — the original goal of #14798
    assert "moa" not in cron_enabled

    cli_enabled = _get_platform_tools({}, "cli")
    assert "homeassistant" in cli_enabled


def test_get_platform_tools_homeassistant_toolset_off_for_cron_when_hass_token_missing(monkeypatch):
    """Without HASS_TOKEN, HA stays off by default — preserves #14798's behavior
    for users who never configured HA."""
    monkeypatch.delenv("HASS_TOKEN", raising=False)

    cron_enabled = _get_platform_tools({}, "cron")
    assert "homeassistant" not in cron_enabled


def test_get_platform_tools_preserves_explicit_empty_selection():
    config = {"platform_toolsets": {"cli": []}}
121 tests/hermes_cli/test_web_ui_build.py Normal file
@ -0,0 +1,121 @@
"""Tests for _web_ui_build_needed — staleness check for the web UI dist.

Critical invariant: the Vite build outputs to hermes_cli/web_dist/
(vite.config.ts: outDir: "../hermes_cli/web_dist"), NOT web/dist/.
The sentinel must be checked in the correct output directory or the
freshness check is a no-op and the OOM rebuild always runs.
"""

import os
import time
from pathlib import Path
from unittest.mock import patch

import pytest

from hermes_cli.main import _web_ui_build_needed, _build_web_ui


def _touch(path: Path, offset: float = 0.0) -> None:
    path.parent.mkdir(parents=True, exist_ok=True)
    path.touch()
    if offset:
        t = time.time() + offset
        os.utime(path, (t, t))


def _make_web_dir(tmp_path: Path) -> tuple[Path, Path]:
    """Return (web_dir, dist_dir) matching real repo layout."""
    web_dir = tmp_path / "web"
    web_dir.mkdir()
    (web_dir / "package.json").touch()
    dist_dir = tmp_path / "hermes_cli" / "web_dist"
    return web_dir, dist_dir


class TestWebUIBuildNeeded:

    def test_returns_true_when_dist_missing(self, tmp_path):
        web_dir, _ = _make_web_dir(tmp_path)
        assert _web_ui_build_needed(web_dir) is True

    def test_returns_false_when_vite_manifest_fresh(self, tmp_path):
        web_dir, dist_dir = _make_web_dir(tmp_path)
        _touch(web_dir / "src" / "App.tsx", offset=-10)
        _touch(dist_dir / ".vite" / "manifest.json")
        assert _web_ui_build_needed(web_dir) is False

    def test_returns_true_when_source_newer_than_manifest(self, tmp_path):
        web_dir, dist_dir = _make_web_dir(tmp_path)
        _touch(dist_dir / ".vite" / "manifest.json", offset=-10)
        _touch(web_dir / "src" / "App.tsx")
        assert _web_ui_build_needed(web_dir) is True

    def test_falls_back_to_index_html_when_manifest_missing(self, tmp_path):
        web_dir, dist_dir = _make_web_dir(tmp_path)
        _touch(web_dir / "src" / "main.ts", offset=-10)
        _touch(dist_dir / "index.html")
        assert _web_ui_build_needed(web_dir) is False

    def test_web_dist_dir_not_web_dist_subdir(self, tmp_path):
        """Regression: sentinel must be in hermes_cli/web_dist/, NOT web/dist/."""
        web_dir, dist_dir = _make_web_dir(tmp_path)
        _touch(web_dir / "src" / "App.tsx", offset=-10)
        # Place manifest in wrong location (web/dist/) — should NOT count as fresh
        wrong_dist = web_dir / "dist" / ".vite" / "manifest.json"
        _touch(wrong_dist)
        # Correct location is empty → still needs build
        assert _web_ui_build_needed(web_dir) is True

    def test_returns_true_when_package_lock_newer_than_dist(self, tmp_path):
        web_dir, dist_dir = _make_web_dir(tmp_path)
        _touch(dist_dir / ".vite" / "manifest.json", offset=-10)
        _touch(web_dir / "package-lock.json")
        assert _web_ui_build_needed(web_dir) is True

    def test_returns_true_when_vite_config_newer_than_dist(self, tmp_path):
        web_dir, dist_dir = _make_web_dir(tmp_path)
        _touch(dist_dir / ".vite" / "manifest.json", offset=-10)
        _touch(web_dir / "vite.config.ts")
        assert _web_ui_build_needed(web_dir) is True

    def test_ignores_node_modules(self, tmp_path):
        web_dir, dist_dir = _make_web_dir(tmp_path)
        # package.json older than manifest; only node_modules file is newer
        _touch(web_dir / "package.json", offset=-20)
        _touch(dist_dir / ".vite" / "manifest.json", offset=-10)
        _touch(web_dir / "node_modules" / "react" / "index.js")
        assert _web_ui_build_needed(web_dir) is False

    def test_ignores_dist_subdir_under_web(self, tmp_path):
        web_dir, dist_dir = _make_web_dir(tmp_path)
        # package.json older than manifest; only web/dist file is newer
        _touch(web_dir / "package.json", offset=-20)
        _touch(dist_dir / ".vite" / "manifest.json", offset=-10)
        _touch(web_dir / "dist" / "assets" / "index.js")
        assert _web_ui_build_needed(web_dir) is False


class TestBuildWebUISkipsWhenFresh:

    def test_skips_npm_when_dist_is_fresh(self, tmp_path):
        web_dir, dist_dir = _make_web_dir(tmp_path)
        _touch(dist_dir / ".vite" / "manifest.json")

        with patch("hermes_cli.main.shutil.which", return_value="/usr/bin/npm"), \
             patch("hermes_cli.main.subprocess.run") as mock_run:
            result = _build_web_ui(web_dir)

        assert result is True
        mock_run.assert_not_called()

    def test_runs_npm_when_dist_missing(self, tmp_path):
        web_dir, _ = _make_web_dir(tmp_path)

        mock_cp = __import__("subprocess").CompletedProcess([], 0, stdout=b"", stderr=b"")
        with patch("hermes_cli.main.shutil.which", return_value="/usr/bin/npm"), \
             patch("hermes_cli.main.subprocess.run", return_value=mock_cp) as mock_run:
            result = _build_web_ui(web_dir)

        assert result is True
        assert mock_run.call_count == 2  # npm install + npm run build
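Every case above reduces to one mtime comparison with two ignore rules. A self-contained sketch that passes the same scenarios (deliberately named _web_ui_build_needed_sketch because the real hermes_cli.main function may differ in detail):

from pathlib import Path

_IGNORED_PARTS = {"node_modules", "dist"}  # web/dist is Vite's stale default output

def _web_ui_build_needed_sketch(web_dir: Path) -> bool:
    """Rebuild when any non-ignored source file is newer than the build sentinel."""
    dist_dir = web_dir.parent / "hermes_cli" / "web_dist"
    sentinel = dist_dir / ".vite" / "manifest.json"
    if not sentinel.exists():
        sentinel = dist_dir / "index.html"  # fallback for builds without a manifest
    if not sentinel.exists():
        return True  # no build output at all in the CORRECT directory
    built_at = sentinel.stat().st_mtime
    for path in web_dir.rglob("*"):
        if any(part in _IGNORED_PARTS for part in path.relative_to(web_dir).parts):
            continue  # node_modules churn and web/dist leftovers never force a rebuild
        if path.is_file() and path.stat().st_mtime > built_at:
            return True
    return False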
@ -7,6 +7,7 @@ turn counting, tags), and schema completeness.

import json
import re
import sys
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock

@ -18,6 +19,7 @@ from plugins.memory.hindsight import (
    REFLECT_SCHEMA,
    RETAIN_SCHEMA,
    _load_config,
    _build_embedded_profile_env,
    _normalize_retain_tags,
    _resolve_bank_id_template,
    _sanitize_bank_segment,
@ -34,7 +36,8 @@ def _clean_env(monkeypatch):
    """Ensure no stale env vars leak between tests."""
    for key in (
        "HINDSIGHT_API_KEY", "HINDSIGHT_API_URL", "HINDSIGHT_BANK_ID",
        "HINDSIGHT_BUDGET", "HINDSIGHT_MODE", "HINDSIGHT_LLM_API_KEY",
        "HINDSIGHT_BUDGET", "HINDSIGHT_MODE", "HINDSIGHT_TIMEOUT",
        "HINDSIGHT_IDLE_TIMEOUT", "HINDSIGHT_LLM_API_KEY",
        "HINDSIGHT_RETAIN_TAGS", "HINDSIGHT_RETAIN_SOURCE",
        "HINDSIGHT_RETAIN_USER_PREFIX", "HINDSIGHT_RETAIN_ASSISTANT_PREFIX",
    ):
@ -251,6 +254,51 @@ class TestConfig:
        assert cfg["banks"]["hermes"]["bankId"] == "env-bank"
        assert cfg["banks"]["hermes"]["budget"] == "high"

    def test_embedded_profile_env_includes_idle_timeout_from_config(self):
        env = _build_embedded_profile_env({
            "llm_provider": "openai",
            "llm_model": "gpt-4o-mini",
            "idle_timeout": 0,
        })

        assert env["HINDSIGHT_EMBED_DAEMON_IDLE_TIMEOUT"] == "0"

    def test_embedded_profile_env_includes_idle_timeout_from_env(self, monkeypatch):
        monkeypatch.setenv("HINDSIGHT_IDLE_TIMEOUT", "42")

        env = _build_embedded_profile_env({
            "llm_provider": "openai",
            "llm_model": "gpt-4o-mini",
        })

        assert env["HINDSIGHT_EMBED_DAEMON_IDLE_TIMEOUT"] == "42"

    def test_get_client_passes_idle_timeout_to_hindsight_embedded(self, monkeypatch):
        captured = {}

        class FakeHindsightEmbedded:
            def __init__(self, **kwargs):
                captured.update(kwargs)

        monkeypatch.setitem(sys.modules, "hindsight", SimpleNamespace(HindsightEmbedded=FakeHindsightEmbedded))
        monkeypatch.setattr("plugins.memory.hindsight._check_local_runtime", lambda: (True, ""))

        p = HindsightMemoryProvider()
        p._mode = "local_embedded"
        p._config = {
            "profile": "hermes",
            "llm_provider": "openai_compatible",
            "llm_api_key": "test-key",
            "llm_model": "test-model",
            "idle_timeout": 0,
        }
        p._llm_base_url = "http://localhost:8060/v1"

        p._get_client()

        assert captured["idle_timeout"] == 0
        assert captured["llm_provider"] == "openai"


class TestPostSetup:
    def test_local_embedded_setup_materializes_profile_env(self, tmp_path, monkeypatch):
@ -272,7 +320,10 @@ class TestPostSetup:
        provider.post_setup(str(hermes_home), {"memory": {}})

        assert saved_configs[-1]["memory"]["provider"] == "hindsight"
        assert (hermes_home / ".env").read_text() == "HINDSIGHT_LLM_API_KEY=sk-local-test\nHINDSIGHT_TIMEOUT=120\n"
        env_text = (hermes_home / ".env").read_text()
        assert "HINDSIGHT_LLM_API_KEY=sk-local-test\n" in env_text
        assert "HINDSIGHT_TIMEOUT=120\n" in env_text
        assert "HINDSIGHT_IDLE_TIMEOUT=300\n" in env_text

        profile_env = user_home / ".hindsight" / "profiles" / "hermes.env"
        assert profile_env.exists()
@ -281,6 +332,7 @@ class TestPostSetup:
            "HINDSIGHT_API_LLM_API_KEY=sk-local-test\n"
            "HINDSIGHT_API_LLM_MODEL=gpt-4o-mini\n"
            "HINDSIGHT_API_LOG_LEVEL=info\n"
            "HINDSIGHT_EMBED_DAEMON_IDLE_TIMEOUT=300\n"
        )

    def test_local_embedded_setup_respects_existing_profile_name(self, tmp_path, monkeypatch):
@ -446,6 +498,28 @@ class TestToolHandlers:
        ))
        assert "error" in result

    def test_local_embedded_recall_reconnects_after_idle_shutdown(self, provider, monkeypatch):
        first_client = _make_mock_client()
        first_client.arecall.side_effect = RuntimeError("Cannot connect to host 127.0.0.1:8888")
        second_client = _make_mock_client()
        second_client.arecall.return_value = SimpleNamespace(
            results=[SimpleNamespace(text="Recovered memory")]
        )
        clients = iter([first_client, second_client])

        provider._mode = "local_embedded"
        provider._client = first_client
        monkeypatch.setattr(provider, "_get_client", lambda: next(clients))

        result = json.loads(provider.handle_tool_call(
            "hindsight_recall", {"query": "test"}
        ))

        assert result["result"] == "1. Recovered memory"
        assert provider._client is second_client
        first_client.arecall.assert_called_once()
        second_client.arecall.assert_called_once()


# ---------------------------------------------------------------------------
# Prefetch tests
@ -1102,3 +1176,22 @@ class TestSharedEventLoopLifecycle:

        mock_client.aclose.assert_called_once()
        assert provider._client is None


class TestShutdown:
    def test_local_embedded_shutdown_closes_inner_async_client_on_shared_loop(self, provider):
        inner_client = _make_mock_client()
        embedded = MagicMock()
        embedded._client = inner_client
        embedded.close = MagicMock()

        provider._mode = "local_embedded"
        provider._client = embedded

        provider.shutdown()

        inner_client.aclose.assert_awaited_once()
        embedded.close.assert_called_once()
        assert embedded._client is None
        assert provider._client is None
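The reconnect test above pins down a retry-once pattern: a connection failure from the idled-out embedded daemon rebuilds the client and replays the call. A minimal sketch of that recovery path (illustrative; the real handler in plugins.memory.hindsight is structured around handle_tool_call):

async def _arecall_with_reconnect(provider, query: str):
    try:
        return await provider._client.arecall(query=query)
    except RuntimeError as exc:
        if "Cannot connect" not in str(exc):
            raise  # real errors still surface to the caller
        # The embedded daemon idled out — rebuild the client and retry once.
        provider._client = provider._get_client()
        return await provider._client.arecall(query=query)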
73 tests/run_agent/test_background_review.py Normal file
@ -0,0 +1,73 @@
"""Regression tests for background review agent cleanup."""

from __future__ import annotations

import run_agent as run_agent_module
from run_agent import AIAgent


def _bare_agent() -> AIAgent:
    agent = object.__new__(AIAgent)
    agent.model = "fake-model"
    agent.platform = "telegram"
    agent.provider = "openai"
    agent.base_url = ""
    agent.api_key = ""
    agent.api_mode = ""
    agent.session_id = "test-session"
    agent._parent_session_id = ""
    agent._credential_pool = None
    agent._memory_store = object()
    agent._memory_enabled = True
    agent._user_profile_enabled = False
    agent._MEMORY_REVIEW_PROMPT = "review memory"
    agent._SKILL_REVIEW_PROMPT = "review skills"
    agent._COMBINED_REVIEW_PROMPT = "review both"
    agent.background_review_callback = None
    agent.status_callback = None
    agent._safe_print = lambda *_args, **_kwargs: None
    return agent


class ImmediateThread:
    def __init__(self, *, target, daemon=None, name=None):
        self._target = target

    def start(self):
        self._target()


def test_background_review_shuts_down_memory_provider_before_close(monkeypatch):
    events = []

    class FakeReviewAgent:
        def __init__(self, **kwargs):
            events.append(("init", kwargs))
            self._session_messages = []

        def run_conversation(self, **kwargs):
            events.append(("run_conversation", kwargs))

        def shutdown_memory_provider(self):
            events.append(("shutdown_memory_provider", None))

        def close(self):
            events.append(("close", None))

    monkeypatch.setattr(run_agent_module, "AIAgent", FakeReviewAgent)
    monkeypatch.setattr(run_agent_module.threading, "Thread", ImmediateThread)

    agent = _bare_agent()

    AIAgent._spawn_background_review(
        agent,
        messages_snapshot=[{"role": "user", "content": "hello"}],
        review_memory=True,
    )

    assert [name for name, _payload in events] == [
        "init",
        "run_conversation",
        "shutdown_memory_provider",
        "close",
    ]
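The ordering assertion is the whole regression: the memory provider must be shut down before close(). A worker sketch consistent with it (the kwargs and prompt plumbing are assumptions, not the real _spawn_background_review signature):

def _review_worker(review_agent, messages_snapshot, prompt):
    try:
        review_agent.run_conversation(user_input=prompt, messages_snapshot=messages_snapshot)
    finally:
        # Release memory-provider state (shared loop, embedded client)
        # BEFORE close(), matching the asserted event order.
        review_agent.shutdown_memory_provider()
        review_agent.close()

# The real spawn path wraps this in threading.Thread(target=..., daemon=True).start();
# the ImmediateThread stub above runs the target inline so the test stays synchronous.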
@ -261,6 +261,42 @@ class TestGatewayMode:
        ]
        assert len(gw_handlers) == 0

    def test_gateway_log_created_after_cli_init(self, hermes_home):
        """Gateway mode attaches gateway.log even after earlier CLI init."""
        hermes_logging.setup_logging(hermes_home=hermes_home, mode="cli")
        hermes_logging.setup_logging(hermes_home=hermes_home, mode="gateway")

        root = logging.getLogger()
        gw_handlers = [
            h for h in root.handlers
            if isinstance(h, RotatingFileHandler)
            and "gateway.log" in getattr(h, "baseFilename", "")
        ]
        assert len(gw_handlers) == 1

        logging.getLogger("gateway.run").info("gateway connected after cli init")

        for h in root.handlers:
            h.flush()

        gw_log = hermes_home / "logs" / "gateway.log"
        assert gw_log.exists()
        assert "gateway connected after cli init" in gw_log.read_text()

    def test_gateway_log_created_after_cli_init_without_duplicate_handlers(self, hermes_home):
        """Repeated gateway setup calls do not attach duplicate gateway handlers."""
        hermes_logging.setup_logging(hermes_home=hermes_home, mode="cli")
        hermes_logging.setup_logging(hermes_home=hermes_home, mode="gateway")
        hermes_logging.setup_logging(hermes_home=hermes_home, mode="gateway")

        root = logging.getLogger()
        gw_handlers = [
            h for h in root.handlers
            if isinstance(h, RotatingFileHandler)
            and "gateway.log" in getattr(h, "baseFilename", "")
        ]
        assert len(gw_handlers) == 1

    def test_gateway_log_receives_gateway_records(self, hermes_home):
        """gateway.log captures records from gateway.* loggers."""
        hermes_logging.setup_logging(hermes_home=hermes_home, mode="gateway")
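Both tests reduce to one idempotency guard when attaching the gateway handler. A minimal sketch of that guard (illustrative only; the real setup_logging also configures formatters, levels, and the per-mode files):

import logging
from logging.handlers import RotatingFileHandler

def _attach_gateway_handler(log_path) -> None:
    root = logging.getLogger()
    already_attached = any(
        isinstance(h, RotatingFileHandler)
        and "gateway.log" in getattr(h, "baseFilename", "")
        for h in root.handlers
    )
    if not already_attached:  # safe after a CLI init, and safe to call repeatedly
        root.addHandler(RotatingFileHandler(log_path))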
@ -2010,3 +2010,58 @@ class TestAutoMaintenance:
        # Should parse as a float timestamp close to now.
        assert abs(float(marker) - time.time()) < 60

    def test_auto_prune_deletes_transcript_files(self, db, tmp_path):
        """Issue #3015: auto-prune must also delete on-disk transcript files."""
        sessions_dir = tmp_path / "sessions"
        sessions_dir.mkdir()

        self._make_old_ended(db, "old1", days_old=100)
        self._make_old_ended(db, "old2", days_old=100)
        db.create_session(session_id="new", source="cli")  # active

        # Transcript files mimicking real gateway/CLI layout
        (sessions_dir / "old1.json").write_text("{}")
        (sessions_dir / "old1.jsonl").write_text("{}\n")
        (sessions_dir / "old2.jsonl").write_text("{}\n")
        (sessions_dir / "request_dump_old1_001.json").write_text("{}")
        (sessions_dir / "new.jsonl").write_text("{}\n")  # active, must survive

        result = db.maybe_auto_prune_and_vacuum(
            retention_days=90, sessions_dir=sessions_dir
        )
        assert result["pruned"] == 2

        # Pruned transcript files are gone
        assert not (sessions_dir / "old1.json").exists()
        assert not (sessions_dir / "old1.jsonl").exists()
        assert not (sessions_dir / "old2.jsonl").exists()
        assert not (sessions_dir / "request_dump_old1_001.json").exists()
        # Active session's transcript is untouched
        assert (sessions_dir / "new.jsonl").exists()

    def test_auto_prune_without_sessions_dir_preserves_files(self, db, tmp_path):
        """Backward-compat: no sessions_dir = DB-only cleanup (legacy behavior)."""
        sessions_dir = tmp_path / "sessions"
        sessions_dir.mkdir()
        self._make_old_ended(db, "old", days_old=100)
        (sessions_dir / "old.jsonl").write_text("{}\n")

        result = db.maybe_auto_prune_and_vacuum(retention_days=90)
        assert result["pruned"] == 1
        # File stays — caller didn't opt in
        assert (sessions_dir / "old.jsonl").exists()

    def test_prune_sessions_deletes_files_for_pruned_only(self, db, tmp_path):
        """Active-session transcripts must never be deleted by prune."""
        sessions_dir = tmp_path / "sessions"
        sessions_dir.mkdir()
        self._make_old_ended(db, "old", days_old=100)
        db.create_session(session_id="active", source="cli")  # not ended
        (sessions_dir / "old.jsonl").write_text("{}\n")
        (sessions_dir / "active.jsonl").write_text("{}\n")

        count = db.prune_sessions(older_than_days=90, sessions_dir=sessions_dir)
        assert count == 1
        assert not (sessions_dir / "old.jsonl").exists()
        assert (sessions_dir / "active.jsonl").exists()
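The file layouts asserted above imply a per-session glob cleanup keyed only to pruned ids. A hypothetical helper capturing that contract (names are illustrative, not the real database method):

from pathlib import Path

def _delete_transcript_files(sessions_dir: Path | None, pruned_ids: list[str]) -> None:
    if sessions_dir is None:
        return  # legacy DB-only cleanup when the caller didn't opt in
    for sid in pruned_ids:
        # Matches {sid}.json, {sid}.jsonl, and request_dump_{sid}_*.json,
        # never touching files belonging to sessions that stayed active.
        for pattern in (f"{sid}.json", f"{sid}.jsonl", f"request_dump_{sid}_*.json"):
            for path in sessions_dir.glob(pattern):
                path.unlink(missing_ok=True)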
416 tests/test_yuanbao_integration.py Normal file
@ -0,0 +1,416 @@
"""
test_yuanbao_integration.py - integration tests for the Yuanbao modules

Verifies that the modules assemble and interact correctly:
- YuanbaoAdapter initialisation
- Config / Platform enum
- get_connected_platforms logic
- proto encode/decode round-trip
- markdown chunking
- API / media module imports
- toolset registration
"""

import sys
import os

# Make sure the hermes-agent repo root is on sys.path
_REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if _REPO_ROOT not in sys.path:
    sys.path.insert(0, _REPO_ROOT)

import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from gateway.config import Platform, PlatformConfig, GatewayConfig
from gateway.platforms.yuanbao import YuanbaoAdapter


def make_config(**kwargs):
    extra = kwargs.pop("extra", {})
    extra.setdefault("app_id", "test_key")
    extra.setdefault("app_secret", "test_secret")
    extra.setdefault("ws_url", "wss://test.example.com/ws")
    extra.setdefault("api_domain", "https://test.example.com")
    return PlatformConfig(
        extra=extra,
        **kwargs,
    )

# ===========================================================
# 1. Adapter initialisation
# ===========================================================

class TestYuanbaoAdapterInit:
    def test_create_adapter(self):
        config = make_config()
        adapter = YuanbaoAdapter(config)
        assert adapter is not None
        assert adapter.PLATFORM == Platform.YUANBAO

    def test_initial_state(self):
        config = make_config()
        adapter = YuanbaoAdapter(config)
        status = adapter.get_status()
        assert status["connected"] is False
        assert status["bot_id"] is None


# ===========================================================
# 2. Config / Platform enum
# ===========================================================

class TestYuanbaoConfig:
    def test_platform_enum(self):
        assert Platform.YUANBAO.value == "yuanbao"

    def test_config_fields(self):
        config = make_config()
        assert config.extra["app_id"] == "test_key"
        assert config.extra["app_secret"] == "test_secret"

    def test_get_connected_platforms_requires_key_and_secret(self):
        # Only a key, no secret → not in the connected list
        gw_only_key = GatewayConfig(
            platforms={
                Platform.YUANBAO: PlatformConfig(
                    enabled=True,
                    extra={"app_id": "key"},
                )
            }
        )
        platforms = gw_only_key.get_connected_platforms()
        assert Platform.YUANBAO not in platforms

        # Key and secret both present → in the connected list
        gw_full = GatewayConfig(
            platforms={
                Platform.YUANBAO: PlatformConfig(
                    enabled=True,
                    extra={"app_id": "key", "app_secret": "secret"},
                )
            }
        )
        platforms2 = gw_full.get_connected_platforms()
        assert Platform.YUANBAO in platforms2

# ===========================================================
# 3. GatewayRunner registration
# ===========================================================

class TestGatewayRunnerRegistration:
    def test_yuanbao_in_platform_enum(self):
        """The Platform enum includes YUANBAO."""
        assert hasattr(Platform, "YUANBAO")
        assert Platform.YUANBAO.value == "yuanbao"

    def _make_minimal_runner(self, config):
        """Bypass run.py's module-level dotenv/ssl side effects via __new__ plus minimal init."""
        import sys
        from unittest.mock import MagicMock

        # Stub out heavy dependencies if not already present
        stubs = [
            "dotenv",
            "hermes_cli.env_loader",
            "hermes_cli.config",
            "hermes_constants",
        ]
        _orig = {}
        for mod in stubs:
            if mod not in sys.modules:
                _orig[mod] = None
                sys.modules[mod] = MagicMock()

        try:
            from gateway.run import GatewayRunner
        finally:
            # Restore only the ones we injected
            for mod, orig in _orig.items():
                if orig is None:
                    sys.modules.pop(mod, None)

        runner = GatewayRunner.__new__(GatewayRunner)
        runner.config = config
        runner.adapters = {}
        runner._failed_platforms = {}
        runner._session_model_overrides = {}
        return runner, GatewayRunner

    def test_runner_creates_yuanbao_adapter(self):
        """GatewayRunner._create_adapter returns a YuanbaoAdapter instance for YUANBAO."""
        from gateway.config import GatewayConfig
        from unittest.mock import patch
        config = make_config(enabled=True)
        gw_config = GatewayConfig(platforms={Platform.YUANBAO: config})

        try:
            runner, _ = self._make_minimal_runner(gw_config)
            # websockets may be missing in the test env, so mock WEBSOCKETS_AVAILABLE
            with patch("gateway.platforms.yuanbao.WEBSOCKETS_AVAILABLE", True):
                adapter = runner._create_adapter(Platform.YUANBAO, config)
        except ImportError as e:
            pytest.skip(f"run.py import unavailable in test env: {e}")

        assert adapter is not None
        assert isinstance(adapter, YuanbaoAdapter)

    def test_runner_adapter_platform_attr(self):
        """The created adapter's PLATFORM is Platform.YUANBAO."""
        from gateway.config import GatewayConfig
        from unittest.mock import patch
        config = make_config(enabled=True)
        gw_config = GatewayConfig(platforms={Platform.YUANBAO: config})

        try:
            runner, _ = self._make_minimal_runner(gw_config)
            with patch("gateway.platforms.yuanbao.WEBSOCKETS_AVAILABLE", True):
                adapter = runner._create_adapter(Platform.YUANBAO, config)
        except ImportError as e:
            pytest.skip(f"run.py import unavailable in test env: {e}")

        assert adapter is not None
        assert adapter.PLATFORM == Platform.YUANBAO

# ===========================================================
# 4. Proto round-trip
# ===========================================================

class TestProtoRoundTrip:
    """Basic sanity checks for proto encode/decode."""

    def test_conn_msg_roundtrip(self):
        from gateway.platforms.yuanbao_proto import encode_conn_msg, decode_conn_msg
        encoded = encode_conn_msg(msg_type=1, seq_no=42, data=b"hello")
        decoded = decode_conn_msg(encoded)
        assert decoded["seq_no"] == 42
        assert decoded["data"] == b"hello"

    def test_text_elem_encoding(self):
        from gateway.platforms.yuanbao_proto import encode_send_c2c_message
        msg = encode_send_c2c_message(
            to_account="user123",
            msg_body=[{"msg_type": "TIMTextElem", "msg_content": {"text": "hello"}}],
            from_account="bot456",
        )
        assert isinstance(msg, bytes)
        assert len(msg) > 0


# ===========================================================
# 5. Markdown chunking
# ===========================================================

class TestMarkdownChunking:
    def test_chunks_are_sent_separately(self):
        from gateway.platforms.yuanbao import MarkdownProcessor
        long_text = "paragraph\n\n" * 100
        chunks = MarkdownProcessor.chunk_markdown_text(long_text, 200)
        assert len(chunks) > 1
        for c in chunks:
            # Paragraph-atomic blocks may slightly exceed the limit; just sanity-check
            assert isinstance(c, str)
            assert len(c) > 0

    def test_chunk_short_text_no_split(self):
        from gateway.platforms.yuanbao import MarkdownProcessor
        text = "hello world"
        chunks = MarkdownProcessor.chunk_markdown_text(text, 3000)
        assert chunks == [text]

# ===========================================================
# 6. Sign token module
# ===========================================================

class TestSignToken:
    def test_import_ok(self):
        from gateway.platforms.yuanbao import SignManager
        assert callable(SignManager.get_token)
        assert callable(SignManager.force_refresh)


# ===========================================================
# 6b. ConnectionManager / OutboundManager
# ===========================================================

class TestManagerImports:
    def test_connection_manager_import(self):
        from gateway.platforms.yuanbao import ConnectionManager
        assert ConnectionManager is not None

    def test_outbound_manager_import(self):
        from gateway.platforms.yuanbao import OutboundManager
        assert OutboundManager is not None

    def test_message_sender_import(self):
        from gateway.platforms.yuanbao import MessageSender
        assert MessageSender is not None

    def test_heartbeat_manager_import(self):
        from gateway.platforms.yuanbao import HeartbeatManager
        assert HeartbeatManager is not None

    def test_slow_response_notifier_import(self):
        from gateway.platforms.yuanbao import SlowResponseNotifier
        assert SlowResponseNotifier is not None

    def test_adapter_has_outbound_manager(self):
        adapter = YuanbaoAdapter(make_config())
        from gateway.platforms.yuanbao import ConnectionManager, OutboundManager
        assert isinstance(adapter._connection, ConnectionManager)
        assert isinstance(adapter._outbound, OutboundManager)

    def test_outbound_composes_sub_managers(self):
        adapter = YuanbaoAdapter(make_config())
        from gateway.platforms.yuanbao import MessageSender, HeartbeatManager, SlowResponseNotifier
        assert isinstance(adapter._outbound.sender, MessageSender)
        assert isinstance(adapter._outbound.heartbeat, HeartbeatManager)
        assert isinstance(adapter._outbound.slow_notifier, SlowResponseNotifier)

# ===========================================================
# 7. Media module
# ===========================================================

class TestMediaModule:
    def test_import_ok(self):
        from gateway.platforms.yuanbao_media import upload_to_cos, download_url
        assert callable(upload_to_cos)
        assert callable(download_url)


# ===========================================================
# 8. Toolset registration
# ===========================================================

class TestToolset:
    def test_yuanbao_toolset_registered(self):
        """toolsets.py contains a hermes-yuanbao key."""
        import importlib
        ts = importlib.import_module("toolsets")
        assert hasattr(ts, "TOOLSETS") or hasattr(ts, "toolsets")
        toolsets_dict = getattr(ts, "TOOLSETS", getattr(ts, "toolsets", {}))
        assert "hermes-yuanbao" in toolsets_dict

    def test_tools_import(self):
        from tools.yuanbao_tools import (
            get_group_info,
            query_group_members,
            send_dm,
        )
        assert all(callable(f) for f in [
            get_group_info,
            query_group_members,
            send_dm,
        ])


# ===========================================================
# 9. platforms/__init__.py exports
# ===========================================================

class TestPlatformInit:
    def test_yuanbao_adapter_exported(self):
        """gateway.platforms.__init__.py should export YuanbaoAdapter."""
        from gateway.platforms import YuanbaoAdapter as _YuanbaoAdapter
        assert _YuanbaoAdapter is YuanbaoAdapter

# ===========================================================
# 10. P0 fixes verification
# ===========================================================

import asyncio
import collections


class TestP0ReconnectGuard:
    """P0-1: the _reconnecting flag prevents concurrent reconnect attempts."""

    def test_reconnecting_flag_initialized(self):
        adapter = YuanbaoAdapter(make_config())
        assert hasattr(adapter._connection, '_reconnecting')
        assert adapter._connection._reconnecting is False

    def test_schedule_reconnect_skips_when_not_running(self):
        adapter = YuanbaoAdapter(make_config())
        adapter._running = False
        adapter._connection._reconnecting = False
        adapter._connection.schedule_reconnect()
        # No task should be created because _running is False

    def test_schedule_reconnect_skips_when_already_reconnecting(self):
        adapter = YuanbaoAdapter(make_config())
        adapter._running = True
        adapter._connection._reconnecting = True
        adapter._connection.schedule_reconnect()
        # No new task should be created because a reconnect is already in flight


class TestP0InboundTaskTracking:
    """P0-2: the _inbound_tasks set is initialized and usable."""

    def test_inbound_tasks_initialized(self):
        adapter = YuanbaoAdapter(make_config())
        assert hasattr(adapter, '_inbound_tasks')
        assert isinstance(adapter._inbound_tasks, set)
        assert len(adapter._inbound_tasks) == 0


class TestP0ChatLockEviction:
    """P0-3: get_chat_lock uses an OrderedDict with safe eviction."""

    def test_chat_locks_is_ordered_dict(self):
        adapter = YuanbaoAdapter(make_config())
        assert isinstance(adapter._outbound._chat_locks, collections.OrderedDict)

    def test_eviction_skips_locked(self):
        """When eviction is needed, locked entries are skipped."""
        adapter = YuanbaoAdapter(make_config())
        from gateway.platforms.yuanbao import OutboundManager

        # Fill to capacity with unlocked locks
        for i in range(OutboundManager.CHAT_DICT_MAX_SIZE):
            adapter._outbound._chat_locks[f"chat_{i}"] = asyncio.Lock()

        # An asyncio.Lock only reports locked() once actually acquired, so this
        # sync test can only verify that get_chat_lock evicts an unlocked entry
        # without crashing.
        new_lock = adapter._outbound.get_chat_lock("new_chat")
        assert "new_chat" in adapter._outbound._chat_locks
        assert isinstance(new_lock, asyncio.Lock)
        # The oldest unlocked entry should have been evicted
        assert len(adapter._outbound._chat_locks) == OutboundManager.CHAT_DICT_MAX_SIZE

    def test_move_to_end_on_access(self):
        """Accessing an existing key moves it to the end (MRU)."""
        adapter = YuanbaoAdapter(make_config())
        adapter._outbound._chat_locks["a"] = asyncio.Lock()
        adapter._outbound._chat_locks["b"] = asyncio.Lock()
        adapter._outbound._chat_locks["c"] = asyncio.Lock()

        # Access "a" — it should move to the end
        adapter._outbound.get_chat_lock("a")
        keys = list(adapter._outbound._chat_locks.keys())
        assert keys[-1] == "a"
        assert keys[0] == "b"

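
# Illustration only (not part of the suite): a minimal sketch of the
# LRU-with-skip eviction pattern the tests above pin down. The name and body
# are hypothetical; the real logic lives in OutboundManager.get_chat_lock.
def _get_chat_lock_sketch(locks, chat_id, max_size):
    """Return a per-chat asyncio.Lock, evicting the oldest unheld one on overflow."""
    if chat_id in locks:
        locks.move_to_end(chat_id)  # mark as most recently used
        return locks[chat_id]
    if len(locks) >= max_size:
        for key, lock in locks.items():  # iterate oldest-first
            if not lock.locked():  # never evict a lock a sender still holds
                del locks[key]
                break
    locks[chat_id] = asyncio.Lock()
    return locks[chat_id]
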
class TestP0PlatformScopedLock:
    """P0-4: connect() calls _acquire_platform_lock."""

    def test_adapter_has_platform_lock_methods(self):
        adapter = YuanbaoAdapter(make_config())
        assert hasattr(adapter, '_acquire_platform_lock')
        assert hasattr(adapter, '_release_platform_lock')


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
324
tests/test_yuanbao_markdown.py
Normal file
@ -0,0 +1,324 @@
"""
test_yuanbao_markdown.py - Unit tests for yuanbao_markdown.py

Run (no pytest needed):
    cd /root/.openclaw/workspace/hermes-agent
    python3 tests/test_yuanbao_markdown.py -v

Or with pytest if available:
    python3 -m pytest tests/test_yuanbao_markdown.py -v
"""

import sys
import os
import unittest

# Ensure the project root is on the path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))

from gateway.platforms.yuanbao import MarkdownProcessor


# ============ has_unclosed_fence ============

class TestHasUnclosedFence(unittest.TestCase):
    def test_unclosed_fence(self):
        self.assertTrue(MarkdownProcessor.has_unclosed_fence("```python\ncode"))

    def test_closed_fence(self):
        self.assertFalse(MarkdownProcessor.has_unclosed_fence("```python\ncode\n```"))

    def test_empty(self):
        self.assertFalse(MarkdownProcessor.has_unclosed_fence(""))

    def test_no_fence(self):
        self.assertFalse(MarkdownProcessor.has_unclosed_fence("just some text\nno fences here"))

    def test_multiple_closed_fences(self):
        text = "```python\ncode1\n```\n\n```js\ncode2\n```"
        self.assertFalse(MarkdownProcessor.has_unclosed_fence(text))

    def test_second_fence_unclosed(self):
        text = "```python\ncode1\n```\n\n```js\ncode2"
        self.assertTrue(MarkdownProcessor.has_unclosed_fence(text))

    def test_fence_at_start(self):
        self.assertTrue(MarkdownProcessor.has_unclosed_fence("```\nsome code"))

    def test_inline_backtick_ignored(self):
        text = "`inline code` is fine"
        self.assertFalse(MarkdownProcessor.has_unclosed_fence(text))

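
# Illustration only: the behaviour pinned down by TestHasUnclosedFence can be
# implemented as a simple fence-parity count. This sketch is an assumption for
# readability, not the production MarkdownProcessor.has_unclosed_fence.
def _has_unclosed_fence_sketch(text: str) -> bool:
    # Only lines that begin with ``` open or close a fence; inline `code` is ignored.
    fence_lines = sum(1 for line in text.splitlines() if line.lstrip().startswith("```"))
    return fence_lines % 2 == 1
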
# ============ ends_with_table_row ============

class TestEndsWithTableRow(unittest.TestCase):
    def test_simple_table_row(self):
        self.assertTrue(MarkdownProcessor.ends_with_table_row("| col1 | col2 |"))

    def test_table_row_with_trailing_newline(self):
        self.assertTrue(MarkdownProcessor.ends_with_table_row("| col1 | col2 |\n"))

    def test_table_row_in_middle(self):
        text = "| col1 | col2 |\nsome other text"
        self.assertFalse(MarkdownProcessor.ends_with_table_row(text))

    def test_empty(self):
        self.assertFalse(MarkdownProcessor.ends_with_table_row(""))

    def test_non_table(self):
        self.assertFalse(MarkdownProcessor.ends_with_table_row("just a normal line"))

    def test_only_pipe_start(self):
        self.assertFalse(MarkdownProcessor.ends_with_table_row("| just pipe at start"))

    def test_table_separator_row(self):
        self.assertTrue(MarkdownProcessor.ends_with_table_row("| --- | --- |"))

    def test_whitespace_only(self):
        self.assertFalse(MarkdownProcessor.ends_with_table_row(" \n "))

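
# Illustration only: a sketch of the "ends with a table row" check these tests
# describe. Assumed logic, not the production implementation.
def _ends_with_table_row_sketch(text: str) -> bool:
    last = text.rstrip("\n").rsplit("\n", 1)[-1].strip()
    # A table row both starts and ends with a pipe, e.g. "| a | b |".
    return len(last) >= 2 and last.startswith("|") and last.endswith("|")
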
# ============ split_at_paragraph_boundary ============

class TestSplitAtParagraphBoundary(unittest.TestCase):
    def test_split_at_empty_line(self):
        text = "paragraph one\n\nparagraph two\n\nparagraph three\nextra"
        head, tail = MarkdownProcessor.split_at_paragraph_boundary(text, 30)
        self.assertLessEqual(len(head), 30)
        self.assertEqual(head + tail, text)

    def test_split_at_sentence_end(self):
        text = "This is a sentence.\nNext line.\nAnother line."
        head, tail = MarkdownProcessor.split_at_paragraph_boundary(text, 25)
        self.assertLessEqual(len(head), 25)
        self.assertEqual(head + tail, text)

    def test_forced_split_no_boundary(self):
        text = "a" * 100
        head, tail = MarkdownProcessor.split_at_paragraph_boundary(text, 50)
        self.assertEqual(len(head), 50)
        self.assertEqual(head + tail, text)

    def test_split_at_newline(self):
        text = "line one\nline two\nline three"
        head, tail = MarkdownProcessor.split_at_paragraph_boundary(text, 15)
        self.assertLessEqual(len(head), 15)
        self.assertEqual(head + tail, text)

    def test_chinese_sentence_boundary(self):
        text = "这是第一句话。\n这是第二句话。\n这是第三句话。"
        head, tail = MarkdownProcessor.split_at_paragraph_boundary(text, 15)
        self.assertLessEqual(len(head), 15)
        self.assertEqual(head + tail, text)

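
# Illustration only: split_at_paragraph_boundary composes naturally into a
# greedy chunker. A hypothetical sketch of that loop, not MarkdownProcessor's
# real code path (which also protects fences and tables):
def _chunk_sketch(text: str, max_chars: int) -> list:
    chunks = []
    while text:
        if len(text) <= max_chars:
            chunks.append(text)
            break
        # Peel off at most max_chars, preferring paragraph/sentence boundaries.
        head, text = MarkdownProcessor.split_at_paragraph_boundary(text, max_chars)
        chunks.append(head)
    return chunks
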
# ============ chunk_markdown_text ============

class TestChunkMarkdownText(unittest.TestCase):
    def test_empty(self):
        self.assertEqual(MarkdownProcessor.chunk_markdown_text(""), [])

    def test_short_text_no_split(self):
        text = "hello world"
        self.assertEqual(MarkdownProcessor.chunk_markdown_text(text, 3000), [text])

    def test_exactly_max_chars(self):
        text = "a" * 3000
        result = MarkdownProcessor.chunk_markdown_text(text, 3000)
        self.assertEqual(len(result), 1)
        self.assertEqual(result[0], text)

    def test_plain_text_split(self):
        """x * 9000 should return 3 chunks of ~3000."""
        text = "x" * 9000
        result = MarkdownProcessor.chunk_markdown_text(text, 3000)
        self.assertEqual(len(result), 3)
        for chunk in result:
            self.assertLessEqual(len(chunk), 3000)
        self.assertEqual(''.join(result), text)

    def test_5000_chars_returns_2(self):
        """Acceptance: 'a' * 5000 with max 3000 → 2 chunks."""
        result = MarkdownProcessor.chunk_markdown_text("a" * 5000, 3000)
        self.assertEqual(len(result), 2)

    def test_code_fence_not_split(self):
        """Code blocks must not be cut apart."""
        code_lines = "\n".join([f" line_{i} = {i}" for i in range(200)])
        text = f"Some intro text.\n\n```python\n{code_lines}\n```\n\nSome outro text."
        result = MarkdownProcessor.chunk_markdown_text(text, 3000)
        for chunk in result:
            self.assertFalse(MarkdownProcessor.has_unclosed_fence(chunk),
                             f"Chunk has unclosed fence:\n{chunk[:200]}...")

    def test_table_not_split(self):
        """Table rows must not be cut apart."""
        header = "| Name | Value | Description |\n| --- | --- | --- |"
        rows = "\n".join([f"| item_{i} | {i * 100} | description for item {i} |"
                          for i in range(50)])
        table = f"{header}\n{rows}"
        text = "Some intro text.\n\n" + table + "\n\nSome outro text."
        result = MarkdownProcessor.chunk_markdown_text(text, 3000)
        for chunk in result:
            self.assertFalse(MarkdownProcessor.has_unclosed_fence(chunk))

    def test_code_fence_200_lines_not_cut(self):
        """A 200-line code block stays intact when the surrounding text is chunked."""
        code_lines = "\n".join([f"x = {i}" for i in range(200)])
        text = f"Intro.\n\n```python\n{code_lines}\n```\n\nOutro."
        result = MarkdownProcessor.chunk_markdown_text(text, 3000)
        for chunk in result:
            self.assertFalse(MarkdownProcessor.has_unclosed_fence(chunk))

    def test_multiple_paragraphs(self):
        """Multi-paragraph text should be split at paragraph boundaries."""
        paragraphs = ["This is paragraph number " + str(i) + ". " * 50
                      for i in range(10)]
        text = "\n\n".join(paragraphs)
        result = MarkdownProcessor.chunk_markdown_text(text, 500)
        self.assertGreater(len(result), 1)
        total_content = ''.join(result)
        self.assertGreaterEqual(len(total_content), len(text) * 0.95)

    def test_single_long_line(self):
        """A single over-long line is force-split."""
        text = "a" * 10000
        result = MarkdownProcessor.chunk_markdown_text(text, 3000)
        self.assertGreaterEqual(len(result), 3)
        for c in result:
            self.assertLessEqual(len(c), 3000)

    def test_fence_followed_by_text(self):
        """Text after a fence is chunked normally."""
        text = "```python\nprint('hi')\n```\n\n" + "Normal text. " * 300
        result = MarkdownProcessor.chunk_markdown_text(text, 500)
        for chunk in result:
            self.assertFalse(MarkdownProcessor.has_unclosed_fence(chunk))

    def test_returns_non_empty_strings(self):
        """Every returned chunk is a non-empty string."""
        text = "Hello world!\n\n" * 100
        result = MarkdownProcessor.chunk_markdown_text(text, 100)
        for chunk in result:
            self.assertGreater(len(chunk), 0)


# ============ Acceptance criteria ============

class TestAcceptanceCriteria(unittest.TestCase):
    def test_9000_x_returns_3_chunks(self):
        """Acceptance: MarkdownProcessor.chunk_markdown_text("x" * 9000, 3000) returns 3 chunks."""
        result = MarkdownProcessor.chunk_markdown_text("x" * 9000, 3000)
        self.assertEqual(len(result), 3)
        for chunk in result:
            self.assertLessEqual(len(chunk), 3000)

    def test_5000_a_returns_2_chunks(self):
        """Acceptance: the documented python -c one-liner prints 2."""
        result = MarkdownProcessor.chunk_markdown_text("a" * 5000, 3000)
        self.assertEqual(len(result), 2)

    def test_has_unclosed_fence_true(self):
        """Acceptance: MarkdownProcessor.has_unclosed_fence("```python\\ncode") returns True."""
        self.assertTrue(MarkdownProcessor.has_unclosed_fence("```python\ncode"))

    def test_has_unclosed_fence_false(self):
        """Acceptance: MarkdownProcessor.has_unclosed_fence("```python\\ncode\\n```") returns False."""
        self.assertFalse(MarkdownProcessor.has_unclosed_fence("```python\ncode\n```"))

    def test_code_block_200_lines_not_broken(self):
        """Acceptance: a 200-line code block is never cut apart."""
        code_lines = "\n".join([f" result_{i} = compute({i})" for i in range(200)])
        text = f"Introduction.\n\n```python\n{code_lines}\n```\n\nConclusion."
        result = MarkdownProcessor.chunk_markdown_text(text, 3000)
        for chunk in result:
            self.assertFalse(MarkdownProcessor.has_unclosed_fence(chunk),
                             f"Found unclosed fence in chunk:\n{chunk[:100]}...")

    def test_table_rows_not_broken(self):
        """Acceptance: table rows are not cut (no chunk is left fence-incomplete)."""
        rows = "\n".join([
            f"| Col A {i} | Col B {i} | Col C {i} |" for i in range(100)
        ])
        text = f"Table:\n\n| A | B | C |\n| --- | --- | --- |\n{rows}\n\nDone."
        result = MarkdownProcessor.chunk_markdown_text(text, 500)
        for chunk in result:
            self.assertFalse(MarkdownProcessor.has_unclosed_fence(chunk))


if __name__ == '__main__':
    unittest.main(verbosity=2)

# ============ pytest-style function tests (task specification) ============

def test_short_text_no_split():
    assert MarkdownProcessor.chunk_markdown_text("hello", 100) == ["hello"]


def test_plain_text_split():
    chunks = MarkdownProcessor.chunk_markdown_text("a" * 5000, 3000)
    assert len(chunks) >= 2
    for c in chunks:
        assert len(c) <= 3000


def test_fence_not_broken():
    """Code blocks must not be cut apart."""
    code_block = "```python\n" + "x = 1\n" * 200 + "```"
    chunks = MarkdownProcessor.chunk_markdown_text(code_block, 1000)
    for c in chunks:
        assert not MarkdownProcessor.has_unclosed_fence(c), f"Chunk has unclosed fence: {c[:100]}"


def test_large_fence_kept_whole():
    """An oversized code block is emitted whole even past max_chars."""
    code_block = "```python\n" + "x = 1\n" * 200 + "```"
    chunks = MarkdownProcessor.chunk_markdown_text(code_block, 500)
    # The code block must stay in a single chunk (allowed to exceed max_chars)
    fence_chunks = [c for c in chunks if "```python" in c]
    for c in fence_chunks:
        assert not MarkdownProcessor.has_unclosed_fence(c)


def test_mixed_content():
    """Plain text before and after a code block is chunked normally."""
    text = "intro paragraph\n\n" + "```python\nx=1\n```" + "\n\noutro paragraph"
    chunks = MarkdownProcessor.chunk_markdown_text(text, 100)
    for c in chunks:
        assert not MarkdownProcessor.has_unclosed_fence(c)


def test_table_not_broken():
    """Tables must not be cut apart."""
    table = "| A | B |\n|---|---|\n| 1 | 2 |\n| 3 | 4 |"
    text = "before\n\n" + table + "\n\nafter"
    chunks = MarkdownProcessor.chunk_markdown_text(text, 30)
    for c in chunks:
        # Any table line that survives chunking must still be a complete row.
        for line in c.split('\n'):
            if line.strip().startswith('|'):
                assert line.strip().endswith('|'), f"table row cut in half: {line!r}"


def test_has_unclosed_fence():
    assert MarkdownProcessor.has_unclosed_fence("```python\ncode")
    assert not MarkdownProcessor.has_unclosed_fence("```python\ncode\n```")
    assert not MarkdownProcessor.has_unclosed_fence("no fence")


def test_ends_with_table_row():
    assert MarkdownProcessor.ends_with_table_row("| a | b |")
    assert not MarkdownProcessor.ends_with_table_row("normal text")


def test_empty_text():
    assert MarkdownProcessor.chunk_markdown_text("", 100) == []


def test_exact_limit():
    text = "a" * 3000
    chunks = MarkdownProcessor.chunk_markdown_text(text, 3000)
    assert len(chunks) == 1
1029
tests/test_yuanbao_pipeline.py
Normal file
File diff suppressed because it is too large
654
tests/test_yuanbao_proto.py
Normal file
@ -0,0 +1,654 @@
"""
test_yuanbao_proto.py - unit tests for yuanbao_proto

Coverage:
1. varint encode/decode round-trip
2. conn-layer encode/decode round-trip
3. biz-layer encode/decode round-trip
4. decode_inbound_push parsing of TIMTextElem messages
5. encode_send_c2c_message / encode_send_group_message encoding
6. fixed-bytes constants (guard against silent protocol changes)
7. auth-bind / ping encoding
"""

import sys
import os

# Make sure the hermes-agent repo root is on sys.path
_REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if _REPO_ROOT not in sys.path:
    sys.path.insert(0, _REPO_ROOT)

import pytest
from gateway.platforms.yuanbao_proto import (
    # low-level helpers
    _encode_varint,
    _decode_varint,
    _parse_fields,
    _fields_to_dict,
    _encode_msg_body_element,
    _decode_msg_body_element,
    _encode_msg_content,
    _decode_msg_content,
    # conn layer
    encode_conn_msg,
    decode_conn_msg,
    encode_conn_msg_full,
    # biz layer
    encode_biz_msg,
    decode_biz_msg,
    # inbound/outbound
    decode_inbound_push,
    encode_send_c2c_message,
    encode_send_group_message,
    # helpers
    encode_auth_bind,
    encode_ping,
    encode_push_ack,
    # constants
    PB_MSG_TYPES,
    BIZ_SERVICES,
    CMD_TYPE,
    CMD,
    MODULE,
    next_seq_no,
)


# ===========================================================
# 1. varint encode/decode
# ===========================================================

class TestVarint:
    def test_small_values(self):
        for v in [0, 1, 127, 128, 255, 300, 16383, 16384, 2**21, 2**28]:
            encoded = _encode_varint(v)
            decoded, pos = _decode_varint(encoded, 0)
            assert decoded == v, f"round-trip failed for {v}"
            assert pos == len(encoded)

    def test_zero(self):
        assert _encode_varint(0) == b"\x00"
        v, p = _decode_varint(b"\x00", 0)
        assert v == 0 and p == 1

    def test_1_byte_boundary(self):
        # 127 = 0x7F => 1 byte
        assert _encode_varint(127) == b"\x7f"
        # 128 => 2 bytes: 0x80 0x01
        assert _encode_varint(128) == b"\x80\x01"

    def test_known_values(self):
        # protobuf spec example: 300 => 0xAC 0x02
        assert _encode_varint(300) == bytes([0xAC, 0x02])

    def test_multi_byte(self):
        # 2^32 - 1 = 4294967295
        v = 2**32 - 1
        enc = _encode_varint(v)
        dec, _ = _decode_varint(enc, 0)
        assert dec == v

    def test_partial_decode(self):
        # decode starting at an offset
        data = b"\x00" + _encode_varint(300) + b"\x00"
        v, pos = _decode_varint(data, 1)
        assert v == 300
        assert pos == 3  # 1 + 2 bytes for 300

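
# Illustration only: the standard protobuf base-128 varint these tests exercise.
# A reference sketch, independent of the module's real _encode_varint.
def _varint_sketch(value: int) -> bytes:
    out = bytearray()
    while True:
        low, value = value & 0x7F, value >> 7
        if value:
            out.append(low | 0x80)  # continuation bit set on non-final bytes
        else:
            out.append(low)  # final byte has the high bit clear
            return bytes(out)
# e.g. _varint_sketch(300) == b"\xac\x02", matching test_known_values.
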
# ===========================================================
# 2. conn-layer round-trip
# ===========================================================

class TestConnCodec:
    def test_basic_round_trip(self):
        payload = b"hello world"
        encoded = encode_conn_msg(msg_type=0, seq_no=42, data=payload)
        decoded = decode_conn_msg(encoded)
        assert decoded["msg_type"] == 0
        assert decoded["seq_no"] == 42
        assert decoded["data"] == payload

    def test_empty_data(self):
        encoded = encode_conn_msg(msg_type=2, seq_no=0, data=b"")
        decoded = decode_conn_msg(encoded)
        assert decoded["msg_type"] == 2
        assert decoded["data"] == b""

    def test_all_cmd_types(self):
        for ct in [0, 1, 2, 3]:
            enc = encode_conn_msg(msg_type=ct, seq_no=1, data=b"\x01\x02")
            dec = decode_conn_msg(enc)
            assert dec["msg_type"] == ct

    def test_large_seq_no(self):
        enc = encode_conn_msg(msg_type=1, seq_no=2**32 - 1, data=b"x")
        dec = decode_conn_msg(enc)
        assert dec["seq_no"] == 2**32 - 1

    def test_full_round_trip(self):
        """encode_conn_msg_full carries cmd/msg_id/module."""
        enc = encode_conn_msg_full(
            cmd_type=CMD_TYPE["Request"],
            cmd="auth-bind",
            seq_no=99,
            msg_id="abc123",
            module="conn_access",
            data=b"\xde\xad\xbe\xef",
        )
        dec = decode_conn_msg(enc)
        head = dec["head"]
        assert head["cmd_type"] == CMD_TYPE["Request"]
        assert head["cmd"] == "auth-bind"
        assert head["seq_no"] == 99
        assert head["msg_id"] == "abc123"
        assert head["module"] == "conn_access"
        assert dec["data"] == b"\xde\xad\xbe\xef"

    # Fixed-bytes tests: guard against silent protocol changes
    def test_fixed_bytes_simple(self):
        """
        Fixed encoding for encode_conn_msg(msg_type=0, seq_no=1, data=b"").
        ConnMsg { head { seq_no=1 } }
        head bytes: field 3 varint(1) = 0x18 0x01
        outer field 1 (head message) = 0x0a 0x02 0x18 0x01
        """
        enc = encode_conn_msg(msg_type=0, seq_no=1, data=b"")
        # head: field 3 (seq_no=1) => tag=0x18, value=0x01
        head_content = bytes([0x18, 0x01])
        # outer field 1 (head message)
        expected = bytes([0x0a, len(head_content)]) + head_content
        assert enc == expected, f"got: {enc.hex()}, expected: {expected.hex()}"

# ===========================================================
# 3. biz-layer round-trip
# ===========================================================

class TestBizCodec:
    def test_round_trip(self):
        body = b"\x0a\x05hello"
        enc = encode_biz_msg(
            service="trpc.yuanbao.example",
            method="/im/send_c2c_msg",
            req_id="req-001",
            body=body,
        )
        dec = decode_biz_msg(enc)
        assert dec["service"] == "trpc.yuanbao.example"
        assert dec["method"] == "/im/send_c2c_msg"
        assert dec["req_id"] == "req-001"
        assert dec["body"] == body
        assert dec["is_response"] is False

    def test_is_response_flag(self):
        # Response cmd_type = 1
        enc = encode_conn_msg_full(
            cmd_type=CMD_TYPE["Response"],
            cmd="/im/send_c2c_msg",
            seq_no=1,
            msg_id="rsp-001",
            module="svc",
            data=b"\x01",
        )
        dec = decode_biz_msg(enc)
        assert dec["is_response"] is True

    def test_empty_body(self):
        enc = encode_biz_msg("svc", "method", "id1", b"")
        dec = decode_biz_msg(enc)
        assert dec["body"] == b""
        assert dec["method"] == "method"


# ===========================================================
# 4. MsgContent / MsgBodyElement encode/decode
# ===========================================================

class TestMsgBodyElement:
    def test_text_elem_round_trip(self):
        el = {
            "msg_type": "TIMTextElem",
            "msg_content": {"text": "Hello, 世界!"},
        }
        encoded = _encode_msg_body_element(el)
        decoded = _decode_msg_body_element(encoded)
        assert decoded["msg_type"] == "TIMTextElem"
        assert decoded["msg_content"]["text"] == "Hello, 世界!"

    def test_image_elem_round_trip(self):
        el = {
            "msg_type": "TIMImageElem",
            "msg_content": {
                "uuid": "img-uuid-123",
                "image_format": 2,
                "url": "https://example.com/img.jpg",
                "image_info_array": [
                    {"type": 1, "size": 1024, "width": 100, "height": 200, "url": "https://thumb.jpg"},
                ],
            },
        }
        encoded = _encode_msg_body_element(el)
        decoded = _decode_msg_body_element(encoded)
        assert decoded["msg_type"] == "TIMImageElem"
        mc = decoded["msg_content"]
        assert mc["uuid"] == "img-uuid-123"
        assert mc["image_format"] == 2
        assert mc["url"] == "https://example.com/img.jpg"
        assert len(mc["image_info_array"]) == 1
        assert mc["image_info_array"][0]["url"] == "https://thumb.jpg"

    def test_file_elem_round_trip(self):
        el = {
            "msg_type": "TIMFileElem",
            "msg_content": {
                "url": "https://example.com/file.pdf",
                "file_size": 204800,
                "file_name": "document.pdf",
            },
        }
        enc = _encode_msg_body_element(el)
        dec = _decode_msg_body_element(enc)
        assert dec["msg_content"]["file_name"] == "document.pdf"
        assert dec["msg_content"]["file_size"] == 204800

    def test_custom_elem_round_trip(self):
        el = {
            "msg_type": "TIMCustomElem",
            "msg_content": {
                "data": '{"key":"value"}',
                "desc": "custom description",
                "ext": "extra info",
            },
        }
        enc = _encode_msg_body_element(el)
        dec = _decode_msg_body_element(enc)
        assert dec["msg_content"]["data"] == '{"key":"value"}'
        assert dec["msg_content"]["desc"] == "custom description"

    def test_empty_content(self):
        el = {"msg_type": "TIMTextElem", "msg_content": {}}
        enc = _encode_msg_body_element(el)
        dec = _decode_msg_body_element(enc)
        assert dec["msg_type"] == "TIMTextElem"

    def test_fixed_text_elem_bytes(self):
        """
        Fixed-bytes check for TIMTextElem { text="hi" }.
        MsgBodyElement:
          field1 (msg_type="TIMTextElem"): 0a 0b 54494d54657874456c656d
          field2 (msg_content): 12 <len> <content>
          MsgContent field1 (text="hi"): 0a 02 6869
        """
        el = {
            "msg_type": "TIMTextElem",
            "msg_content": {"text": "hi"},
        }
        enc = _encode_msg_body_element(el)
        # compute the expected bytes by hand
        # msg_type = "TIMTextElem" (11 bytes)
        type_bytes = b"TIMTextElem"
        # MsgContent: field1(text="hi") = tag(0a) + len(02) + "hi"
        content_inner = bytes([0x0a, 0x02]) + b"hi"
        # MsgBodyElement:
        #   field1: tag=0x0a, len=11, type_bytes
        #   field2: tag=0x12, len=len(content_inner), content_inner
        expected = (
            bytes([0x0a, len(type_bytes)]) + type_bytes
            + bytes([0x12, len(content_inner)]) + content_inner
        )
        assert enc == expected, f"got {enc.hex()}, expected {expected.hex()}"

# ===========================================================
# 5. decode_inbound_push
# ===========================================================

class TestDecodeInboundPush:
    def _build_inbound_push_bytes(
        self,
        from_account: str = "user123",
        to_account: str = "bot456",
        group_code: str = "",
        msg_key: str = "key-001",
        msg_seq: int = 12345,
        text: str = "Hello!",
    ) -> bytes:
        """Hand-build InboundMessagePush bytes (field order matches the proto)."""
        from gateway.platforms.yuanbao_proto import (
            _encode_field, _encode_string, _encode_message,
            _encode_varint, WT_LEN, WT_VARINT,
        )
        el = {
            "msg_type": "TIMTextElem",
            "msg_content": {"text": text},
        }
        el_bytes = _encode_msg_body_element(el)

        buf = b""
        buf += _encode_field(2, WT_LEN, _encode_string(from_account))  # from_account
        buf += _encode_field(3, WT_LEN, _encode_string(to_account))  # to_account
        if group_code:
            buf += _encode_field(6, WT_LEN, _encode_string(group_code))  # group_code
        buf += _encode_field(8, WT_VARINT, _encode_varint(msg_seq))  # msg_seq
        buf += _encode_field(11, WT_LEN, _encode_string(msg_key))  # msg_key
        buf += _encode_field(13, WT_LEN, _encode_message(el_bytes))  # msg_body[0]
        return buf

    def test_basic_c2c_text_message(self):
        raw = self._build_inbound_push_bytes(
            from_account="alice",
            to_account="bot",
            msg_key="k001",
            msg_seq=100,
            text="你好",
        )
        result = decode_inbound_push(raw)
        assert result is not None
        assert result["from_account"] == "alice"
        assert result["to_account"] == "bot"
        assert result["msg_seq"] == 100
        assert result["msg_key"] == "k001"
        assert len(result["msg_body"]) == 1
        assert result["msg_body"][0]["msg_type"] == "TIMTextElem"
        assert result["msg_body"][0]["msg_content"]["text"] == "你好"

    def test_group_message(self):
        raw = self._build_inbound_push_bytes(
            from_account="bob",
            to_account="bot",
            group_code="group-789",
            msg_seq=999,
            text="group msg",
        )
        result = decode_inbound_push(raw)
        assert result is not None
        assert result["group_code"] == "group-789"
        assert result["msg_body"][0]["msg_content"]["text"] == "group msg"

    def test_returns_none_on_empty(self):
        # Empty bytes decode to an empty-field dict (msg_body=[]), not a crash;
        # both {} and None are acceptable outcomes here.
        result = decode_inbound_push(b"")
        assert result is None or isinstance(result, dict)

    def test_multiple_msg_body_elements(self):
        from gateway.platforms.yuanbao_proto import (
            _encode_field, _encode_message, WT_LEN,
        )
        el1 = _encode_msg_body_element(
            {"msg_type": "TIMTextElem", "msg_content": {"text": "part1"}}
        )
        el2 = _encode_msg_body_element(
            {"msg_type": "TIMTextElem", "msg_content": {"text": "part2"}}
        )
        buf = (
            _encode_field(2, WT_LEN, b"\x05alice")
            + _encode_field(13, WT_LEN, _encode_message(el1))
            + _encode_field(13, WT_LEN, _encode_message(el2))
        )
        result = decode_inbound_push(buf)
        assert result is not None
        assert len(result["msg_body"]) == 2
        assert result["msg_body"][0]["msg_content"]["text"] == "part1"
        assert result["msg_body"][1]["msg_content"]["text"] == "part2"

# ===========================================================
# 6. Outbound message encoding
# ===========================================================

class TestEncodeOutbound:
    def test_encode_send_c2c_message(self):
        msg_body = [{"msg_type": "TIMTextElem", "msg_content": {"text": "hi"}}]
        result = encode_send_c2c_message(
            to_account="user_b",
            msg_body=msg_body,
            from_account="bot",
            msg_id="msg-001",
        )
        assert isinstance(result, bytes)
        assert len(result) > 0
        # decode and verify the ConnMsg structure
        dec = decode_conn_msg(result)
        assert dec["head"]["cmd"] == "send_c2c_message"
        assert dec["head"]["msg_id"] == "msg-001"
        assert dec["head"]["module"] == "yuanbao_openclaw_proxy"
        assert len(dec["data"]) > 0

    def test_encode_send_group_message(self):
        msg_body = [{"msg_type": "TIMTextElem", "msg_content": {"text": "group hello"}}]
        result = encode_send_group_message(
            group_code="grp-100",
            msg_body=msg_body,
            from_account="bot",
            msg_id="msg-002",
        )
        assert isinstance(result, bytes)
        dec = decode_conn_msg(result)
        assert dec["head"]["cmd"] == "send_group_message"
        assert dec["head"]["msg_id"] == "msg-002"
        assert len(dec["data"]) > 0

    def test_c2c_biz_payload_contains_to_account(self):
        """The biz payload carries the to_account field."""
        from gateway.platforms.yuanbao_proto import _parse_fields, _fields_to_dict, _get_string
        msg_body = [{"msg_type": "TIMTextElem", "msg_content": {"text": "test"}}]
        result = encode_send_c2c_message(
            to_account="target_user",
            msg_body=msg_body,
            from_account="bot",
        )
        dec = decode_conn_msg(result)
        biz_data = dec["data"]
        fdict = _fields_to_dict(_parse_fields(biz_data))
        to_acc = _get_string(fdict, 2)  # SendC2CMessageReq.to_account = field 2
        assert to_acc == "target_user"

    def test_group_biz_payload_contains_group_code(self):
        from gateway.platforms.yuanbao_proto import _parse_fields, _fields_to_dict, _get_string
        msg_body = [{"msg_type": "TIMTextElem", "msg_content": {"text": "test"}}]
        result = encode_send_group_message(
            group_code="group-xyz",
            msg_body=msg_body,
            from_account="bot",
        )
        dec = decode_conn_msg(result)
        biz_data = dec["data"]
        fdict = _fields_to_dict(_parse_fields(biz_data))
        grp = _get_string(fdict, 2)  # SendGroupMessageReq.group_code = field 2
        assert grp == "group-xyz"


# ===========================================================
# 7. AuthBind / Ping encoding
# ===========================================================

class TestAuthAndPing:
    def test_encode_auth_bind(self):
        result = encode_auth_bind(
            biz_id="ybBot",
            uid="user_001",
            source="app",
            token="tok_abc",
            msg_id="auth-001",
            app_version="1.0.0",
            operation_system="Linux",
            bot_version="0.1.0",
        )
        assert isinstance(result, bytes)
        dec = decode_conn_msg(result)
        assert dec["head"]["cmd"] == "auth-bind"
        assert dec["head"]["module"] == "conn_access"
        assert dec["head"]["msg_id"] == "auth-001"
        assert len(dec["data"]) > 0

    def test_encode_ping(self):
        result = encode_ping("ping-001")
        assert isinstance(result, bytes)
        dec = decode_conn_msg(result)
        assert dec["head"]["cmd"] == "ping"
        assert dec["head"]["module"] == "conn_access"

    def test_encode_push_ack(self):
        original_head = {
            "cmd_type": CMD_TYPE["Push"],
            "cmd": "some-push",
            "seq_no": 100,
            "msg_id": "push-001",
            "module": "im_module",
            "need_ack": True,
            "status": 0,
        }
        result = encode_push_ack(original_head)
        dec = decode_conn_msg(result)
        assert dec["head"]["cmd_type"] == CMD_TYPE["PushAck"]
        assert dec["head"]["cmd"] == "some-push"
        assert dec["head"]["msg_id"] == "push-001"


# ===========================================================
# 8. Constants
# ===========================================================

class TestConstants:
    def test_pb_msg_types_keys(self):
        assert "ConnMsg" in PB_MSG_TYPES
        assert "AuthBindReq" in PB_MSG_TYPES
        assert "PingReq" in PB_MSG_TYPES
        assert "KickoutMsg" in PB_MSG_TYPES
        assert "PushMsg" in PB_MSG_TYPES

    def test_biz_services_keys(self):
        assert "SendC2CMessageReq" in BIZ_SERVICES
        assert "SendGroupMessageReq" in BIZ_SERVICES
        assert "InboundMessagePush" in BIZ_SERVICES

    def test_cmd_type_values(self):
        assert CMD_TYPE["Request"] == 0
        assert CMD_TYPE["Response"] == 1
        assert CMD_TYPE["Push"] == 2
        assert CMD_TYPE["PushAck"] == 3

    def test_pkg_prefix(self):
        for k, v in BIZ_SERVICES.items():
            assert v.startswith("yuanbao_openclaw_proxy"), \
                f"{k}: unexpected prefix in {v}"

# ===========================================================
# 9. seq_no generation
# ===========================================================

class TestSeqNo:
    def test_monotonic(self):
        a = next_seq_no()
        b = next_seq_no()
        c = next_seq_no()
        assert b > a
        assert c > b

    def test_thread_safety(self):
        import threading
        results = []
        lock = threading.Lock()

        def worker():
            for _ in range(100):
                v = next_seq_no()
                with lock:
                    results.append(v)

        threads = [threading.Thread(target=worker) for _ in range(10)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        # no duplicates
        assert len(results) == len(set(results)), "duplicate seq_no detected"

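
# Illustration only: one thread-safe way to produce unique, monotonically
# increasing sequence numbers with the properties TestSeqNo asserts.
# Hypothetical sketch; the module's real next_seq_no may differ.
import itertools as _itertools
import threading as _threading

_seq_lock = _threading.Lock()
_seq_counter = _itertools.count(1)

def _next_seq_no_sketch() -> int:
    with _seq_lock:  # the explicit lock keeps the read-increment atomic
        return next(_seq_counter)
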
# ===========================================================
# 10. Full end-to-end flow (simulated send -> recv)
# ===========================================================

class TestEndToEnd:
    def test_send_recv_c2c(self):
        """Encode a C2C message as the sender, then decode it as the receiver."""
        msg_body = [
            {"msg_type": "TIMTextElem", "msg_content": {"text": "端到端测试"}},
        ]
        # sender side: encode
        wire_bytes = encode_send_c2c_message(
            to_account="recv_user",
            msg_body=msg_body,
            from_account="send_bot",
            msg_id="e2e-001",
        )
        # receiver side: decode the ConnMsg
        dec = decode_conn_msg(wire_bytes)
        assert dec["head"]["cmd"] == "send_c2c_message"
        assert dec["head"]["msg_id"] == "e2e-001"

        # read to_account and msg_body out of the biz payload
        from gateway.platforms.yuanbao_proto import (
            _parse_fields, _fields_to_dict, _get_string, _get_repeated_bytes,
        )
        biz = dec["data"]
        fdict = _fields_to_dict(_parse_fields(biz))
        assert _get_string(fdict, 2) == "recv_user"  # to_account
        assert _get_string(fdict, 3) == "send_bot"  # from_account

        el_list = _get_repeated_bytes(fdict, 5)  # msg_body repeated
        assert len(el_list) == 1
        el_dec = _decode_msg_body_element(el_list[0])
        assert el_dec["msg_type"] == "TIMTextElem"
        assert el_dec["msg_content"]["text"] == "端到端测试"

    def test_inbound_push_full_flow(self):
        """Build a server push, then decode the inbound message."""
        from gateway.platforms.yuanbao_proto import (
            _encode_field, _encode_string, _encode_message,
            _encode_varint, WT_LEN, WT_VARINT,
        )
        # build the inbound biz payload
        el_bytes = _encode_msg_body_element(
            {"msg_type": "TIMTextElem", "msg_content": {"text": "server push"}}
        )
        biz_payload = (
            _encode_field(2, WT_LEN, _encode_string("alice"))
            + _encode_field(3, WT_LEN, _encode_string("bot"))
            + _encode_field(6, WT_LEN, _encode_string("grp-001"))
            + _encode_field(8, WT_VARINT, _encode_varint(555))
            + _encode_field(11, WT_LEN, _encode_string("msg-key-xyz"))
            + _encode_field(13, WT_LEN, _encode_message(el_bytes))
        )
        # wrap in a ConnMsg (simulating a server push)
        wire = encode_conn_msg_full(
            cmd_type=CMD_TYPE["Push"],
            cmd="/im/new_message",
            seq_no=77,
            msg_id="push-abc",
            module="yuanbao_openclaw_proxy",
            data=biz_payload,
            need_ack=True,
        )
        # receiver side: decode
        conn = decode_conn_msg(wire)
        assert conn["head"]["cmd_type"] == CMD_TYPE["Push"]
        assert conn["head"]["need_ack"] is True

        msg = decode_inbound_push(conn["data"])
        assert msg is not None
        assert msg["from_account"] == "alice"
        assert msg["group_code"] == "grp-001"
        assert msg["msg_seq"] == 555
        assert msg["msg_key"] == "msg-key-xyz"
        assert msg["msg_body"][0]["msg_content"]["text"] == "server push"


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
@ -235,3 +235,21 @@ class TestPostRedirectSsrf:

        assert result["success"] is True
        assert result["url"] == final


class TestAllowPrivateUrlsConfig:
    @pytest.fixture(autouse=True)
    def _reset_cache(self):
        browser_tool._allow_private_urls_resolved = False
        browser_tool._cached_allow_private_urls = None
        yield
        browser_tool._allow_private_urls_resolved = False
        browser_tool._cached_allow_private_urls = None

    def test_browser_config_string_false_stays_disabled(self, monkeypatch):
        monkeypatch.setattr(
            "hermes_cli.config.read_raw_config",
            lambda: {"browser": {"allow_private_urls": "false"}},
        )

        assert browser_tool._allow_private_urls() is False
@ -717,3 +717,193 @@ class TestGpgAndGlobalConfigIsolation:
        mgr = CheckpointManager(enabled=True)
        assert mgr.ensure_checkpoint(str(work_dir), reason="prefix-shadow") is True
        assert len(mgr.list_checkpoints(str(work_dir))) == 1


# =========================================================================
# Auto-maintenance: prune_checkpoints + maybe_auto_prune_checkpoints
# =========================================================================

class TestPruneCheckpoints:
    """Sweep orphan/stale shadow repos under CHECKPOINT_BASE (issue #3015 follow-up)."""

    def _seed_shadow_repo(
        self, base: Path, dir_hash: str, workdir: Path, mtime: float = None
    ) -> Path:
        """Create a minimal shadow repo on disk without invoking real git."""
        import os
        shadow = base / dir_hash
        shadow.mkdir(parents=True)
        (shadow / "HEAD").write_text("ref: refs/heads/main\n")
        (shadow / "HERMES_WORKDIR").write_text(str(workdir) + "\n")
        (shadow / "info").mkdir()
        (shadow / "info" / "exclude").write_text("node_modules/\n")
        if mtime is not None:
            for p in shadow.rglob("*"):
                os.utime(p, (mtime, mtime))
            os.utime(shadow, (mtime, mtime))
        return shadow

    def test_deletes_orphan_when_workdir_missing(self, tmp_path):
        from tools.checkpoint_manager import prune_checkpoints

        base = tmp_path / "checkpoints"
        alive_work = tmp_path / "alive"
        alive_work.mkdir()
        alive_repo = self._seed_shadow_repo(base, "aaaa" * 4, alive_work)
        orphan_repo = self._seed_shadow_repo(
            base, "bbbb" * 4, tmp_path / "was-deleted"
        )

        result = prune_checkpoints(retention_days=0, checkpoint_base=base)

        assert result["scanned"] == 2
        assert result["deleted_orphan"] == 1
        assert result["deleted_stale"] == 0
        assert alive_repo.exists()
        assert not orphan_repo.exists()

    def test_deletes_stale_by_mtime_when_workdir_alive(self, tmp_path):
        from tools.checkpoint_manager import prune_checkpoints
        import time as _time

        base = tmp_path / "checkpoints"
        work = tmp_path / "work"
        work.mkdir()

        fresh_repo = self._seed_shadow_repo(base, "cccc" * 4, work)
        stale_work = tmp_path / "stale_work"
        stale_work.mkdir()
        old = _time.time() - 60 * 86400  # 60 days ago
        stale_repo = self._seed_shadow_repo(base, "dddd" * 4, stale_work, mtime=old)

        result = prune_checkpoints(
            retention_days=30, delete_orphans=False, checkpoint_base=base
        )

        assert result["deleted_orphan"] == 0
        assert result["deleted_stale"] == 1
        assert fresh_repo.exists()
        assert not stale_repo.exists()

    def test_orphan_takes_priority_over_stale(self, tmp_path):
        """Orphan detection counts first — reason="orphan" even if also stale."""
        from tools.checkpoint_manager import prune_checkpoints
        import time as _time

        base = tmp_path / "checkpoints"
        old = _time.time() - 60 * 86400
        self._seed_shadow_repo(base, "eeee" * 4, tmp_path / "gone", mtime=old)

        result = prune_checkpoints(retention_days=30, checkpoint_base=base)
        assert result["deleted_orphan"] == 1
        assert result["deleted_stale"] == 0

    def test_delete_orphans_disabled_keeps_orphans(self, tmp_path):
        from tools.checkpoint_manager import prune_checkpoints

        base = tmp_path / "checkpoints"
        orphan = self._seed_shadow_repo(base, "ffff" * 4, tmp_path / "gone")

        result = prune_checkpoints(
            retention_days=0, delete_orphans=False, checkpoint_base=base
        )
        assert result["deleted_orphan"] == 0
        assert orphan.exists()

    def test_skips_non_shadow_dirs(self, tmp_path):
        """Dirs without HEAD (non-initialised) are left alone."""
        from tools.checkpoint_manager import prune_checkpoints

        base = tmp_path / "checkpoints"
        base.mkdir()
        (base / "garbage-dir").mkdir()
        (base / "garbage-dir" / "random.txt").write_text("hi")

        result = prune_checkpoints(retention_days=0, checkpoint_base=base)
        assert result["scanned"] == 0
        assert (base / "garbage-dir").exists()

    def test_tracks_bytes_freed(self, tmp_path):
        from tools.checkpoint_manager import prune_checkpoints

        base = tmp_path / "checkpoints"
        orphan = self._seed_shadow_repo(base, "1234" * 4, tmp_path / "gone")
        (orphan / "objects").mkdir()
        (orphan / "objects" / "pack.bin").write_bytes(b"x" * 5000)

        result = prune_checkpoints(retention_days=0, checkpoint_base=base)
        assert result["deleted_orphan"] == 1
        assert result["bytes_freed"] >= 5000

    def test_base_missing_returns_empty_counts(self, tmp_path):
        from tools.checkpoint_manager import prune_checkpoints

        result = prune_checkpoints(checkpoint_base=tmp_path / "does-not-exist")
        assert result == {
            "scanned": 0, "deleted_orphan": 0, "deleted_stale": 0,
            "errors": 0, "bytes_freed": 0,
        }

class TestMaybeAutoPruneCheckpoints:
|
||||
def _seed(self, base, dir_hash, workdir):
|
||||
base.mkdir(parents=True, exist_ok=True)
|
||||
shadow = base / dir_hash
|
||||
shadow.mkdir()
|
||||
(shadow / "HEAD").write_text("ref: refs/heads/main\n")
|
||||
(shadow / "HERMES_WORKDIR").write_text(str(workdir) + "\n")
|
||||
return shadow
|
||||
|
||||
def test_first_call_prunes_and_writes_marker(self, tmp_path):
|
||||
from tools.checkpoint_manager import maybe_auto_prune_checkpoints
|
||||
|
||||
base = tmp_path / "checkpoints"
|
||||
self._seed(base, "0000" * 4, tmp_path / "gone")
|
||||
|
||||
out = maybe_auto_prune_checkpoints(checkpoint_base=base)
|
||||
assert out["skipped"] is False
|
||||
assert out["result"]["deleted_orphan"] == 1
|
||||
assert (base / ".last_prune").exists()
|
||||
|
||||
def test_second_call_within_interval_skips(self, tmp_path):
|
||||
from tools.checkpoint_manager import maybe_auto_prune_checkpoints
|
||||
|
||||
base = tmp_path / "checkpoints"
|
||||
self._seed(base, "1111" * 4, tmp_path / "gone")
|
||||
|
||||
first = maybe_auto_prune_checkpoints(
|
||||
checkpoint_base=base, min_interval_hours=24
|
||||
)
|
||||
assert first["skipped"] is False
|
||||
|
||||
self._seed(base, "2222" * 4, tmp_path / "also-gone")
|
||||
second = maybe_auto_prune_checkpoints(
|
||||
checkpoint_base=base, min_interval_hours=24
|
||||
)
|
||||
assert second["skipped"] is True
|
||||
# The second orphan must still exist — skip was honoured.
|
||||
assert (base / ("2222" * 4)).exists()
|
||||
|
||||
def test_corrupt_marker_treated_as_no_prior_run(self, tmp_path):
|
||||
from tools.checkpoint_manager import maybe_auto_prune_checkpoints
|
||||
|
||||
base = tmp_path / "checkpoints"
|
||||
base.mkdir()
|
||||
(base / ".last_prune").write_text("not-a-timestamp")
|
||||
self._seed(base, "3333" * 4, tmp_path / "gone")
|
||||
|
||||
out = maybe_auto_prune_checkpoints(checkpoint_base=base)
|
||||
assert out["skipped"] is False
|
||||
assert out["result"]["deleted_orphan"] == 1
|
||||
|
||||
def test_missing_base_no_raise(self, tmp_path):
|
||||
from tools.checkpoint_manager import maybe_auto_prune_checkpoints
|
||||
|
||||
out = maybe_auto_prune_checkpoints(
|
||||
checkpoint_base=tmp_path / "does-not-exist"
|
||||
)
|
||||
assert out["skipped"] is False
|
||||
assert out["result"]["scanned"] == 0
|
||||
|
||||
|
||||
@ -16,8 +16,11 @@ from unittest.mock import patch, MagicMock

from tools.file_tools import (
    read_file_tool,
    write_file_tool,
    reset_file_dedup,
    _is_blocked_device,
    _invalidate_dedup_for_path,
    _READ_DEDUP_STATUS_MESSAGE,
    _get_max_read_chars,
    _DEFAULT_MAX_READ_CHARS,
    _read_tracker,
@ -161,7 +164,7 @@ class TestFileDedup(unittest.TestCase):

    @patch("tools.file_tools._get_file_ops")
    def test_second_read_returns_dedup_stub(self, mock_ops):
        """Second read of same file+range returns dedup stub."""
        """Second read of same file+range returns non-content dedup status."""
        mock_ops.return_value = _make_fake_ops(
            content="line one\nline two\n", file_size=20,
        )
@ -172,7 +175,83 @@ class TestFileDedup(unittest.TestCase):
        # Second read — should get dedup stub
        r2 = json.loads(read_file_tool(self._tmpfile, task_id="dup"))
        self.assertTrue(r2.get("dedup"), "Second read should return dedup stub")
        self.assertIn("unchanged", r2.get("content", ""))
        self.assertEqual(r2.get("status"), "unchanged")
        self.assertIn("unchanged", r2.get("message", ""))
        self.assertFalse(r2.get("content_returned"))
        self.assertNotIn("content", r2)

    @patch("tools.file_tools._get_file_ops")
    def test_write_rejects_internal_read_status_text(self, mock_ops):
        """write_file must not persist internal read_file status text."""
        fake = MagicMock()
        fake.write_file = MagicMock()
        mock_ops.return_value = fake

        result = json.loads(write_file_tool(
            self._tmpfile,
            _READ_DEDUP_STATUS_MESSAGE,
            task_id="guard",
        ))

        self.assertIn("error", result)
        self.assertIn("internal read_file status text", result["error"])
        fake.write_file.assert_not_called()

    @patch("tools.file_tools._get_file_ops")
    def test_write_rejects_status_text_with_small_framing(self, mock_ops):
        """write_file rejects small wrappers around the status text too.

        Real-world corruption shapes aren't always the verbatim message — the
        model sometimes prepends a short note or appends a trailing comment
        before calling write_file. A short, status-dominated write is still
        corruption, not legitimate file content.
        """
        fake = MagicMock()
        fake.write_file = MagicMock()
        mock_ops.return_value = fake

        wrapped = "Note: " + _READ_DEDUP_STATUS_MESSAGE + "\n\n(continuing.)"
        result = json.loads(write_file_tool(
            self._tmpfile,
            wrapped,
            task_id="guard",
        ))

        self.assertIn("error", result)
        self.assertIn("internal read_file status text", result["error"])
        fake.write_file.assert_not_called()

    @patch("tools.file_tools._get_file_ops")
    def test_write_allows_large_file_that_quotes_status_text(self, mock_ops):
        """Legitimate large content that happens to quote the status is allowed.

        Hermes' own docs / SKILL.md files may legitimately mention the dedup
        message verbatim. Only short, status-dominated writes are rejected —
        a normal file that contains the message as one line out of many must
        still write successfully.
        """
        fake = MagicMock()
        fake.write_file = lambda path, content: MagicMock(
            to_dict=lambda: {"success": True, "path": path}
        )
        mock_ops.return_value = fake

        # Build content that contains the status text but is much larger,
        # so the status doesn't "dominate" — this is a legitimate file.
        large_content = (
            "# Skill reference\n\n"
            "Example internal message (do not write back):\n\n"
            f" {_READ_DEDUP_STATUS_MESSAGE}\n\n"
            + ("This is documentation content. " * 200)
        )
        result = json.loads(write_file_tool(
            self._tmpfile,
            large_content,
            task_id="guard",
        ))

        self.assertNotIn("error", result)
        self.assertTrue(result.get("success"))

    @patch("tools.file_tools._get_file_ops")
    def test_modified_file_not_deduped(self, mock_ops):
@ -374,5 +453,174 @@ class TestConfigOverride(unittest.TestCase):
        self.assertIn("content", result)


# ---------------------------------------------------------------------------
# Write invalidates dedup cache (fixes #13144)
# ---------------------------------------------------------------------------

class TestWriteInvalidatesDedup(unittest.TestCase):
    """write_file_tool and patch_tool must invalidate the read_file dedup
    cache for the written path. Without this, a read→write→read sequence
    within the same mtime second returns a stale 'File unchanged' stub.

    Regression test for https://github.com/NousResearch/hermes-agent/issues/13144
    """

    def setUp(self):
        _read_tracker.clear()
        self._tmpdir = tempfile.mkdtemp()
        self._tmpfile = os.path.join(self._tmpdir, "write_dedup.txt")
        with open(self._tmpfile, "w") as f:
            f.write("original content\n")

    def tearDown(self):
        _read_tracker.clear()
        try:
            os.unlink(self._tmpfile)
            os.rmdir(self._tmpdir)
        except OSError:
            pass

    @patch("tools.file_tools._get_file_ops")
    def test_write_invalidates_dedup_same_second(self, mock_ops):
        """read→write→read within the same mtime second returns fresh content.

        This is the core #13144 scenario: on filesystems with ≥1ms mtime
        granularity, a write that lands in the same timestamp as the prior
        read would previously cause the second read to return a stale dedup
        stub because the mtime comparison saw no change.
        """
        fake = MagicMock()
        fake.read_file = lambda path, offset=1, limit=500: _FakeReadResult(
            content="original content\n", total_lines=1, file_size=18,
        )
        fake.write_file = lambda path, content: MagicMock(
            to_dict=lambda: {"success": True, "path": path}
        )
        mock_ops.return_value = fake

        # 1. Read — populates dedup cache.
        r1 = json.loads(read_file_tool(self._tmpfile, task_id="wr"))
        self.assertNotEqual(r1.get("dedup"), True)

        # 2. Write — must invalidate dedup for this path.
        #    (No sleep — we intentionally stay in the same mtime second.)
        write_file_tool(self._tmpfile, "new content\n", task_id="wr")

        # 3. Read again — should get full content, NOT dedup stub.
        fake.read_file = lambda path, offset=1, limit=500: _FakeReadResult(
            content="new content\n", total_lines=1, file_size=13,
        )
        r2 = json.loads(read_file_tool(self._tmpfile, task_id="wr"))
        self.assertNotEqual(r2.get("dedup"), True,
                            "read after write must not return dedup stub")
        self.assertIn("content", r2)

    @patch("tools.file_tools._get_file_ops")
    def test_write_invalidates_all_offsets(self, mock_ops):
        """A write invalidates dedup entries for ALL offset/limit combos."""
        fake = MagicMock()
        fake.read_file = lambda path, offset=1, limit=500: _FakeReadResult(
            content="line1\nline2\nline3\n", total_lines=3, file_size=20,
        )
        fake.write_file = lambda path, content: MagicMock(
            to_dict=lambda: {"success": True, "path": path}
        )
        mock_ops.return_value = fake

        # Read with different offsets to populate multiple dedup entries.
        read_file_tool(self._tmpfile, offset=1, limit=100, task_id="off")
        read_file_tool(self._tmpfile, offset=50, limit=100, task_id="off")

        # Write — should invalidate BOTH dedup entries.
        write_file_tool(self._tmpfile, "replaced\n", task_id="off")

        # Both reads should return fresh content.
        r1 = json.loads(read_file_tool(self._tmpfile, offset=1, limit=100, task_id="off"))
        r2 = json.loads(read_file_tool(self._tmpfile, offset=50, limit=100, task_id="off"))
        self.assertNotEqual(r1.get("dedup"), True,
                            "offset=1 should not dedup after write")
        self.assertNotEqual(r2.get("dedup"), True,
                            "offset=50 should not dedup after write")

    @patch("tools.file_tools._get_file_ops")
    def test_write_does_not_invalidate_other_files(self, mock_ops):
        """Writing file A should not invalidate dedup for file B."""
        other = os.path.join(self._tmpdir, "other.txt")
        with open(other, "w") as f:
            f.write("other content\n")

        fake = MagicMock()
        fake.read_file = lambda path, offset=1, limit=500: _FakeReadResult(
            content="other content\n", total_lines=1, file_size=15,
        )
        fake.write_file = lambda path, content: MagicMock(
            to_dict=lambda: {"success": True, "path": path}
        )
        mock_ops.return_value = fake

        # Read file B.
        read_file_tool(other, task_id="iso")

        # Write file A.
        write_file_tool(self._tmpfile, "changed A\n", task_id="iso")

        # File B should still dedup (untouched).
        r2 = json.loads(read_file_tool(other, task_id="iso"))
        self.assertTrue(r2.get("dedup"),
                        "Unrelated file should still dedup after writing another file")

        try:
            os.unlink(other)
        except OSError:
            pass

    @patch("tools.file_tools._get_file_ops")
    def test_write_does_not_invalidate_other_tasks(self, mock_ops):
        """Writing in task A should not invalidate dedup for task B."""
        fake = MagicMock()
        fake.read_file = lambda path, offset=1, limit=500: _FakeReadResult(
            content="original content\n", total_lines=1, file_size=18,
        )
        fake.write_file = lambda path, content: MagicMock(
            to_dict=lambda: {"success": True, "path": path}
        )
        mock_ops.return_value = fake

        # Both tasks read the file.
        read_file_tool(self._tmpfile, task_id="taskA")
        read_file_tool(self._tmpfile, task_id="taskB")

        # Task A writes.
        write_file_tool(self._tmpfile, "new\n", task_id="taskA")

        # Task A's dedup should be invalidated.
        rA = json.loads(read_file_tool(self._tmpfile, task_id="taskA"))
        self.assertNotEqual(rA.get("dedup"), True,
                            "Writing task's dedup should be invalidated")

        # Task B still sees dedup (its cache is separate — the file
        # *may* have changed on disk, but mtime comparison handles that;
        # here we test that invalidation is scoped to the writing task).
        # Note: on real FS, task B's dedup might or might not hit depending
        # on mtime. The point is that _invalidate_dedup_for_path is
        # correctly scoped to task_id.

    def test_invalidate_dedup_for_path_noop_on_missing_task(self):
        """_invalidate_dedup_for_path is safe when task_id doesn't exist."""
        _read_tracker.clear()
        # Should not raise.
        _invalidate_dedup_for_path("/nonexistent/path", "no_such_task")

    def test_invalidate_dedup_for_path_noop_on_empty_dedup(self):
        """_invalidate_dedup_for_path is safe when dedup dict is empty."""
        _read_tracker.clear()
        _read_tracker["t"] = {
            "last_key": None, "consecutive": 0,
            "read_history": set(), "dedup": {},
        }
        _invalidate_dedup_for_path("/some/path", "t")
        self.assertEqual(_read_tracker["t"]["dedup"], {})


if __name__ == "__main__":
    unittest.main()
@ -81,37 +81,51 @@ class TestStdioPidTracking:

    def test_kill_orphaned_noop_when_empty(self):
        """_kill_orphaned_mcp_children does nothing when no PIDs tracked."""
        from tools.mcp_tool import _kill_orphaned_mcp_children, _stdio_pids, _lock
        from tools.mcp_tool import (
            _kill_orphaned_mcp_children,
            _orphan_stdio_pids,
            _stdio_pids,
            _lock,
        )

        with _lock:
            _stdio_pids.clear()
            _orphan_stdio_pids.clear()

        # Should not raise
        _kill_orphaned_mcp_children()

    def test_kill_orphaned_handles_dead_pids(self):
        """_kill_orphaned_mcp_children gracefully handles already-dead PIDs."""
        from tools.mcp_tool import _kill_orphaned_mcp_children, _stdio_pids, _lock
        from tools.mcp_tool import (
            _kill_orphaned_mcp_children,
            _orphan_stdio_pids,
            _stdio_pids,
            _lock,
        )

        # Use a PID that definitely doesn't exist
        fake_pid = 999999999
        with _lock:
            _stdio_pids[fake_pid] = "test"
            _orphan_stdio_pids.add(fake_pid)

        # Should not raise (ProcessLookupError is caught)
        _kill_orphaned_mcp_children()

        with _lock:
            assert fake_pid not in _stdio_pids
            assert fake_pid not in _orphan_stdio_pids

    def test_kill_orphaned_uses_sigkill_when_available(self, monkeypatch):
        """SIGTERM-first then SIGKILL after 2s for orphan cleanup."""
        from tools.mcp_tool import _kill_orphaned_mcp_children, _stdio_pids, _lock
        from tools.mcp_tool import (
            _kill_orphaned_mcp_children,
            _orphan_stdio_pids,
            _stdio_pids,
            _lock,
        )

        fake_pid = 424242
        with _lock:
            _stdio_pids.clear()
            _stdio_pids[fake_pid] = "test"
            _orphan_stdio_pids.clear()
            _orphan_stdio_pids.add(fake_pid)

        fake_sigkill = 9
        monkeypatch.setattr(signal, "SIGKILL", fake_sigkill, raising=False)
@ -128,16 +142,20 @@ class TestStdioPidTracking:
        mock_sleep.assert_called_once_with(2)

        with _lock:
            assert fake_pid not in _stdio_pids
            assert fake_pid not in _orphan_stdio_pids

    def test_kill_orphaned_falls_back_without_sigkill(self, monkeypatch):
        """Without SIGKILL, SIGTERM is used for both phases."""
        from tools.mcp_tool import _kill_orphaned_mcp_children, _stdio_pids, _lock
        from tools.mcp_tool import (
            _kill_orphaned_mcp_children,
            _orphan_stdio_pids,
            _stdio_pids,
            _lock,
        )

        fake_pid = 434343
        with _lock:
            _stdio_pids.clear()
            _stdio_pids[fake_pid] = "test"
            _orphan_stdio_pids.clear()
            _orphan_stdio_pids.add(fake_pid)

        monkeypatch.delattr(signal, "SIGKILL", raising=False)

@ -150,7 +168,7 @@ class TestStdioPidTracking:
        assert mock_sleep.called

        with _lock:
            assert fake_pid not in _stdio_pids
            assert fake_pid not in _orphan_stdio_pids


# ---------------------------------------------------------------------------
@ -317,6 +317,7 @@ class TestBuiltinDiscovery:
            "tools.tts_tool",
            "tools.vision_tools",
            "tools.web_tools",
            "tools.yuanbao_tools",
        }

        with patch("tools.registry.importlib.import_module"):
@ -167,6 +167,39 @@ class TestSendMessageTool:
            media_files=[],
        )

    def test_mirror_receives_current_session_user_id(self):
        config, _telegram_cfg = _make_config()

        with patch("gateway.config.load_gateway_config", return_value=config), \
             patch("tools.interrupt.is_interrupted", return_value=False), \
             patch("model_tools._run_async", side_effect=_run_async_immediately), \
             patch("tools.send_message_tool._send_to_platform", new=AsyncMock(return_value={"success": True})), \
             patch("gateway.session_context.get_session_env") as get_session_env_mock, \
             patch("gateway.mirror.mirror_to_session", return_value=True) as mirror_mock:
            get_session_env_mock.side_effect = lambda name, default="": {
                "HERMES_SESSION_PLATFORM": "telegram",
                "HERMES_SESSION_USER_ID": "user-123",
            }.get(name, default)
            result = json.loads(
                send_message_tool(
                    {
                        "action": "send",
                        "target": "telegram:12345",
                        "message": "hello",
                    }
                )
            )

        assert result["success"] is True
        mirror_mock.assert_called_once_with(
            "telegram",
            "12345",
            "hello",
            source_label="telegram",
            thread_id=None,
            user_id="user-123",
        )

    def test_top_level_send_failure_redacts_query_token(self):
        config, _telegram_cfg = _make_config()
        leaked = "very-secret-query-token-123456"
@ -810,6 +843,44 @@ class TestParseTargetRefE164:
        assert _parse_target_ref("matrix", "+15551234567")[2] is False


class TestParseTargetRefSlack:
    """_parse_target_ref recognizes Slack channel/user IDs as explicit."""

    def test_public_channel_id_is_explicit(self):
        chat_id, thread_id, is_explicit = _parse_target_ref("slack", "C0B0QV5434G")
        assert chat_id == "C0B0QV5434G"
        assert thread_id is None
        assert is_explicit is True

    def test_private_channel_id_is_explicit(self):
        assert _parse_target_ref("slack", "G123ABCDEF")[2] is True

    def test_dm_id_is_explicit(self):
        assert _parse_target_ref("slack", "D123ABCDEF")[2] is True

    def test_user_id_is_not_explicit(self):
        """Slack user IDs (U...) and workspace IDs (W...) are NOT explicit send
        targets. chat.postMessage rejects them — a DM must be opened first via
        conversations.open to obtain a D... conversation ID.
        """
        assert _parse_target_ref("slack", "U123ABCDEF")[2] is False
        assert _parse_target_ref("slack", "W123ABCDEF")[2] is False

    def test_whitespace_is_stripped(self):
        chat_id, _, is_explicit = _parse_target_ref("slack", " C0B0QV5434G ")
        assert chat_id == "C0B0QV5434G"
        assert is_explicit is True

    def test_lowercase_or_short_id_is_not_explicit(self):
        assert _parse_target_ref("slack", "c0b0qv5434g")[2] is False
        assert _parse_target_ref("slack", "C123")[2] is False
        assert _parse_target_ref("slack", "X0B0QV5434G")[2] is False

    def test_slack_id_not_explicit_for_other_platforms(self):
        assert _parse_target_ref("discord", "C0B0QV5434G")[2] is False
        assert _parse_target_ref("telegram", "C0B0QV5434G")[2] is False


class TestSendDiscordThreadId:
    """_send_discord uses thread_id when provided."""
@ -10,6 +10,7 @@ from tools.session_search_tool import (
    _format_conversation,
    _truncate_around_matches,
    _get_session_search_max_concurrency,
    _list_recent_sessions,
    _HIDDEN_SESSION_SOURCES,
    MAX_SESSION_CHARS,
    SESSION_SEARCH_SCHEMA,
@ -240,6 +241,54 @@ class TestSessionSearchConcurrency:
        assert max_seen["value"] == 1


class TestRecentSessionListing:
    def test_current_child_session_excludes_root_lineage_even_when_child_id_is_longer(self):
        from unittest.mock import MagicMock

        mock_db = MagicMock()
        mock_db.list_sessions_rich.return_value = [
            {
                "id": "root",
                "title": "Current conversation",
                "source": "cli",
                "started_at": 1709500000,
                "last_active": 1709500100,
                "message_count": 4,
                "preview": "current root",
                "parent_session_id": None,
            },
            {
                "id": "other_session",
                "title": "Other conversation",
                "source": "cli",
                "started_at": 1709400000,
                "last_active": 1709400100,
                "message_count": 3,
                "preview": "other root",
                "parent_session_id": None,
            },
        ]

        def _get_session(session_id):
            if session_id == "child_session_id_that_is_definitely_longer":
                return {"parent_session_id": "root"}
            if session_id == "root":
                return {"parent_session_id": None}
            return None

        mock_db.get_session.side_effect = _get_session

        result = json.loads(_list_recent_sessions(
            mock_db,
            limit=5,
            current_session_id="child_session_id_that_is_definitely_longer",
        ))

        assert result["success"] is True
        assert [item["session_id"] for item in result["results"]] == ["other_session"]
        assert all(item["session_id"] != "root" for item in result["results"])


# =========================================================================
# session_search (dispatcher)
# =========================================================================
107  tests/tools/test_shared_container_task_id.py  (new file)
@ -0,0 +1,107 @@
"""
Regression tests for the shared-container task_id mapping.

The top-level agent and all delegate_task subagents share a single
terminal sandbox keyed by ``"default"``. ``_resolve_container_task_id``
is the sole gatekeeper for which tool-call task_ids go to the shared
container vs. get their own isolated sandbox. RL / benchmark
environments opt in to isolation by calling
``register_task_env_overrides(task_id, {...})`` before the agent loop;
every other task_id collapses back to ``"default"``.

If you change the collapse logic, update both the helper and these
tests -- see `hermes-agent-dev` skill, "Why do subagents get their own
containers?" section, and the Container lifecycle paragraph under
Docker Backend in ``website/docs/user-guide/configuration.md``.
"""
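# The behaviour exercised below implies a collapse helper along these
# lines. This is an illustrative sketch only, not the real code in
# tools/terminal_tool.py; it assumes ``_task_env_overrides`` is the
# module-level dict that ``register_task_env_overrides`` populates.
#
#     def _resolve_container_task_id(task_id):
#         # Anything without a registered env override (including None,
#         # "", and subagent/session IDs) collapses to the shared key.
#         if task_id and task_id in _task_env_overrides:
#             return task_id
#         return "default"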
import pytest

from tools import terminal_tool


@pytest.fixture(autouse=True)
def _clean_overrides():
    """Ensure no stray overrides from other tests leak in."""
    before = dict(terminal_tool._task_env_overrides)
    terminal_tool._task_env_overrides.clear()
    yield
    terminal_tool._task_env_overrides.clear()
    terminal_tool._task_env_overrides.update(before)


def test_none_task_id_maps_to_default():
    assert terminal_tool._resolve_container_task_id(None) == "default"


def test_empty_task_id_maps_to_default():
    assert terminal_tool._resolve_container_task_id("") == "default"


def test_literal_default_stays_default():
    assert terminal_tool._resolve_container_task_id("default") == "default"


def test_subagent_task_id_collapses_to_default():
    # delegate_task constructs IDs like "subagent-<N>-<uuid_hex>"; these
    # should share the parent's container, not spin up their own.
    assert terminal_tool._resolve_container_task_id("subagent-0-deadbeef") == "default"
    assert terminal_tool._resolve_container_task_id("subagent-42-cafef00d") == "default"


def test_arbitrary_session_id_collapses_to_default():
    # Session UUIDs or anything else without an override still collapse.
    assert terminal_tool._resolve_container_task_id("sess-123e4567-e89b-12d3") == "default"


def test_rl_task_with_override_keeps_its_own_id():
    # RL / benchmark pattern: register a per-task image, then the task_id
    # must survive ``_resolve_container_task_id`` so the rollout lands in
    # its own sandbox.
    terminal_tool.register_task_env_overrides(
        "tb2-task-fix-git", {"docker_image": "tb2:fix-git", "cwd": "/app"}
    )
    try:
        assert (
            terminal_tool._resolve_container_task_id("tb2-task-fix-git")
            == "tb2-task-fix-git"
        )
    finally:
        terminal_tool.clear_task_env_overrides("tb2-task-fix-git")


def test_cleared_override_collapses_again():
    terminal_tool.register_task_env_overrides("tb2-x", {"docker_image": "x:y"})
    assert terminal_tool._resolve_container_task_id("tb2-x") == "tb2-x"
    terminal_tool.clear_task_env_overrides("tb2-x")
    assert terminal_tool._resolve_container_task_id("tb2-x") == "default"


def test_get_active_env_reads_shared_container_from_subagent_id():
    """``get_active_env`` must see the shared ``"default"`` sandbox when
    called with a subagent's task_id, so the agent loop's turn-budget
    enforcement reads the real env (not None) during delegation."""
    sentinel = object()
    terminal_tool._active_environments["default"] = sentinel
    try:
        assert terminal_tool.get_active_env("subagent-7-cafe") is sentinel
        assert terminal_tool.get_active_env(None) is sentinel
        assert terminal_tool.get_active_env("default") is sentinel
    finally:
        terminal_tool._active_environments.pop("default", None)


def test_get_active_env_honours_rl_override():
    rl_env = object()
    default_env = object()
    terminal_tool._active_environments["default"] = default_env
    terminal_tool._active_environments["rl-42"] = rl_env
    terminal_tool.register_task_env_overrides("rl-42", {"docker_image": "x"})
    try:
        # With an override registered, lookup returns the task's own env,
        # not the shared "default" one.
        assert terminal_tool.get_active_env("rl-42") is rl_env
    finally:
        terminal_tool.clear_task_env_overrides("rl-42")
        terminal_tool._active_environments.pop("default", None)
        terminal_tool._active_environments.pop("rl-42", None)
@ -22,6 +22,7 @@ from tools.tool_backend_helpers import (
    managed_nous_tools_enabled,
    normalize_browser_cloud_provider,
    normalize_modal_mode,
    prefers_gateway,
    resolve_modal_backend_state,
    resolve_openai_audio_api_key,
)
@ -189,6 +190,27 @@ class TestHasDirectModalCredentials:
        assert has_direct_modal_credentials() is True


# ---------------------------------------------------------------------------
# prefers_gateway
# ---------------------------------------------------------------------------
class TestPrefersGateway:
    """Honor bool-ish config values for tool gateway routing."""

    def test_returns_false_for_quoted_false(self, monkeypatch):
        monkeypatch.setattr(
            "hermes_cli.config.load_config",
            lambda: {"web": {"use_gateway": "false"}},
        )
        assert prefers_gateway("web") is False

    def test_returns_true_for_quoted_true(self, monkeypatch):
        monkeypatch.setattr(
            "hermes_cli.config.load_config",
            lambda: {"web": {"use_gateway": "true"}},
        )
        assert prefers_gateway("web") is True


# ---------------------------------------------------------------------------
# resolve_modal_backend_state
# ---------------------------------------------------------------------------
@ -259,6 +259,20 @@ class TestGlobalAllowPrivateUrls:
        with patch("hermes_cli.config.read_raw_config", return_value=cfg):
            assert _global_allow_private_urls() is True

    def test_config_security_string_false_stays_disabled(self, monkeypatch):
        """Quoted false must not opt out of SSRF protection."""
        monkeypatch.delenv("HERMES_ALLOW_PRIVATE_URLS", raising=False)
        cfg = {"security": {"allow_private_urls": "false"}}
        with patch("hermes_cli.config.read_raw_config", return_value=cfg):
            assert _global_allow_private_urls() is False

    def test_config_browser_string_false_stays_disabled(self, monkeypatch):
        """Legacy browser.allow_private_urls also normalises quoted false."""
        monkeypatch.delenv("HERMES_ALLOW_PRIVATE_URLS", raising=False)
        cfg = {"browser": {"allow_private_urls": "false"}}
        with patch("hermes_cli.config.read_raw_config", return_value=cfg):
            assert _global_allow_private_urls() is False

    def test_config_security_takes_precedence_over_browser(self, monkeypatch):
        """security section is checked before browser section."""
        monkeypatch.delenv("HERMES_ALLOW_PRIVATE_URLS", raising=False)
@ -67,6 +67,7 @@ from typing import Dict, Any, Optional, List, Tuple
from pathlib import Path
from agent.auxiliary_client import call_llm
from hermes_constants import get_hermes_home
from utils import is_truthy_value

try:
    from tools.website_policy import check_website_access
@ -639,7 +640,11 @@ def _allow_private_urls() -> bool:
    try:
        from hermes_cli.config import read_raw_config
        cfg = read_raw_config()
        _cached_allow_private_urls = bool(cfg.get("browser", {}).get("allow_private_urls"))
        browser_cfg = cfg.get("browser", {})
        if isinstance(browser_cfg, dict):
            _cached_allow_private_urls = is_truthy_value(
                browser_cfg.get("allow_private_urls"), default=False
            )
    except Exception as e:
        logger.debug("Could not read allow_private_urls from config: %s", e)
    return _cached_allow_private_urls
@ -651,3 +651,204 @@ def format_checkpoint_list(checkpoints: List[Dict], directory: str) -> str:
    lines.append("  /rollback diff <N>    preview changes since checkpoint N")
    lines.append("  /rollback <N> <file>  restore a single file from checkpoint N")
    return "\n".join(lines)


# ---------------------------------------------------------------------------
# Auto-maintenance (issue #3015 follow-up)
# ---------------------------------------------------------------------------
#
# Every working directory the agent has ever touched gets its own shadow
# repo under CHECKPOINT_BASE. Per-repo ``_prune`` is a no-op (see comment
# in CheckpointManager._prune), so abandoned repos (deleted projects,
# one-off tmp dirs, long-stale work trees) accumulate forever. Field
# reports put the typical offender at 1000+ repos / ~12 GB on active
# contributor machines.
#
# ``prune_checkpoints`` sweeps CHECKPOINT_BASE at startup, deleting shadow
# repos that match either criterion:
#   * orphan: the ``HERMES_WORKDIR`` path no longer exists on disk
#   * stale:  the repo's newest mtime is older than ``retention_days``
#
# ``maybe_auto_prune_checkpoints`` wraps it with an idempotency marker
# (``CHECKPOINT_BASE/.last_prune``) so calling it on every CLI/gateway
# startup is free after the first run of the day. Opt-in via
# ``checkpoints.auto_prune`` in config.yaml — default off so users who
# rely on ``/rollback`` against long-ago sessions never lose data
# silently.
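# Illustrative startup wiring (a sketch, not the actual CLI code; the
# ``config`` object and its key layout below are assumptions mirroring
# the ``checkpoints.auto_prune`` opt-in described above):
#
#     from tools.checkpoint_manager import maybe_auto_prune_checkpoints
#
#     if config.get("checkpoints", {}).get("auto_prune"):
#         out = maybe_auto_prune_checkpoints(retention_days=7)
#         # out["skipped"] is True when .last_prune is fresher than the
#         # min interval; otherwise out["result"] carries the counts.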
_PRUNE_MARKER_NAME = ".last_prune"


def _read_workdir_marker(shadow_repo: Path) -> Optional[str]:
    """Read ``HERMES_WORKDIR`` from a shadow repo, or None if missing/unreadable."""
    try:
        return (shadow_repo / "HERMES_WORKDIR").read_text(encoding="utf-8").strip()
    except (OSError, UnicodeDecodeError):
        return None


def _shadow_repo_newest_mtime(shadow_repo: Path) -> float:
    """Return newest mtime across the shadow repo (walks objects/refs/HEAD).

    We walk instead of trusting the directory mtime because git's pack
    operations can leave the top-level dir untouched while refs/objects
    inside get updated. Best-effort — returns 0.0 on any error.
    """
    newest = 0.0
    try:
        for p in shadow_repo.rglob("*"):
            try:
                m = p.stat().st_mtime
                if m > newest:
                    newest = m
            except OSError:
                continue
    except OSError:
        pass
    return newest


def prune_checkpoints(
    retention_days: int = 7,
    delete_orphans: bool = True,
    checkpoint_base: Optional[Path] = None,
) -> Dict[str, int]:
    """Delete stale/orphan shadow repos under ``checkpoint_base``.

    A shadow repo is deleted when either:

    * ``delete_orphans=True`` and its ``HERMES_WORKDIR`` path no longer
      exists on disk (the original project was deleted / moved); OR
    * its newest in-repo mtime is older than ``retention_days`` days.

    Returns a dict with counts ``{"scanned", "deleted_orphan",
    "deleted_stale", "errors", "bytes_freed"}``.

    Never raises — maintenance must never block interactive startup.
    """
    base = checkpoint_base or CHECKPOINT_BASE
    result = {
        "scanned": 0,
        "deleted_orphan": 0,
        "deleted_stale": 0,
        "errors": 0,
        "bytes_freed": 0,
    }
    if not base.exists():
        return result

    cutoff = 0.0
    if retention_days > 0:
        import time as _time
        cutoff = _time.time() - retention_days * 86400

    for child in base.iterdir():
        if not child.is_dir():
            continue
        # Protect the marker file and anything that isn't a real shadow
        # repo (no HEAD = not initialised, leave alone).
        if not (child / "HEAD").exists():
            continue
        result["scanned"] += 1

        reason: Optional[str] = None
        if delete_orphans:
            workdir = _read_workdir_marker(child)
            if workdir is None or not Path(workdir).exists():
                reason = "orphan"

        if reason is None and retention_days > 0:
            newest = _shadow_repo_newest_mtime(child)
            if newest > 0 and newest < cutoff:
                reason = "stale"

        if reason is None:
            continue

        # Measure size before delete (best-effort)
        try:
            size = sum(p.stat().st_size for p in child.rglob("*") if p.is_file())
        except OSError:
            size = 0
        try:
            shutil.rmtree(child)
            result["bytes_freed"] += size
            if reason == "orphan":
                result["deleted_orphan"] += 1
            else:
                result["deleted_stale"] += 1
            logger.debug("Pruned %s checkpoint repo: %s (%d bytes)", reason, child.name, size)
        except OSError as exc:
            result["errors"] += 1
            logger.warning("Failed to prune checkpoint repo %s: %s", child.name, exc)

    return result


def maybe_auto_prune_checkpoints(
    retention_days: int = 7,
    min_interval_hours: int = 24,
    delete_orphans: bool = True,
    checkpoint_base: Optional[Path] = None,
) -> Dict[str, object]:
    """Idempotent wrapper around ``prune_checkpoints`` for startup hooks.

    Writes ``CHECKPOINT_BASE/.last_prune`` on completion so subsequent
    calls within ``min_interval_hours`` short-circuit. Designed to be
    called once per CLI/gateway process startup; the marker keeps costs
    bounded regardless of how many times hermes is invoked per day.

    Returns ``{"skipped": bool, "result": prune_checkpoints-dict,
    "error": optional str}``.
    """
    import time as _time
    base = checkpoint_base or CHECKPOINT_BASE
    out: Dict[str, object] = {"skipped": False}

    try:
        if not base.exists():
            out["result"] = {
                "scanned": 0, "deleted_orphan": 0, "deleted_stale": 0,
                "errors": 0, "bytes_freed": 0,
            }
            return out

        marker = base / _PRUNE_MARKER_NAME
        now = _time.time()
        if marker.exists():
            try:
                last_ts = float(marker.read_text(encoding="utf-8").strip())
                if now - last_ts < min_interval_hours * 3600:
                    out["skipped"] = True
                    return out
            except (OSError, ValueError):
                pass  # corrupt marker — treat as no prior run

        result = prune_checkpoints(
            retention_days=retention_days,
            delete_orphans=delete_orphans,
            checkpoint_base=base,
        )
        out["result"] = result

        try:
            marker.write_text(str(now), encoding="utf-8")
        except OSError as exc:
            logger.debug("Could not write checkpoint prune marker: %s", exc)

        total = result["deleted_orphan"] + result["deleted_stale"]
        if total > 0:
            logger.info(
                "checkpoint auto-maintenance: pruned %d repo(s) "
                "(%d orphan, %d stale), reclaimed %.1f MB",
                total,
                result["deleted_orphan"],
                result["deleted_stale"],
                result["bytes_freed"] / (1024 * 1024),
            )
    except Exception as exc:
        logger.warning("checkpoint auto-maintenance failed: %s", exc)
        out["error"] = str(exc)

    return out
@ -440,9 +440,10 @@ def _get_or_create_env(task_id: str):
        _active_environments, _env_lock, _create_environment,
        _get_env_config, _last_activity, _start_cleanup_thread,
        _creation_locks, _creation_locks_lock, _task_env_overrides,
        _resolve_container_task_id,
    )

    effective_task_id = task_id or "default"
    effective_task_id = _resolve_container_task_id(task_id)

    # Fast path: environment already exists
    with _env_lock:

@ -88,8 +88,14 @@ def _resolve_path(filepath: str, task_id: str = "default") -> Path:

def _get_live_tracking_cwd(task_id: str = "default") -> str | None:
    """Return the task's live terminal cwd for bookkeeping when available."""
    try:
        from tools.terminal_tool import _resolve_container_task_id
        container_key = _resolve_container_task_id(task_id)
    except Exception:
        container_key = task_id

    with _file_ops_lock:
        cached = _file_ops_cache.get(task_id)
        cached = _file_ops_cache.get(container_key) or _file_ops_cache.get(task_id)
        if cached is not None:
            live_cwd = getattr(getattr(cached, "env", None), "cwd", None) or getattr(
                cached, "cwd", None
@ -101,7 +107,7 @@ def _get_live_tracking_cwd(task_id: str = "default") -> str | None:
    from tools.terminal_tool import _active_environments, _env_lock

    with _env_lock:
        env = _active_environments.get(task_id)
        env = _active_environments.get(container_key) or _active_environments.get(task_id)
        live_cwd = getattr(env, "cwd", None) if env is not None else None
    if live_cwd:
        return live_cwd
@ -208,6 +214,11 @@ _read_tracker: dict = {}
_READ_HISTORY_CAP = 500      # set; used only by get_read_files_summary
_DEDUP_CAP = 1000            # dict; skip-identical-reread guard
_READ_TIMESTAMPS_CAP = 1000  # dict; external-edit detection for write/patch
_READ_DEDUP_STATUS_MESSAGE = (
    "File unchanged since last read. The content from "
    "the earlier read_file result in this conversation is "
    "still current — refer to that instead of re-reading."
)


def _cap_read_tracker_data(task_data: dict) -> None:
@ -252,6 +263,37 @@ def _cap_read_tracker_data(task_data: dict) -> None:
            break


def _is_internal_file_status_text(content: str) -> bool:
    """Return True when content looks like an internal file-tool status, not real file bytes.

    The read_file dedup status message must never be persisted as file
    content. The obvious shape is the model echoing the message verbatim,
    but in practice it also wraps it with small framing text (a leading
    "Note:", a trailing newline + short comment, etc.) before calling
    write_file. We treat any short-ish write whose body is dominated by
    the status message as the same class of corruption.

    Heuristic:
      * Strict equality (after strip) — the verbatim shape.
      * OR the stripped content contains the full status message AND is
        short enough that the status dominates it (<=2x the message length).
        Short, status-dominated writes can't plausibly be real files —
        legitimate docs/notes that happen to quote this internal message
        are always dramatically longer.
    """
    if not isinstance(content, str):
        return False
    stripped = content.strip()
    if not stripped:
        return False
    if stripped == _READ_DEDUP_STATUS_MESSAGE:
        return True
    if _READ_DEDUP_STATUS_MESSAGE in stripped and \
            len(stripped) <= 2 * len(_READ_DEDUP_STATUS_MESSAGE):
        return True
    return False
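# Behaviour sketch of the heuristic above (inputs are illustrative):
#
#     _is_internal_file_status_text(_READ_DEDUP_STATUS_MESSAGE)             # True: verbatim
#     _is_internal_file_status_text("Note: " + _READ_DEDUP_STATUS_MESSAGE)  # True: short framing
#     _is_internal_file_status_text("x" * 10_000)                           # False: no status text
#     # A long document quoting the message once stays writable, because
#     # len(stripped) far exceeds 2 * len(_READ_DEDUP_STATUS_MESSAGE).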
def _get_file_ops(task_id: str = "default") -> ShellFileOperations:
    """Get or create ShellFileOperations for a terminal environment.

@ -261,15 +303,23 @@ def _get_file_ops(task_id: str = "default") -> ShellFileOperations:

    Thread-safe: uses the same per-task creation locks as terminal_tool to
    prevent duplicate sandbox creation from concurrent tool calls.

    Note: subagent task_ids are collapsed to "default" via
    ``_resolve_container_task_id`` so delegate_task children share the
    parent's container and its cached file_ops. RL/benchmark task_ids with
    a registered env override keep their isolation.
    """
    from tools.terminal_tool import (
        _active_environments, _env_lock, _create_environment,
        _get_env_config, _last_activity, _start_cleanup_thread,
        _creation_locks,
        _creation_locks_lock,
        _resolve_container_task_id,
    )
    import time

    task_id = _resolve_container_task_id(task_id)

    # Fast path: check cache -- but also verify the underlying environment
    # is still alive (it may have been killed by the cleanup thread).
    with _file_ops_lock:
@ -437,13 +487,11 @@ def read_file_tool(path: str, offset: int = 1, limit: int = 500, task_id: str =
                current_mtime = os.path.getmtime(resolved_str)
                if current_mtime == cached_mtime:
                    return json.dumps({
                        "content": (
                            "File unchanged since last read. The content from "
                            "the earlier read_file result in this conversation is "
                            "still current — refer to that instead of re-reading."
                        ),
                        "status": "unchanged",
                        "message": _READ_DEDUP_STATUS_MESSAGE,
                        "path": path,
                        "dedup": True,
                        "content_returned": False,
                    }, ensure_ascii=False)
            except OSError:
                pass  # stat failed — fall through to full read
@ -598,13 +646,48 @@ def notify_other_tool_call(task_id: str = "default"):
            task_data["consecutive"] = 0


def _invalidate_dedup_for_path(filepath: str, task_id: str) -> None:
    """Remove all dedup cache entries whose resolved path matches *filepath*.

    Called after write_file and patch so that a subsequent read_file on
    the same path always returns fresh content instead of a stale
    "File unchanged" stub. The dedup cache keys are tuples of
    ``(resolved_path, offset, limit)``; we must evict **all** offset/limit
    combinations for the written path because any cached range could now
    be stale.

    Must be called with ``_read_tracker_lock`` **not** held — acquires it
    internally.
    """
    try:
        resolved = str(_resolve_path(filepath))
    except (OSError, ValueError):
        return
    with _read_tracker_lock:
        task_data = _read_tracker.get(task_id)
        if task_data is None:
            return
        dedup = task_data.get("dedup")
        if not dedup:
            return
        # Collect keys to remove (can't mutate dict during iteration).
        stale_keys = [k for k in dedup if k[0] == resolved]
        for k in stale_keys:
            del dedup[k]
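# Shape of the cache this helper evicts from, per the docstring above.
# Illustrative values only; the stored value type is an assumption based
# on the mtime comparison in read_file_tool:
#
#     _read_tracker["taskA"]["dedup"] = {
#         ("/abs/path/a.txt", 1, 500):  mtime_at_read,  # evicted on write to a.txt
#         ("/abs/path/a.txt", 50, 100): mtime_at_read,  # evicted too: all ranges go
#         ("/abs/path/b.txt", 1, 500):  mtime_at_read,  # untouched: different path
#     }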
def _update_read_timestamp(filepath: str, task_id: str) -> None:
    """Record the file's current modification time after a successful write.

    Called after write_file and patch so that consecutive edits by the
    same task don't trigger false staleness warnings — each write
    refreshes the stored timestamp to match the file's new state.

    Also invalidates the dedup cache for the written path so that
    subsequent reads return fresh content (fixes #13144).
    """
    # Invalidate dedup first (before acquiring lock for timestamp update).
    _invalidate_dedup_for_path(filepath, task_id)
    try:
        resolved = str(_resolve_path_for_task(filepath, task_id))
        current_mtime = os.path.getmtime(resolved)
@ -653,6 +736,11 @@ def write_file_tool(path: str, content: str, task_id: str = "default") -> str:
    sensitive_err = _check_sensitive_path(path, task_id)
    if sensitive_err:
        return tool_error(sensitive_err)
    if _is_internal_file_status_text(content):
        return tool_error(
            "Refusing to write internal read_file status text as file content. "
            "Re-read the file or reconstruct the intended file contents before writing."
        )
    try:
        # Resolve once for the registry lock + stale check. Failures here
        # fall back to the legacy path — write proceeds, per-task staleness
@ -1044,33 +1044,51 @@ class MCPServerTask:

        # Snapshot child PIDs before spawning so we can track the new one.
        pids_before = _snapshot_child_pids()
        new_pids: set = set()
        # Redirect subprocess stderr into a shared log file so MCP servers
        # (FastMCP banners, slack-mcp startup JSON, etc.) don't dump onto
        # the user's TTY and corrupt the TUI. Preserves debuggability via
        # ~/.hermes/logs/mcp-stderr.log.
        _write_stderr_log_header(self.name)
        _errlog = _get_mcp_stderr_log()
        async with stdio_client(server_params, errlog=_errlog) as (read_stream, write_stream):
            # Capture the newly spawned subprocess PID for force-kill cleanup.
            new_pids = _snapshot_child_pids() - pids_before
        try:
            async with stdio_client(server_params, errlog=_errlog) as (
                read_stream,
                write_stream,
            ):
                # Capture the newly spawned subprocess PID for force-kill cleanup.
                new_pids = _snapshot_child_pids() - pids_before
                if new_pids:
                    with _lock:
                        for _pid in new_pids:
                            _stdio_pids[_pid] = self.name
                async with ClientSession(
                    read_stream, write_stream, **sampling_kwargs
                ) as session:
                    await session.initialize()
                    self.session = session
                    await self._discover_tools()
                    self._ready.set()
                    # stdio transport does not use OAuth, but we still honor
                    # _reconnect_event (e.g. future manual /mcp refresh) for
                    # consistency with _run_http.
                    await self._wait_for_lifecycle_event()
        finally:
            # Runs on clean exit, exceptions, AND asyncio cancellation.
            # If any of the spawned PIDs are still alive, the SDK's
            # teardown failed (common when the task is cancelled mid-way
            # on Linux, where setsid() children escape the parent cgroup).
            # Mark them as orphans so the next cleanup sweep can reap them.
            if new_pids:
                with _lock:
                    for _pid in new_pids:
                        _stdio_pids[_pid] = self.name
            async with ClientSession(read_stream, write_stream, **sampling_kwargs) as session:
                await session.initialize()
                self.session = session
                await self._discover_tools()
                self._ready.set()
                # stdio transport does not use OAuth, but we still honor
                # _reconnect_event (e.g. future manual /mcp refresh) for
                # consistency with _run_http.
                await self._wait_for_lifecycle_event()
            # Context exited cleanly — subprocess was terminated by the SDK.
            if new_pids:
                with _lock:
                    for _pid in new_pids:
                        _stdio_pids.pop(_pid, None)
                        _stdio_pids.pop(_pid, None)
                for pid in new_pids:
                    try:
                        os.kill(pid, 0)  # signal 0: probe liveness only
                    except (ProcessLookupError, PermissionError, OSError):
                        continue  # process already exited — nothing to do
                    _orphan_stdio_pids.add(pid)

    async def _run_http(self, config: dict):
        """Run the server using HTTP/StreamableHTTP transport."""
@ -1718,6 +1736,13 @@ _lock = threading.Lock()
# normal server shutdown.
_stdio_pids: Dict[int, str] = {}  # pid -> server_name

# PIDs that survived their session context exit (SDK teardown failed to
# terminate them). These are detected in _run_stdio's finally block and
# can be cleaned up asynchronously by _kill_orphaned_mcp_children().
# Separate from _stdio_pids so cleanup sweeps never race with active
# sessions (e.g. concurrent cron jobs or live user chats).
_orphan_stdio_pids: set = set()
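# Lifecycle sketch (illustrative, condensed from _run_stdio and the
# cleanup sweep below): a PID moves between the two containers as its
# session starts, tears down, and, if teardown fails, gets reaped later.
#
#     _stdio_pids[pid] = name        # _run_stdio: subprocess spawned
#     _stdio_pids.pop(pid, None)     # session context exited
#     _orphan_stdio_pids.add(pid)    # ...but os.kill(pid, 0) says it's alive
#     _kill_orphaned_mcp_children()  # later sweep: SIGTERM, wait 2s, SIGKILL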
def _snapshot_child_pids() -> set:
    """Return a set of current child process PIDs.
@ -2959,21 +2984,34 @@ def shutdown_mcp_servers():
    _stop_mcp_loop()


def _kill_orphaned_mcp_children() -> None:
    """Graceful shutdown of MCP stdio subprocesses that survived loop cleanup.
def _kill_orphaned_mcp_children(include_active: bool = False) -> None:
    """Best-effort graceful shutdown of stdio MCP subprocesses to reap orphans.

    Sends SIGTERM first, waits 2 seconds, then escalates to SIGKILL.
    This prevents shared-resource collisions when multiple hermes processes
    run on the same host (each has its own _stdio_pids dict).
    Orphans are PIDs that survived their session context exit (SDK teardown
    did not terminate the process — common on Linux when stdio children escape
    the parent cgroup on cancellation). By default only entries in
    ``_orphan_stdio_pids`` are reaped so concurrent cron jobs and live user
    sessions are not disrupted.

    Only kills PIDs tracked in ``_stdio_pids`` — never arbitrary children.
    Sends SIGTERM, waits 2 seconds, then escalates to SIGKILL for any
    survivors, avoiding shared-resource collisions when multiple hermes
    processes run on the same host (each has its own ``_stdio_pids`` dict).

    With ``include_active=True`` also kills every PID in ``_stdio_pids`` —
    used only at final shutdown, after the MCP event loop has stopped and no
    sessions can still be in flight.
    """
    import signal as _signal
    import time as _time

    with _lock:
        pids = dict(_stdio_pids)
        _stdio_pids.clear()
        pids: Dict[int, str] = {}
        for opid in _orphan_stdio_pids:
            pids[opid] = "orphan"
        _orphan_stdio_pids.clear()
        if include_active:
            pids.update(dict(_stdio_pids))
            _stdio_pids.clear()

    # Fast path: no tracked stdio PIDs to reap. Skip the SIGTERM/sleep/SIGKILL
    # dance entirely — otherwise every MCP-free shutdown pays a 2s sleep tax.
@ -3022,5 +3060,6 @@ def _stop_mcp_loop():
        except Exception:
            pass
    # After closing the loop, any stdio subprocesses that survived the
    # graceful shutdown are now orphaned. Force-kill them.
    _kill_orphaned_mcp_children()
    # graceful shutdown are now orphaned — include active PIDs too
    # since the loop is gone and no session can still be in flight.
    _kill_orphaned_mcp_children(include_active=True)
@ -20,7 +20,15 @@ logger = logging.getLogger(__name__)

_TELEGRAM_TOPIC_TARGET_RE = re.compile(r"^\s*(-?\d+)(?::(\d+))?\s*$")
_FEISHU_TARGET_RE = re.compile(r"^\s*((?:oc|ou|on|chat|open)_[-A-Za-z0-9]+)(?::([-A-Za-z0-9_]+))?\s*$")
# Slack conversation IDs: C (public channel), G (private/group channel), D (DM).
# Must be uppercase alphanumeric, 9+ chars. User IDs (U...) and workspace IDs
# (W...) are NOT valid chat.postMessage channel values — posting to them fails
# because the API requires a conversation ID. To DM a user you must first call
# conversations.open to obtain a D... ID. Without this gate, Slack IDs fall
# through to channel-name resolution, which only matches by name and fails.
_SLACK_TARGET_RE = re.compile(r"^\s*([CGD][A-Z0-9]{8,})\s*$")
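# Examples of the gate above (illustrative inputs):
#     "C0B0QV5434G"  -> explicit conversation ID (public channel)
#     "D123ABCDEF"   -> explicit (DM conversation already opened)
#     "U123ABCDEF"   -> no match; U.../W... IDs need conversations.open first
#     "C123"         -> no match; shorter than the 9-char minimum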
_WEIXIN_TARGET_RE = re.compile(r"^\s*((?:wxid|gh|v\d+|wm|wb)_[A-Za-z0-9_-]+|[A-Za-z0-9._-]+@chatroom|filehelper)\s*$")
|
||||
_YUANBAO_TARGET_RE = re.compile(r"^\s*((?:group|direct):[^:]+)\s*$")
|
||||
# Discord snowflake IDs are numeric, same regex pattern as Telegram topic targets.
|
||||
_NUMERIC_TOPIC_RE = _TELEGRAM_TOPIC_TARGET_RE
|
||||
# Platforms that address recipients by phone number and accept E.164 format
|
||||
@ -120,11 +128,11 @@ SEND_MESSAGE_SCHEMA = {
|
||||
},
|
||||
"target": {
|
||||
"type": "string",
|
||||
"description": "Delivery target. Format: 'platform' (uses home channel), 'platform:#channel-name', 'platform:chat_id', or 'platform:chat_id:thread_id' for Telegram topics and Discord threads. Examples: 'telegram', 'telegram:-1001234567890:17585', 'discord:999888777:555444333', 'discord:#bot-home', 'slack:#engineering', 'signal:+155****4567', 'matrix:!roomid:server.org', 'matrix:@user:server.org'"
|
||||
"description": "Delivery target. Format: 'platform' (uses home channel), 'platform:#channel-name', 'platform:chat_id', or 'platform:chat_id:thread_id' for Telegram topics and Discord threads. Examples: 'telegram', 'telegram:-1001234567890:17585', 'discord:999888777:555444333', 'discord:#bot-home', 'slack:#engineering', 'signal:+155****4567', 'matrix:!roomid:server.org', 'matrix:@user:server.org', 'yuanbao:direct:<account_id>' (DM), 'yuanbao:group:<group_code>' (group chat)"
|
||||
},
|
||||
"message": {
|
||||
"type": "string",
|
||||
"description": "The message text to send"
|
||||
"description": "The message text to send. To send an image or file, include MEDIA:<local_path> (e.g. 'MEDIA:/tmp/hermes/cache/img_xxx.jpg') in the message — the platform will deliver it as a native media attachment."
|
||||
}
|
||||
},
|
||||
"required": []
|
||||
@ -215,6 +223,7 @@ def _handle_send(args):
        "weixin": Platform.WEIXIN,
        "email": Platform.EMAIL,
        "sms": Platform.SMS,
        "yuanbao": Platform.YUANBAO,
    }
    platform = platform_map.get(platform_name)
    if not platform:

@ -292,7 +301,15 @@ def _handle_send(args):
        from gateway.mirror import mirror_to_session
        from gateway.session_context import get_session_env
        source_label = get_session_env("HERMES_SESSION_PLATFORM", "cli")
        if mirror_to_session(platform_name, chat_id, mirror_text, source_label=source_label, thread_id=thread_id):
        user_id = get_session_env("HERMES_SESSION_USER_ID", "") or None
        if mirror_to_session(
            platform_name,
            chat_id,
            mirror_text,
            source_label=source_label,
            thread_id=thread_id,
            user_id=user_id,
        ):
            result["mirrored"] = True
    except Exception:
        pass
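get_session_env is imported but not defined in this diff; the "or None" is what keeps an unset or empty HERMES_SESSION_USER_ID from being mirrored as an empty string. A sketch of the assumed semantics:

import os

def get_session_env(name: str, default: str = "") -> str:
    # Assumption: session-scoped values are surfaced via the process
    # environment; the real implementation is not shown in this diff.
    return os.environ.get(name, default)

# Empty string coerces to None, so mirror_to_session sees "no user id"
# rather than user_id="".
user_id = get_session_env("HERMES_SESSION_USER_ID", "") or None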
@ -318,10 +335,21 @@ def _parse_target_ref(platform_name: str, target_ref: str):
        match = _NUMERIC_TOPIC_RE.fullmatch(target_ref)
        if match:
            return match.group(1), match.group(2), True
    if platform_name == "slack":
        match = _SLACK_TARGET_RE.fullmatch(target_ref)
        if match:
            return match.group(1), None, True
    if platform_name == "weixin":
        match = _WEIXIN_TARGET_RE.fullmatch(target_ref)
        if match:
            return match.group(1), None, True
    if platform_name == "yuanbao":
        match = _YUANBAO_TARGET_RE.fullmatch(target_ref)
        if match:
            return match.group(1), None, True
        if target_ref.strip().isdigit():
            return f"group:{target_ref.strip()}", None, True
        return None, None, False
    if platform_name in _PHONE_PLATFORMS:
        match = _E164_TARGET_RE.fullmatch(target_ref)
        if match:
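Given the regexes and the digit fallback above, the yuanbao branch should resolve targets as follows (return shape is (chat_id, thread_id, matched); these expectations are derived from the code shown, not from separate tests):

# Explicit group/direct prefixes match _YUANBAO_TARGET_RE:
assert _parse_target_ref("yuanbao", "direct:acct_42") == ("direct:acct_42", None, True)
assert _parse_target_ref("yuanbao", "group:dev-chat") == ("group:dev-chat", None, True)
# Bare digits are assumed to be a group code:
assert _parse_target_ref("yuanbao", "12345") == ("group:12345", None, True)
# Everything else is reported as unmatched:
assert _parse_target_ref("yuanbao", "not:a:valid:target") == (None, None, False)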
@ -532,7 +560,7 @@ async def _send_to_platform(platform, pconfig, chat_id, message, thread_id=None,
    if media_files and not message.strip():
        return {
            "error": (
                f"send_message MEDIA delivery is currently only supported for telegram, discord, matrix, weixin, and signal; "
                f"send_message MEDIA delivery is currently only supported for telegram, discord, matrix, weixin, signal, and yuanbao; "
                f"target {platform.value} had only media attachments"
            )
        }

@ -540,7 +568,7 @@ async def _send_to_platform(platform, pconfig, chat_id, message, thread_id=None,
    if media_files:
        warning = (
            f"MEDIA attachments were omitted for {platform.value}; "
            "native send_message media delivery is currently only supported for telegram, discord, matrix, weixin, and signal"
            "native send_message media delivery is currently only supported for telegram, discord, matrix, weixin, signal, and yuanbao"
        )

    last_result = None
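Both strings gate the same capability list. For illustration, a sketch of how MEDIA: markers might be split out of a message before dispatch; the actual extraction helper is not shown in this diff:

import re

_MEDIA_RE = re.compile(r"MEDIA:(\S+)")

def split_media(message: str) -> tuple[str, list[str]]:
    # Sketch only: pull MEDIA:<path> markers out of the text and return
    # the cleaned message plus the list of local file paths.
    paths = _MEDIA_RE.findall(message)
    text = _MEDIA_RE.sub("", message).strip()
    return text, paths

text, files = split_media("Done! MEDIA:/tmp/hermes/cache/img_xxx.jpg")
assert files == ["/tmp/hermes/cache/img_xxx.jpg"]
assert text == "Done!"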
@ -1510,6 +1538,35 @@ async def _send_qqbot(pconfig, chat_id, message):
        return _error(f"QQBot send failed: {e}")


async def _send_yuanbao(chat_id, message, media_files=None):
    """Send via Yuanbao using the running gateway adapter's WebSocket connection.

    Yuanbao uses a persistent WebSocket — unlike HTTP-based platforms, we
    cannot create a throwaway client. We obtain the running singleton from
    the adapter module itself (``get_active_adapter``).

    chat_id format:
      - Group: "group:<group_code>"
      - DM: "direct:<account_id>" or just "<account_id>"
    """
    try:
        from gateway.platforms.yuanbao import get_active_adapter, send_yuanbao_direct
    except ImportError:
        return _error("Yuanbao adapter module not available.")

    adapter = get_active_adapter()
    if adapter is None:
        return _error(
            "Yuanbao adapter is not running. "
            "Start the gateway with the yuanbao platform enabled first."
        )

    try:
        return await send_yuanbao_direct(adapter, chat_id, message, media_files=media_files)
    except Exception as e:
        return _error(f"Yuanbao send failed: {e}")

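The docstring above leans on a module-level singleton. A minimal sketch of the adapter side, assuming the adapter registers itself when its WebSocket comes up; only get_active_adapter and send_yuanbao_direct actually appear in this diff, the rest is an assumption:

# gateway/platforms/yuanbao.py (sketch)
_active_adapter = None

def set_active_adapter(adapter) -> None:
    # Hypothetical hook: called by the adapter once its persistent
    # WebSocket connection is established.
    global _active_adapter
    _active_adapter = adapter

def get_active_adapter():
    # Return the running singleton, or None when the gateway was started
    # without the yuanbao platform.
    return _active_adapter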
# --- Registry ---
from tools.registry import registry, tool_error

@ -274,12 +274,13 @@ def _list_recent_sessions(db, limit: int, current_session_id: str = None) -> str
    try:
        sid = current_session_id
        visited = set()
        current_root = current_session_id
        while sid and sid not in visited:
            visited.add(sid)
            current_root = sid
            s = db.get_session(sid)
            parent = s.get("parent_session_id") if s else None
            sid = parent if parent else None
        current_root = max(visited, key=len) if visited else current_session_id
    except Exception:
        current_root = current_session_id

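The dropped line picked the longest id in the visited set, which has no relationship to ancestry; the new code records the last node reached while following parent_session_id links, so current_root ends up at the true root (or at the last node before a cycle repeats). A self-contained check of that walk, with a dict standing in for db.get_session:

# Stub chain: c -> b -> a (a is the root).
sessions = {
    "a": {"parent_session_id": None},
    "b": {"parent_session_id": "a"},
    "c": {"parent_session_id": "b"},
}

def find_root(session_id: str) -> str:
    sid, visited, root = session_id, set(), session_id
    while sid and sid not in visited:
        visited.add(sid)
        root = sid
        parent = sessions.get(sid, {}).get("parent_session_id")
        sid = parent or None
    return root

assert find_root("c") == "a"              # walks the chain to the top
assert find_root("missing") == "missing"  # unknown id is its own root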