forked from molecule-ai/molecule-core
- PLATFORM_URL: replace unreachable http://platform:8080 mesh-only default with Docker-aware detection (host.docker.internal in containers, localhost for local dev) across all workspace Python modules and the git-token-helper shell script. - WORKSPACE_ID: add fail-fast validation in main.py (SystemExit if empty) consistent with coordinator.py / a2a_cli.py patterns already in place. - INCIDENT_LOG.md: replace all 3 F1088 credential types with ***REDACTED*** (sk-cp- 2x, github_pat_ 2x, ADMIN_TOKEN base64 3x). Fixes #1124, #1333. Co-authored-by: Molecule AI Dev Lead <dev-lead@agents.moleculesai.app>
562 lines
20 KiB
Python
562 lines
20 KiB
Python
"""Human-In-The-Loop (HITL) workflow primitives.
|
|
|
|
Generalizes the approval tool into reusable HITL building blocks that work
|
|
across all Molecule AI adapters.
|
|
|
|
Features
|
|
--------
|
|
@requires_approval
|
|
Decorator that gates *any* async callable (tool, method, standalone fn)
|
|
behind a human approval request. The decorated function only runs if
|
|
the request is granted. Roles in ``hitl.bypass_roles`` skip the gate.
|
|
|
|
pause_task / resume_task
|
|
LangChain tools for explicit pause/resume of in-flight tasks. An agent
|
|
calls ``pause_task(task_id, reason)`` to suspend itself; an external
|
|
signal (webhook, dashboard click, another agent) calls ``resume_task``
|
|
with the same task_id to wake it up.
|
|
|
|
Notification channels
|
|
---------------------
|
|
Configured under ``hitl:`` in ``config.yaml``:
|
|
|
|
hitl:
|
|
channels:
|
|
- type: dashboard # always active; uses platform approval API
|
|
- type: slack
|
|
webhook_url: https://hooks.slack.com/services/…
|
|
- type: email
|
|
smtp_host: smtp.example.com
|
|
smtp_port: 587
|
|
from: alerts@example.com
|
|
to: ops@example.com
|
|
username: alerts@example.com # optional; password from SMTP_PASSWORD env
|
|
default_timeout: 300 # seconds before an unanswered request times out
|
|
bypass_roles: [admin] # roles that skip the approval gate entirely
|
|
|
|
Environment variables
|
|
---------------------
|
|
    SMTP_PASSWORD    Password for SMTP authentication. Recommended instead of a
                     ``password`` key in config.yaml; if that key is set, it
                     takes precedence over this variable.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import functools
|
|
import logging
|
|
import os
|
|
import smtplib
|
|
from dataclasses import dataclass, field
|
|
from email.mime.text import MIMEText
|
|
from typing import Any, Callable
|
|
|
|
import httpx
|
|
from langchain_core.tools import tool
|
|
|
|
# Module-level logger; records are emitted under this module's dotted name.
logger = logging.getLogger(__name__)
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Config
|
|
# ---------------------------------------------------------------------------
|
|
|
|
@dataclass
class HITLConfig:
    """Settings for the human-in-the-loop layer (``hitl:`` in config.yaml).

    Every field has a default, so ``HITLConfig()`` is a valid
    "nothing configured" value.
    """

    # Notification channel specs; each is a mapping with at least a ``type``.
    # Every instance gets its own fresh list holding only the dashboard channel.
    channels: list[dict] = field(
        default_factory=lambda: [dict(type="dashboard")]
    )
    # Seconds an unanswered request may wait before it times out.
    default_timeout: float = 300.0
    # Roles that skip the approval gate; an empty list means nobody bypasses.
    bypass_roles: list[str] = field(default_factory=lambda: [])
|
|
|
|
|
|
def _load_hitl_config() -> HITLConfig:
    """Build a :class:`HITLConfig` from the workspace config.

    Reads the ``hitl`` attribute of the object returned by
    ``config.load_config()``. That section may be either an attribute-style
    object or a plain mapping; both shapes are accepted. A key that is missing,
    or explicitly set to ``None``, falls back to that field's default without
    discarding the other keys.

    Never raises: any error while importing or reading the config is logged at
    DEBUG level and an all-defaults ``HITLConfig()`` is returned, so a broken
    config cannot take down the approval flow.
    """
    from collections.abc import Mapping

    def _read(section: Any, key: str, fallback: Any) -> Any:
        # Support both ``section.key`` and ``section["key"]``; treat explicit
        # nulls as "not configured".
        if isinstance(section, Mapping):
            value = section.get(key)
        else:
            value = getattr(section, key, None)
        return fallback if value is None else value

    try:
        from config import load_config

        raw = getattr(load_config(), "hitl", None)
        if raw is None:
            return HITLConfig()
        defaults = HITLConfig()
        return HITLConfig(
            channels=list(_read(raw, "channels", defaults.channels)),
            default_timeout=float(
                _read(raw, "default_timeout", defaults.default_timeout)
            ),
            bypass_roles=list(_read(raw, "bypass_roles", defaults.bypass_roles)),
        )
    except Exception:
        logger.debug("HITL: could not load hitl config; using defaults", exc_info=True)
        return HITLConfig()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Pause / Resume registry
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class _TaskPauseRegistry:
|
|
"""In-process registry mapping task_id → asyncio.Event + optional result.
|
|
|
|
Multiple coroutines awaiting the same task_id are all unblocked when
|
|
``resume()`` is called. Results survive until the awaiting coroutine
|
|
calls ``pop_result()``.
|
|
"""
|
|
|
|
def __init__(self) -> None:
|
|
self._events: dict[str, asyncio.Event] = {}
|
|
self._results: dict[str, dict] = {}
|
|
# #265: owner map — workspace_id that created each task.
|
|
# Empty string means "no owner / legacy" (bypasses ownership check).
|
|
self._owners: dict[str, str] = {}
|
|
|
|
def register(self, task_id: str, owner: str = "") -> asyncio.Event:
|
|
"""Create and store an Event for *task_id*. Returns the event.
|
|
|
|
Args:
|
|
task_id: Unique task identifier.
|
|
owner: Workspace ID that owns this task. When set, ``resume``
|
|
will reject callers from a different workspace.
|
|
"""
|
|
ev = asyncio.Event()
|
|
self._events[task_id] = ev
|
|
self._owners[task_id] = owner
|
|
return ev
|
|
|
|
def resume(self, task_id: str, result: dict | None = None, owner: str = "") -> bool:
|
|
"""Signal the Event for *task_id*. Returns False if not registered.
|
|
|
|
Args:
|
|
task_id: The identifier used in ``register``.
|
|
result: Optional result payload forwarded to the waiting coroutine.
|
|
owner: Caller's workspace ID. When both the stored owner and
|
|
*owner* are non-empty and they differ, the call is rejected
|
|
(returns False) — prevents cross-workspace prompt injection
|
|
(#265). Passing ``owner=""`` bypasses the check (used in
|
|
direct registry calls from tests and platform code).
|
|
"""
|
|
# #265 ownership check
|
|
stored_owner = self._owners.get(task_id, "")
|
|
if owner and stored_owner and owner != stored_owner:
|
|
logger.warning(
|
|
"HITL: resume rejected for task %s — caller workspace %r != owner %r",
|
|
task_id, owner, stored_owner,
|
|
)
|
|
return False
|
|
ev = self._events.get(task_id)
|
|
if ev is None:
|
|
return False
|
|
self._results[task_id] = result or {}
|
|
ev.set()
|
|
return True
|
|
|
|
def pop_result(self, task_id: str) -> dict:
|
|
"""Return and remove the stored result for *task_id*."""
|
|
return self._results.pop(task_id, {})
|
|
|
|
def cleanup(self, task_id: str) -> None:
|
|
"""Remove *task_id* from all dicts."""
|
|
self._events.pop(task_id, None)
|
|
self._results.pop(task_id, None)
|
|
self._owners.pop(task_id, None)
|
|
|
|
def list_paused(self) -> list[str]:
|
|
"""Return IDs of tasks whose events have not yet been set."""
|
|
return [tid for tid, ev in self._events.items() if not ev.is_set()]
|
|
|
|
|
|
# Global singleton used by pause_task, resume_task and list_paused_tasks.
# Safe within one asyncio event loop / process: it holds asyncio.Event objects
# and no lock. NOTE(review): confirm it is never touched from another thread.
pause_registry = _TaskPauseRegistry()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Notification channels
|
|
# ---------------------------------------------------------------------------
|
|
|
|
async def _notify_channels(
    action: str,
    reason: str,
    approval_id: str,
    cfg: HITLConfig,
) -> None:
    """Fire-and-forget notifications to all configured channels.

    Channels are notified one after another in config order. ``slack`` and
    ``email`` entries are delivered here. ``dashboard`` needs no work because
    the platform handles it via the approval POST.

    Errors in individual channels are logged but never re-raised. A
    misconfigured Slack webhook, an unknown ``type``, or a malformed
    ``hitl.channels`` entry therefore cannot block the approval flow or the
    remaining channels.

    Args:
        action: Short label of the gated action.
        reason: Reason line shown to the approver.
        approval_id: Identifier interpolated into the approve/deny links.
        cfg: HITL settings whose ``channels`` list is iterated.
    """
    # Drop a trailing slash so links don't come out as "...:8080//workspaces/...".
    platform_url = os.environ.get(
        "PLATFORM_URL", "http://host.docker.internal:8080"
    ).rstrip("/")
    workspace_id = os.environ.get("WORKSPACE_ID", "")

    for channel in cfg.channels or []:
        try:
            ch_type = channel.get("type", "dashboard")
        except AttributeError:
            # e.g. a bare string ("- slack") instead of a mapping ("- type: slack").
            logger.warning("HITL: ignoring malformed channel entry %r", channel)
            continue
        try:
            if ch_type == "slack":
                await _notify_slack(channel, action, reason, approval_id,
                                    platform_url, workspace_id)
            elif ch_type == "email":
                await _notify_email(channel, action, reason, approval_id,
                                    platform_url, workspace_id)
            elif ch_type != "dashboard":
                logger.warning("HITL: unknown channel type %r — skipping", ch_type)
            # "dashboard" is handled by the platform via the approval POST
        except Exception as exc:
            logger.warning("HITL: channel '%s' notification failed: %s", ch_type, exc)
|
|
|
|
|
|
async def _notify_slack(
|
|
cfg: dict,
|
|
action: str,
|
|
reason: str,
|
|
approval_id: str,
|
|
platform_url: str,
|
|
workspace_id: str,
|
|
) -> None:
|
|
webhook_url = cfg.get("webhook_url", "")
|
|
if not webhook_url:
|
|
return
|
|
|
|
approve_url = f"{platform_url}/workspaces/{workspace_id}/approvals/{approval_id}/approve"
|
|
deny_url = f"{platform_url}/workspaces/{workspace_id}/approvals/{approval_id}/deny"
|
|
|
|
payload = {
|
|
"text": f":warning: Approval required from workspace `{workspace_id}`",
|
|
"blocks": [
|
|
{
|
|
"type": "section",
|
|
"text": {
|
|
"type": "mrkdwn",
|
|
"text": (
|
|
f"*Action:* {action}\n"
|
|
f"*Reason:* {reason}\n"
|
|
f"*Approval ID:* `{approval_id}`"
|
|
),
|
|
},
|
|
},
|
|
{
|
|
"type": "actions",
|
|
"elements": [
|
|
{
|
|
"type": "button",
|
|
"text": {"type": "plain_text", "text": "Approve"},
|
|
"style": "primary",
|
|
"url": approve_url,
|
|
},
|
|
{
|
|
"type": "button",
|
|
"text": {"type": "plain_text", "text": "Deny"},
|
|
"style": "danger",
|
|
"url": deny_url,
|
|
},
|
|
],
|
|
},
|
|
],
|
|
}
|
|
async with httpx.AsyncClient(timeout=10.0) as client:
|
|
await client.post(webhook_url, json=payload)
|
|
logger.info("HITL: Slack notification sent for approval %s", approval_id)
|
|
|
|
|
|
async def _notify_email(
|
|
cfg: dict,
|
|
action: str,
|
|
reason: str,
|
|
approval_id: str,
|
|
platform_url: str,
|
|
workspace_id: str,
|
|
) -> None:
|
|
smtp_host = cfg.get("smtp_host", "")
|
|
smtp_port = int(cfg.get("smtp_port", 587))
|
|
from_addr = cfg.get("from", "")
|
|
to_addr = cfg.get("to", "")
|
|
|
|
if not all([smtp_host, from_addr, to_addr]):
|
|
logger.warning("HITL: email channel missing smtp_host/from/to — skipping")
|
|
return
|
|
|
|
approve_url = f"{platform_url}/workspaces/{workspace_id}/approvals/{approval_id}/approve"
|
|
deny_url = f"{platform_url}/workspaces/{workspace_id}/approvals/{approval_id}/deny"
|
|
|
|
body = (
|
|
f"Approval required from workspace {workspace_id}\n\n"
|
|
f"Action : {action}\n"
|
|
f"Reason : {reason}\n"
|
|
f"ID : {approval_id}\n\n"
|
|
f"Approve: {approve_url}\n"
|
|
f"Deny : {deny_url}\n"
|
|
)
|
|
|
|
msg = MIMEText(body, "plain", "utf-8")
|
|
msg["Subject"] = f"[Molecule AI] Approval required: {action}"
|
|
msg["From"] = from_addr
|
|
msg["To"] = to_addr
|
|
|
|
username = cfg.get("username", "")
|
|
password = cfg.get("password", os.environ.get("SMTP_PASSWORD", ""))
|
|
|
|
def _send() -> None:
|
|
with smtplib.SMTP(smtp_host, smtp_port) as srv:
|
|
srv.ehlo()
|
|
srv.starttls()
|
|
if username and password:
|
|
srv.login(username, password)
|
|
srv.send_message(msg)
|
|
|
|
await asyncio.to_thread(_send)
|
|
logger.info("HITL: email notification sent for approval %s", approval_id)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# @requires_approval decorator
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def requires_approval(
|
|
action_description: str = "",
|
|
reason_template: str = "",
|
|
bypass_roles: list[str] | None = None,
|
|
) -> Callable[[Callable], Callable]:
|
|
"""Decorator that gates an async callable behind a human approval request.
|
|
|
|
The wrapped function executes only when a human approves. Use this on
|
|
any tool or async helper that performs destructive or high-impact work.
|
|
|
|
Args:
|
|
action_description: Short label for the action shown to the approver.
|
|
Defaults to the function's ``name`` attribute or
|
|
``__name__``.
|
|
reason_template: f-string template for the reason line. Keyword
|
|
arguments of the decorated function are available,
|
|
e.g. ``"Delete table {table_name}"``).
|
|
bypass_roles: Roles that skip the gate entirely. Overrides
|
|
``hitl.bypass_roles`` in config.yaml when given.
|
|
|
|
Returns:
|
|
A decorator; applying it to a function returns an async wrapper.
|
|
|
|
Usage::
|
|
|
|
@tool
|
|
@requires_approval("Wipe production DB", bypass_roles=["admin"])
|
|
async def drop_table(table_name: str) -> dict:
|
|
...
|
|
|
|
# Works with plain async functions too:
|
|
@requires_approval("Send customer email")
|
|
async def send_email(to: str, body: str) -> dict:
|
|
...
|
|
"""
|
|
def decorator(fn: Callable) -> Callable:
|
|
action = action_description or getattr(fn, "name", None) or fn.__name__
|
|
|
|
@functools.wraps(fn)
|
|
async def wrapper(*args: Any, **kwargs: Any) -> Any:
|
|
hitl_cfg = _load_hitl_config()
|
|
|
|
# --- Check bypass roles -----------------------------------------
|
|
active_bypass = bypass_roles if bypass_roles is not None else hitl_cfg.bypass_roles
|
|
if active_bypass:
|
|
try:
|
|
from builtin_tools.audit import get_workspace_roles
|
|
roles, _ = get_workspace_roles()
|
|
if any(r in active_bypass for r in roles):
|
|
logger.info(
|
|
"@requires_approval bypassed (role %s) for '%s'", roles, action
|
|
)
|
|
return await fn(*args, **kwargs)
|
|
except Exception:
|
|
pass # If RBAC check fails, proceed to approval gate
|
|
|
|
# --- Build reason string -----------------------------------------
|
|
if reason_template:
|
|
try:
|
|
reason = reason_template.format(**kwargs)
|
|
except (KeyError, IndexError):
|
|
reason = reason_template
|
|
else:
|
|
arg_parts = [f"{k}={str(v)[:60]}" for k, v in list(kwargs.items())[:3]]
|
|
reason = f"Args: {', '.join(arg_parts)}" if arg_parts else "Automated action"
|
|
|
|
# --- Fire non-dashboard notifications (async, non-blocking) ------
|
|
asyncio.create_task(
|
|
_notify_channels(action, reason, "pending", hitl_cfg)
|
|
)
|
|
|
|
# --- Request approval via approval tool --------------------------
|
|
try:
|
|
from builtin_tools.approval import request_approval
|
|
approval_result = await request_approval.ainvoke(
|
|
{"action": action, "reason": reason}
|
|
)
|
|
except Exception as exc:
|
|
logger.error("@requires_approval: approval call failed: %s", exc)
|
|
return {
|
|
"success": False,
|
|
"error": f"Approval gate error: {exc}",
|
|
}
|
|
|
|
if not approval_result.get("approved"):
|
|
# Art. 14 audit: log the denial outcome so the activity log
|
|
# contains evidence that the human oversight gate was exercised.
|
|
try:
|
|
from builtin_tools.audit import log_event
|
|
log_event(
|
|
event_type="hitl",
|
|
action="approve",
|
|
resource=action,
|
|
outcome="denied",
|
|
actor=approval_result.get("decided_by"),
|
|
approval_id=approval_result.get("approval_id"),
|
|
reason=reason,
|
|
)
|
|
except Exception:
|
|
pass
|
|
return {
|
|
"success": False,
|
|
"error": (
|
|
f"Action '{action}' not approved: "
|
|
f"{approval_result.get('message', approval_result.get('error', 'denied'))}"
|
|
),
|
|
"approval_id": approval_result.get("approval_id"),
|
|
}
|
|
|
|
# Art. 14 audit: log the approval grant before running the function.
|
|
try:
|
|
from builtin_tools.audit import log_event
|
|
log_event(
|
|
event_type="hitl",
|
|
action="approve",
|
|
resource=action,
|
|
outcome="granted",
|
|
actor=approval_result.get("decided_by"),
|
|
approval_id=approval_result.get("approval_id"),
|
|
reason=reason,
|
|
)
|
|
except Exception:
|
|
pass
|
|
|
|
# --- Approved — run the original function ------------------------
|
|
return await fn(*args, **kwargs)
|
|
|
|
return wrapper
|
|
|
|
return decorator
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Pause / Resume LangChain tools
|
|
# ---------------------------------------------------------------------------
|
|
|
|
@tool
async def pause_task(task_id: str, reason: str = "") -> dict:
    """Suspend the current task and wait for a resume signal.

    The agent calls this to pause itself at a decision point. Execution
    resumes when ``resume_task`` is called with the same task_id, or after
    the configured ``hitl.default_timeout`` seconds.

    Args:
        task_id: Unique identifier for this pause point (use the A2A task ID
                 or any stable string that the caller can reference later).
        reason: Human-readable description of why the task is pausing.
    """
    # NOTE: the docstring above doubles as the LangChain tool description, so
    # implementation notes live in comments instead.

    def _audit(action: str, outcome: str, **extra: Any) -> None:
        # Best-effort audit record; failures are deliberately ignored so that
        # auditing can never break the pause/resume flow.
        try:
            from builtin_tools.audit import log_event
            log_event(
                event_type="hitl",
                action=action,
                resource=task_id,
                outcome=outcome,
                trace_id=task_id,
                **extra,
            )
        except Exception:
            pass

    # #265: the registration records which workspace owns this pause, so a
    # resume_task call from another workspace is rejected. The task_id itself
    # is not altered.
    owner_ws = os.environ.get("WORKSPACE_ID", "")

    _audit("pause", "paused", reason=reason)

    pause_event = pause_registry.register(task_id, owner=owner_ws)
    wait_seconds = _load_hitl_config().default_timeout
    logger.info("HITL: task %s paused — %s", task_id, reason or "(no reason given)")

    try:
        await asyncio.wait_for(pause_event.wait(), timeout=wait_seconds)
    except asyncio.TimeoutError:
        logger.warning("HITL: task %s timed out after %.0fs", task_id, wait_seconds)
        _audit("pause", "timeout", timeout_seconds=wait_seconds)
        return {
            "resumed": False,
            "task_id": task_id,
            "error": f"Timed out after {wait_seconds:.0f}s waiting for resume signal",
        }
    else:
        payload = pause_registry.pop_result(task_id)
        logger.info("HITL: task %s resumed", task_id)
        _audit("resume", "resumed")
        return {"resumed": True, "task_id": task_id, **payload}
    finally:
        # Always drop the registry entry, whether resumed, timed out or cancelled.
        pause_registry.cleanup(task_id)
|
|
|
|
|
|
@tool
async def resume_task(task_id: str, message: str = "") -> dict:
    """Resume a previously paused task.

    Signals the ``pause_task`` coroutine waiting on *task_id* to continue.
    Safe to call even if the task has already resumed or timed out (returns
    success=False in that case).

    Args:
        task_id: The identifier passed to ``pause_task``.
        message: Optional message forwarded to the resumed task.
    """
    # The docstring above is the LangChain tool description; notes go here.
    # #265: the caller's workspace ID lets the registry refuse a resume that
    # comes from a different workspace than the one that paused the task.
    caller_ws = os.environ.get("WORKSPACE_ID", "")

    payload: dict = {}
    if message:
        payload["message"] = message

    if not pause_registry.resume(task_id, payload, owner=caller_ws):
        return {
            "success": False,
            "task_id": task_id,
            "error": "Task not found or already resumed",
        }

    logger.info("HITL: resume signal sent for task %s", task_id)
    # Best-effort audit record; never allowed to fail the resume itself.
    try:
        from builtin_tools.audit import log_event
        log_event(
            event_type="hitl",
            action="resume",
            resource=task_id,
            outcome="success",
            trace_id=task_id,
            message=message,
        )
    except Exception:
        pass
    return {"success": True, "task_id": task_id}
|
|
|
|
|
|
@tool
async def list_paused_tasks() -> dict:
    """List all tasks currently suspended and waiting for a resume signal."""
    # Task IDs whose pause Event has not been set yet (see the registry).
    waiting = pause_registry.list_paused()
    summary = {"paused_tasks": waiting}
    summary["count"] = len(waiting)
    return summary
|