commit 851a6d7bfd60f3df932a2930e7fd0f95d6715ef9 Author: Hongming Wang Date: Thu Apr 16 04:26:06 2026 -0700 feat: initial release of molecule-ai-workspace-runtime 0.1.0 Extracts shared workspace runtime from molecule-monorepo/workspace-template into a publishable PyPI package. - molecule_runtime/ package with all shared infrastructure modules - Adapter discovery via ADAPTER_MODULE env var (standalone repos) + built-in scan - molecule-runtime console script entry point (main_sync) - CI workflow to publish on version tags - Published to PyPI as molecule-ai-workspace-runtime==0.1.0 Co-Authored-By: Claude Sonnet 4.6 diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml new file mode 100644 index 0000000..0bdeae0 --- /dev/null +++ b/.github/workflows/publish.yml @@ -0,0 +1,29 @@ +name: Publish to PyPI + +on: + push: + tags: + - "v*" + workflow_dispatch: + +jobs: + build-and-publish: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Install build tools + run: pip install build twine + + - name: Build package + run: python -m build + + - name: Publish to PyPI + env: + TWINE_USERNAME: __token__ + TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }} + run: python -m twine upload dist/* diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..196906c --- /dev/null +++ b/.gitignore @@ -0,0 +1,11 @@ +__pycache__/ +*.py[cod] +*.pyo +*.egg-info/ +dist/ +build/ +.eggs/ +*.egg +.venv/ +venv/ +*.pyc diff --git a/README.md b/README.md new file mode 100644 index 0000000..5a9e468 --- /dev/null +++ b/README.md @@ -0,0 +1,65 @@ +# molecule-ai-workspace-runtime + +Shared Python runtime infrastructure for all Molecule AI agent adapters. + +This package provides the core machinery that every Molecule AI workspace container needs: + +- **A2A server** — Registers with the platform, heartbeats, serves A2A JSON-RPC +- **Adapter interface** — `BaseAdapter` / `AdapterConfig` / `SetupResult` +- **Built-in tools** — delegation, memory, approvals, sandbox, telemetry +- **Skill loader** — loads and hot-reloads skill modules from `/configs/skills/` +- **Plugin system** — per-workspace + shared plugin discovery and install +- **Config / preflight** — YAML config loading with validation + +## Installation + +```bash +pip install molecule-ai-workspace-runtime +``` + +## Adapter Discovery + +The runtime discovers adapters in two ways: + +1. **`ADAPTER_MODULE` env var** (standalone adapter repos): + ```bash + ADAPTER_MODULE=my_adapter molecule-runtime + ``` + The module must export an `Adapter` class extending `BaseAdapter`. + +2. **Built-in subdirectory scan** (monorepo local dev): + Scans `molecule_runtime/adapters/` subdirectories for `Adapter` classes. + +## Writing an Adapter + +```python +from molecule_runtime.adapters.base import BaseAdapter, AdapterConfig +from a2a.server.agent_execution import AgentExecutor + +class Adapter(BaseAdapter): + @staticmethod + def name() -> str: + return "my-runtime" + + @staticmethod + def display_name() -> str: + return "My Runtime" + + @staticmethod + def description() -> str: + return "My custom agent runtime" + + async def setup(self, config: AdapterConfig) -> None: + result = await self._common_setup(config) + # Store result attributes for create_executor + + async def create_executor(self, config: AdapterConfig) -> AgentExecutor: + # Return an AgentExecutor instance + ... +``` + +Set `ADAPTER_MODULE=my_package.adapter` and run `molecule-runtime`. 
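+
+One workable pattern (a sketch — the `_setup` attribute and `build_my_executor`
+factory are illustrative names, not part of the interface) is to keep the
+`SetupResult` returned by `_common_setup()` and consume its fields
+(`system_prompt`, `loaded_skills`, `langchain_tools`, `is_coordinator`,
+`children`) inside your `Adapter` class:
+
+```python
+    async def setup(self, config: AdapterConfig) -> None:
+        self._setup = await self._common_setup(config)
+
+    async def create_executor(self, config: AdapterConfig) -> AgentExecutor:
+        return build_my_executor(  # hypothetical factory for your runtime
+            prompt=self._setup.system_prompt,
+            tools=self._setup.langchain_tools,
+        )
+```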
+
+## License
+
+BSL-1.1 — see LICENSE for details.
diff --git a/molecule_runtime/__init__.py b/molecule_runtime/__init__.py
new file mode 100644
index 0000000..108e972
--- /dev/null
+++ b/molecule_runtime/__init__.py
@@ -0,0 +1,6 @@
+"""Molecule AI workspace runtime — shared infrastructure for all agent adapters."""
+
+from molecule_runtime.adapters.base import BaseAdapter, AdapterConfig, SetupResult
+
+__version__ = "0.1.0"
+__all__ = ["BaseAdapter", "AdapterConfig", "SetupResult"]
diff --git a/molecule_runtime/a2a_cli.py b/molecule_runtime/a2a_cli.py
new file mode 100644
index 0000000..00af26f
--- /dev/null
+++ b/molecule_runtime/a2a_cli.py
@@ -0,0 +1,245 @@
+#!/usr/bin/env python3
+"""A2A CLI — command-line tools for inter-workspace communication.
+
+Supports both synchronous and asynchronous delegation:
+    a2a delegate <workspace_id> <task>          — Send task, wait for response (sync)
+    a2a delegate --async <workspace_id> <task>  — Send task, return task ID immediately
+    a2a status <workspace_id> <task_id>         — Check task status / get result
+    a2a peers                                   — List available peers
+    a2a info                                    — Show this workspace's info
+
+Environment variables:
+    WORKSPACE_ID — this workspace's ID
+    PLATFORM_URL — platform API base URL
+"""
+
+import asyncio
+import json
+import os
+import sys
+import uuid
+
+import httpx
+
+WORKSPACE_ID = os.environ.get("WORKSPACE_ID", "")
+PLATFORM_URL = os.environ.get("PLATFORM_URL", "http://platform:8080")
+
+
+async def discover(target_id: str) -> dict | None:
+    """Discover a peer workspace's URL."""
+    async with httpx.AsyncClient(timeout=30.0) as client:
+        resp = await client.get(
+            f"{PLATFORM_URL}/registry/discover/{target_id}",
+            headers={"X-Workspace-ID": WORKSPACE_ID},
+        )
+        if resp.status_code == 200:
+            return resp.json()
+        return None
+
+
+async def delegate(target_id: str, task: str, async_mode: bool = False):
+    """Delegate a task to another workspace."""
+    peer = await discover(target_id)
+    if not peer:
+        print(f"Error: cannot reach workspace {target_id} (access denied or offline)", file=sys.stderr)
+        sys.exit(1)
+
+    target_url = peer.get("url", "")
+    if not target_url:
+        print(f"Error: workspace {target_id} has no URL", file=sys.stderr)
+        sys.exit(1)
+
+    task_id = str(uuid.uuid4())
+
+    if async_mode:
+        # Async: send and return immediately, don't wait for the full response —
+        # a single short-timeout POST stands in for fire-and-forget
+        async with httpx.AsyncClient(timeout=10.0) as client:
+            try:
+                # Send with a short timeout — just confirm receipt
+                await client.post(
+                    target_url,
+                    json={
+                        "jsonrpc": "2.0",
+                        "id": task_id,
+                        "method": "message/send",
+                        "params": {
+                            "message": {
+                                "role": "user",
+                                "messageId": str(uuid.uuid4()),
+                                "parts": [{"kind": "text", "text": task}],
+                            }
+                        },
+                    },
+                )
+                # Got a response before the timeout — the target has the task
+                print(json.dumps({
+                    "task_id": task_id,
+                    "target": target_id,
+                    "status": "submitted",
+                    "target_url": target_url,
+                }))
+            except httpx.TimeoutException:
+                # Request was sent but we didn't get confirmation — task may or may not have been received
+                print(json.dumps({
+                    "task_id": task_id,
+                    "target": target_id,
+                    "status": "uncertain",
+                    "note": "Request sent but response timed out — delivery unconfirmed. 
Use 'a2a status' to check.", + }), file=sys.stderr) + return + + # Sync: wait for full response with retry on rate limit + max_retries = 3 + for attempt in range(max_retries): + async with httpx.AsyncClient(timeout=300.0) as client: + try: + resp = await client.post( + target_url, + json={ + "jsonrpc": "2.0", + "id": task_id, + "method": "message/send", + "params": { + "message": { + "role": "user", + "messageId": str(uuid.uuid4()), + "parts": [{"kind": "text", "text": task}], + } + }, + }, + ) + try: + data = resp.json() + except Exception: + print(f"Error: invalid JSON response (status {resp.status_code})", file=sys.stderr) + sys.exit(1) + if "result" in data: + parts = data["result"].get("parts", []) + text = parts[0].get("text", "") if parts else "" + if text and text != "(no response generated)": + print(text) + return + # Empty or no-response — might be rate limited, retry + if attempt < max_retries - 1: + delay = 5 * (2 ** attempt) + print(f"(empty response, retrying in {delay}s...)", file=sys.stderr) + await asyncio.sleep(delay) + continue + print(text or "(no response after retries)") + elif "error" in data: + error_msg = data['error'].get('message', 'unknown') + if ("rate" in error_msg.lower() or "overloaded" in error_msg.lower()) and attempt < max_retries - 1: + delay = 5 * (2 ** attempt) + print(f"(rate limited, retrying in {delay}s...)", file=sys.stderr) + await asyncio.sleep(delay) + continue + print(f"Error: {error_msg}", file=sys.stderr) + sys.exit(1) + return + except httpx.TimeoutException: + if attempt < max_retries - 1: + delay = 5 * (2 ** attempt) + print(f"(timeout, retrying in {delay}s...)", file=sys.stderr) + await asyncio.sleep(delay) + continue + print("Error: request timed out after retries", file=sys.stderr) + sys.exit(1) + + +async def check_status(target_id: str, task_id: str): + """Check the status of an async task.""" + peer = await discover(target_id) + if not peer: + print(f"Error: cannot reach workspace {target_id}", file=sys.stderr) + sys.exit(1) + + target_url = peer.get("url", "") + async with httpx.AsyncClient(timeout=30.0) as client: + resp = await client.post( + target_url, + json={ + "jsonrpc": "2.0", + "id": str(uuid.uuid4()), + "method": "tasks/get", + "params": {"id": task_id}, + }, + ) + data = resp.json() + if "result" in data: + task = data["result"] + status = task.get("status", {}).get("state", "unknown") + print(f"Status: {status}") + if status == "completed": + artifacts = task.get("artifacts", []) + for a in artifacts: + for p in a.get("parts", []): + if p.get("text"): + print(p["text"]) + elif "error" in data: + print(f"Error: {data['error'].get('message', 'unknown')}") + + +async def peers(): + """List available peers.""" + async with httpx.AsyncClient(timeout=10.0) as client: + resp = await client.get(f"{PLATFORM_URL}/registry/{WORKSPACE_ID}/peers") + if resp.status_code != 200: + print("Error: could not fetch peers", file=sys.stderr) + sys.exit(1) + for p in resp.json(): + status = p.get("status", "?") + role = p.get("role", "") + print(f"{p['id']} {p['name']:30s} {status:10s} {role}") + + +async def info(): + """Get this workspace's info.""" + async with httpx.AsyncClient(timeout=10.0) as client: + resp = await client.get(f"{PLATFORM_URL}/workspaces/{WORKSPACE_ID}") + if resp.status_code == 200: + d = resp.json() + print(f"ID: {d['id']}") + print(f"Name: {d['name']}") + print(f"Role: {d.get('role', '')}") + print(f"Tier: {d['tier']}") + print(f"Status: {d['status']}") + print(f"Parent: {d.get('parent_id', '(root)')}") + + +def 
main():
+    if len(sys.argv) < 2:
+        print("Usage: a2a <command> [args]")
+        print("Commands:")
+        print("  delegate <workspace_id> <task>          — Send task, wait for response")
+        print("  delegate --async <workspace_id> <task>  — Send task, return immediately")
+        print("  status <workspace_id> <task_id>         — Check async task status")
+        print("  peers                                   — List available peers")
+        print("  info                                    — Show workspace info")
+        sys.exit(1)
+
+    cmd = sys.argv[1]
+
+    if cmd == "delegate":
+        async_mode = "--async" in sys.argv
+        args = [a for a in sys.argv[2:] if a != "--async"]
+        if len(args) < 2:
+            print("Usage: a2a delegate [--async] <workspace_id> <task>", file=sys.stderr)
+            sys.exit(1)
+        asyncio.run(delegate(args[0], " ".join(args[1:]), async_mode))
+    elif cmd == "status":
+        if len(sys.argv) < 4:
+            print("Usage: a2a status <workspace_id> <task_id>", file=sys.stderr)
+            sys.exit(1)
+        asyncio.run(check_status(sys.argv[2], sys.argv[3]))
+    elif cmd == "peers":
+        asyncio.run(peers())
+    elif cmd == "info":
+        asyncio.run(info())
+    else:
+        print(f"Unknown command: {cmd}", file=sys.stderr)
+        sys.exit(1)
+
+
+if __name__ == "__main__":  # pragma: no cover
+    main()
diff --git a/molecule_runtime/a2a_client.py b/molecule_runtime/a2a_client.py
new file mode 100644
index 0000000..0ae6dd2
--- /dev/null
+++ b/molecule_runtime/a2a_client.py
@@ -0,0 +1,111 @@
+"""A2A protocol client — peer discovery, messaging, and workspace info.
+
+Shared constants (WORKSPACE_ID, PLATFORM_URL) live here so that
+a2a_tools and a2a_mcp_server can import them from a single place.
+"""
+
+import logging
+import os
+import uuid
+
+import httpx
+
+from platform_auth import auth_headers
+
+logger = logging.getLogger(__name__)
+
+WORKSPACE_ID = os.environ.get("WORKSPACE_ID", "")
+PLATFORM_URL = os.environ.get("PLATFORM_URL", "http://platform:8080")
+
+# Cache workspace ID → name mappings (populated by list_peers calls)
+_peer_names: dict[str, str] = {}
+
+# Sentinel prefix for errors originating from send_a2a_message / child agents.
+# Used by delegate_task to distinguish real errors from normal response text.
+_A2A_ERROR_PREFIX = "[A2A_ERROR] "
+
+
+async def discover_peer(target_id: str) -> dict | None:
+    """Discover a peer workspace's URL via the platform registry."""
+    async with httpx.AsyncClient(timeout=10.0) as client:
+        try:
+            resp = await client.get(
+                f"{PLATFORM_URL}/registry/discover/{target_id}",
+                headers={"X-Workspace-ID": WORKSPACE_ID, **auth_headers()},
+            )
+            if resp.status_code == 200:
+                return resp.json()
+            return None
+        except Exception as e:
+            logger.error(f"Discovery failed for {target_id}: {e}")
+            return None
+
+
+async def send_a2a_message(target_url: str, message: str) -> str:
+    """Send an A2A message/send to a target workspace."""
+    # Fix F (Cycle 5 / H2 — flagged 5 consecutive audits): timeout=None allowed
+    # a hung upstream to block the agent indefinitely. Use a generous but bounded
+    # timeout: 30s connect + 300s read (long enough for slow LLM responses). 
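+    # httpx.Timeout with all four phases (connect/read/write/pool) given needs
+    # no positional default; read=300.0 bounds each socket read rather than the
+    # whole request, so a peer that keeps streaming bytes isn't cut off at 300s.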
+ async with httpx.AsyncClient( + timeout=httpx.Timeout(connect=30.0, read=300.0, write=30.0, pool=30.0) + ) as client: + try: + resp = await client.post( + target_url, + headers=auth_headers(), + json={ + "jsonrpc": "2.0", + "id": str(uuid.uuid4()), + "method": "message/send", + "params": { + "message": { + "role": "user", + "messageId": str(uuid.uuid4()), + "parts": [{"kind": "text", "text": message}], + } + }, + }, + ) + data = resp.json() + if "result" in data: + parts = data["result"].get("parts", []) + text = parts[0].get("text", "") if parts else "(no response)" + # Tag child-reported errors so the caller can detect them reliably + if text.startswith("Agent error:"): + return f"{_A2A_ERROR_PREFIX}{text}" + return text + elif "error" in data: + return f"{_A2A_ERROR_PREFIX}{data['error'].get('message', 'unknown')}" + return str(data) + except Exception as e: + return f"{_A2A_ERROR_PREFIX}{e}" + + +async def get_peers() -> list[dict]: + """Get this workspace's peers from the platform registry.""" + async with httpx.AsyncClient(timeout=10.0) as client: + try: + resp = await client.get( + f"{PLATFORM_URL}/registry/{WORKSPACE_ID}/peers", + headers={"X-Workspace-ID": WORKSPACE_ID, **auth_headers()}, + ) + if resp.status_code == 200: + return resp.json() + return [] + except Exception: + return [] + + +async def get_workspace_info() -> dict: + """Get this workspace's info from the platform.""" + async with httpx.AsyncClient(timeout=10.0) as client: + try: + resp = await client.get( + f"{PLATFORM_URL}/workspaces/{WORKSPACE_ID}", + headers=auth_headers(), + ) + if resp.status_code == 200: + return resp.json() + return {"error": "not found"} + except Exception as e: + return {"error": str(e)} diff --git a/molecule_runtime/a2a_executor.py b/molecule_runtime/a2a_executor.py new file mode 100644 index 0000000..ebe4008 --- /dev/null +++ b/molecule_runtime/a2a_executor.py @@ -0,0 +1,419 @@ +"""Bridge between LangGraph agent and A2A protocol, with SSE streaming support. + +SSE streaming architecture +-------------------------- +The A2A SDK (``DefaultRequestHandler`` + ``EventQueue``) owns the SSE transport +layer. This executor's job is to push the right event types into the queue as +work progresses: + + 1. ``TaskStatusUpdateEvent(state=working)`` — immediately signals start + 2. ``TaskArtifactUpdateEvent(chunk, append=…)`` — one per LLM text token + 3. ``Message(final_text)`` — terminal event + +Client compatibility +-------------------- +*Non-streaming* (``message/send``): + ``ResultAggregator.consume_all()`` processes status/artifact events + (updating the task in the store) and returns the final ``Message`` + immediately — backward-compatible with ``a2a_client.py`` which reads + ``data["result"]["parts"][0]["text"]``. + +*Streaming* (``message/stream``): + ``consume_and_emit()`` yields every event above as SSE, letting the client + render tokens in real time. + +LangGraph integration +--------------------- +Uses ``agent.astream_events(version="v2")`` to receive ``on_chat_model_stream`` +events with ``AIMessageChunk`` payloads. Text is extracted from both plain +strings (OpenAI / Groq) and Anthropic-style content-block lists. Non-text +content (tool_use, etc.) is silently skipped. A fresh ``artifact_id`` is +generated for each new LLM ``run_id`` so tool-call cycles are grouped cleanly. 
+""" + +import functools +import logging +import os +import uuid + +from a2a.server.agent_execution import AgentExecutor, RequestContext +from a2a.server.events import EventQueue +from a2a.server.tasks import TaskUpdater +from a2a.types import Part, TextPart +from a2a.utils import new_agent_text_message +from adapters.shared_runtime import ( + extract_history as _extract_history, + extract_message_text, + brief_task, + set_current_task, +) +from builtin_tools.telemetry import ( + A2A_TASK_ID, + GEN_AI_OPERATION_NAME, + GEN_AI_REQUEST_MODEL, + GEN_AI_SYSTEM, + WORKSPACE_ID_ATTR, + _incoming_trace_context, + gen_ai_system_from_model, + get_tracer, + record_llm_token_usage, +) + +logger = logging.getLogger(__name__) + +_WORKSPACE_ID = os.environ.get("WORKSPACE_ID", "unknown") + +# LangGraph ReAct cycle budget per turn. Library default is 25; 500 covers +# PM fan-outs (plan → 6 delegations → 6 awaits → 6 results → synthesize ≈ +# 30+ steps even before retries). Overridable via LANGGRAPH_RECURSION_LIMIT. +DEFAULT_RECURSION_LIMIT = 500 + + +def _parse_recursion_limit() -> int: + """Read LANGGRAPH_RECURSION_LIMIT; fall back to DEFAULT_RECURSION_LIMIT + with a WARNING log on any unparseable or non-positive value.""" + raw = os.environ.get("LANGGRAPH_RECURSION_LIMIT", "") + if not raw: + return DEFAULT_RECURSION_LIMIT + try: + n = int(raw) + except ValueError: + logger.warning( + "LANGGRAPH_RECURSION_LIMIT=%r is not an integer; using default %d", + raw, DEFAULT_RECURSION_LIMIT, + ) + return DEFAULT_RECURSION_LIMIT + if n <= 0: + logger.warning( + "LANGGRAPH_RECURSION_LIMIT=%d is not positive; using default %d", + n, DEFAULT_RECURSION_LIMIT, + ) + return DEFAULT_RECURSION_LIMIT + return n + +# --------------------------------------------------------------------------- +# Compliance (OWASP Top 10 for Agentic Apps) — optional, lazy-loaded +# --------------------------------------------------------------------------- + +try: + from builtin_tools.compliance import ( + AgencyTracker, + ExcessiveAgencyError, + PromptInjectionError, + redact_pii as _redact_pii, + sanitize_input as _sanitize_input, + ) + _COMPLIANCE_AVAILABLE = True +except ImportError: # pragma: no cover + _COMPLIANCE_AVAILABLE = False + + +@functools.lru_cache(maxsize=1) +def _get_compliance_cfg(): + """Return ComplianceConfig or None (cached for process lifetime).""" + try: + from config import load_config + return load_config().compliance + except Exception: + return None + + +def _extract_chunk_text(content) -> list[str]: + """Extract text strings from an LLM streaming chunk's content field. + + Handles both provider content styles: + - OpenAI / Groq: ``content`` is a plain ``str`` (empty for tool-call chunks). + - Anthropic: ``content`` is a list of typed blocks, e.g. + ``[{"type": "text", "text": "Hello"}, {"type": "tool_use", ...}]`` + + Only ``"text"`` blocks are returned; ``tool_use``, ``tool_result``, and + other non-text blocks are filtered out so raw tool JSON never appears in + the SSE stream. + + Args: + content: ``chunk.content`` value from an ``on_chat_model_stream`` event. + + Returns: + List of non-empty text strings. 
+ """ + if isinstance(content, str): + return [content] if content else [] + if isinstance(content, list): + texts: list[str] = [] + for block in content: + if isinstance(block, dict) and block.get("type") == "text": + text = block.get("text", "") + if text: + texts.append(text) + elif isinstance(block, str) and block: + texts.append(block) + return texts + return [] + + +class LangGraphA2AExecutor(AgentExecutor): + """Bridges LangGraph agent to A2A event model with SSE streaming support. + + Always uses ``agent.astream_events()`` so that: + - Streaming clients (``message/stream``) receive token-level SSE events. + - Non-streaming clients (``message/send``) receive the final ``Message`` + collected from the same stream — no duplicate LLM call, full compat. + """ + + def __init__(self, agent, heartbeat=None, model: str = "unknown"): + self.agent = agent # Compiled LangGraph graph (create_react_agent output) + self._heartbeat = heartbeat + self._model = model # e.g. "anthropic:claude-sonnet-4-6" + + async def execute(self, context: RequestContext, event_queue: EventQueue) -> None: + """Execute a task from an A2A request with SSE streaming. + + Routes through the Temporal durable workflow when a global + ``TemporalWorkflowWrapper`` is initialised and connected to Temporal; + otherwise falls back to ``_core_execute()`` (direct path). + + Event emission sequence: + 1. TaskStatusUpdateEvent(working) — immediate start signal + 2. TaskArtifactUpdateEvent chunks — token-by-token via astream_events + 3. Message(final_text) — terminal; non-streaming clients + return on this; streaming clients + also receive it as the last SSE event. + """ + # ── Optional Temporal durable execution wrapper ────────────────────── + # When a TemporalWorkflowWrapper is active this routes execution through + # a MoleculeAIAgentWorkflow (task_receive → llm_call → task_complete). + # Falls back silently to _core_execute() on any error or if Temporal + # is unavailable, so the client always receives a response. + try: + from builtin_tools.temporal_workflow import get_wrapper as _get_temporal_wrapper + + _tw = _get_temporal_wrapper() + if _tw is not None and _tw.is_available(): + return await _tw.run(self, context, event_queue) + except Exception: + pass # Never let the wrapper path crash the executor + + await self._core_execute(context, event_queue) + + async def _core_execute(self, context: RequestContext, event_queue: EventQueue) -> str: + """Core execution pipeline — called directly or from a Temporal activity. + + This is the original ``execute()`` body, extracted so that the Temporal + ``llm_call`` activity can invoke it without re-entering the wrapper + check and causing infinite recursion. + + Returns the final response text (empty string on empty input or error). + + Event emission sequence: + 1. TaskStatusUpdateEvent(working) — immediate start signal + 2. TaskArtifactUpdateEvent chunks — token-by-token via astream_events + 3. 
Message(final_text) — terminal event + """ + user_input = extract_message_text(context) + if not user_input: + parts = getattr(getattr(context, "message", None), "parts", None) + logger.warning("A2A execute: no text content in message parts: %s", parts) + await event_queue.enqueue_event( + new_agent_text_message("Error: message contained no text content.") + ) + return "" + + # ── OA-01: Prompt injection check (OWASP Agentic Top 10) ──────────── + _compliance_cfg = _get_compliance_cfg() if _COMPLIANCE_AVAILABLE else None + if _COMPLIANCE_AVAILABLE and _compliance_cfg and _compliance_cfg.mode == "owasp_agentic": + try: + user_input = _sanitize_input( + user_input, + prompt_injection_mode=_compliance_cfg.prompt_injection, + context_id=context.context_id or "", + ) + except PromptInjectionError as exc: + await event_queue.enqueue_event( + new_agent_text_message(f"Request blocked: {exc}") + ) + return "" + + logger.info("A2A execute: user_input=%s", user_input[:200]) + + # ── OTEL: task_receive span ────────────────────────────────────────── + parent_ctx = _incoming_trace_context.get() + tracer = get_tracer() + + _result: str = "" # captured inside the span for return after it closes + + with tracer.start_as_current_span("task_receive", context=parent_ctx) as task_span: + task_span.set_attribute(WORKSPACE_ID_ATTR, _WORKSPACE_ID) + task_span.set_attribute(A2A_TASK_ID, context.context_id or "") + task_span.set_attribute("a2a.input_preview", user_input[:256]) + + await set_current_task(self._heartbeat, brief_task(user_input)) + + # Resolve IDs — the RequestContextBuilder always sets them, but + # we generate fallbacks for safety (e.g. in unit tests). + task_id = context.task_id or str(uuid.uuid4()) + context_id = context.context_id or str(uuid.uuid4()) + + updater = TaskUpdater(event_queue, task_id, context_id) + + try: + messages = _extract_history(context) + if messages: + logger.info("A2A execute: injecting %d history messages", len(messages)) + messages.append(("human", user_input)) + + # Recursion limit: see DEFAULT_RECURSION_LIMIT and + # _parse_recursion_limit() at module top. Re-read on every + # call so the env var can be hot-changed between requests. + recursion_limit = _parse_recursion_limit() + run_config = { + "configurable": {"thread_id": context_id}, + "run_name": f"a2a-{context_id[:8]}", + "recursion_limit": recursion_limit, + } + + # ── OTEL: llm_call span ────────────────────────────────────── + with tracer.start_as_current_span("llm_call") as llm_span: + llm_span.set_attribute(GEN_AI_OPERATION_NAME, "chat") + llm_span.set_attribute(GEN_AI_SYSTEM, gen_ai_system_from_model(self._model)) + llm_span.set_attribute(GEN_AI_REQUEST_MODEL, self._model) + llm_span.set_attribute(WORKSPACE_ID_ATTR, _WORKSPACE_ID) + + # ── Step 1: signal "working" to streaming clients ───────── + await updater.start_work() + + # ── Step 2: stream tokens via LangGraph astream_events ──── + # Each "on_chat_model_stream" event carries an AIMessageChunk. + # We emit one TaskArtifactUpdateEvent per text chunk so SSE + # clients can render tokens in real time. + # artifact_id resets on each new LLM run_id so agent→tool→agent + # cycles each get their own artifact slot. 
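+                    # Illustrative sequence for "Hi there" streamed as two
+                    # chunks within one LLM run (arguments abbreviated):
+                    #   add_artifact(parts=["Hi "],   artifact_id=A, append=False)
+                    #   add_artifact(parts=["there"], artifact_id=A, append=True)
+                    # A later run_id starts artifact B with append=False again.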
+ + artifact_id = str(uuid.uuid4()) + has_streamed = False # True after first chunk for current artifact + current_run_id = None # Detects new LLM call in a ReAct cycle + accumulated: list[str] = [] # All text for the final Message + last_ai_message = None # Saved for token-usage telemetry + + # ── OA-03: Excessive agency tracker ────────────────────── + _agency = ( + AgencyTracker( + max_tool_calls=_compliance_cfg.max_tool_calls_per_task, + max_duration_seconds=float(_compliance_cfg.max_task_duration_seconds), + ) + if _COMPLIANCE_AVAILABLE and _compliance_cfg and _compliance_cfg.mode == "owasp_agentic" + else None + ) + + async for event in self.agent.astream_events( + {"messages": messages}, + config=run_config, + version="v2", + ): + kind = event.get("event", "") + + if kind == "on_chat_model_stream": + run_id = event.get("run_id", "") + if run_id and run_id != current_run_id: + # New LLM run started — fresh artifact slot + current_run_id = run_id + artifact_id = str(uuid.uuid4()) + has_streamed = False + + chunk = event.get("data", {}).get("chunk") + if chunk is not None: + texts = _extract_chunk_text(chunk.content) + for text in texts: + await updater.add_artifact( + parts=[Part(root=TextPart(text=text))], + artifact_id=artifact_id, + append=has_streamed, # False=first, True=append + last_chunk=False, + ) + has_streamed = True + accumulated.append(text) + + elif kind == "on_tool_start": + tool_name = event.get("name", "?") + logger.debug("SSE: tool start — %s", tool_name) + if _agency is not None: + _agency.on_tool_call( + tool_name=tool_name, + context_id=context_id, + ) + + elif kind == "on_tool_end": + logger.debug("SSE: tool end — %s", event.get("name", "?")) + + elif kind == "on_chat_model_end": + # Capture the last completed AIMessage for token telemetry + output = event.get("data", {}).get("output") + if output is not None: + last_ai_message = output + + # Record token usage from the last completed LLM call + if last_ai_message is not None: + record_llm_token_usage(llm_span, {"messages": [last_ai_message]}) + + # Build final text from all accumulated streaming tokens + final_text = "".join(accumulated).strip() or "(no response generated)" + logger.info("A2A execute: response length=%d chars", len(final_text)) + + # ── OA-02 / OA-06: Output PII redaction ────────────────────── + if _COMPLIANCE_AVAILABLE and _compliance_cfg and _compliance_cfg.mode == "owasp_agentic": + final_text, _pii_types = _redact_pii(final_text) + if _pii_types: + from builtin_tools.audit import log_event as _audit_log + _audit_log( + event_type="compliance", + action="pii.redact", + resource="task_output", + outcome="redacted", + pii_types=_pii_types, + context_id=context_id, + ) + + # ── OTEL: task_complete span ───────────────────────────────── + with tracer.start_as_current_span("task_complete") as done_span: + done_span.set_attribute(WORKSPACE_ID_ATTR, _WORKSPACE_ID) + done_span.set_attribute(A2A_TASK_ID, context_id) + done_span.set_attribute("task.has_response", bool(accumulated)) + done_span.set_attribute("task.response_length", len(final_text)) + + # ── Step 3: emit final Message ──────────────────────────────── + # Non-streaming: ResultAggregator.consume_all() returns this + # immediately as the response (a2a_client.py reads .parts[0].text). + # Streaming: yielded as the last SSE event in the stream. 
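+                # (send_a2a_message() in a2a_client.py reads exactly
+                # data["result"]["parts"][0]["text"] from this Message.)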
+ await event_queue.enqueue_event( + new_agent_text_message(final_text, task_id=task_id, context_id=context_id) + ) + _result = final_text + + except Exception as e: + logger.error("A2A execute error: %s", e, exc_info=True) + try: + task_span.record_exception(e) + from opentelemetry.trace import StatusCode + task_span.set_status(StatusCode.ERROR, str(e)) + except Exception: + pass + # Emit a Message so both streaming and non-streaming clients + # receive an error response rather than hanging. + await event_queue.enqueue_event( + new_agent_text_message( + f"Agent error: {e}", task_id=task_id, context_id=context_id + ) + ) + finally: + await set_current_task(self._heartbeat, "") + + return _result + + async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None: + """Cancel a running task — emits canceled state to comply with A2A protocol.""" + from a2a.types import TaskStatus, TaskState, TaskStatusUpdateEvent + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + status=TaskStatus(state=TaskState.canceled), + final=True, + ) + ) diff --git a/molecule_runtime/a2a_mcp_server.py b/molecule_runtime/a2a_mcp_server.py new file mode 100644 index 0000000..29ca254 --- /dev/null +++ b/molecule_runtime/a2a_mcp_server.py @@ -0,0 +1,293 @@ +#!/usr/bin/env python3 +"""A2A MCP Server — runs inside each workspace container. + +Exposes A2A delegation, peer discovery, and workspace info as MCP tools +so CLI-based runtimes (Claude Code, Codex) can communicate with other workspaces. + +Launched automatically by main.py for CLI runtimes. Runs on stdio transport +and is configured as a local MCP server for the claude --print invocation. + +Environment variables (set by the workspace container): + WORKSPACE_ID — this workspace's ID + PLATFORM_URL — platform API base URL (e.g. http://platform:8080) +""" + +import asyncio +import json +import logging +import sys + +from a2a_tools import ( + tool_check_task_status, + tool_commit_memory, + tool_delegate_task, + tool_delegate_task_async, + tool_get_workspace_info, + tool_list_peers, + tool_recall_memory, + tool_send_message_to_user, +) + +logger = logging.getLogger(__name__) + +# Re-export constants and client functions so existing imports +# (e.g. tests that do `import a2a_mcp_server`) still work. +from a2a_client import ( # noqa: F401, E402 + PLATFORM_URL, + WORKSPACE_ID, + _A2A_ERROR_PREFIX, + _peer_names, + discover_peer, + get_peers, + get_workspace_info, + send_a2a_message, +) +from a2a_tools import report_activity # noqa: F401, E402 + +# --- Tool definitions (schemas) --- + +TOOLS = [ + { + "name": "delegate_task", + "description": "Delegate a task to another workspace via A2A protocol and WAIT for the response. Use for quick tasks. The target must be a peer (sibling or parent/child). Use list_peers to find available targets.", + "inputSchema": { + "type": "object", + "properties": { + "workspace_id": { + "type": "string", + "description": "Target workspace ID (from list_peers)", + }, + "task": { + "type": "string", + "description": "The task description to send to the target workspace", + }, + }, + "required": ["workspace_id", "task"], + }, + }, + { + "name": "delegate_task_async", + "description": "Send a task to another workspace with a short timeout (fire-and-forget). Returns immediately — the target continues processing. Best when you don't need the result right away. 
Note: check_task_status may not work with all workspace implementations.", + "inputSchema": { + "type": "object", + "properties": { + "workspace_id": { + "type": "string", + "description": "Target workspace ID (from list_peers)", + }, + "task": { + "type": "string", + "description": "The task description to send to the target workspace", + }, + }, + "required": ["workspace_id", "task"], + }, + }, + { + "name": "check_task_status", + "description": "Check the status of a previously submitted async task via tasks/get. Note: only works if the target workspace's A2A implementation supports task persistence. May return 'not found' for completed tasks.", + "inputSchema": { + "type": "object", + "properties": { + "workspace_id": { + "type": "string", + "description": "The workspace ID the task was sent to", + }, + "task_id": { + "type": "string", + "description": "The task_id returned by delegate_task_async", + }, + }, + "required": ["workspace_id", "task_id"], + }, + }, + { + "name": "list_peers", + "description": "List all workspaces this agent can communicate with (siblings and parent/children). Returns name, ID, status, and role for each peer.", + "inputSchema": {"type": "object", "properties": {}}, + }, + { + "name": "get_workspace_info", + "description": "Get this workspace's own info — ID, name, role, tier, parent, status.", + "inputSchema": {"type": "object", "properties": {}}, + }, + { + "name": "send_message_to_user", + "description": "Send a message directly to the user's canvas chat — pushed instantly via WebSocket. Use this to: (1) acknowledge a task immediately ('Got it, I'll start working on this'), (2) send interim progress updates while doing long work, (3) deliver follow-up results after delegation completes. The message appears in the user's chat as if you're proactively reaching out.", + "inputSchema": { + "type": "object", + "properties": { + "message": { + "type": "string", + "description": "The message to send to the user", + }, + }, + "required": ["message"], + }, + }, + { + "name": "commit_memory", + "description": "Save important information to persistent memory. Use this to remember decisions, conversation context, task results, and anything that should survive a restart. Scope: LOCAL (this workspace only), TEAM (parent + siblings), GLOBAL (entire org).", + "inputSchema": { + "type": "object", + "properties": { + "content": { + "type": "string", + "description": "The information to remember — be detailed and specific", + }, + "scope": { + "type": "string", + "enum": ["LOCAL", "TEAM", "GLOBAL"], + "description": "Memory scope (default: LOCAL)", + }, + }, + "required": ["content"], + }, + }, + { + "name": "recall_memory", + "description": "Search persistent memory for previously saved information. Returns all matching memories. 
Use this at the start of conversations to recall prior context.", + "inputSchema": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "Search query (empty returns all memories)", + }, + "scope": { + "type": "string", + "enum": ["LOCAL", "TEAM", "GLOBAL", ""], + "description": "Filter by scope (empty returns all accessible)", + }, + }, + }, + }, +] + + +# --- Tool dispatch --- + +async def handle_tool_call(name: str, arguments: dict) -> str: + """Handle a tool call and return the result as text.""" + if name == "delegate_task": + return await tool_delegate_task( + arguments.get("workspace_id", ""), + arguments.get("task", ""), + ) + elif name == "delegate_task_async": + return await tool_delegate_task_async( + arguments.get("workspace_id", ""), + arguments.get("task", ""), + ) + elif name == "check_task_status": + return await tool_check_task_status( + arguments.get("workspace_id", ""), + arguments.get("task_id", ""), + ) + elif name == "send_message_to_user": + return await tool_send_message_to_user(arguments.get("message", "")) + elif name == "list_peers": + return await tool_list_peers() + elif name == "get_workspace_info": + return await tool_get_workspace_info() + elif name == "commit_memory": + return await tool_commit_memory( + arguments.get("content", ""), + arguments.get("scope", "LOCAL"), + ) + elif name == "recall_memory": + return await tool_recall_memory( + arguments.get("query", ""), + arguments.get("scope", ""), + ) + return f"Unknown tool: {name}" + + +# --- MCP Server (JSON-RPC over stdio) --- + +async def main(): # pragma: no cover + """Run MCP server on stdio — reads JSON-RPC requests, writes responses.""" + reader = asyncio.StreamReader() + protocol = asyncio.StreamReaderProtocol(reader) + await asyncio.get_event_loop().connect_read_pipe(lambda: protocol, sys.stdin) + + writer_transport, writer_protocol = await asyncio.get_event_loop().connect_write_pipe( + asyncio.streams.FlowControlMixin, sys.stdout + ) + writer = asyncio.StreamWriter(writer_transport, writer_protocol, None, asyncio.get_event_loop()) + + async def write_response(response: dict): + data = json.dumps(response) + "\n" + writer.write(data.encode()) + await writer.drain() + + buffer = "" + while True: + try: + chunk = await reader.read(65536) + if not chunk: + break + buffer += chunk.decode(errors="replace") + + while "\n" in buffer: + line, buffer = buffer.split("\n", 1) + line = line.strip() + if not line: + continue + + try: + request = json.loads(line) + except json.JSONDecodeError: + continue + + req_id = request.get("id") + method = request.get("method", "") + + if method == "initialize": + await write_response({ + "jsonrpc": "2.0", + "id": req_id, + "result": { + "protocolVersion": "2024-11-05", + "capabilities": {"tools": {"listChanged": False}}, + "serverInfo": {"name": "a2a-delegation", "version": "1.0.0"}, + }, + }) + + elif method == "notifications/initialized": + pass # No response needed + + elif method == "tools/list": + await write_response({ + "jsonrpc": "2.0", + "id": req_id, + "result": {"tools": TOOLS}, + }) + + elif method == "tools/call": + params = request.get("params", {}) + tool_name = params.get("name", "") + tool_args = params.get("arguments", {}) + result_text = await handle_tool_call(tool_name, tool_args) + await write_response({ + "jsonrpc": "2.0", + "id": req_id, + "result": { + "content": [{"type": "text", "text": result_text}], + }, + }) + + else: + await write_response({ + "jsonrpc": "2.0", + "id": req_id, + "error": {"code": -32601, 
"message": f"Method not found: {method}"}, + }) + + except Exception as e: + logger.error(f"MCP server error: {e}") + break + + +if __name__ == "__main__": # pragma: no cover + asyncio.run(main()) diff --git a/molecule_runtime/a2a_tools.py b/molecule_runtime/a2a_tools.py new file mode 100644 index 0000000..6ba37a0 --- /dev/null +++ b/molecule_runtime/a2a_tools.py @@ -0,0 +1,269 @@ +"""A2A MCP tool implementations — the body of each tool handler. + +Imports shared client functions and constants from a2a_client. +""" + +import json +import uuid + +import httpx + +from a2a_client import ( + PLATFORM_URL, + WORKSPACE_ID, + _A2A_ERROR_PREFIX, + _peer_names, + discover_peer, + get_peers, + get_workspace_info, + send_a2a_message, +) + + +def _auth_headers_for_heartbeat() -> dict[str, str]: + """Return Phase 30.1 auth headers; tolerate platform_auth being absent + in older installs (e.g. during rolling upgrade).""" + try: + from platform_auth import auth_headers + return auth_headers() + except Exception: + return {} + + +async def report_activity( + activity_type: str, target_id: str = "", summary: str = "", status: str = "ok", + task_text: str = "", response_text: str = "", +): + """Report activity to the platform for live progress tracking.""" + try: + async with httpx.AsyncClient(timeout=5.0) as client: + payload: dict = { + "activity_type": activity_type, + "source_id": WORKSPACE_ID, + "target_id": target_id, + "method": "message/send", + "summary": summary, + "status": status, + } + if task_text: + payload["request_body"] = {"task": task_text} + if response_text: + payload["response_body"] = {"result": response_text} + await client.post( + f"{PLATFORM_URL}/workspaces/{WORKSPACE_ID}/activity", + json=payload, + headers=_auth_headers_for_heartbeat(), + ) + # Also push current_task via heartbeat for canvas card display + if summary: + await client.post( + f"{PLATFORM_URL}/registry/heartbeat", + json={ + "workspace_id": WORKSPACE_ID, + "current_task": summary, + "active_tasks": 1, + "error_rate": 0, + "sample_error": "", + "uptime_seconds": 0, + }, + headers=_auth_headers_for_heartbeat(), + ) + except Exception: + pass # Best-effort — don't block delegation on activity reporting + + +async def tool_delegate_task(workspace_id: str, task: str) -> str: + """Delegate a task to another workspace via A2A (synchronous — waits for response).""" + if not workspace_id or not task: + return "Error: workspace_id and task are required" + + # Discover the target + peer = await discover_peer(workspace_id) + if not peer: + return f"Error: workspace {workspace_id} not found or not accessible (check access control)" + + target_url = peer.get("url", "") + if not target_url: + return f"Error: workspace {workspace_id} has no URL (may be offline)" + + # Report delegation start — include the task text for traceability + peer_name = peer.get("name") or _peer_names.get(workspace_id) or workspace_id[:8] + _peer_names[workspace_id] = peer_name # cache for future use + # Brief summary for canvas display — just the delegation target + await report_activity("a2a_send", workspace_id, f"Delegating to {peer_name}", task_text=task) + + # Send A2A message and log the full round-trip + result = await send_a2a_message(target_url, task) + + # Detect delegation failures — wrap them clearly so the calling agent + # can decide to retry, use another peer, or handle the task itself. 
+ is_error = result.startswith(_A2A_ERROR_PREFIX) + await report_activity( + "a2a_receive", workspace_id, + f"{peer_name} responded ({len(result)} chars)" if not is_error else f"{peer_name} failed", + task_text=task, response_text=result, + status="error" if is_error else "ok", + ) + if is_error: + return ( + f"DELEGATION FAILED to {peer_name}: {result}\n" + f"You should either: (1) try a different peer, (2) handle this task yourself, " + f"or (3) inform the user that {peer_name} is unavailable and provide your best answer." + ) + return result + + +async def tool_delegate_task_async(workspace_id: str, task: str) -> str: + """Delegate a task via the platform's async delegation API (fire-and-forget). + + Uses POST /workspaces/:id/delegate which runs the A2A request in the background. + Results are tracked in the platform DB and broadcast via WebSocket. + Use check_task_status to poll for results. + """ + if not workspace_id or not task: + return "Error: workspace_id and task are required" + + try: + async with httpx.AsyncClient(timeout=10.0) as client: + resp = await client.post( + f"{PLATFORM_URL}/workspaces/{WORKSPACE_ID}/delegate", + json={"target_id": workspace_id, "task": task}, + headers=_auth_headers_for_heartbeat(), + ) + if resp.status_code == 202: + data = resp.json() + return json.dumps({ + "delegation_id": data.get("delegation_id", ""), + "workspace_id": workspace_id, + "status": "delegated", + "note": "Task delegated. The platform runs it in the background. Use check_task_status to poll for results.", + }) + else: + return f"Error: delegation failed with status {resp.status_code}: {resp.text[:200]}" + except Exception as e: + return f"Error: delegation failed — {e}" + + +async def tool_check_task_status(workspace_id: str, task_id: str) -> str: + """Check delegations for this workspace via the platform API. + + Args: + workspace_id: Ignored (kept for backward compat). Checks this workspace's delegations. + task_id: Optional delegation_id to filter. If empty, returns all recent delegations. 
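+
+    Example return value when task_id is empty (fields mirror the platform's
+    delegation records; values abbreviated)::
+
+        {"delegations": [{"delegation_id": "...", "target_id": "...",
+                          "status": "completed", "summary": "...",
+                          "response_preview": "..."}], "count": 1}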
+ """ + try: + async with httpx.AsyncClient(timeout=10.0) as client: + resp = await client.get( + f"{PLATFORM_URL}/workspaces/{WORKSPACE_ID}/delegations", + headers=_auth_headers_for_heartbeat(), + ) + if resp.status_code != 200: + return f"Error: failed to check delegations ({resp.status_code})" + delegations = resp.json() + if task_id: + # Filter by delegation_id + matching = [d for d in delegations if d.get("delegation_id") == task_id] + if matching: + return json.dumps(matching[0]) + return json.dumps({"status": "not_found", "delegation_id": task_id}) + # Return all recent delegations + summary = [] + for d in delegations[:10]: + summary.append({ + "delegation_id": d.get("delegation_id", ""), + "target_id": d.get("target_id", ""), + "status": d.get("status", ""), + "summary": d.get("summary", ""), + "response_preview": d.get("response_preview", ""), + }) + return json.dumps({"delegations": summary, "count": len(delegations)}) + except Exception as e: + return f"Error checking delegations: {e}" + + +async def tool_send_message_to_user(message: str) -> str: + """Send a message directly to the user's canvas chat via WebSocket.""" + if not message: + return "Error: message is required" + try: + async with httpx.AsyncClient(timeout=5.0) as client: + resp = await client.post( + f"{PLATFORM_URL}/workspaces/{WORKSPACE_ID}/notify", + json={"message": message}, + headers=_auth_headers_for_heartbeat(), + ) + if resp.status_code == 200: + return "Message sent to user" + return f"Error: platform returned {resp.status_code}" + except Exception as e: + return f"Error sending message: {e}" + + +async def tool_list_peers() -> str: + """List all workspaces this agent can communicate with.""" + peers = await get_peers() + if not peers: + return "No peers available (this workspace may be isolated)" + lines = [] + for p in peers: + status = p.get("status", "unknown") + role = p.get("role", "") + # Cache name for use in delegate_task + _peer_names[p["id"]] = p["name"] + lines.append(f"- {p['name']} (ID: {p['id']}, status: {status}, role: {role})") + return "\n".join(lines) + + +async def tool_get_workspace_info() -> str: + """Get this workspace's own info.""" + info = await get_workspace_info() + return json.dumps(info, indent=2) + + +async def tool_commit_memory(content: str, scope: str = "LOCAL") -> str: + """Save important information to persistent memory.""" + if not content: + return "Error: content is required" + scope = scope.upper() + if scope not in ("LOCAL", "TEAM", "GLOBAL"): + scope = "LOCAL" + try: + async with httpx.AsyncClient(timeout=10.0) as client: + resp = await client.post( + f"{PLATFORM_URL}/workspaces/{WORKSPACE_ID}/memories", + json={"content": content, "scope": scope}, + headers=_auth_headers_for_heartbeat(), + ) + data = resp.json() + if resp.status_code in (200, 201): + return json.dumps({"success": True, "id": data.get("id"), "scope": scope}) + return f"Error: {data.get('error', resp.text)}" + except Exception as e: + return f"Error saving memory: {e}" + + +async def tool_recall_memory(query: str = "", scope: str = "") -> str: + """Search persistent memory for previously saved information.""" + params = {} + if query: + params["q"] = query + if scope: + params["scope"] = scope.upper() + try: + async with httpx.AsyncClient(timeout=10.0) as client: + resp = await client.get( + f"{PLATFORM_URL}/workspaces/{WORKSPACE_ID}/memories", + params=params, + headers=_auth_headers_for_heartbeat(), + ) + data = resp.json() + if isinstance(data, list): + if not data: + return "No memories 
found." + lines = [] + for m in data: + lines.append(f"[{m.get('scope', '?')}] {m.get('content', '')}") + return "\n".join(lines) + return json.dumps(data) + except Exception as e: + return f"Error recalling memory: {e}" diff --git a/molecule_runtime/adapters/__init__.py b/molecule_runtime/adapters/__init__.py new file mode 100644 index 0000000..e0ea570 --- /dev/null +++ b/molecule_runtime/adapters/__init__.py @@ -0,0 +1,86 @@ +"""Adapter registry — discovers and loads agent infrastructure adapters.""" + +import importlib +import logging +import os +from .base import BaseAdapter, AdapterConfig, SetupResult + +logger = logging.getLogger(__name__) + +_ADAPTER_CACHE: dict[str, type[BaseAdapter]] = {} + + +def discover_adapters() -> dict[str, type[BaseAdapter]]: + """Scan subdirectories for adapter modules. Each must export an Adapter class. + + This is used for local development inside the monorepo where adapters + live as subdirectories. In standalone adapter repos, use ADAPTER_MODULE + env var instead. + """ + if _ADAPTER_CACHE: + return _ADAPTER_CACHE + + from pathlib import Path + adapters_dir = Path(__file__).parent + for entry in sorted(adapters_dir.iterdir()): + if not entry.is_dir() or entry.name.startswith("_"): + continue + try: + mod = importlib.import_module(f"molecule_runtime.adapters.{entry.name}") + adapter_cls = getattr(mod, "Adapter", None) + if adapter_cls and issubclass(adapter_cls, BaseAdapter): + _ADAPTER_CACHE[adapter_cls.name()] = adapter_cls + logger.debug(f"Loaded adapter: {adapter_cls.name()} ({adapter_cls.display_name()})") + except Exception as e: + # Log but don't crash — adapter may have uninstalled deps + logger.debug(f"Skipped adapter {entry.name}: {e}") + + return _ADAPTER_CACHE + + +def get_adapter(runtime: str) -> type[BaseAdapter]: + """Get adapter class by runtime name. + + Resolution order: + 1. ADAPTER_MODULE env var — used by standalone adapter repos to register + their adapter without modifying the runtime package. + 2. Built-in discovery — scans subdirectories (for local monorepo dev). + + Raises KeyError if the adapter cannot be found. + """ + # First check env override (standalone adapter repos set this) + adapter_module = os.environ.get("ADAPTER_MODULE") + if adapter_module: + try: + mod = importlib.import_module(adapter_module) + cls = getattr(mod, "Adapter") + if cls and issubclass(cls, BaseAdapter): + return cls + except Exception as e: + raise KeyError( + f"ADAPTER_MODULE={adapter_module!r} could not be loaded: {e}" + ) from e + + # Fall back to built-in discovery (for local dev / monorepo) + adapters = discover_adapters() + if runtime not in adapters: + available = ", ".join(sorted(adapters.keys())) + raise KeyError(f"Unknown runtime '{runtime}'. 
Available: {available}")
+    return adapters[runtime]
+
+
+def list_adapters() -> list[dict]:
+    """Return metadata for all discovered adapters (for API/UI)."""
+    adapters = discover_adapters()
+    return [
+        {
+            "name": cls.name(),
+            "display_name": cls.display_name(),
+            "description": cls.description(),
+            "config_schema": cls.get_config_schema(),
+        }
+        for cls in adapters.values()
+    ]
+
+
+__all__ = ["BaseAdapter", "AdapterConfig", "SetupResult", "get_adapter", "list_adapters", "discover_adapters"]
diff --git a/molecule_runtime/adapters/base.py b/molecule_runtime/adapters/base.py
new file mode 100644
index 0000000..a1820e7
--- /dev/null
+++ b/molecule_runtime/adapters/base.py
@@ -0,0 +1,309 @@
+"""Base adapter interface for agent infrastructure providers."""
+
+import logging
+import os
+from abc import ABC, abstractmethod
+from dataclasses import dataclass, field
+from typing import Any
+
+from a2a.server.agent_execution import AgentExecutor
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class SetupResult:
+    """Result from the shared _common_setup() pipeline."""
+    system_prompt: str
+    loaded_skills: list      # LoadedSkill instances
+    langchain_tools: list    # LangChain BaseTool instances
+    is_coordinator: bool
+    children: list           # child workspace dicts
+
+
+@dataclass
+class AdapterConfig:
+    """Standardized config passed to every adapter."""
+    model: str  # e.g. "anthropic:claude-sonnet-4-6" or "openrouter:google/gemini-2.5-flash"
+    system_prompt: str | None = None  # Assembled system prompt text
+    tools: list[str] = field(default_factory=list)  # Tool names from config.yaml
+    runtime_config: dict[str, Any] = field(default_factory=dict)  # Raw runtime_config block
+    config_path: str = "/configs"  # Path to configs directory
+    workspace_id: str = ""  # Workspace identifier
+    prompt_files: list[str] = field(default_factory=list)  # Ordered prompt file names
+    a2a_port: int = 8000  # Port for A2A server
+    heartbeat: Any = None  # HeartbeatLoop instance
+
+
+class BaseAdapter(ABC):
+    """Interface every agent infrastructure adapter must implement.
+
+    To add a new agent infra:
+    1. Create workspace-template/adapters/<your-adapter>/
+    2. Implement adapter.py with a class extending BaseAdapter
+    3. Add requirements.txt with your infra's dependencies
+    4. Export as Adapter in __init__.py
+    5. Submit a PR
+    """
+
+    @staticmethod
+    @abstractmethod
+    def name() -> str:  # pragma: no cover
+        """Return the runtime identifier (e.g. 'langgraph', 'crewai').
+        This must match the 'runtime' field in config.yaml."""
+        ...
+
+    @staticmethod
+    @abstractmethod
+    def display_name() -> str:  # pragma: no cover
+        """Human-readable name for UI display."""
+        ...
+
+    @staticmethod
+    @abstractmethod
+    def description() -> str:  # pragma: no cover
+        """Short description of what this adapter provides."""
+        ...
+
+    @staticmethod
+    def get_config_schema() -> dict:
+        """Return JSON Schema for runtime_config fields this adapter supports.
+        Used by the Config tab UI to render the right form fields.
+        Override in subclasses for adapter-specific settings."""
+        return {}
+
+    # ------------------------------------------------------------------
+    # Plugin install hooks
+    # ------------------------------------------------------------------
+    # New pipeline: each plugin ships per-runtime adaptors resolved via
+    # `plugins_registry.resolve()`. Adapters expose hooks below that
+    # adaptors call to wire plugin content into the runtime.
+    #
+    # Default implementations are filesystem-only (write to /configs,
+    # append to CLAUDE.md). Runtimes with a dynamic tool registry
+    # (e.g. DeepAgents sub-agents) override the hooks to also register
+    # in-process state.
+
+    def memory_filename(self) -> str:
+        """File under /configs that the runtime treats as long-lived memory.
+
+        Both Claude Code and DeepAgents read CLAUDE.md natively, so this is
+        the sensible default. Override only if a runtime expects a different
+        filename.
+        """
+        return "CLAUDE.md"
+
+    def register_tool_hook(self, name: str, fn) -> None:
+        """Default no-op. Override on runtimes with a dynamic tool registry.
+
+        Runtimes that pick tools up at startup via filesystem scan (Claude
+        Code reads /configs/skills, LangGraph globs **/*.py) don't need to
+        do anything here — the adaptor's file-write step is enough.
+        """
+        return None
+
+    async def transcript_lines(self, since: int = 0, limit: int = 100) -> dict:
+        """Return live transcript entries for the most-recent agent session.
+
+        Default implementation returns ``supported: False`` for runtimes
+        that don't expose a per-session log on disk. Override in subclasses
+        that DO (Claude Code reads ``~/.claude/projects/<project>/<session>.jsonl``).
+
+        This is the "look over the agent's shoulder" feature — lets canvas /
+        operators see live tool calls + AI thinking instead of waiting for
+        the high-level activity log to flush.
+
+        Args:
+            since: line offset to skip — caller's last cursor (0 = from start)
+            limit: max lines to return (caller-side cap, default 100, max 1000)
+
+        Returns:
+            ``{runtime, supported, lines, cursor, more, source}`` where
+            ``cursor`` is the new offset to pass on the next poll, ``more``
+            is True if additional lines remain past ``limit``, and ``source``
+            is the file path lines were read from (useful for debugging).
+        """
+        return {
+            "runtime": self.name(),
+            "supported": False,
+            "lines": [],
+            "cursor": since,
+            "more": False,
+            "source": None,
+        }
+
+    def register_subagent_hook(self, name: str, spec: dict) -> None:
+        """Default no-op. DeepAgents overrides to register a sub-agent."""
+        return None
+
+    def append_to_memory_hook(self, config: AdapterConfig, filename: str, content: str) -> None:
+        """Append text to /configs/<filename> if the marker isn't already present.
+
+        Idempotent: looks for the first line of `content` as a marker so a
+        re-install doesn't duplicate the block. Adaptors should pass content
+        beginning with a unique header (e.g. ``# Plugin: molecule-dev-conventions``).
+        """
+        import os
+        target = os.path.join(config.config_path, filename)
+        marker = content.splitlines()[0].strip() if content else ""
+        existing = ""
+        if os.path.exists(target):
+            with open(target) as f:
+                existing = f.read()
+        if marker and marker in existing:
+            logger.info("append_to_memory: %s already contains %r — skipping", filename, marker)
+            return
+        os.makedirs(os.path.dirname(target) or ".", exist_ok=True)
+        with open(target, "a") as f:
+            if existing and not existing.endswith("\n"):
+                f.write("\n")
+            f.write(content if content.endswith("\n") else content + "\n")
+        logger.info("append_to_memory: appended %d chars to %s", len(content), filename)
+
+    async def install_plugins_via_registry(
+        self,
+        config: AdapterConfig,
+        plugins,
+    ) -> list:
+        """Drive the new per-runtime adaptor pipeline for every loaded plugin.
+
+        For each plugin in `plugins.plugins`, resolve the adaptor for this
+        runtime (via :func:`plugins_registry.resolve`) and invoke
+        ``install(ctx)``. Returns the list of :class:`InstallResult` so
+        callers can surface warnings (e.g. raw-drop fallback hits). 
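+
+        A caller might surface those warnings like so (sketch)::
+
+            results = await self.install_plugins_via_registry(config, plugins)
+            for r in results:
+                for w in r.warnings:
+                    logger.warning("plugin install: %s", w)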
+ + Adapters whose runtime supports the new pipeline call this from + ``setup()`` instead of the legacy ``inject_plugins()``. + """ + from pathlib import Path + from plugins_registry import InstallContext, resolve + + results = [] + runtime = self.name().replace("-", "_") # e.g. "claude-code" -> "claude_code" + + for plugin in plugins.plugins: + adaptor, source = resolve(plugin.name, runtime, Path(plugin.path)) + ctx = InstallContext( + configs_dir=Path(config.config_path), + workspace_id=config.workspace_id, + runtime=runtime, + plugin_root=Path(plugin.path), + memory_filename=self.memory_filename(), + register_tool=self.register_tool_hook, + register_subagent=self.register_subagent_hook, + append_to_memory=lambda fn, c, _cfg=config: self.append_to_memory_hook(_cfg, fn, c), + ) + try: + result = await adaptor.install(ctx) + results.append(result) + logger.info( + "Plugin %s installed via %s adaptor (warnings: %d)", + plugin.name, source, len(result.warnings), + ) + except Exception as exc: + logger.exception("Plugin %s install via %s failed: %s", plugin.name, source, exc) + + return results + + async def inject_plugins(self, config: AdapterConfig, plugins) -> None: + """Legacy hook — kept for backwards compatibility during migration. + + Default: drive the new per-runtime adaptor pipeline. Adapters not yet + migrated may still override this with their own logic. + """ + await self.install_plugins_via_registry(config, plugins) + + async def _common_setup(self, config: AdapterConfig) -> SetupResult: + """Shared setup pipeline — loads plugins, skills, tools, coordinator, and builds system prompt. + + All adapters can call this to get the full platform feature set. + Returns a SetupResult with LangChain BaseTool instances that adapters + convert to their native format if needed. 
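+
+        Example (a minimal adapter ``setup()``; the ``self._system_prompt`` /
+        ``self._tools`` attribute names are assumptions, not part of the
+        interface)::
+
+            async def setup(self, config: AdapterConfig) -> None:
+                result = await self._common_setup(config)
+                self._system_prompt = result.system_prompt
+                self._tools = result.langchain_tools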
+ """ + from plugins import load_plugins + from skill_loader.loader import load_skills + from coordinator import get_children, get_parent_context, build_children_description + from prompt import build_system_prompt, get_peer_capabilities + from builtin_tools.approval import request_approval + from builtin_tools.delegation import delegate_to_workspace, check_delegation_status + from builtin_tools.memory import commit_memory, search_memory + from builtin_tools.sandbox import run_code + + platform_url = os.environ.get("PLATFORM_URL", "http://platform:8080") + + # Load plugins from per-workspace dir first, then shared fallback + workspace_plugins_dir = os.path.join(config.config_path, "plugins") + plugins = load_plugins( + workspace_plugins_dir=workspace_plugins_dir, + shared_plugins_dir=os.environ.get("PLUGINS_DIR", "/plugins"), + ) + await self.inject_plugins(config, plugins) + if plugins.plugin_names: + logger.info(f"Plugins: {', '.join(plugins.plugin_names)}") + + # Load skills (workspace + plugin skills, deduped) + loaded_skills = load_skills(config.config_path, config.tools) + seen_skill_ids = {s.metadata.id for s in loaded_skills} + for plugin_skills_dir in plugins.skill_dirs: + plugin_skill_names = [ + d for d in os.listdir(plugin_skills_dir) + if os.path.isdir(os.path.join(plugin_skills_dir, d)) + ] + for skill in load_skills(plugin_skills_dir, plugin_skill_names): + if skill.metadata.id not in seen_skill_ids: + loaded_skills.append(skill) + seen_skill_ids.add(skill.metadata.id) + logger.info(f"Loaded {len(loaded_skills)} skills: {[s.metadata.id for s in loaded_skills]}") + + # Assemble tools: 6 core + skill tools + all_tools = [delegate_to_workspace, check_delegation_status, request_approval, commit_memory, search_memory, run_code] + for skill in loaded_skills: + all_tools.extend(skill.tools) + + # Coordinator mode: detect children and add routing tool + children = await get_children() + is_coordinator = len(children) > 0 + if is_coordinator: + from coordinator import route_task_to_team + logger.info(f"Coordinator mode: {len(children)} children") + all_tools.append(route_task_to_team) + + # Parent context (if this is a child workspace) + parent_context = await get_parent_context() + + # Build system prompt with all context + peers = await get_peer_capabilities(platform_url, config.workspace_id) + coordinator_prompt = build_children_description(children) if is_coordinator else "" + extra_prompts = list(plugins.prompt_fragments) + if coordinator_prompt: + extra_prompts.append(coordinator_prompt) + + system_prompt = build_system_prompt( + config.config_path, config.workspace_id, loaded_skills, peers, + prompt_files=config.prompt_files, + plugin_rules=plugins.rules, + plugin_prompts=extra_prompts, + parent_context=parent_context, + ) + + return SetupResult( + system_prompt=system_prompt, + loaded_skills=loaded_skills, + langchain_tools=all_tools, + is_coordinator=is_coordinator, + children=children, + ) + + @abstractmethod + async def setup(self, config: AdapterConfig) -> None: + """One-time setup: validate config, prepare internal state. + Called after deps are installed but before create_executor(). + Raise RuntimeError if setup fails (missing deps, bad config, etc.).""" + ... # pragma: no cover + + @abstractmethod + async def create_executor(self, config: AdapterConfig) -> AgentExecutor: + """Create and return an AgentExecutor ready for A2A integration. + The returned executor's execute() method will be called by the + A2A server's DefaultRequestHandler.""" + ... 
# pragma: no cover diff --git a/molecule_runtime/adapters/shared_runtime.py b/molecule_runtime/adapters/shared_runtime.py new file mode 100644 index 0000000..a383866 --- /dev/null +++ b/molecule_runtime/adapters/shared_runtime.py @@ -0,0 +1,190 @@ +"""Shared runtime helpers for A2A-backed workspace executors.""" + +from __future__ import annotations + +from typing import Any + +from a2a.server.agent_execution import RequestContext + + +def _extract_part_text(part) -> str: + """Extract text from a message part, handling dicts and A2A objects.""" + if isinstance(part, dict): + text = part.get("text", "") + if text: + return text + root = part.get("root") + if isinstance(root, dict): + return root.get("text", "") + return "" + if hasattr(part, "text") and part.text: + return part.text + if hasattr(part, "root") and hasattr(part.root, "text") and part.root.text: + return part.root.text + return "" + + +def extract_message_text(context_or_parts) -> str: + """Extract concatenated plain text from A2A message parts.""" + parts = getattr(getattr(context_or_parts, "message", None), "parts", None) + if parts is None: + parts = context_or_parts + return " ".join( + text for part in (parts or []) if (text := _extract_part_text(part)) + ).strip() + + +def extract_history(context: RequestContext) -> list[tuple[str, str]]: + """Extract conversation history from A2A request metadata.""" + messages: list[tuple[str, str]] = [] + request = getattr(context, "request", None) + metadata = getattr(request, "metadata", None) if request else None + if not isinstance(metadata, dict): + metadata = getattr(context, "metadata", None) or {} + history = metadata.get("history", []) if isinstance(metadata, dict) else [] + if not isinstance(history, list): + return messages + + for entry in history: + if not isinstance(entry, dict): + continue + role = entry.get("role", "user") + parts = entry.get("parts", []) + text = " ".join( + text for part in (parts or []) if (text := _extract_part_text(part)) + ).strip() + if text: + mapped_role = "human" if role == "user" else "ai" + messages.append((mapped_role, text)) + return messages + + +def format_conversation_history(history: list[tuple[str, str]]) -> str: + """Render `(role, text)` history into a stable human-readable transcript.""" + return "\n".join( + f"{'User' if role == 'human' else 'Agent'}: {text}" for role, text in history + ) + + +def build_task_text(user_message: str, history: list[tuple[str, str]]) -> str: + """Build a single task/request string with optional prepended conversation history.""" + if not history: + return user_message + transcript = format_conversation_history(history) + return f"Conversation so far:\n{transcript}\n\nCurrent request: {user_message}" + + +def append_peer_guidance( + base_text: str | None, + peers_info: str, + *, + default_text: str, + tool_name: str, +) -> str: + """Append peer guidance text when peers are available.""" + text = (base_text or default_text).strip() + if peers_info: + text += f"\n\n## Peers\n{peers_info}\nUse {tool_name} to communicate with them." 
+ return text + + +def summarize_peer_cards(peers: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Return compact peer metadata for prompt rendering.""" + summaries: list[dict[str, Any]] = [] + for peer in peers: + agent_card = peer.get("agent_card") + if not agent_card: + continue + if isinstance(agent_card, str): + try: + import json + + agent_card = json.loads(agent_card) + except Exception: + continue + if not isinstance(agent_card, dict): + continue + + skills = agent_card.get("skills", []) + summaries.append( + { + "id": peer.get("id", "unknown"), + "name": agent_card.get("name", peer.get("name", "Unknown")), + "status": peer.get("status", "unknown"), + "skills": [ + s.get("name", s.get("id", "")) + for s in skills + if isinstance(s, dict) + ], + } + ) + return summaries + + +def build_peer_section( + peers: list[dict[str, Any]], + *, + heading: str = "## Your Peers (workspaces you can delegate to)", + instruction: str = ( + "Use the `delegate_to_workspace` tool to send tasks to peers. " + "Only delegate to peers listed above." + ), +) -> str: + """Render a stable peer section for system prompts.""" + summaries = summarize_peer_cards(peers) + if not summaries: + return "" + + parts = [heading, ""] + for peer in summaries: + parts.append(f"- **{peer['name']}** (id: `{peer['id']}`, status: {peer['status']})") + if peer["skills"]: + parts.append(f" Skills: {', '.join(peer['skills'])}") + parts.append("") + parts.append(instruction) + return "\n".join(parts) + + +def brief_task(text: str, limit: int = 60) -> str: + """Create a short human-readable task label for the heartbeat banner.""" + return text[:limit] + ("..." if len(text) > limit else "") + + +async def set_current_task(heartbeat: Any, task: str) -> None: + """Update current task on heartbeat and push immediately to platform. + + The heartbeat loop only fires every 30s, so quick tasks would finish + before the canvas ever sees them. Setting a task pushes immediately. + Clearing a task only updates the heartbeat object — the next heartbeat + cycle will broadcast the clear, keeping the task visible longer. + """ + if heartbeat: + heartbeat.current_task = task + heartbeat.active_tasks = 1 if task else 0 + + # Only push immediately when SETTING a task (not clearing) + # Clearing is handled by the next heartbeat cycle, which keeps + # the task visible on the canvas for quick A2A responses + if not task: + return + + import os + workspace_id = os.environ.get("WORKSPACE_ID", "") + platform_url = os.environ.get("PLATFORM_URL", "") + if workspace_id and platform_url: + try: + import httpx + async with httpx.AsyncClient(timeout=3.0) as client: + await client.post( + f"{platform_url}/registry/heartbeat", + json={ + "workspace_id": workspace_id, + "current_task": task, + "active_tasks": 1, + "error_rate": 0, + "sample_error": "", + "uptime_seconds": 0, + }, + ) + except Exception: + pass # Best-effort diff --git a/molecule_runtime/agent.py b/molecule_runtime/agent.py new file mode 100644 index 0000000..d50403e --- /dev/null +++ b/molecule_runtime/agent.py @@ -0,0 +1,133 @@ +"""Create the Deep Agent with model + skills + tools.""" + +import os +import logging + +from langgraph.prebuilt import create_react_agent + +logger = logging.getLogger(__name__) + + +def create_agent(model_str: str, tools: list, system_prompt: str): + """Create a LangGraph ReAct agent. 
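+
+    Example (values illustrative; the ``provider:model`` convention matches
+    ``AdapterConfig.model``)::
+
+        agent = create_agent(
+            "anthropic:claude-sonnet-4-6",
+            tools=[],
+            system_prompt="You are a helpful workspace agent.",
+        )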
+
+    Args:
+        model_str: LangChain-compatible model string (e.g., 'anthropic:claude-sonnet-4-6')
+        tools: List of tool functions
+        system_prompt: The system prompt for the agent
+    """
+    # Parse provider:model format
+    if ":" in model_str:
+        provider, model_name = model_str.split(":", 1)
+    else:
+        provider = "anthropic"
+        model_name = model_str
+
+    # Import the provider package
+    try:
+        if provider in ("anthropic",):
+            from langchain_anthropic import ChatAnthropic as LLMClass
+        elif provider in ("openai", "openrouter", "groq", "cerebras", "qianfan"):
+            from langchain_openai import ChatOpenAI as LLMClass
+        elif provider == "google_genai":
+            from langchain_google_genai import ChatGoogleGenerativeAI as LLMClass
+        elif provider == "ollama":
+            from langchain_ollama import ChatOllama as LLMClass
+        else:
+            raise ValueError(f"Unsupported model provider: {provider}")
+    except ImportError as e:
+        # Every OpenAI-compatible provider is served by ChatOpenAI, so the
+        # missing package is langchain-openai; otherwise derive the pip name
+        # from the provider (normalizing underscores, e.g. google_genai ->
+        # langchain-google-genai).
+        if provider in ("openai", "openrouter", "groq", "cerebras", "qianfan"):
+            pkg = "langchain-openai"
+        else:
+            pkg = f"langchain-{provider.replace('_', '-')}"
+        raise ImportError(f"Provider '{provider}' requires package '{pkg}'. Install: pip install {pkg}") from e
+
+    # Instantiate the LLM
+    if provider == "anthropic":
+        llm_kwargs = {"model": model_name}
+        anthropic_base_url = os.environ.get("ANTHROPIC_BASE_URL", "")
+        if anthropic_base_url:
+            llm_kwargs["anthropic_api_url"] = anthropic_base_url
+        llm = LLMClass(**llm_kwargs)
+    elif provider == "openrouter":
+        api_key = os.environ.get("OPENROUTER_API_KEY", os.environ.get("OPENAI_API_KEY", ""))
+        max_tokens = int(os.environ.get("MAX_TOKENS", "2048"))
+        llm = LLMClass(
+            model=model_name,
+            openai_api_key=api_key,
+            openai_api_base="https://openrouter.ai/api/v1",
+            max_tokens=max_tokens,
+        )
+    elif provider == "groq":
+        api_key = os.environ.get("GROQ_API_KEY", "")
+        llm = LLMClass(
+            model=model_name,
+            openai_api_key=api_key,
+            openai_api_base="https://api.groq.com/openai/v1",
+        )
+    elif provider == "cerebras":
+        api_key = os.environ.get("CEREBRAS_API_KEY", "")
+        llm = LLMClass(
+            model=model_name,
+            openai_api_key=api_key,
+            openai_api_base="https://api.cerebras.ai/v1",
+        )
+    elif provider == "qianfan":
+        api_key = os.environ.get("QIANFAN_API_KEY", os.environ.get("AISTUDIO_API_KEY", ""))
+        llm = LLMClass(
+            model=model_name,
+            openai_api_key=api_key,
+            openai_api_base="https://qianfan.baidubce.com/v2",
+        )
+    elif provider == "openai":
+        llm_kwargs = {"model": model_name}
+        openai_base_url = os.environ.get("OPENAI_BASE_URL", "")
+        if openai_base_url:
+            llm_kwargs["openai_api_base"] = openai_base_url
+        llm = LLMClass(**llm_kwargs)
+    else:
+        llm = LLMClass(model=model_name)
+
+    # Auto-inject Langfuse tracing if env vars are present
+    callbacks = _setup_langfuse()
+    if callbacks:
+        llm.callbacks = callbacks
+
+    agent = create_react_agent(
+        model=llm,
+        tools=tools,
+        prompt=system_prompt,
+    )
+
+    return agent
+
+
+def _setup_langfuse():
+    """Set up Langfuse tracing if LANGFUSE_* env vars are present.
+
+    Returns list of callbacks to pass to agent invocations, or empty list.
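+
+    Example environment (host shown is Langfuse Cloud; any self-hosted
+    instance works the same way)::
+
+        export LANGFUSE_HOST=https://cloud.langfuse.com
+        export LANGFUSE_PUBLIC_KEY=pk-lf-...
+        export LANGFUSE_SECRET_KEY=sk-lf-...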
+ """ + langfuse_host = os.environ.get("LANGFUSE_HOST") + langfuse_public = os.environ.get("LANGFUSE_PUBLIC_KEY") + langfuse_secret = os.environ.get("LANGFUSE_SECRET_KEY") + + if not (langfuse_host and langfuse_public and langfuse_secret): + return [] + + try: + from langfuse.callback import CallbackHandler + + handler = CallbackHandler( + host=langfuse_host, + public_key=langfuse_public, + secret_key=langfuse_secret, + ) + logger.info("Langfuse tracing enabled: %s", langfuse_host) + + # Also set LANGSMITH_TRACING for LangGraph native integration + os.environ.setdefault("LANGSMITH_TRACING", "true") + + return [handler] + except ImportError: + logger.warning("Langfuse env vars set but langfuse package not installed") + return [] + except Exception as e: + logger.warning("Langfuse setup failed: %s", e) + return [] diff --git a/molecule_runtime/builtin_tools/__init__.py b/molecule_runtime/builtin_tools/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/molecule_runtime/builtin_tools/a2a_tools.py b/molecule_runtime/builtin_tools/a2a_tools.py new file mode 100644 index 0000000..ac05290 --- /dev/null +++ b/molecule_runtime/builtin_tools/a2a_tools.py @@ -0,0 +1,85 @@ +"""A2A communication tools — framework-agnostic delegation and peer discovery. + +These are plain async functions that any adapter can wrap in its native tool format. +The LangChain @tool versions are in tools/delegation.py. +""" + +import os +import uuid + +import httpx + +PLATFORM_URL = os.environ.get("PLATFORM_URL", "http://platform:8080") +WORKSPACE_ID = os.environ.get("WORKSPACE_ID", "") + + +async def list_peers() -> list[dict]: + """Get this workspace's peers from the platform registry.""" + async with httpx.AsyncClient(timeout=10.0) as client: + try: + resp = await client.get(f"{PLATFORM_URL}/registry/{WORKSPACE_ID}/peers") + if resp.status_code == 200: + return resp.json() + return [] + except Exception: + return [] + + +async def delegate_task(workspace_id: str, task: str) -> str: + """Send a task to a peer workspace via A2A and return the response text.""" + async with httpx.AsyncClient(timeout=120.0) as client: + # Discover target URL + try: + resp = await client.get( + f"{PLATFORM_URL}/registry/discover/{workspace_id}", + headers={"X-Workspace-ID": WORKSPACE_ID}, + ) + if resp.status_code != 200: + return f"Error: cannot reach workspace {workspace_id} (status {resp.status_code})" + target_url = resp.json().get("url", "") + if not target_url: + return f"Error: workspace {workspace_id} has no URL" + except Exception as e: + return f"Error discovering workspace: {e}" + + # Send A2A message + try: + a2a_resp = await client.post( + target_url, + json={ + "jsonrpc": "2.0", + "id": str(uuid.uuid4()), + "method": "message/send", + "params": { + "message": { + "role": "user", + "messageId": str(uuid.uuid4()), + "parts": [{"kind": "text", "text": task}], + }, + }, + }, + ) + data = a2a_resp.json() + if "result" in data: + parts = data["result"].get("parts", []) + return parts[0].get("text", "(no text)") if parts else str(data["result"]) + elif "error" in data: + return f"Error: {data['error'].get('message', str(data['error']))}" + return str(data) + except Exception as e: + return f"Error sending A2A message: {e}" + + +async def get_peers_summary() -> str: + """Return a formatted string of available peers for system prompts.""" + peers = await list_peers() + if not peers: + return "No peers available." 
+ lines = [] + for p in peers: + name = p.get("name", "Unknown") + pid = p.get("id", "") + role = p.get("role", "") + status = p.get("status", "") + lines.append(f"- {name} (ID: {pid}) — {role} [{status}]") + return "Available peers:\n" + "\n".join(lines) diff --git a/molecule_runtime/builtin_tools/approval.py b/molecule_runtime/builtin_tools/approval.py new file mode 100644 index 0000000..fb2465b --- /dev/null +++ b/molecule_runtime/builtin_tools/approval.py @@ -0,0 +1,320 @@ +"""Approval tool for human-in-the-loop workflows. + +When an agent encounters a destructive, expensive, or unauthorized action, +it calls request_approval() which creates a request and waits for a decision. + +## Notification strategy + +By default this module uses a **WebSocket subscription** (APPROVAL_USE_WEBSOCKET=true +or when the ``websockets`` package is installed). The platform pushes an +``APPROVAL_DECIDED`` event to the workspace WebSocket as soon as a human +clicks Approve / Deny on the canvas — no polling required, instant delivery. + +If WebSocket is unavailable (env var opt-out or import error) the module +falls back to a **polling loop** so existing deployments without WebSocket +support continue to work without any config change. + +RBAC enforcement +---------------- +The calling workspace must hold a role that grants the ``"approve"`` action. +Roles are read from ``config.yaml`` under ``rbac.roles`` (default: operator). + +Audit trail +----------- +Every approval lifecycle emits structured JSON Lines records: + + 1. ``approval / approve / requested`` — request submitted to platform + 2. ``approval / approve / granted`` — human approved (actor = decided_by) + 3. ``approval / approve / denied`` — human denied (actor = decided_by) + 4. ``approval / approve / timeout`` — no decision within APPROVAL_TIMEOUT + +RBAC denials emit an ``rbac / rbac.deny / denied`` event instead. 
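+
+An illustrative audit line (field set per ``builtin_tools/audit.py``; the
+values here are invented for the example)::
+
+    {"timestamp": "2026-04-16T12:00:00+00:00", "event_type": "approval",
+     "action": "approve", "resource": "delete prod table", "outcome": "granted",
+     "trace_id": "...", "approval_id": "...", "decided_by": "operator@example.com"}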
+
+Environment variables
+---------------------
+PLATFORM_URL             Platform base URL (default: http://platform:8080)
+WORKSPACE_ID             This workspace's ID (default: "")
+APPROVAL_TIMEOUT         Max wait in seconds (default: 300)
+APPROVAL_POLL_INTERVAL   Polling interval in seconds (default: 5, polling path only)
+APPROVAL_USE_WEBSOCKET   "true" to force WS, "false" to force polling
+                         (default: auto-detect)
+AUDIT_LOG_PATH           Path for JSON Lines audit log (default: /var/log/molecule/audit.jsonl)
+"""
+
+import asyncio
+import json
+import logging
+import os
+import uuid
+
+import httpx
+from langchain_core.tools import tool
+
+from builtin_tools.audit import check_permission, get_workspace_roles, log_event
+
+logger = logging.getLogger(__name__)
+
+PLATFORM_URL = os.environ.get("PLATFORM_URL", "http://platform:8080")
+WORKSPACE_ID = os.environ.get("WORKSPACE_ID", "")
+APPROVAL_POLL_INTERVAL = float(os.environ.get("APPROVAL_POLL_INTERVAL", "5"))
+APPROVAL_TIMEOUT = float(os.environ.get("APPROVAL_TIMEOUT", "300"))
+
+# Auto-detect WebSocket support; can be overridden with env var
+_ws_env = os.environ.get("APPROVAL_USE_WEBSOCKET", "").lower()
+if _ws_env == "false":
+    _USE_WEBSOCKET_DEFAULT = False
+elif _ws_env == "true":
+    _USE_WEBSOCKET_DEFAULT = True
+else:
+    try:
+        import websockets as _ws_probe  # noqa: F401
+        _USE_WEBSOCKET_DEFAULT = True
+    except ImportError:
+        _USE_WEBSOCKET_DEFAULT = False
+
+# Module-level reference so tests can monkeypatch it
+try:
+    import websockets
+except ImportError:
+    websockets = None  # type: ignore[assignment]
+
+# Expose for test introspection
+APPROVAL_USE_WEBSOCKET = _USE_WEBSOCKET_DEFAULT
+
+
+# ---------------------------------------------------------------------------
+# Internal helpers
+# ---------------------------------------------------------------------------
+
+async def _create_approval_request(action: str, reason: str) -> dict:
+    """POST to the platform to create an approval request.
+
+    Returns {"approval_id": str} on success or {"error": str} on failure.
+    """
+    async with httpx.AsyncClient(timeout=10.0) as client:
+        try:
+            resp = await client.post(
+                f"{PLATFORM_URL}/workspaces/{WORKSPACE_ID}/approvals",
+                json={"action": action, "reason": reason},
+            )
+            if resp.status_code != 201:
+                return {"error": f"Failed to create request: {resp.status_code}"}
+            try:
+                approval_id = resp.json().get("approval_id")
+            except ValueError:
+                # resp.json() raises ValueError (JSONDecodeError) on non-JSON bodies
+                return {"error": f"Platform returned invalid JSON (status {resp.status_code})"}
+            logger.info("Approval requested: %s (id=%s)", action, approval_id)
+            return {"approval_id": approval_id}
+        except Exception as e:
+            return {"error": f"Failed to request approval: {e}"}
+
+
+async def _wait_websocket(approval_id: str, timeout: float) -> dict:
+    """Subscribe to the platform WebSocket and wait for APPROVAL_DECIDED event.
+
+    Returns the decision dict or raises asyncio.TimeoutError on expiry.
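+
+    The handler below matches decision events of this shape (keys as read
+    in the loop; values illustrative)::
+
+        {"event": "APPROVAL_DECIDED", "approval_id": "...",
+         "status": "approved", "decided_by": "operator@example.com"}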
+ """ + ws_url = ( + PLATFORM_URL.replace("http://", "ws://").replace("https://", "wss://") + + "/ws" + ) + headers = {"X-Workspace-ID": WORKSPACE_ID} + + logger.debug("Approval %s: waiting via WebSocket %s", approval_id, ws_url) + + async with websockets.connect(ws_url, additional_headers=headers) as ws: + async for raw_message in ws: + try: + event = json.loads(raw_message) + except json.JSONDecodeError: + continue + + if event.get("event") != "APPROVAL_DECIDED": + continue + if event.get("approval_id") != approval_id: + continue + + status = event.get("status") + decided_by = event.get("decided_by", "") + logger.info("Approval %s decided via WebSocket: %s by %s", + approval_id, status, decided_by) + + if status == "approved": + return { + "approved": True, + "approval_id": approval_id, + "decided_by": decided_by, + } + else: + return { + "approved": False, + "approval_id": approval_id, + "decided_by": decided_by, + "message": "Denied by human", + } + + +async def _wait_polling(approval_id: str, timeout: float) -> dict: + """Legacy polling loop — checks platform REST endpoint every APPROVAL_POLL_INTERVAL seconds.""" + elapsed = 0.0 + async with httpx.AsyncClient(timeout=10.0) as client: + while elapsed < timeout: + await asyncio.sleep(APPROVAL_POLL_INTERVAL) + elapsed += APPROVAL_POLL_INTERVAL + try: + resp = await client.get( + f"{PLATFORM_URL}/workspaces/{WORKSPACE_ID}/approvals", + ) + if resp.status_code == 200: + for a in resp.json(): + if a.get("id") == approval_id: + status = a.get("status") + if status == "approved": + logger.info("Approval granted (poll): %s", approval_id) + return { + "approved": True, + "approval_id": approval_id, + "decided_by": a.get("decided_by"), + } + elif status == "denied": + logger.info("Approval denied (poll): %s", approval_id) + return { + "approved": False, + "approval_id": approval_id, + "decided_by": a.get("decided_by"), + "message": "Denied by human", + } + except Exception: + pass # transient error — keep retrying + + raise asyncio.TimeoutError() + + +# --------------------------------------------------------------------------- +# Public tool +# --------------------------------------------------------------------------- + +@tool +async def request_approval( + action: str, + reason: str, +) -> dict: + """Request human approval before proceeding with a sensitive action. + + Use this when you're about to do something destructive, expensive, + or outside your normal authority. The request is sent to the canvas + where a human can approve or deny it. + + Args: + action: Short description of what you want to do + reason: Why this action is necessary + """ + # One trace_id links every audit event for this approval lifecycle. + trace_id = str(uuid.uuid4()) + + # --- RBAC check ----------------------------------------------------------- + roles, custom_perms = get_workspace_roles() + if not check_permission("approve", roles, custom_perms): + log_event( + event_type="rbac", + action="rbac.deny", + resource=action, + outcome="denied", + trace_id=trace_id, + attempted_action="approve", + roles=roles, + ) + return { + "approved": False, + "error": ( + "RBAC: this workspace does not have the 'approve' permission. 
" + f"Current roles: {roles}" + ), + } + + # Step 1: Create the approval request + creation = await _create_approval_request(action, reason) + if "error" in creation: + log_event( + event_type="approval", + action="approve", + resource=action, + outcome="failure", + trace_id=trace_id, + reason="submit_failed", + error=creation["error"], + ) + return {"approved": False, "error": creation["error"]} + + approval_id = creation["approval_id"] + log_event( + event_type="approval", + action="approve", + resource=action, + outcome="requested", + trace_id=trace_id, + approval_id=approval_id, + reason_text=reason, + ) + + timeout = float(os.environ.get("APPROVAL_TIMEOUT", str(APPROVAL_TIMEOUT))) + + # Step 2: Wait for decision — WebSocket preferred, polling as fallback + use_ws = APPROVAL_USE_WEBSOCKET and websockets is not None + + try: + if use_ws: + try: + result = await asyncio.wait_for( + _wait_websocket(approval_id, timeout), + timeout=timeout, + ) + except Exception as ws_err: + # WebSocket failed (connection error, etc.) — fall through to polling + logger.warning( + "WebSocket approval wait failed (%s), falling back to polling", + ws_err, + ) + result = await asyncio.wait_for( + _wait_polling(approval_id, timeout), + timeout=timeout + APPROVAL_POLL_INTERVAL, + ) + else: + # Polling path (primary when WS disabled) + result = await asyncio.wait_for( + _wait_polling(approval_id, timeout), + timeout=timeout + APPROVAL_POLL_INTERVAL, # slight grace period + ) + + # Log the human decision + decided_by = result.get("decided_by") + outcome = "granted" if result.get("approved") else "denied" + log_event( + event_type="approval", + action="approve", + resource=action, + outcome=outcome, + # Record the human identity as actor when available + actor=decided_by or WORKSPACE_ID, + trace_id=trace_id, + approval_id=approval_id, + decided_by=decided_by, + ) + return result + + except asyncio.TimeoutError: + logger.warning("Approval timed out after %.0fs: %s", timeout, approval_id) + log_event( + event_type="approval", + action="approve", + resource=action, + outcome="timeout", + trace_id=trace_id, + approval_id=approval_id, + timeout_seconds=timeout, + ) + return { + "approved": False, + "approval_id": approval_id, + "error": f"Timed out after {timeout}s waiting for human decision", + } diff --git a/molecule_runtime/builtin_tools/audit.py b/molecule_runtime/builtin_tools/audit.py new file mode 100644 index 0000000..7806cf2 --- /dev/null +++ b/molecule_runtime/builtin_tools/audit.py @@ -0,0 +1,274 @@ +"""Immutable append-only audit log for EU AI Act compliance. + +Fulfils Article 12 (record-keeping), Article 13 (transparency), and +Article 17 (quality-management system) requirements for high-risk AI systems. + +Log format: JSON Lines (one UTF-8 JSON object per line), suitable for direct +ingestion by any SIEM (Splunk, Elastic, Datadog, etc.). + +Required event fields +--------------------- +timestamp ISO 8601 UTC datetime with timezone offset +event_type Coarse category: "delegation", "approval", "memory", "rbac" +workspace_id Workspace that generated this event +actor Entity that triggered the action; defaults to workspace_id for + automated events, or the human identity for approval decisions +action Verb describing what was attempted: + delegate | approve | memory.read | memory.write | rbac.deny +resource Object of the action: target workspace ID, memory scope, + approval action string, etc. 
+outcome One of: allowed | denied | success | failure | timeout | + requested | granted +trace_id UUID v4 correlating related events across workspaces + +The log file is opened in append mode ("a") on every write — it is NEVER +truncated, rewritten, or deleted by this module. Rotate externally using +logrotate (with ``copytruncate`` disabled) or ship to a SIEM before rotating. + +Configuration +------------- +AUDIT_LOG_PATH env var — full path to the JSONL file + default: /var/log/molecule/audit.jsonl +""" + +from __future__ import annotations + +import functools +import json +import logging +import os +import threading +import uuid +from datetime import datetime, timezone +from pathlib import Path +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + pass # avoid circular import at runtime + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Configuration +# --------------------------------------------------------------------------- + +AUDIT_LOG_PATH: str = os.environ.get( + "AUDIT_LOG_PATH", "/var/log/molecule/audit.jsonl" +) +WORKSPACE_ID: str = os.environ.get("WORKSPACE_ID", "") + +# Protects the open() + write() sequence; prevents interleaved JSON lines +# when multiple async tasks run in the same event-loop thread. +_write_lock = threading.Lock() + + +# --------------------------------------------------------------------------- +# Built-in role → permitted-action mappings +# --------------------------------------------------------------------------- + +#: Maps each built-in role name to the set of actions it grants. +#: Custom roles can be added in config.yaml under ``rbac.allowed_actions``. +ROLE_PERMISSIONS: dict[str, set[str]] = { + # Full access — shortcircuits all other checks + "admin": {"delegate", "approve", "memory.read", "memory.write"}, + # Standard agent role + "operator": {"delegate", "approve", "memory.read", "memory.write"}, + # Read-only observer — no writes, no delegation, no approvals + "read-only": {"memory.read"}, + # Can approve and write memory, but cannot delegate + "no-delegation": {"approve", "memory.read", "memory.write"}, + # Can delegate and write memory, but cannot invoke approval gate + "no-approval": {"delegate", "memory.read", "memory.write"}, + # Memory reads only (useful for analytic sidecars) + "memory-readonly": {"memory.read"}, +} + + +# --------------------------------------------------------------------------- +# Config loader (lazy, cached per process) +# --------------------------------------------------------------------------- + +@functools.lru_cache(maxsize=1) +def _load_workspace_config(): + """Return the WorkspaceConfig or None if it cannot be loaded.""" + try: + from config import load_config # local import avoids circular deps + return load_config() + except Exception as exc: + logger.warning("audit: could not load workspace config for RBAC: %s", exc) + return None + + +def get_workspace_roles() -> tuple[list[str], dict[str, list[str]]]: + """Return ``(roles, custom_permissions)`` from the workspace config. + + Falls back to ``["operator"]`` / ``{}`` when the config is unavailable so + that agents remain functional in degraded environments. 
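+
+    Example ``config.yaml`` fragment (shape follows ``rbac.roles`` /
+    ``rbac.allowed_actions``; role and action names illustrative)::
+
+        rbac:
+          roles: ["operator", "developer"]
+          allowed_actions:
+            developer: ["deploy", "memory.read"]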
+ """ + cfg = _load_workspace_config() + if cfg is None: + return ["operator"], {} + return list(cfg.rbac.roles), dict(cfg.rbac.allowed_actions) + + +# --------------------------------------------------------------------------- +# RBAC helpers +# --------------------------------------------------------------------------- + +def check_permission( + action: str, + roles: list[str], + custom_permissions: dict[str, list[str]] | None = None, +) -> bool: + """Return True if *any* of ``roles`` grants ``action``. + + Evaluation order + ~~~~~~~~~~~~~~~~ + 1. ``"admin"`` shortcircuits — always grants everything. + 2. Custom role definitions (from ``rbac.allowed_actions`` in config.yaml). + 3. Built-in :data:`ROLE_PERMISSIONS` table. + + When a role appears in *custom_permissions* its built-in definition is + **ignored** — the custom list is the complete permission set for that role. + + Args: + action: Action to authorise, e.g. ``"delegate"``. + roles: Roles assigned to the calling workspace. + custom_permissions: Optional ``{role: [action, ...]}`` mapping loaded + from ``WorkspaceConfig.rbac.allowed_actions``. + + Returns: + ``True`` if the action is permitted, ``False`` otherwise. + + Examples:: + + >>> check_permission("delegate", ["operator"]) + True + >>> check_permission("delegate", ["read-only"]) + False + >>> check_permission("deploy", ["developer"], {"developer": ["deploy"]}) + True + """ + for role in roles: + if role == "admin": + return True + if custom_permissions and role in custom_permissions: + # Custom entry is definitive for this role + if action in custom_permissions[role]: + return True + continue # Don't fall through to built-ins for custom roles + if role in ROLE_PERMISSIONS and action in ROLE_PERMISSIONS[role]: + return True + return False + + +# --------------------------------------------------------------------------- +# Public audit API +# --------------------------------------------------------------------------- + +def log_event( + event_type: str, + action: str, + resource: str, + outcome: str, + actor: str | None = None, + trace_id: str | None = None, + **extra: Any, +) -> str: + """Append one audit event to the immutable JSON Lines log. + + Args: + event_type: Coarse category — ``"delegation"``, ``"approval"``, + ``"memory"``, or ``"rbac"``. + action: Verb — ``"delegate"``, ``"approve"``, ``"memory.write"``, + ``"memory.read"``, ``"rbac.deny"``. + resource: Object of the action — target workspace ID, memory scope, + approval action string, etc. + outcome: Terminal state — one of ``"allowed"``, ``"denied"``, + ``"success"``, ``"failure"``, ``"timeout"``, + ``"requested"``, ``"granted"``. + actor: Identity that triggered the event. Defaults to + ``WORKSPACE_ID`` (the running workspace) for automated + events. Pass ``decided_by`` for human approval decisions. + trace_id: Caller-supplied UUID v4 for cross-event correlation. + A fresh UUID is generated when omitted. + **extra: Additional key-value pairs appended verbatim to the JSON + object (e.g. ``target_workspace_id``, ``memory_scope``, + ``attempt``). Built-in keys cannot be overridden. + + Returns: + The ``trace_id`` used for this event, enabling callers to chain + related events under a single correlation identifier. 
+ + Example:: + + trace = log_event( + event_type="delegation", + action="delegate", + resource="billing-agent", + outcome="success", + target_workspace_id="billing-agent", + attempt=1, + ) + """ + if trace_id is None: + trace_id = str(uuid.uuid4()) + + event: dict[str, Any] = { + "timestamp": datetime.now(timezone.utc).isoformat(), + "event_type": event_type, + "workspace_id": WORKSPACE_ID, + "actor": actor if actor is not None else WORKSPACE_ID, + "action": action, + "resource": resource, + "outcome": outcome, + "trace_id": trace_id, + } + + # Merge extra fields — built-in keys are not overridable + for key, value in extra.items(): + if key not in event: + event[key] = value + + _write_event(event) + return trace_id + + +# --------------------------------------------------------------------------- +# Internal writer +# --------------------------------------------------------------------------- + +def _ensure_log_dir(path: str) -> None: + """Create the parent directory for *path* if it does not already exist.""" + Path(path).parent.mkdir(parents=True, exist_ok=True) + + +def _write_event(event: dict[str, Any]) -> None: + """Serialise *event* as a JSON line and fsync-append it to the log file. + + The write is atomic with respect to other threads in this process: the + lock ensures that no two JSON objects are interleaved on the same line. + + Failures are emitted to the standard Python logger at WARNING level but + are **never** re-raised — the application must not crash because audit + logging is temporarily unavailable (e.g. disk full, permission error). + In production, consider wiring an alert on WARNING messages from this + module so that missing audit records are detected quickly. + """ + try: + log_path = AUDIT_LOG_PATH + _ensure_log_dir(log_path) + line = json.dumps(event, default=str, ensure_ascii=False) + "\n" + with _write_lock: + with open(log_path, "a", encoding="utf-8") as fh: + fh.write(line) + fh.flush() + os.fsync(fh.fileno()) + except Exception as exc: # pylint: disable=broad-except + logger.warning( + "Audit log write failed — event NOT persisted " + "(trace_id=%s, action=%s): %s", + event.get("trace_id", "?"), + event.get("action", "?"), + exc, + ) diff --git a/molecule_runtime/builtin_tools/awareness_client.py b/molecule_runtime/builtin_tools/awareness_client.py new file mode 100644 index 0000000..696ce05 --- /dev/null +++ b/molecule_runtime/builtin_tools/awareness_client.py @@ -0,0 +1,122 @@ +"""Workspace-scoped awareness backend wrapper. + +The agent-facing memory tools keep their existing signatures and delegate +to this helper when workspace awareness is configured. 
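+
+Connection settings come from the environment (as read in
+``get_awareness_config``)::
+
+    AWARENESS_URL          awareness service base URL (required)
+    WORKSPACE_ID           used to derive the default namespace
+    AWARENESS_NAMESPACE    explicit namespace override (optional)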
+""" + +from __future__ import annotations + +import os +import sys +from types import SimpleNamespace +from typing import Any + +from policies.namespaces import resolve_awareness_namespace + +try: # pragma: no cover - optional runtime dependency in lightweight test envs + import httpx # type: ignore +except ImportError: # pragma: no cover + httpx = SimpleNamespace(AsyncClient=None) + + +DEFAULT_AWARENESS_TIMEOUT = 10.0 + + +def get_awareness_config() -> dict[str, str] | None: + """Return awareness connection settings if the workspace is configured.""" + base_url = os.environ.get("AWARENESS_URL", "").rstrip("/") + workspace_id = os.environ.get("WORKSPACE_ID", "") + configured_namespace = os.environ.get("AWARENESS_NAMESPACE", "") + if not base_url: + return None + if not workspace_id and not configured_namespace: + return None + namespace = resolve_awareness_namespace(workspace_id, configured_namespace) + return { + "base_url": base_url, + "namespace": namespace, + } + + +class AwarenessClient: + """Small HTTP client for workspace-scoped awareness memory operations.""" + + def __init__(self, base_url: str, namespace: str, timeout: float = DEFAULT_AWARENESS_TIMEOUT): + self.base_url = base_url.rstrip("/") + self.namespace = namespace + self.timeout = timeout + + def _memories_url(self) -> str: + # Keep the awareness path isolated in one helper so the contract can + # be adjusted later without touching the agent-facing tools. + return f"{self.base_url}/api/v1/namespaces/{self.namespace}/memories" + + async def commit(self, content: str, scope: str) -> dict[str, Any]: + client_cls = _resolve_async_client() + async with client_cls(timeout=self.timeout) as client: + resp = await client.post( + self._memories_url(), + json={"content": content, "scope": scope}, + ) + return _parse_commit_response(resp, scope) + + async def search(self, query: str = "", scope: str = "") -> dict[str, Any]: + params: dict[str, str] = {} + if query: + params["q"] = query + if scope: + params["scope"] = scope + + client_cls = _resolve_async_client() + async with client_cls(timeout=self.timeout) as client: + resp = await client.get(self._memories_url(), params=params) + return _parse_search_response(resp) + + +def build_awareness_client() -> AwarenessClient | None: + """Create an awareness client from the current workspace environment.""" + config = get_awareness_config() + if not config: + return None + return AwarenessClient(config["base_url"], config["namespace"]) + + +def _parse_commit_response(resp: httpx.Response, scope: str) -> dict[str, Any]: + data = _safe_json(resp) + if resp.status_code in (200, 201): + return {"success": True, "id": data.get("id"), "scope": scope} + return {"success": False, "error": data.get("error", resp.text)} + + +def _parse_search_response(resp: httpx.Response) -> dict[str, Any]: + data = _safe_json(resp) + if resp.status_code == 200: + memories = data if isinstance(data, list) else data.get("memories", []) + return { + "success": True, + "count": len(memories), + "memories": memories, + } + return {"success": False, "error": data.get("error", resp.text)} + + +def _safe_json(resp: httpx.Response) -> dict[str, Any] | list[Any]: + try: + return resp.json() + except ValueError: + return {"error": resp.text} + + +def _resolve_async_client(): + client_cls = getattr(httpx, "AsyncClient", None) + if client_cls is not None: + return client_cls + + memory_module = sys.modules.get("builtin_tools.memory") + if memory_module is not None: + memory_httpx = getattr(memory_module, "httpx", None) + 
client_cls = getattr(memory_httpx, "AsyncClient", None) + if client_cls is not None: + return client_cls + + raise RuntimeError("httpx.AsyncClient is unavailable") diff --git a/molecule_runtime/builtin_tools/compliance.py b/molecule_runtime/builtin_tools/compliance.py new file mode 100644 index 0000000..1c4e45e --- /dev/null +++ b/molecule_runtime/builtin_tools/compliance.py @@ -0,0 +1,359 @@ +"""OWASP Top 10 for Agentic Applications compliance enforcement (Dec 2025). + +Enable via config.yaml:: + + compliance: + mode: owasp_agentic + prompt_injection: detect # detect | block + max_tool_calls_per_task: 50 + max_task_duration_seconds: 300 + +When ``mode`` is absent or empty, this module is a no-op — no overhead, no +behaviour change. This makes it safe to import unconditionally. + +Coverage +-------- + +OA-01 Prompt Injection (``sanitize_input``) + Scans user-supplied text for instruction-override patterns, role-hijacking + attempts, system-prompt delimiter injection, and known jailbreak keywords. + + - ``detect`` (default): log an audit event, return the original text so + the agent still processes the input. Operators are alerted without + breaking legitimate use-cases that happen to contain trigger words. + + - ``block``: raise ``PromptInjectionError`` before the agent sees the text. + +OA-03 Excessive Agency (``check_agency_limits``) + Tracks the number of tool calls and wall-clock time elapsed per task. + When a limit is exceeded, ``ExcessiveAgencyError`` is raised. The caller + (``a2a_executor.py``) catches it and terminates the task gracefully. + +OA-02 / OA-06 Insecure Output / Sensitive Data Exposure (``redact_pii``) + Scans agent output for credit-card numbers, SSNs, API keys, AWS access + keys, and e-mail addresses. Detected values are replaced with + ``[REDACTED:]`` tokens before the response reaches the caller. + An audit event records the PII types found (not the values themselves). + + Note on streaming: ``redact_pii`` is applied to the *final accumulated + text* before the terminal ``Message`` event is emitted. Token-by-token + SSE artifacts that have already been sent to streaming clients are not + retroactively redacted. For full streaming redaction, integrate + ``redact_pii`` at the ``TaskArtifactUpdateEvent`` level. + +Compliance posture report (``get_compliance_posture``) + Returns the current effective compliance configuration as a plain ``dict`` + suitable for a health or audit endpoint, letting operators verify that the + correct settings are active without reading config files. +""" + +from __future__ import annotations + +import logging +import re +import time +import uuid +from dataclasses import dataclass, field +from typing import Any + +from builtin_tools.audit import log_event + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Public exceptions +# --------------------------------------------------------------------------- + + +class PromptInjectionError(ValueError): + """Raised when prompt injection is detected and ``prompt_injection=block``.""" + + +class ExcessiveAgencyError(RuntimeError): + """Raised when the tool-call count or task-duration limit is exceeded.""" + + +# --------------------------------------------------------------------------- +# OA-01 — Prompt Injection detection +# --------------------------------------------------------------------------- + +#: Compiled patterns matched against normalised (lowercased + collapsed) input. 
+#: Add workspace-specific patterns in config if needed.
+_INJECTION_PATTERNS: list[tuple[re.Pattern[str], str]] = [
+    # Instruction override
+    (re.compile(r"ignore\s+(all\s+)?previous\s+instructions?", re.I), "instruction_override"),
+    (re.compile(r"disregard\s+(all\s+)?previous", re.I), "instruction_override"),
+    (re.compile(r"forget\s+(all\s+)?previous", re.I), "instruction_override"),
+    (re.compile(r"override\s+(your\s+)?(instructions?|guidelines?|rules?)", re.I), "instruction_override"),
+    # Role hijacking
+    (re.compile(r"you\s+are\s+now\s+\w", re.I), "role_hijack"),
+    (re.compile(r"act\s+as\s+(a\s+)?(new\s+|different\s+|unrestricted\s+)", re.I), "role_hijack"),
+    (re.compile(r"roleplay\s+as", re.I), "role_hijack"),
+    (re.compile(r"pretend\s+(you\s+are|to\s+be)\b", re.I), "role_hijack"),
+    (re.compile(r"from\s+now\s+on\s+(you\s+are|act\s+as)", re.I), "role_hijack"),
+    # System-prompt delimiter injection (LLM-specific tokens)
+    (re.compile(r"<\|?\s*(system|im_start|im_end|endoftext)\s*\|?>", re.I), "delimiter_injection"),
+    (re.compile(r"\[INST\]|\[/INST\]|\[\[SYS\]\]|\[\[/SYS\]\]", re.I), "delimiter_injection"),
+    (re.compile(r"<<SYS>>|<</SYS>>", re.I), "delimiter_injection"),
+    # DAN / jailbreak keywords
+    (re.compile(r"\bDAN\b.{0,30}(mode|now|enabled|activated)", re.I), "jailbreak"),
+    (re.compile(r"do\s+anything\s+now", re.I), "jailbreak"),
+    (re.compile(r"\bjailbreak\b", re.I), "jailbreak"),
+    (re.compile(r"developer\s+mode\s+(enabled|on)", re.I), "jailbreak"),
+    # Prompt exfiltration
+    (re.compile(r"(repeat|print|output|show|reveal|display)\s+(your\s+)?(system\s+prompt|initial\s+instructions?)", re.I), "prompt_exfiltration"),
+    (re.compile(r"what\s+(are\s+)?your\s+(instructions?|system\s+prompt)", re.I), "prompt_exfiltration"),
+]
+
+
+def detect_prompt_injection(text: str) -> list[tuple[str, str]]:
+    """Return a list of ``(pattern_description, category)`` for each match.
+
+    Args:
+        text: Raw user input to scan.
+
+    Returns:
+        List of ``(matched_pattern, category)`` tuples; empty means clean.
+    """
+    matches: list[tuple[str, str]] = []
+    for pattern, category in _INJECTION_PATTERNS:
+        m = pattern.search(text)
+        if m:
+            matches.append((m.group(0)[:80], category))
+    return matches
+
+
+def sanitize_input(
+    text: str,
+    *,
+    prompt_injection_mode: str = "detect",
+    context_id: str = "",
+) -> str:
+    """Check *text* for prompt injection and enforce the configured response.
+
+    Args:
+        text: User-supplied input to the agent.
+        prompt_injection_mode: ``"detect"`` or ``"block"``.
+        context_id: Task/context identifier for audit correlation.
+
+    Returns:
+        The original *text* unchanged (``detect`` mode always returns input).
+
+    Raises:
+        :class:`PromptInjectionError`: only when ``prompt_injection_mode="block"``
+            and at least one injection pattern is matched.
+    """
+    matches = detect_prompt_injection(text)
+    if not matches:
+        return text
+
+    categories = list({cat for _, cat in matches})
+    trace_id = str(uuid.uuid4())
+
+    log_event(
+        event_type="compliance",
+        action="prompt_injection.detect",
+        resource="user_input",
+        outcome="detected" if prompt_injection_mode == "detect" else "blocked",
+        trace_id=trace_id,
+        context_id=context_id,
+        categories=categories,
+        match_count=len(matches),
+        # Log category + truncated match, never the full raw text (OA-06)
+        matches=[{"category": cat, "snippet": snippet} for snippet, cat in matches[:5]],
+    )
+
+    if prompt_injection_mode == "block":
+        raise PromptInjectionError(
+            f"Prompt injection detected ({', '.join(categories)}). 
" + "Request blocked by compliance policy." + ) + + # detect mode — log and continue + logger.warning( + "Prompt injection patterns detected (context_id=%s, categories=%s) — " + "passing to agent in detect mode", + context_id, + categories, + ) + return text + + +# --------------------------------------------------------------------------- +# OA-03 — Excessive Agency +# --------------------------------------------------------------------------- + + +@dataclass +class AgencyTracker: + """Per-task mutable state for excessive-agency enforcement. + + Instantiate once per ``execute()`` call and pass to + :func:`check_agency_limits` at each tool-start event. + """ + + max_tool_calls: int = 50 + max_duration_seconds: float = 300.0 + tool_call_count: int = field(default=0, init=False) + start_time: float = field(default_factory=time.monotonic, init=False) + + def on_tool_call(self, tool_name: str = "", context_id: str = "") -> None: + """Increment counter and enforce limits. + + Raises: + :class:`ExcessiveAgencyError`: if either limit is exceeded. + """ + self.tool_call_count += 1 + elapsed = time.monotonic() - self.start_time + + if self.tool_call_count > self.max_tool_calls: + log_event( + event_type="compliance", + action="excessive_agency.tool_limit", + resource=tool_name or "unknown_tool", + outcome="blocked", + context_id=context_id, + tool_call_count=self.tool_call_count, + limit=self.max_tool_calls, + elapsed_seconds=round(elapsed, 2), + ) + raise ExcessiveAgencyError( + f"Tool call limit exceeded: {self.tool_call_count} calls > " + f"max {self.max_tool_calls} per task" + ) + + if elapsed > self.max_duration_seconds: + log_event( + event_type="compliance", + action="excessive_agency.duration_limit", + resource=tool_name or "unknown_tool", + outcome="blocked", + context_id=context_id, + tool_call_count=self.tool_call_count, + elapsed_seconds=round(elapsed, 2), + limit_seconds=self.max_duration_seconds, + ) + raise ExcessiveAgencyError( + f"Task duration limit exceeded: {elapsed:.0f}s > " + f"max {self.max_duration_seconds:.0f}s per task" + ) + + +# --------------------------------------------------------------------------- +# OA-02 / OA-06 — PII redaction +# --------------------------------------------------------------------------- + +#: ``(compiled_pattern, replacement_token)`` pairs applied in order. +#: The replacement tokens are SIEM-friendly: ``[REDACTED:type]``. +_PII_PATTERNS: list[tuple[re.Pattern[str], str]] = [ + # Formatted credit cards: XXXX-XXXX-XXXX-XXXX or XXXX XXXX XXXX XXXX + (re.compile(r"\b\d{4}[\s\-]\d{4}[\s\-]\d{4}[\s\-]\d{4}\b"), "[REDACTED:credit_card]"), + # US Social Security Numbers: XXX-XX-XXXX + (re.compile(r"\b\d{3}-\d{2}-\d{4}\b"), "[REDACTED:ssn]"), + # OpenAI-style keys: sk-... 
(≥ 32 chars after prefix) + (re.compile(r"\bsk-[A-Za-z0-9_\-]{32,}\b"), "[REDACTED:api_key]"), + # Generic API/secret keys with common prefixes + (re.compile(r"\b(?:sk|pk|api|secret|token|auth)[-_][A-Za-z0-9_\-]{20,}\b", re.I), "[REDACTED:api_key]"), + # AWS Access Key IDs + (re.compile(r"\bAKIA[0-9A-Z]{16}\b"), "[REDACTED:aws_key]"), + # GitHub personal access tokens — classic format (36-char alphanumeric suffix) + (re.compile(r"\bghp_[A-Za-z0-9]{36}\b"), "[REDACTED:github_token]"), + # GitHub personal access tokens — fine-grained format (82-char alphanumeric+underscore suffix) + (re.compile(r"\bgithub_pat_[A-Za-z0-9_]{82}\b"), "[REDACTED:github_token]"), + # Email addresses + (re.compile(r"\b[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}\b"), "[REDACTED:email]"), +] + + +def redact_pii(text: str) -> tuple[str, list[str]]: + """Redact PII from *text* and return ``(redacted_text, pii_types_found)``. + + Each unique PII type is reported at most once in ``pii_types_found``. + The replacement tokens (``[REDACTED:type]``) are SIEM-indexable and + preserve the structural context of the output while hiding sensitive data. + + Args: + text: Agent output text to scan. + + Returns: + Tuple of ``(redacted_text, list_of_pii_type_strings)``. The list is + empty when no PII is detected (the common case). + + Examples:: + + >>> redacted, types = redact_pii("Call me at test@example.com sk-abc123...") + >>> "email" in types + True + >>> "[REDACTED:email]" in redacted + True + """ + found: list[str] = [] + result = text + for pattern, replacement in _PII_PATTERNS: + new_result = pattern.sub(replacement, result) + if new_result != result: + # Extract type from "[REDACTED:type]" + pii_type = replacement[len("[REDACTED:"):-1] + if pii_type not in found: + found.append(pii_type) + result = new_result + return result, found + + +# --------------------------------------------------------------------------- +# Compliance posture report +# --------------------------------------------------------------------------- + + +def get_compliance_posture() -> dict[str, Any]: + """Return the current compliance configuration as a serialisable dict. + + Loads ``WorkspaceConfig`` lazily (cached) and returns a snapshot of the + active compliance settings. Safe to call from a health endpoint. 
+ + Returns a dict with these keys:: + + { + "compliance_mode": "owasp_agentic" | "", + "enabled": true | false, + "prompt_injection": "detect" | "block", + "max_tool_calls_per_task": 50, + "max_task_duration_seconds": 300, + "pii_redaction_enabled": true, + "security_scan_mode": "warn" | "block" | "off", + "rbac_roles": ["operator"], + } + """ + try: + from builtin_tools.audit import _load_workspace_config + cfg = _load_workspace_config() + except Exception: + cfg = None + + if cfg is None: + return { + "compliance_mode": "", + "enabled": False, + "prompt_injection": "detect", + "max_tool_calls_per_task": 50, + "max_task_duration_seconds": 300, + "pii_redaction_enabled": False, + "security_scan_mode": "warn", + "rbac_roles": [], + "note": "config unavailable", + } + + c = cfg.compliance + enabled = c.mode == "owasp_agentic" + return { + "compliance_mode": c.mode, + "enabled": enabled, + "prompt_injection": c.prompt_injection, + "max_tool_calls_per_task": c.max_tool_calls_per_task, + "max_task_duration_seconds": c.max_task_duration_seconds, + # PII redaction is active whenever compliance mode is on + "pii_redaction_enabled": enabled, + "security_scan_mode": cfg.security_scan.mode, + "rbac_roles": list(cfg.rbac.roles), + } diff --git a/molecule_runtime/builtin_tools/delegation.py b/molecule_runtime/builtin_tools/delegation.py new file mode 100644 index 0000000..b9c4296 --- /dev/null +++ b/molecule_runtime/builtin_tools/delegation.py @@ -0,0 +1,366 @@ +"""Async delegation tool for sending tasks to peer workspaces via A2A. + +Delegations are non-blocking: the tool fires the A2A request in the background +and returns immediately with a task_id. The agent can check status anytime via +check_delegation_status, or just continue working and check later. + +When the delegate responds, the result is stored and the agent is notified +via a status update. 
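+
+Sketch of the round trip (LangChain tool-call form; the peer ID and the
+``check_delegation_status`` argument name are illustrative)::
+
+    out = await delegate_to_workspace.ainvoke(
+        {"workspace_id": "billing-agent", "task": "Reconcile March invoices"}
+    )
+    status = await check_delegation_status.ainvoke({"task_id": out["task_id"]})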
+""" + +import asyncio +import os +import uuid +from dataclasses import dataclass, field +from enum import Enum +from typing import Optional + +import httpx +from langchain_core.tools import tool + +from builtin_tools.audit import check_permission, get_workspace_roles, log_event +from builtin_tools.telemetry import ( + A2A_SOURCE_WORKSPACE, + A2A_TARGET_WORKSPACE, + A2A_TASK_ID, + WORKSPACE_ID_ATTR, + get_current_traceparent, + get_tracer, + inject_trace_headers, +) + +PLATFORM_URL = os.environ.get("PLATFORM_URL", "http://platform:8080") +WORKSPACE_ID = os.environ.get("WORKSPACE_ID", "") +DELEGATION_RETRY_ATTEMPTS = int(os.environ.get("DELEGATION_RETRY_ATTEMPTS", "3")) +DELEGATION_RETRY_DELAY = float(os.environ.get("DELEGATION_RETRY_DELAY", "5.0")) +DELEGATION_TIMEOUT = float(os.environ.get("DELEGATION_TIMEOUT", "300.0")) + + +class DelegationStatus(str, Enum): + PENDING = "pending" + IN_PROGRESS = "in_progress" + COMPLETED = "completed" + FAILED = "failed" + + +@dataclass +class DelegationTask: + task_id: str + workspace_id: str + task_description: str + status: DelegationStatus = DelegationStatus.PENDING + result: Optional[str] = None + error: Optional[str] = None + + +# In-memory store of delegation tasks for this workspace +_delegations: dict[str, DelegationTask] = {} +_background_tasks: set[asyncio.Task] = set() +MAX_DELEGATION_HISTORY = 100 +logger = __import__("logging").getLogger(__name__) + + +def _evict_old_delegations(): + """Remove completed/failed delegations when store exceeds MAX_DELEGATION_HISTORY.""" + if len(_delegations) <= MAX_DELEGATION_HISTORY: + return + # Evict oldest completed/failed first + removable = [ + tid for tid, d in _delegations.items() + if d.status in (DelegationStatus.COMPLETED, DelegationStatus.FAILED) + ] + for tid in removable[:len(_delegations) - MAX_DELEGATION_HISTORY]: + del _delegations[tid] + + +def _on_task_done(task: asyncio.Task): + """Callback for background tasks — log unhandled exceptions.""" + _background_tasks.discard(task) + if not task.cancelled() and task.exception(): + logger.error("Delegation background task failed: %s", task.exception()) + + +async def _notify_completion(task_id: str, target_workspace_id: str, status: str): + """Push notification to platform when delegation completes/fails.""" + try: + async with httpx.AsyncClient(timeout=10) as client: + await client.post( + f"{PLATFORM_URL}/workspaces/{WORKSPACE_ID}/notify", + json={ + "type": "delegation_complete", + "task_id": task_id, + "target_workspace_id": target_workspace_id, + "status": status, + }, + ) + except Exception as e: + logger.debug("Delegation notify failed (best-effort): %s", e) + + +async def _record_delegation_on_platform(task_id: str, target_workspace_id: str, task: str): + """Register the delegation in the platform's activity_logs (#64 fix). + + Best-effort POST to /workspaces//delegations/record. The agent still + fires A2A directly for speed + OTEL propagation, but the platform's + GET /delegations endpoint now mirrors the same set an agent's local + check_delegation_status sees. 
    """
    try:
        async with httpx.AsyncClient(timeout=10) as client:
            await client.post(
                f"{PLATFORM_URL}/workspaces/{WORKSPACE_ID}/delegations/record",
                json={
                    "target_id": target_workspace_id,
                    "task": task,
                    "delegation_id": task_id,
                },
            )
    except Exception as e:
        logger.debug("Delegation record failed (best-effort): %s", e)


async def _update_delegation_on_platform(task_id: str, status: str, error: str = "", response_preview: str = ""):
    """Mirror status changes to the platform's activity_logs (#64 fix).

    Paired with _record_delegation_on_platform — fires on completion/failure
    so the platform view stays in sync with the agent's local dict.
    """
    try:
        async with httpx.AsyncClient(timeout=10) as client:
            await client.post(
                f"{PLATFORM_URL}/workspaces/{WORKSPACE_ID}/delegations/{task_id}/update",
                json={
                    "status": status,
                    "error": error,
                    "response_preview": response_preview[:500],
                },
            )
    except Exception as e:
        logger.debug("Delegation update failed (best-effort): %s", e)


async def _execute_delegation(task_id: str, workspace_id: str, task: str):
    """Background coroutine that sends the A2A request and stores the result."""
    delegation = _delegations[task_id]
    delegation.status = DelegationStatus.IN_PROGRESS

    # #64: register on the platform so GET /workspaces/:id/delegations
    # sees the same set as check_delegation_status. Best-effort — platform
    # unreachability must not block the actual A2A delegation.
    await _record_delegation_on_platform(task_id, workspace_id, task)

    tracer = get_tracer()
    with tracer.start_as_current_span("task_delegate") as delegate_span:
        delegate_span.set_attribute(WORKSPACE_ID_ATTR, WORKSPACE_ID)
        delegate_span.set_attribute(A2A_SOURCE_WORKSPACE, WORKSPACE_ID)
        delegate_span.set_attribute(A2A_TARGET_WORKSPACE, workspace_id)
        delegate_span.set_attribute(A2A_TASK_ID, task_id)

        async with httpx.AsyncClient(timeout=DELEGATION_TIMEOUT) as client:
            # Discover target URL
            try:
                discover_resp = await client.get(
                    f"{PLATFORM_URL}/registry/discover/{workspace_id}",
                    headers={"X-Workspace-ID": WORKSPACE_ID},
                )
                if discover_resp.status_code != 200:
                    delegation.status = DelegationStatus.FAILED
                    delegation.error = f"Discovery failed: HTTP {discover_resp.status_code}"
                    log_event(event_type="delegation", action="delegate", resource=workspace_id,
                              outcome="failure", trace_id=task_id, reason="discovery_error")
                    return

                target_url = discover_resp.json().get("url")
                if not target_url:
                    delegation.status = DelegationStatus.FAILED
                    delegation.error = "No URL for workspace"
                    return
            except Exception as e:
                delegation.status = DelegationStatus.FAILED
                delegation.error = f"Discovery error: {e}"
                return

            # Send A2A with retry
            outgoing_headers = inject_trace_headers({
                "Content-Type": "application/json",
                "X-Workspace-ID": WORKSPACE_ID,
            })
            traceparent = get_current_traceparent()

            last_error = None
            for attempt in range(DELEGATION_RETRY_ATTEMPTS):
                try:
                    a2a_resp = await client.post(
                        target_url,
                        headers=outgoing_headers,
                        json={
                            "jsonrpc": "2.0",
                            "method": "message/send",
                            "id": f"delegation-{task_id}-{attempt}",
                            "params": {
                                "message": {
                                    "role": "user",
                                    "parts": [{"kind": "text", "text": task}],
                                    "messageId": f"msg-{task_id}-{attempt}",
                                },
                                "metadata": {
                                    "parent_task_id": task_id,
                                    "source_workspace_id": WORKSPACE_ID,
                                    "traceparent": traceparent,
                                },
                            },
                        },
                    )

                    if a2a_resp.status_code == 200:
                        try:
                            result = a2a_resp.json()
                        except Exception:
                            delegation.status = DelegationStatus.FAILED
                            delegation.error = "Invalid JSON response"
                            return

                        if "result" in result:
                            task_result = result["result"]
                            artifacts = task_result.get("artifacts", [])
                            texts = []
                            for artifact in artifacts:
                                for part in artifact.get("parts", []):
                                    if part.get("kind") == "text":
                                        texts.append(part["text"])
                            # Also check top-level parts
                            for part in task_result.get("parts", []):
                                if part.get("kind") == "text":
                                    texts.append(part["text"])

                            delegation.status = DelegationStatus.COMPLETED
                            delegation.result = "\n".join(texts) if texts else str(task_result)
                            log_event(event_type="delegation", action="delegate", resource=workspace_id,
                                      outcome="success", trace_id=task_id, attempt=attempt + 1)
                            await _notify_completion(task_id, workspace_id, "completed")
                            # #64: mirror to platform activity_logs so
                            # GET /delegations shows the completion state.
                            await _update_delegation_on_platform(
                                task_id, "completed", "",
                                delegation.result or "",
                            )
                            return

                        if "error" in result:
                            last_error = result["error"].get("message", str(result["error"]))
                            break
                    else:
                        # Non-200 from the target: record it so the final error
                        # message is meaningful, then back off before retrying.
                        last_error = f"HTTP {a2a_resp.status_code} from target"
                        if attempt < DELEGATION_RETRY_ATTEMPTS - 1:
                            await asyncio.sleep(DELEGATION_RETRY_DELAY * (attempt + 1))

                except (httpx.ConnectError, httpx.TimeoutException) as e:
                    last_error = str(e)
                    if attempt < DELEGATION_RETRY_ATTEMPTS - 1:
                        await asyncio.sleep(DELEGATION_RETRY_DELAY * (attempt + 1))
                        continue

            delegation.status = DelegationStatus.FAILED
            delegation.error = str(last_error or "no successful response after retries")
            log_event(event_type="delegation", action="delegate", resource=workspace_id,
                      outcome="failure", trace_id=task_id, last_error=str(last_error))
            await _notify_completion(task_id, workspace_id, "failed")
            # #64: mirror failure to platform activity_logs.
            await _update_delegation_on_platform(
                task_id, "failed", str(last_error), "",
            )


@tool
async def delegate_to_workspace(
    workspace_id: str,
    task: str,
) -> dict:
    """Delegate a task to a peer workspace via A2A protocol (non-blocking).

    Sends the task in the background and returns immediately with a task_id.
    Use check_delegation_status to poll for the result, or continue working
    and check later. The delegate works independently.

    Args:
        workspace_id: The ID of the target workspace to delegate to.
        task: The task description to send to the peer.

    Returns:
        A dict with task_id and status="delegated". Use check_delegation_status(task_id) to get results.
    """
    task_id = str(uuid.uuid4())

    # RBAC check
    roles, custom_perms = get_workspace_roles()
    if not check_permission("delegate", roles, custom_perms):
        log_event(event_type="rbac", action="rbac.deny", resource=workspace_id,
                  outcome="denied", trace_id=task_id, attempted_action="delegate", roles=roles)
        return {"success": False, "error": f"RBAC: no 'delegate' permission. Roles: {roles}"}

    log_event(event_type="delegation", action="delegate", resource=workspace_id,
              outcome="dispatched", trace_id=task_id, task_preview=task[:200])

    # Store the delegation and launch background task
    delegation = DelegationTask(
        task_id=task_id,
        workspace_id=workspace_id,
        task_description=task[:200],
    )
    _delegations[task_id] = delegation
    _evict_old_delegations()

    bg_task = asyncio.create_task(_execute_delegation(task_id, workspace_id, task))
    _background_tasks.add(bg_task)
    bg_task.add_done_callback(_on_task_done)

    return {
        "success": True,
        "task_id": task_id,
        "status": "delegated",
        "message": f"Task delegated to {workspace_id}. Use check_delegation_status('{task_id}') to get the result when ready.",
    }


@tool
async def check_delegation_status(
    task_id: str = "",
) -> dict:
    """Check the status of a delegated task, or list all active delegations.

    Args:
        task_id: The task_id returned by delegate_to_workspace. If empty, lists all delegations.

    Returns:
        Status and result (if completed) of the delegation.
    """
    if not task_id:
        # List all delegations
        summary = []
        for tid, d in _delegations.items():
            entry = {
                "task_id": tid,
                "workspace_id": d.workspace_id,
                "status": d.status.value,
                "task": d.task_description,
            }
            if d.status == DelegationStatus.COMPLETED:
                entry["result_preview"] = (d.result or "")[:200]
            if d.status == DelegationStatus.FAILED:
                entry["error"] = d.error
            summary.append(entry)
        return {"delegations": summary, "count": len(summary)}

    delegation = _delegations.get(task_id)
    if not delegation:
        return {"error": f"No delegation found with task_id {task_id}"}

    result = {
        "task_id": task_id,
        "workspace_id": delegation.workspace_id,
        "status": delegation.status.value,
        "task": delegation.task_description,
    }

    if delegation.status == DelegationStatus.COMPLETED:
        result["result"] = delegation.result
    elif delegation.status == DelegationStatus.FAILED:
        result["error"] = delegation.error

    return result
diff --git a/molecule_runtime/builtin_tools/governance.py b/molecule_runtime/builtin_tools/governance.py new file mode 100644 index 0000000..3399f44 --- /dev/null +++ b/molecule_runtime/builtin_tools/governance.py @@ -0,0 +1,403 @@ +"""Bridge between Molecule AI's RBAC + audit subsystem and the Microsoft Agent
Governance Toolkit (agent-os-kernel, released April 2, 2026).

Integration points
------------------
* ``check_permission`` → ``PolicyEvaluator.evaluate()``
  Molecule AI's RBAC gate runs first; if RBAC allows the action the toolkit
  evaluator is consulted according to ``policy_mode``.

* ``log_event`` → governance audit sink
  Every permission decision (allow or deny) is written via
  ``tools.audit.log_event`` with extra governance metadata so the full
  decision trail lands in Molecule AI's existing audit stream.

* OTEL traceparent flows through
  ``tools.telemetry.get_current_traceparent()`` is called inside ``emit()``
  and the W3C traceparent string is attached to every audit record, giving
  end-to-end distributed tracing across agent boundaries.

Graceful degradation
--------------------
If ``agent-os-kernel`` is not installed the module falls back to Molecule AI
RBAC alone. No exception propagates to the agent — governance is a
best-effort overlay, never a hard dependency.

Install::

    pip install agent-os-kernel

Minimal config.yaml snippet::

    governance:
      enabled: true
      toolkit: microsoft
      policy_mode: strict        # strict | permissive | audit
      policy_endpoint: https://your-tenant.governance.azure.com
      policy_file: policies/workspace.rego
      blocked_patterns:
        - ".*\\.exec$"
        - "shell\\."
      max_tool_calls_per_task: 50

NOTE: The agent-os-kernel package was released April 2, 2026 and is in
community preview. The API bindings in this module target v3.0.x of the
package (agent_os.policies.PolicyEvaluator). If the package API changes,
update _init_evaluator() accordingly.
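
Call-site sketch (uses the module-level helper defined below; the role and
context values are illustrative)::

    allowed, reason = check_permission_with_governance(
        "memory.write",
        roles=["operator"],
        context={"resource": "TEAM"},
    )
    if not allowed:
        raise PermissionError(reason)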
+""" + +import logging +import os +from typing import Any, Optional + +logger = logging.getLogger(__name__) +WORKSPACE_ID: str = os.environ.get("WORKSPACE_ID", "") + +# Module-level singleton — set by initialize_governance() at startup +_adapter: Optional["GovernanceAdapter"] = None + + +class GovernanceAdapter: + """Bridges Molecule AI RBAC + audit trail to the Microsoft Agent Governance Toolkit.""" + + def __init__(self, config: Any) -> None: + self._config = config + self._evaluator = None + self._toolkit_available: bool = False + + async def initialize(self) -> None: + """Async entry point: initialise evaluator and log outcome.""" + self._init_evaluator() + if self._toolkit_available: + logger.info( + "GovernanceAdapter initialised — toolkit=%s mode=%s", + self._config.toolkit, + self._config.policy_mode, + ) + else: + logger.warning( + "GovernanceAdapter initialised in RBAC-only mode " + "(agent-os-kernel not available or failed to load)." + ) + + def _init_evaluator(self) -> None: + """Lazy-import and configure the PolicyEvaluator from agent-os-kernel. + + All failures are caught and logged; the adapter simply runs without + the toolkit rather than crashing the workspace. + """ + try: + try: + from agent_os.policies import PolicyEvaluator # type: ignore[import] + except ImportError: + logger.warning( + "agent-os-kernel is not installed — graceful degradation active. " + "Governance will use Molecule AI RBAC only. " + "To enable the Microsoft Agent Governance Toolkit run: " + "pip install agent-os-kernel" + ) + return + + kwargs: dict[str, Any] = { + "policy_mode": self._config.policy_mode, + "max_tool_calls_per_task": self._config.max_tool_calls_per_task, + "blocked_patterns": self._config.blocked_patterns, + } + if self._config.policy_endpoint: + kwargs["endpoint"] = self._config.policy_endpoint + + self._evaluator = PolicyEvaluator(**kwargs) + + # Load a policy file if one is configured and exists on disk. + if self._config.policy_file: + policy_file = self._config.policy_file + if os.path.exists(policy_file): + ext = os.path.splitext(policy_file)[1].lower() + if ext == ".rego": + self._evaluator.load_rego(path=policy_file) + logger.info("Loaded Rego policy file: %s", policy_file) + elif ext in (".yaml", ".yml"): + self._evaluator.load_yaml(path=policy_file) + logger.info("Loaded YAML policy file: %s", policy_file) + elif ext == ".cedar": + self._evaluator.load_cedar(path=policy_file) + logger.info("Loaded Cedar policy file: %s", policy_file) + else: + logger.warning( + "Unrecognised policy file extension '%s' — skipping load.", + ext, + ) + else: + logger.warning( + "policy_file '%s' does not exist — skipping load.", + policy_file, + ) + + self._toolkit_available = True + logger.info( + "agent-os-kernel PolicyEvaluator ready — policy_mode=%s", + self._config.policy_mode, + ) + + except Exception as exc: # noqa: BLE001 + logger.warning( + "Failed to initialise agent-os-kernel PolicyEvaluator: %s — " + "graceful degradation active (RBAC only).", + exc, + ) + + def check_permission( + self, + action: str, + roles: list[str], + custom_permissions: dict | None = None, + context: dict | None = None, + ) -> tuple[bool, str]: + """Evaluate an action against Molecule AI RBAC and (optionally) the toolkit. + + Returns + ------- + tuple[bool, str] + ``(allowed, reason)`` — reason is a short human-readable string + explaining the decision. 
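
        Example
        -------
        A sketch (``adapter`` is an initialised ``GovernanceAdapter``; the
        action, role, and resource values are illustrative)::

            allowed, reason = adapter.check_permission(
                "delegate", roles=["operator"], context={"resource": "ws-42"}
            )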
+ """ + from builtin_tools import audit # inline import to avoid circular dependencies + + context = context or {} + + # --- Step 1: Molecule AI RBAC gate (always runs) --- + rbac_allowed: bool = audit.check_permission(action, roles, custom_permissions) + + if not rbac_allowed: + self.emit( + event_type="permission_check", + action=action, + resource=context.get("resource", ""), + outcome="denied", + actor=context.get("actor"), + policy_decision="rbac_deny", + roles=roles, + ) + return False, f"RBAC denied action '{action}' for roles {roles}" + + # --- Step 2: If toolkit unavailable or audit-only mode, return RBAC result --- + if not self._toolkit_available or self._config.policy_mode == "audit": + self.emit( + event_type="permission_check", + action=action, + resource=context.get("resource", ""), + outcome="allowed", + actor=context.get("actor"), + policy_decision="rbac_allowed", + roles=roles, + toolkit_mode=self._config.policy_mode, + ) + return rbac_allowed, "rbac_allowed" + + # --- Step 3: Toolkit evaluation --- + eval_context: dict[str, Any] = { + "action": action, + "resource": context.get("resource", ""), + "roles": roles, + "workspace_id": WORKSPACE_ID, + } + # Merge any extra context keys the caller supplied. + for key, value in context.items(): + if key not in eval_context: + eval_context[key] = value + + toolkit_allowed: bool = True + reason: str = "" + evaluator_name: str = "agent-os-kernel" + + try: + decision = self._evaluator.evaluate(eval_context) + toolkit_allowed = getattr(decision, "allowed", True) + reason = getattr(decision, "reason", "") + evaluator_name = getattr(decision, "evaluator_name", "agent-os-kernel") + except Exception as exc: # noqa: BLE001 + logger.warning( + "agent-os-kernel evaluation raised an exception: %s — " + "falling back to RBAC result to avoid blocking the agent.", + exc, + ) + self.emit( + event_type="permission_check", + action=action, + resource=context.get("resource", ""), + outcome="allowed", + actor=context.get("actor"), + policy_decision="toolkit_evaluation_error", + toolkit_mode=self._config.policy_mode, + roles=roles, + ) + return rbac_allowed, "toolkit_evaluation_error" + + # --- Step 4: Combine results according to policy_mode --- + if self._config.policy_mode == "permissive": + # Toolkit denial is advisory only in permissive mode. + if not toolkit_allowed: + logger.warning( + "Governance toolkit denied action '%s' (reason=%s) but policy_mode " + "is 'permissive' — allowing and logging advisory denial.", + action, + reason, + ) + final_allowed = rbac_allowed + else: + # strict: both gates must allow. + final_allowed = rbac_allowed and toolkit_allowed + + outcome = "allowed" if final_allowed else "denied" + self.emit( + event_type="permission_check", + action=action, + resource=context.get("resource", ""), + outcome=outcome, + actor=context.get("actor"), + policy_decision=reason or outcome, + evaluator=evaluator_name, + toolkit_mode=self._config.policy_mode, + roles=roles, + ) + return final_allowed, reason or "allowed" + + def emit( + self, + event_type: str, + action: str, + resource: str, + outcome: str, + actor: str | None = None, + trace_id: str | None = None, + **extra: Any, + ) -> str: + """Write a governance-annotated audit event. + + Pulls the current W3C traceparent from the active OTEL span so that + governance decisions are traceable across service boundaries. + + Returns + ------- + str + The ``trace_id`` produced by ``audit.log_event``. 
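
        Example
        -------
        A sketch; extra keyword fields flow into the audit record via
        ``**extra`` (the values shown are illustrative)::

            adapter.emit(
                event_type="permission_check",
                action="memory.write",
                resource="TEAM",
                outcome="allowed",
                policy_decision="rbac_allowed",
            )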
+ """ + from builtin_tools import audit # inline import to avoid circular dependencies + from builtin_tools.telemetry import get_current_traceparent # inline import + + traceparent: str | None = get_current_traceparent() + + recorded_trace_id: str = audit.log_event( + event_type, + action, + resource, + outcome, + actor=actor, + trace_id=trace_id, + governance_toolkit=( + self._config.toolkit if self._toolkit_available else "disabled" + ), + traceparent=traceparent or "", + **extra, + ) + return recorded_trace_id + + +# --------------------------------------------------------------------------- +# Module-level functions +# --------------------------------------------------------------------------- + + +async def initialize_governance(config: Any) -> Optional[GovernanceAdapter]: + """Initialize the module-level GovernanceAdapter singleton. + + Called once at startup by main.py when governance.enabled is True. + Returns the adapter, or None if initialization fails. + """ + global _adapter + + try: + adapter = GovernanceAdapter(config) + await adapter.initialize() + _adapter = adapter + logger.info( + "Governance singleton initialised — toolkit=%s mode=%s", + config.toolkit, + config.policy_mode, + ) + return adapter + except Exception as exc: # noqa: BLE001 + logger.warning( + "initialize_governance() failed: %s — governance disabled for this session.", + exc, + ) + return None + + +def get_governance_adapter() -> Optional[GovernanceAdapter]: + """Return the module-level GovernanceAdapter singleton (may be None).""" + return _adapter + + +def check_permission_with_governance( + action: str, + roles: list[str], + custom_permissions: dict | None = None, + context: dict | None = None, +) -> tuple[bool, str]: + """Convenience wrapper: use GovernanceAdapter when available, else RBAC only. + + Parameters + ---------- + action: + The action name to evaluate (e.g. ``"memory.write"``). + roles: + The list of role names held by the requesting actor. + custom_permissions: + Optional custom role→action mapping to overlay on built-in roles. + context: + Optional extra context forwarded to the PolicyEvaluator. + + Returns + ------- + tuple[bool, str] + ``(allowed, reason)`` + """ + if _adapter is None: + from builtin_tools import audit # inline import to avoid circular dependencies + + result: bool = audit.check_permission(action, roles, custom_permissions) + return result, "rbac_only" + + return _adapter.check_permission(action, roles, custom_permissions, context) + + +# --------------------------------------------------------------------------- +# Private helper +# --------------------------------------------------------------------------- + + +def _emit_governance_event( + event_type: str, + action: str, + resource: str, + outcome: str, + actor: str | None = None, + trace_id: str | None = None, + **extra: Any, +) -> Optional[str]: + """Emit a governance audit event via the singleton adapter if one is set. + + Returns the trace_id produced by log_event, or None if no adapter is set. + """ + if _adapter is None: + return None + return _adapter.emit( + event_type, + action, + resource, + outcome, + actor=actor, + trace_id=trace_id, + **extra, + ) diff --git a/molecule_runtime/builtin_tools/hitl.py b/molecule_runtime/builtin_tools/hitl.py new file mode 100644 index 0000000..d7bccc2 --- /dev/null +++ b/molecule_runtime/builtin_tools/hitl.py @@ -0,0 +1,531 @@ +"""Human-In-The-Loop (HITL) workflow primitives. 
+ +Generalizes the approval tool into reusable HITL building blocks that work +across all Molecule AI adapters. + +Features +-------- +@requires_approval + Decorator that gates *any* async callable (tool, method, standalone fn) + behind a human approval request. The decorated function only runs if + the request is granted. Roles in ``hitl.bypass_roles`` skip the gate. + +pause_task / resume_task + LangChain tools for explicit pause/resume of in-flight tasks. An agent + calls ``pause_task(task_id, reason)`` to suspend itself; an external + signal (webhook, dashboard click, another agent) calls ``resume_task`` + with the same task_id to wake it up. + +Notification channels +--------------------- +Configured under ``hitl:`` in ``config.yaml``: + + hitl: + channels: + - type: dashboard # always active; uses platform approval API + - type: slack + webhook_url: https://hooks.slack.com/services/… + - type: email + smtp_host: smtp.example.com + smtp_port: 587 + from: alerts@example.com + to: ops@example.com + username: alerts@example.com # optional; password from SMTP_PASSWORD env + default_timeout: 300 # seconds before an unanswered request times out + bypass_roles: [admin] # roles that skip the approval gate entirely + +Environment variables +--------------------- +SMTP_PASSWORD Password for SMTP authentication (preferred over config file) +""" + +from __future__ import annotations + +import asyncio +import functools +import logging +import os +import smtplib +from dataclasses import dataclass, field +from email.mime.text import MIMEText +from typing import Any, Callable + +import httpx +from langchain_core.tools import tool + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Config +# --------------------------------------------------------------------------- + +@dataclass +class HITLConfig: + """HITL settings loaded from the ``hitl:`` block in config.yaml.""" + channels: list[dict] = field(default_factory=lambda: [{"type": "dashboard"}]) + default_timeout: float = 300.0 + bypass_roles: list[str] = field(default_factory=list) + + +def _load_hitl_config() -> HITLConfig: + """Load HITL config from workspace config; fall back to safe defaults.""" + try: + from config import load_config + cfg = load_config() + raw = getattr(cfg, "hitl", None) + if raw is None: + return HITLConfig() + return HITLConfig( + channels=raw.channels if hasattr(raw, "channels") else [{"type": "dashboard"}], + default_timeout=float(raw.default_timeout if hasattr(raw, "default_timeout") else 300), + bypass_roles=list(raw.bypass_roles if hasattr(raw, "bypass_roles") else []), + ) + except Exception: + return HITLConfig() + + +# --------------------------------------------------------------------------- +# Pause / Resume registry +# --------------------------------------------------------------------------- + +class _TaskPauseRegistry: + """In-process registry mapping task_id → asyncio.Event + optional result. + + Multiple coroutines awaiting the same task_id are all unblocked when + ``resume()`` is called. Results survive until the awaiting coroutine + calls ``pop_result()``. + """ + + def __init__(self) -> None: + self._events: dict[str, asyncio.Event] = {} + self._results: dict[str, dict] = {} + # #265: owner map — workspace_id that created each task. + # Empty string means "no owner / legacy" (bypasses ownership check). 
+ self._owners: dict[str, str] = {} + + def register(self, task_id: str, owner: str = "") -> asyncio.Event: + """Create and store an Event for *task_id*. Returns the event. + + Args: + task_id: Unique task identifier. + owner: Workspace ID that owns this task. When set, ``resume`` + will reject callers from a different workspace. + """ + ev = asyncio.Event() + self._events[task_id] = ev + self._owners[task_id] = owner + return ev + + def resume(self, task_id: str, result: dict | None = None, owner: str = "") -> bool: + """Signal the Event for *task_id*. Returns False if not registered. + + Args: + task_id: The identifier used in ``register``. + result: Optional result payload forwarded to the waiting coroutine. + owner: Caller's workspace ID. When both the stored owner and + *owner* are non-empty and they differ, the call is rejected + (returns False) — prevents cross-workspace prompt injection + (#265). Passing ``owner=""`` bypasses the check (used in + direct registry calls from tests and platform code). + """ + # #265 ownership check + stored_owner = self._owners.get(task_id, "") + if owner and stored_owner and owner != stored_owner: + logger.warning( + "HITL: resume rejected for task %s — caller workspace %r != owner %r", + task_id, owner, stored_owner, + ) + return False + ev = self._events.get(task_id) + if ev is None: + return False + self._results[task_id] = result or {} + ev.set() + return True + + def pop_result(self, task_id: str) -> dict: + """Return and remove the stored result for *task_id*.""" + return self._results.pop(task_id, {}) + + def cleanup(self, task_id: str) -> None: + """Remove *task_id* from all dicts.""" + self._events.pop(task_id, None) + self._results.pop(task_id, None) + self._owners.pop(task_id, None) + + def list_paused(self) -> list[str]: + """Return IDs of tasks whose events have not yet been set.""" + return [tid for tid, ev in self._events.items() if not ev.is_set()] + + +# Global singleton — safe within one asyncio event loop / process +pause_registry = _TaskPauseRegistry() + + +# --------------------------------------------------------------------------- +# Notification channels +# --------------------------------------------------------------------------- + +async def _notify_channels( + action: str, + reason: str, + approval_id: str, + cfg: HITLConfig, +) -> None: + """Fire-and-forget notifications to all configured channels. + + Errors in individual channels are logged but never re-raised so that a + misconfigured Slack webhook cannot block the approval flow. 
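
    Example (a sketch; the webhook URL and approval ID are placeholders)::

        await _notify_channels(
            action="drop_table",
            reason="Wipe staging DB",
            approval_id="appr-123",
            cfg=HITLConfig(channels=[
                {"type": "slack", "webhook_url": "https://hooks.slack.com/services/…"},
            ]),
        )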
    """
    platform_url = os.environ.get("PLATFORM_URL", "http://platform:8080")
    workspace_id = os.environ.get("WORKSPACE_ID", "")

    for channel in cfg.channels:
        ch_type = channel.get("type", "dashboard")
        try:
            if ch_type == "slack":
                await _notify_slack(channel, action, reason, approval_id,
                                    platform_url, workspace_id)
            elif ch_type == "email":
                await _notify_email(channel, action, reason, approval_id,
                                    platform_url, workspace_id)
            # "dashboard" is handled by the platform via the approval POST
        except Exception as exc:
            logger.warning("HITL: channel '%s' notification failed: %s", ch_type, exc)


async def _notify_slack(
    cfg: dict,
    action: str,
    reason: str,
    approval_id: str,
    platform_url: str,
    workspace_id: str,
) -> None:
    webhook_url = cfg.get("webhook_url", "")
    if not webhook_url:
        return

    approve_url = f"{platform_url}/workspaces/{workspace_id}/approvals/{approval_id}/approve"
    deny_url = f"{platform_url}/workspaces/{workspace_id}/approvals/{approval_id}/deny"

    payload = {
        "text": f":warning: Approval required from workspace `{workspace_id}`",
        "blocks": [
            {
                "type": "section",
                "text": {
                    "type": "mrkdwn",
                    "text": (
                        f"*Action:* {action}\n"
                        f"*Reason:* {reason}\n"
                        f"*Approval ID:* `{approval_id}`"
                    ),
                },
            },
            {
                "type": "actions",
                "elements": [
                    {
                        "type": "button",
                        "text": {"type": "plain_text", "text": "Approve"},
                        "style": "primary",
                        "url": approve_url,
                    },
                    {
                        "type": "button",
                        "text": {"type": "plain_text", "text": "Deny"},
                        "style": "danger",
                        "url": deny_url,
                    },
                ],
            },
        ],
    }
    async with httpx.AsyncClient(timeout=10.0) as client:
        await client.post(webhook_url, json=payload)
    logger.info("HITL: Slack notification sent for approval %s", approval_id)


async def _notify_email(
    cfg: dict,
    action: str,
    reason: str,
    approval_id: str,
    platform_url: str,
    workspace_id: str,
) -> None:
    smtp_host = cfg.get("smtp_host", "")
    smtp_port = int(cfg.get("smtp_port", 587))
    from_addr = cfg.get("from", "")
    to_addr = cfg.get("to", "")

    if not all([smtp_host, from_addr, to_addr]):
        logger.warning("HITL: email channel missing smtp_host/from/to — skipping")
        return

    approve_url = f"{platform_url}/workspaces/{workspace_id}/approvals/{approval_id}/approve"
    deny_url = f"{platform_url}/workspaces/{workspace_id}/approvals/{approval_id}/deny"

    body = (
        f"Approval required from workspace {workspace_id}\n\n"
        f"Action : {action}\n"
        f"Reason : {reason}\n"
        f"ID     : {approval_id}\n\n"
        f"Approve: {approve_url}\n"
        f"Deny   : {deny_url}\n"
    )

    msg = MIMEText(body, "plain", "utf-8")
    msg["Subject"] = f"[Molecule AI] Approval required: {action}"
    msg["From"] = from_addr
    msg["To"] = to_addr

    username = cfg.get("username", "")
    # SMTP_PASSWORD env var takes precedence over a config-file password
    # (matches the module docstring's "preferred over config file").
    password = os.environ.get("SMTP_PASSWORD") or cfg.get("password", "")

    def _send() -> None:
        with smtplib.SMTP(smtp_host, smtp_port) as srv:
            srv.ehlo()
            srv.starttls()
            if username and password:
                srv.login(username, password)
            srv.send_message(msg)

    await asyncio.to_thread(_send)
    logger.info("HITL: email notification sent for approval %s", approval_id)


# ---------------------------------------------------------------------------
# @requires_approval decorator
# ---------------------------------------------------------------------------

def requires_approval(
    action_description: str = "",
    reason_template: str = "",
    bypass_roles: list[str] | None = None,
) -> Callable[[Callable], Callable]:
    """Decorator that gates an async callable behind a human approval request.

    The wrapped function executes only when a human approves. Use this on
    any tool or async helper that performs destructive or high-impact work.

    Args:
        action_description: Short label for the action shown to the approver.
                            Defaults to the function's ``name`` attribute or
                            ``__name__``.
        reason_template: f-string template for the reason line. Keyword
                         arguments of the decorated function are available,
                         e.g. ``"Delete table {table_name}"``.
        bypass_roles: Roles that skip the gate entirely. Overrides
                      ``hitl.bypass_roles`` in config.yaml when given.

    Returns:
        A decorator; applying it to a function returns an async wrapper.

    Usage::

        @tool
        @requires_approval("Wipe production DB", bypass_roles=["admin"])
        async def drop_table(table_name: str) -> dict:
            ...

        # Works with plain async functions too:
        @requires_approval("Send customer email")
        async def send_email(to: str, body: str) -> dict:
            ...
    """
    def decorator(fn: Callable) -> Callable:
        action = action_description or getattr(fn, "name", None) or fn.__name__

        @functools.wraps(fn)
        async def wrapper(*args: Any, **kwargs: Any) -> Any:
            hitl_cfg = _load_hitl_config()

            # --- Check bypass roles -----------------------------------------
            active_bypass = bypass_roles if bypass_roles is not None else hitl_cfg.bypass_roles
            if active_bypass:
                try:
                    from builtin_tools.audit import get_workspace_roles
                    roles, _ = get_workspace_roles()
                    if any(r in active_bypass for r in roles):
                        logger.info(
                            "@requires_approval bypassed (role %s) for '%s'", roles, action
                        )
                        return await fn(*args, **kwargs)
                except Exception:
                    pass  # If RBAC check fails, proceed to approval gate

            # --- Build reason string -----------------------------------------
            if reason_template:
                try:
                    reason = reason_template.format(**kwargs)
                except (KeyError, IndexError):
                    reason = reason_template
            else:
                arg_parts = [f"{k}={str(v)[:60]}" for k, v in list(kwargs.items())[:3]]
                reason = f"Args: {', '.join(arg_parts)}" if arg_parts else "Automated action"

            # --- Fire non-dashboard notifications (async, non-blocking) ------
            # The real approval_id is only created by request_approval below,
            # so these early pings carry a "pending" placeholder in their links.
            asyncio.create_task(
                _notify_channels(action, reason, "pending", hitl_cfg)
            )

            # --- Request approval via approval tool --------------------------
            try:
                from builtin_tools.approval import request_approval
                approval_result = await request_approval.ainvoke(
                    {"action": action, "reason": reason}
                )
            except Exception as exc:
                logger.error("@requires_approval: approval call failed: %s", exc)
                return {
                    "success": False,
                    "error": f"Approval gate error: {exc}",
                }

            if not approval_result.get("approved"):
                return {
                    "success": False,
                    "error": (
                        f"Action '{action}' not approved: "
                        f"{approval_result.get('message', approval_result.get('error', 'denied'))}"
                    ),
                    "approval_id": approval_result.get("approval_id"),
                }

            # --- Approved — run the original function ------------------------
            return await fn(*args, **kwargs)

        return wrapper

    return decorator


# ---------------------------------------------------------------------------
# Pause / Resume LangChain tools
# ---------------------------------------------------------------------------

@tool
async def pause_task(task_id: str, reason: str = "") -> dict:
    """Suspend the current task and wait for a resume signal.

    The agent calls this to pause itself at a decision point.
Execution + resumes when ``resume_task`` is called with the same task_id, or after + the configured ``hitl.default_timeout`` seconds. + + Args: + task_id: Unique identifier for this pause point (use the A2A task ID + or any stable string that the caller can reference later). + reason: Human-readable description of why the task is pausing. + """ + # #265: record workspace ownership on registration so resume_task can + # reject callers from a different workspace (cross-workspace prompt-injection + # prevention). External task_id is unchanged — only internal ownership + # metadata is added, so no tests or callers need to update their task IDs. + _ws = os.environ.get("WORKSPACE_ID", "") + + try: + from builtin_tools.audit import log_event + log_event( + event_type="hitl", + action="pause", + resource=task_id, + outcome="paused", + trace_id=task_id, + reason=reason, + ) + except Exception: + pass + + event = pause_registry.register(task_id, owner=_ws) + timeout = _load_hitl_config().default_timeout + logger.info("HITL: task %s paused — %s", task_id, reason or "(no reason given)") + + try: + await asyncio.wait_for(event.wait(), timeout=timeout) + result = pause_registry.pop_result(task_id) + logger.info("HITL: task %s resumed", task_id) + try: + from builtin_tools.audit import log_event + log_event( + event_type="hitl", + action="resume", + resource=task_id, + outcome="resumed", + trace_id=task_id, + ) + except Exception: + pass + return {"resumed": True, "task_id": task_id, **result} + + except asyncio.TimeoutError: + logger.warning("HITL: task %s timed out after %.0fs", task_id, timeout) + try: + from builtin_tools.audit import log_event + log_event( + event_type="hitl", + action="pause", + resource=task_id, + outcome="timeout", + trace_id=task_id, + timeout_seconds=timeout, + ) + except Exception: + pass + return { + "resumed": False, + "task_id": task_id, + "error": f"Timed out after {timeout:.0f}s waiting for resume signal", + } + finally: + pause_registry.cleanup(task_id) + + +@tool +async def resume_task(task_id: str, message: str = "") -> dict: + """Resume a previously paused task. + + Signals the ``pause_task`` coroutine waiting on *task_id* to continue. + Safe to call even if the task has already resumed or timed out (returns + success=False in that case). + + Args: + task_id: The identifier passed to ``pause_task``. + message: Optional message forwarded to the resumed task. + """ + # #265: pass caller's workspace ID so the registry can reject a resume + # from a different workspace (ownership check in _TaskPauseRegistry.resume). 
+ _ws = os.environ.get("WORKSPACE_ID", "") + + result_payload = {"message": message} if message else {} + success = pause_registry.resume(task_id, result_payload, owner=_ws) + + if success: + logger.info("HITL: resume signal sent for task %s", task_id) + try: + from builtin_tools.audit import log_event + log_event( + event_type="hitl", + action="resume", + resource=task_id, + outcome="success", + trace_id=task_id, + message=message, + ) + except Exception: + pass + return {"success": True, "task_id": task_id} + + return { + "success": False, + "task_id": task_id, + "error": "Task not found or already resumed", + } + + +@tool +async def list_paused_tasks() -> dict: + """List all tasks currently suspended and waiting for a resume signal.""" + paused = pause_registry.list_paused() + return {"paused_tasks": paused, "count": len(paused)} diff --git a/molecule_runtime/builtin_tools/medo.py b/molecule_runtime/builtin_tools/medo.py new file mode 100644 index 0000000..0c824f9 --- /dev/null +++ b/molecule_runtime/builtin_tools/medo.py @@ -0,0 +1,106 @@ +"""MeDo builtin tools — Baidu MeDo no-code AI platform integration. + +MeDo (摩搭, moda.baidu.com) is Baidu's no-code AI application builder used in +the Molecule AI hackathon integration (May 2026). Three core operations: + create_medo_app — scaffold a new application from a template + update_medo_app — push content / config changes to an existing app + publish_medo_app — publish a draft app to a target environment + +Authentication: set MEDO_API_KEY as a workspace secret. +Override base URL via MEDO_BASE_URL (default: https://api.moda.baidu.com/v1). + +Mock backend: when MEDO_API_KEY is absent the tools return a predictable stub +response — safe for unit tests and local development. +TODO: swap _mock_http_post for a real httpx.AsyncClient call once keys are live. +""" + +import logging +import os + +from langchain_core.tools import tool + +logger = logging.getLogger(__name__) + +MEDO_BASE_URL = os.environ.get("MEDO_BASE_URL", "https://api.moda.baidu.com/v1") +MEDO_API_KEY = os.environ.get("MEDO_API_KEY", "") + +_VALID_TEMPLATES = ("blank", "chatbot", "form", "dashboard") +_VALID_ENVS = ("production", "staging") + + +async def _mock_http_post(path: str, payload: dict) -> dict: + """Stub HTTP call. TODO: replace with real httpx.AsyncClient once MEDO_API_KEY is live.""" + return {"status": "ok", "mock": True, "path": path, "payload_keys": list(payload.keys())} + + +@tool +async def create_medo_app(name: str, template: str = "blank", description: str = "") -> dict: + """Create a new MeDo application. + + Args: + name: Application name (required). + template: Starting template — blank | chatbot | form | dashboard (default: blank). + description: Short description of the application. + + Returns: + dict with 'app_id' and 'status' on success, 'error' key on failure. + """ + if not name: + return {"error": "name is required"} + if template not in _VALID_TEMPLATES: + return {"error": f"template must be one of: {', '.join(_VALID_TEMPLATES)}"} + try: + result = await _mock_http_post("/apps", {"name": name, "template": template, "description": description}) + logger.info("MeDo create_app: name=%s template=%s → %s", name, template, result) + return result + except Exception as exc: + logger.exception("MeDo create_app failed") + return {"error": str(exc)} + + +@tool +async def update_medo_app(app_id: str, content: dict) -> dict: + """Push content or configuration changes to an existing MeDo application. 
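
    Example (a sketch; the app_id and content fields are placeholders)::

        await update_medo_app.ainvoke({
            "app_id": "app-123",
            "content": {"title": "Support Bot", "nodes": []},
        })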
+ + Args: + app_id: The MeDo application ID returned by create_medo_app. + content: Dict of fields to update (e.g. {"title": "...", "nodes": [...]}). + + Returns: + dict with 'status' on success, 'error' key on failure. + """ + if not app_id: + return {"error": "app_id is required"} + if not content: + return {"error": "content must be a non-empty dict"} + try: + result = await _mock_http_post(f"/apps/{app_id}", content) + logger.info("MeDo update_app: app_id=%s keys=%s → %s", app_id, list(content.keys()), result) + return result + except Exception as exc: + logger.exception("MeDo update_app failed") + return {"error": str(exc)} + + +@tool +async def publish_medo_app(app_id: str, environment: str = "production") -> dict: + """Publish a MeDo application to a target environment. + + Args: + app_id: The MeDo application ID to publish. + environment: Target — production | staging (default: production). + + Returns: + dict with 'status' on success, 'error' key on failure. + """ + if not app_id: + return {"error": "app_id is required"} + if environment not in _VALID_ENVS: + return {"error": f"environment must be one of: {', '.join(_VALID_ENVS)}"} + try: + result = await _mock_http_post(f"/apps/{app_id}/publish", {"environment": environment}) + logger.info("MeDo publish_app: app_id=%s env=%s → %s", app_id, environment, result) + return result + except Exception as exc: + logger.exception("MeDo publish_app failed") + return {"error": str(exc)} diff --git a/molecule_runtime/builtin_tools/memory.py b/molecule_runtime/builtin_tools/memory.py new file mode 100644 index 0000000..0d36f97 --- /dev/null +++ b/molecule_runtime/builtin_tools/memory.py @@ -0,0 +1,468 @@ +"""HMA memory tools for agents. + +Hierarchical Memory Architecture: +- LOCAL: private to this workspace, invisible to others +- TEAM: shared with parent + siblings (same team) +- GLOBAL: readable by all, writable by root workspaces only + +RBAC enforcement +---------------- +``commit_memory`` requires the ``"memory.write"`` action. +``search_memory`` requires the ``"memory.read"`` action. +Roles are read from ``config.yaml`` under ``rbac.roles`` (default: operator). + +Audit trail +----------- +Every memory operation appends a JSON Lines record to the audit log: + + memory / memory.write / allowed — write permitted by RBAC + memory / memory.write / success — write committed successfully + memory / memory.write / failure — write failed (platform error) + memory / memory.read / allowed — read permitted by RBAC + memory / memory.read / success — search returned results + memory / memory.read / failure — search failed (platform error) + +RBAC denials emit ``rbac / rbac.deny / denied`` events instead. +""" + +import json +import os +import uuid +from types import SimpleNamespace +from typing import Any + +from langchain_core.tools import tool +from builtin_tools.awareness_client import build_awareness_client +from builtin_tools.audit import check_permission, get_workspace_roles, log_event +from builtin_tools.telemetry import MEMORY_QUERY, MEMORY_SCOPE, WORKSPACE_ID_ATTR, get_tracer + +try: # pragma: no cover - optional runtime dependency in lightweight test envs + import httpx # type: ignore +except ImportError: # pragma: no cover + httpx = SimpleNamespace(AsyncClient=None) + +PLATFORM_URL = os.environ.get("PLATFORM_URL", "http://platform:8080") +WORKSPACE_ID = os.environ.get("WORKSPACE_ID", "") + + +@tool +async def commit_memory(content: str, scope: str = "LOCAL") -> dict: + """Store a fact in memory with a specific scope. 
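
    Example (a sketch; the content shown is a placeholder fact)::

        await commit_memory.ainvoke(
            {"content": "Deploy window is Friday 14:00 UTC", "scope": "TEAM"}
        )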
+ + Args: + content: The fact or knowledge to remember. + scope: Memory scope — LOCAL (private), TEAM (shared with team), or GLOBAL (company-wide, root only). + """ + trace_id = str(uuid.uuid4()) + scope = scope.upper() + if scope not in ("LOCAL", "TEAM", "GLOBAL"): + return {"error": "scope must be LOCAL, TEAM, or GLOBAL"} + + # --- RBAC check ----------------------------------------------------------- + roles, custom_perms = get_workspace_roles() + if not check_permission("memory.write", roles, custom_perms): + log_event( + event_type="rbac", + action="rbac.deny", + resource=scope, + outcome="denied", + trace_id=trace_id, + attempted_action="memory.write", + roles=roles, + ) + return { + "success": False, + "error": ( + "RBAC: this workspace does not have the 'memory.write' permission. " + f"Current roles: {roles}" + ), + } + + log_event( + event_type="memory", + action="memory.write", + resource=scope, + outcome="allowed", + trace_id=trace_id, + memory_scope=scope, + content_length=len(content), + ) + + # ── OTEL: memory_write span ────────────────────────────────────────────── + tracer = get_tracer() + + with tracer.start_as_current_span("memory_write") as mem_span: + mem_span.set_attribute(WORKSPACE_ID_ATTR, WORKSPACE_ID) + mem_span.set_attribute(MEMORY_SCOPE, scope) + mem_span.set_attribute("memory.content_length", len(content)) + + awareness_client = build_awareness_client() + if awareness_client is not None: + try: + result = await awareness_client.commit(content, scope) + except Exception as e: + log_event( + event_type="memory", + action="memory.write", + resource=scope, + outcome="failure", + trace_id=trace_id, + memory_scope=scope, + error=str(e), + ) + try: + mem_span.record_exception(e) + except Exception: + pass + return {"success": False, "error": str(e)} + else: + # #215-class bug: platform now gates /workspaces/:id/memories behind + # workspace auth. Import auth_headers lazily (same pattern as the + # activity-log path below) so test environments that don't ship + # platform_auth still work. + try: + from platform_auth import auth_headers as _auth + _headers = _auth() + except Exception: + _headers = {} + async with httpx.AsyncClient(timeout=10.0) as client: + try: + resp = await client.post( + f"{PLATFORM_URL}/workspaces/{WORKSPACE_ID}/memories", + json={"content": content, "scope": scope}, + headers=_headers, + ) + if resp.status_code == 201: + result = {"success": True, "id": resp.json().get("id"), "scope": scope} + else: + result = {"success": False, "error": resp.json().get("error", resp.text)} + except Exception as e: + log_event( + event_type="memory", + action="memory.write", + resource=scope, + outcome="failure", + trace_id=trace_id, + memory_scope=scope, + error=str(e), + ) + try: + mem_span.record_exception(e) + except Exception: + pass + return {"success": False, "error": str(e)} + + if result.get("success"): + mem_span.set_attribute("memory.id", result.get("id") or "") + mem_span.set_attribute("memory.success", True) + log_event( + event_type="memory", + action="memory.write", + resource=scope, + outcome="success", + trace_id=trace_id, + memory_scope=scope, + memory_id=result.get("id"), + ) + # #125: surface memory writes in /activity so the Canvas + # "Agent Comms" tab shows what an agent chose to remember. + # Fire-and-forget — failure here must not poison the tool + # response since the memory write itself already succeeded. 
+ await _record_memory_activity(scope, content, result.get("id")) + await _maybe_log_skill_promotion(content, scope, result) + else: + mem_span.set_attribute("memory.success", False) + log_event( + event_type="memory", + action="memory.write", + resource=scope, + outcome="failure", + trace_id=trace_id, + memory_scope=scope, + error=result.get("error"), + ) + + return result + + +@tool +async def search_memory(query: str = "", scope: str = "") -> dict: + """Search stored memories. + + Args: + query: Text to search for (empty returns all). + scope: Filter by scope — LOCAL, TEAM, GLOBAL, or empty for all accessible. + """ + trace_id = str(uuid.uuid4()) + scope = scope.upper() + if scope and scope not in ("LOCAL", "TEAM", "GLOBAL"): + return {"error": "scope must be LOCAL, TEAM, GLOBAL, or empty"} + + # --- RBAC check ----------------------------------------------------------- + roles, custom_perms = get_workspace_roles() + if not check_permission("memory.read", roles, custom_perms): + log_event( + event_type="rbac", + action="rbac.deny", + resource=scope or "all", + outcome="denied", + trace_id=trace_id, + attempted_action="memory.read", + roles=roles, + ) + return { + "success": False, + "error": ( + "RBAC: this workspace does not have the 'memory.read' permission. " + f"Current roles: {roles}" + ), + } + + log_event( + event_type="memory", + action="memory.read", + resource=scope or "all", + outcome="allowed", + trace_id=trace_id, + memory_scope=scope or "all", + query_length=len(query), + ) + + # ── OTEL: memory_read span ─────────────────────────────────────────────── + tracer = get_tracer() + + with tracer.start_as_current_span("memory_read") as mem_span: + mem_span.set_attribute(WORKSPACE_ID_ATTR, WORKSPACE_ID) + mem_span.set_attribute(MEMORY_SCOPE, scope or "all") + mem_span.set_attribute(MEMORY_QUERY, query[:256] if query else "") + + awareness_client = build_awareness_client() + if awareness_client is not None: + try: + result = await awareness_client.search(query, scope) + mem_span.set_attribute("memory.result_count", result.get("count", 0)) + mem_span.set_attribute("memory.success", result.get("success", False)) + log_event( + event_type="memory", + action="memory.read", + resource=scope or "all", + outcome="success" if result.get("success") else "failure", + trace_id=trace_id, + memory_scope=scope or "all", + result_count=result.get("count", 0), + ) + return result + except Exception as e: + log_event( + event_type="memory", + action="memory.read", + resource=scope or "all", + outcome="failure", + trace_id=trace_id, + memory_scope=scope or "all", + error=str(e), + ) + try: + mem_span.record_exception(e) + except Exception: + pass + return {"success": False, "error": str(e)} + + params = {} + if query: + params["q"] = query + if scope: + params["scope"] = scope.upper() + + # #215-class bug (search path): same fix as commit_memory above — + # the platform gates GET /workspaces/:id/memories behind workspace + # auth, so without auth_headers() every search silently 401s and the + # agent thinks its backlog is empty (observed on Technical Researcher + # idle-loop pilot 2026-04-15). 
+ try: + from platform_auth import auth_headers as _auth + _headers = _auth() + except Exception: + _headers = {} + + async with httpx.AsyncClient(timeout=10.0) as client: + try: + resp = await client.get( + f"{PLATFORM_URL}/workspaces/{WORKSPACE_ID}/memories", + params=params, + headers=_headers, + ) + if resp.status_code == 200: + memories = resp.json() + mem_span.set_attribute("memory.result_count", len(memories)) + mem_span.set_attribute("memory.success", True) + log_event( + event_type="memory", + action="memory.read", + resource=scope or "all", + outcome="success", + trace_id=trace_id, + memory_scope=scope or "all", + result_count=len(memories), + ) + return { + "success": True, + "count": len(memories), + "memories": memories, + } + mem_span.set_attribute("memory.success", False) + log_event( + event_type="memory", + action="memory.read", + resource=scope or "all", + outcome="failure", + trace_id=trace_id, + memory_scope=scope or "all", + http_status=resp.status_code, + ) + return {"success": False, "error": resp.json().get("error", resp.text)} + except Exception as e: + log_event( + event_type="memory", + action="memory.read", + resource=scope or "all", + outcome="failure", + trace_id=trace_id, + memory_scope=scope or "all", + error=str(e), + ) + try: + mem_span.record_exception(e) + except Exception: + pass + return {"success": False, "error": str(e)} + + +def _parse_promotion_packet(content: str) -> dict[str, Any] | None: + """Return a structured memory packet when content looks like promotion metadata.""" + text = content.strip() + if not text.startswith("{"): + return None + + try: + payload = json.loads(text) + except json.JSONDecodeError: + return None + + if not isinstance(payload, dict): # pragma: no cover + return None + if not payload.get("promote_to_skill"): + return None + + return payload + + +async def _record_memory_activity(scope: str, content: str, memory_id: str | None) -> None: + """Surface a successful memory write as an activity row so the Canvas + "Agent Comms" tab can display what an agent chose to remember. + Fire-and-forget — never raises. #125. + + The summary is intentionally short (scope tag + first 80 chars of + content with a ``…`` ellipsis when truncated) so the activity table + stays readable; full content lives in ``agent_memories``. + """ + workspace_id = WORKSPACE_ID.strip() + platform_url = PLATFORM_URL.strip().rstrip("/") + if not workspace_id or not platform_url: + return + + preview = content.strip().replace("\n", " ") + if len(preview) > 80: + preview = preview[:80] + "…" + summary = f"[{scope}] {preview}" + + # NOTE: target_id is a UUID column scoped to workspace_id references — + # cannot hold awareness/memory IDs (which are arbitrary strings). + # We embed the memory_id in the summary instead so it's still searchable. + if memory_id: + summary = f"{summary} (id={memory_id[:24]})" + payload: dict[str, Any] = { + "workspace_id": workspace_id, + "activity_type": "memory_write", + "summary": summary, + "status": "ok", + } + + try: + try: + from platform_auth import auth_headers as _auth + _headers = _auth() + except Exception: + _headers = {} + async with httpx.AsyncClient(timeout=5.0) as client: + await client.post( + f"{platform_url}/workspaces/{workspace_id}/activity", + json=payload, + headers=_headers, + ) + except Exception: + # Activity logging is purely observability — never poison the + # tool response on a failure here. 
We don't even log_event the + # failure since the memory write itself succeeded and that's + # what matters to the caller. + pass + + +async def _maybe_log_skill_promotion(content: str, scope: str, memory_result: dict) -> None: + """Best-effort activity log for durable memory entries that should become skills.""" + packet = _parse_promotion_packet(content) + if packet is None: + return + + workspace_id = WORKSPACE_ID.strip() + platform_url = PLATFORM_URL.strip().rstrip("/") + if not workspace_id or not platform_url: + return + + repetition_signal = packet.get("repetition_signal") + summary = ( + packet.get("summary") + or packet.get("title") + or packet.get("what changed") + or "Repeatable workflow promoted to skill candidate" + ) + metadata: dict[str, Any] = { + "source": "memory-curation", + "scope": scope, + "memory_id": memory_result.get("id"), + "promote_to_skill": True, + "repetition_signal": repetition_signal, + "memory_packet": packet, + } + + payload = { + "activity_type": "skill_promotion", + "method": "memory/skill-promotion", + "summary": summary, + "status": "ok", + "source_id": workspace_id, + "request_body": packet, + "metadata": metadata, + } + + try: + async with httpx.AsyncClient(timeout=5.0) as client: + await client.post( + f"{platform_url}/workspaces/{workspace_id}/activity", + json=payload, + ) + await client.post( + f"{platform_url}/registry/heartbeat", + json={ + "workspace_id": workspace_id, + "error_rate": 0, + "sample_error": "", + "active_tasks": 1, + "uptime_seconds": 0, + "current_task": f"Skill promotion: {summary}", + }, + ) + except Exception: + # Best-effort observability only. Memory commits must never fail because + # the promotion log could not be written. + return diff --git a/molecule_runtime/builtin_tools/sandbox.py b/molecule_runtime/builtin_tools/sandbox.py new file mode 100644 index 0000000..dc1fd37 --- /dev/null +++ b/molecule_runtime/builtin_tools/sandbox.py @@ -0,0 +1,281 @@ +"""Code sandbox tool for safe code execution. + +Executes code in an isolated environment. Three backends are supported: + +subprocess (default) + Runs code locally via asyncio subprocess with a hard timeout. + Best for Tier 1/2 agents where run_code is lightly used and the + workspace container itself is the isolation boundary. + +docker + Throwaway Docker-in-Docker container: network disabled, memory capped, + read-only filesystem. Requires Docker socket access inside the container. + Best for Tier 3 on-prem deployments. + +e2b + Cloud-hosted microVM sandbox via E2B (https://e2b.dev). + No local Docker required — code runs in E2B's isolated cloud VMs. + Supports Python and JavaScript. + Requires: + - e2b-code-interpreter Python package (pinned in requirements.txt) + - E2B_API_KEY workspace secret (set via canvas Secrets panel or API) + Best for hosted/cloud Molecule AI deployments. + +Backend is selected via the SANDBOX_BACKEND env var, which the provisioner +sets from config.yaml → sandbox.backend. Default: "subprocess". +""" + +import asyncio +import logging +import os +import tempfile + +from langchain_core.tools import tool + +logger = logging.getLogger(__name__) + +SANDBOX_BACKEND = os.environ.get("SANDBOX_BACKEND", "subprocess") +SANDBOX_TIMEOUT = int(os.environ.get("SANDBOX_TIMEOUT", "30")) +SANDBOX_MEMORY_LIMIT = os.environ.get("SANDBOX_MEMORY_LIMIT", "256m") +MAX_OUTPUT = 10_000 + +# E2B kernel names differ from internal language names. 
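# For example, language="javascript" from run_code resolves to the "js"
# kernel below before it is handed to Sandbox.run_code in _run_e2b.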
_E2B_KERNEL_MAP = {
    "python": "python3",
    "javascript": "js",
    "js": "js",
}


@tool
async def run_code(code: str, language: str = "python") -> dict:
    """Execute code in an isolated sandbox and return the output.

    Args:
        code: The code to execute.
        language: Programming language — python, javascript, or shell.
                  The e2b backend supports python and javascript only.
    """
    if SANDBOX_BACKEND == "docker":
        return await _run_docker(code, language)
    elif SANDBOX_BACKEND == "e2b":
        return await _run_e2b(code, language)
    else:
        return await _run_subprocess(code, language)


async def _run_subprocess(code: str, language: str) -> dict:
    """Fallback: run code in a subprocess with timeout."""
    cmd_map = {
        "python": ["python3", "-c"],
        "javascript": ["node", "-e"],
        "shell": ["sh", "-c"],
        "bash": ["bash", "-c"],
    }

    cmd_prefix = cmd_map.get(language)
    if not cmd_prefix:
        return {"error": f"Unsupported language: {language}", "exit_code": -1}

    try:
        proc = await asyncio.create_subprocess_exec(
            *cmd_prefix, code,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )

        stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=SANDBOX_TIMEOUT)

        return {
            "exit_code": proc.returncode,
            "stdout": stdout.decode("utf-8", errors="replace")[:MAX_OUTPUT],
            "stderr": stderr.decode("utf-8", errors="replace")[:MAX_OUTPUT],
            "language": language,
            "backend": "subprocess",
        }
    except asyncio.TimeoutError:
        try:
            proc.kill()
            await proc.wait()
        except ProcessLookupError:
            pass
        return {"error": f"Timeout after {SANDBOX_TIMEOUT}s", "exit_code": -1}
    except Exception as e:
        return {"error": str(e), "exit_code": -1}


async def _run_docker(code: str, language: str) -> dict:
    """Run code in a throwaway Docker container via mounted temp file."""
    image_map = {
        "python": ("python:3.11-slim", ["python3", "/sandbox/code.py"]),
        "javascript": ("node:20-slim", ["node", "/sandbox/code.js"]),
        "shell": ("alpine:3.18", ["sh", "/sandbox/code.sh"]),
        "bash": ("alpine:3.18", ["sh", "/sandbox/code.sh"]),
    }

    entry = image_map.get(language)
    if not entry:
        return {"error": f"Unsupported language: {language}", "exit_code": -1}

    image, run_cmd = entry
    code_file = None

    try:
        # Write code to temp file — avoids shell metacharacter injection
        ext = {"python": ".py", "javascript": ".js", "shell": ".sh", "bash": ".sh"}.get(language, ".txt")
        fd, code_file = tempfile.mkstemp(suffix=ext, prefix="sandbox_")
        with os.fdopen(fd, "w") as f:
            f.write(code)

        cmd = [
            "docker", "run", "--rm",
            "--network", "none",
            "--memory", SANDBOX_MEMORY_LIMIT,
            "--cpus", "0.5",
            "--read-only",
            "--tmpfs", "/tmp:size=32m",
            "-v", f"{code_file}:/sandbox/code{ext}:ro",
            image,
        ] + run_cmd

        proc = await asyncio.create_subprocess_exec(
            *cmd,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )

        stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=SANDBOX_TIMEOUT)

        return {
            "exit_code": proc.returncode,
            "stdout": stdout.decode("utf-8", errors="replace")[:MAX_OUTPUT],
            "stderr": stderr.decode("utf-8", errors="replace")[:MAX_OUTPUT],
            "language": language,
            "backend": "docker",
            "image": image,
        }
    except asyncio.TimeoutError:
        # Reap the still-running `docker run` client on timeout (mirrors the
        # subprocess backend); the container itself is best-effort cleaned up
        # by the daemon via --rm once it exits.
        try:
            proc.kill()
            await proc.wait()
        except ProcessLookupError:
            pass
        return {"error": f"Timeout after {SANDBOX_TIMEOUT}s", "exit_code": -1}
    except Exception as e:
        return {"error": str(e), "exit_code": -1}
    finally:
        if code_file:
            try:
                os.unlink(code_file)
            except OSError:
                pass


async def _run_e2b(code,
language: str) -> dict: + """Run code in an E2B cloud microVM sandbox. + + Requires the e2b-code-interpreter package and an E2B_API_KEY secret. + Each call creates a fresh sandbox, runs the code, and destroys the sandbox. + Sandbox lifetime is bounded by SANDBOX_TIMEOUT seconds. + + Supported languages: python, javascript. + """ + # Import lazily so the package is only required when the e2b backend is + # actually configured — other backends work without it installed. + try: + from e2b_code_interpreter import Sandbox + except ImportError: + return { + "error": ( + "e2b-code-interpreter is not installed. " + "Add it to requirements.txt or switch to the docker/subprocess backend." + ), + "exit_code": -1, + } + + api_key = os.environ.get("E2B_API_KEY") + if not api_key: + return { + "error": ( + "E2B_API_KEY is not set. " + "Add it as a workspace secret via the canvas Secrets panel or platform API." + ), + "exit_code": -1, + } + + kernel = _E2B_KERNEL_MAP.get(language) + if kernel is None: + return { + "error": ( + f"Language '{language}' is not supported by the e2b backend. " + "Supported: python, javascript." + ), + "exit_code": -1, + } + + sandbox = None + try: + # Create a fresh sandbox for this execution. + # timeout controls the sandbox lifetime in seconds. + sandbox = await asyncio.wait_for( + asyncio.get_running_loop().run_in_executor( + None, + lambda: Sandbox(api_key=api_key, timeout=SANDBOX_TIMEOUT), + ), + timeout=SANDBOX_TIMEOUT, + ) + + # Execute code and collect results. + execution = await asyncio.wait_for( + asyncio.get_running_loop().run_in_executor( + None, + lambda: sandbox.run_code(code, language=kernel), + ), + timeout=SANDBOX_TIMEOUT, + ) + + # E2B returns a list of Result objects; collect text/error output. + stdout_parts = [] + stderr_parts = [] + + for result in execution.results: + # result.text is the primary output (stdout equivalent) + if hasattr(result, "text") and result.text: + stdout_parts.append(str(result.text)) + # Some result types expose an error attribute + if hasattr(result, "error") and result.error: + stderr_parts.append(str(result.error)) + + # Logs are stored separately in execution.logs + if hasattr(execution, "logs"): + logs = execution.logs + if hasattr(logs, "stdout") and logs.stdout: + stdout_parts.extend(logs.stdout) + if hasattr(logs, "stderr") and logs.stderr: + stderr_parts.extend(logs.stderr) + + combined_stdout = "".join(stdout_parts)[:MAX_OUTPUT] + combined_stderr = "".join(stderr_parts)[:MAX_OUTPUT] + + # Treat any stderr output as a non-zero exit code (e2b doesn't expose + # a numeric exit code at the sandbox level). + exit_code = 1 if combined_stderr else 0 + + return { + "exit_code": exit_code, + "stdout": combined_stdout, + "stderr": combined_stderr, + "language": language, + "backend": "e2b", + } + + except asyncio.TimeoutError: + logger.warning("E2B sandbox timed out after %ds", SANDBOX_TIMEOUT) + return {"error": f"Timeout after {SANDBOX_TIMEOUT}s", "exit_code": -1} + except Exception as e: + logger.exception("E2B sandbox error: %s", e) + return {"error": str(e), "exit_code": -1} + finally: + # Always destroy the sandbox to avoid leaking E2B credits. 
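+        # Sandbox.kill() is a blocking HTTP call in e2b's sync API (an
+        # assumption mirrored from the run_in_executor pattern above), so it
+        # is pushed off the event loop the same way. Even if the kill fails,
+        # the sandbox self-expires once the `timeout` given at creation
+        # elapses, so the worst case is a short-lived credit leak.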
+ if sandbox is not None: + try: + await asyncio.get_running_loop().run_in_executor( + None, sandbox.kill + ) + except Exception: + pass # Best-effort cleanup diff --git a/molecule_runtime/builtin_tools/security_scan.py b/molecule_runtime/builtin_tools/security_scan.py new file mode 100644 index 0000000..214e5fb --- /dev/null +++ b/molecule_runtime/builtin_tools/security_scan.py @@ -0,0 +1,344 @@ +"""Skill dependency security scanner — supply-chain risk management. + +Scans a skill's ``requirements.txt`` for known CVEs before the skill is +loaded into the workspace. Two scanners are supported: + + Snyk CLI — ``snyk test --file=requirements.txt --json`` + Preferred; requires the ``snyk`` binary in PATH and + a SNYK_TOKEN env var for authenticated scans. + + pip-audit — ``pip-audit -r requirements.txt --json`` + Fallback; no authentication required. + +The scanner is auto-selected: Snyk if available, pip-audit otherwise. +If neither is present in PATH the scan is silently skipped with a log line. + +Scan mode (``security_scan.mode`` in config.yaml): + + block — raise ``SkillSecurityError`` when critical/high CVEs are found; + the skill is *not* loaded. + warn — log a WARNING + audit event; the skill is loaded anyway. + off — skip scanning entirely; useful in air-gapped CI. + +Audit trail +----------- +Every scan (pass or fail) is recorded via ``tools.audit.log_event`` with +``event_type="security_scan"``, enabling compliance reports to prove that +all loaded skills were checked before activation. +""" + +from __future__ import annotations + +import json +import logging +import shutil +import subprocess +from dataclasses import dataclass, field +from pathlib import Path +from typing import Optional + +from builtin_tools.audit import log_event + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Public exception +# --------------------------------------------------------------------------- + + +class SkillSecurityError(RuntimeError): + """Raised when a skill fails security scanning in ``block`` mode. + + The message contains the skill name, scanner used, and a summary of the + critical/high findings so operators can act on it immediately. + """ + + +# --------------------------------------------------------------------------- +# Data models +# --------------------------------------------------------------------------- + + +@dataclass +class CVEFinding: + """A single vulnerability finding from a security scanner.""" + + vuln_id: str + """CVE or advisory identifier, e.g. ``SNYK-PYTHON-REQUESTS-1234``.""" + package: str + """Affected package name.""" + version: str + """Installed version of the package.""" + severity: str + """One of: critical | high | medium | low | unknown.""" + description: str + """Short human-readable summary (≤ 200 chars).""" + + +@dataclass +class ScanResult: + """Aggregated result of a single skill dependency scan.""" + + skill_name: str + scanner: str + """Scanner used: ``"snyk"`` | ``"pip-audit"`` | ``"none"``.""" + requirements_file: Optional[str] + """Absolute path to the scanned requirements.txt, or ``None``.""" + findings: list[CVEFinding] = field(default_factory=list) + scan_error: Optional[str] = None + """Non-fatal scanner error (e.g. 
timeout); findings may be incomplete.""" + + @property + def critical_or_high(self) -> list[CVEFinding]: + return [f for f in self.findings if f.severity in ("critical", "high")] + + @property + def has_critical_or_high(self) -> bool: + return bool(self.critical_or_high) + + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- + + +def _find_requirements(skill_path: Path) -> Optional[Path]: + """Return the first ``requirements.txt`` found in the skill tree.""" + for candidate in ( + skill_path / "requirements.txt", + skill_path / "tools" / "requirements.txt", + ): + if candidate.exists(): + return candidate + return None + + +def _run_scanner(cmd: list[str], timeout: int = 120) -> tuple[str, Optional[str]]: + """Run a scanner subprocess and return ``(stdout, error_or_None)``.""" + try: + result = subprocess.run( + cmd, + capture_output=True, + text=True, + timeout=timeout, + ) + # Both Snyk and pip-audit exit 1 when vulns are found — not an error. + # Exit 2 from Snyk means a genuine scan failure. + if result.returncode == 2 and not result.stdout.strip(): + return "", f"scanner exited 2: {result.stderr.strip()[:200]}" + return result.stdout, None + except subprocess.TimeoutExpired: + return "", f"scanner timed out after {timeout}s" + except FileNotFoundError as exc: + return "", str(exc) + except Exception as exc: # pylint: disable=broad-except + return "", str(exc) + + +def _parse_snyk(stdout: str) -> tuple[list[CVEFinding], Optional[str]]: + """Parse ``snyk test --json`` output.""" + if not stdout.strip(): + return [], "empty snyk output" + try: + data = json.loads(stdout) + except json.JSONDecodeError as exc: + return [], f"snyk JSON parse error: {exc}" + + vulns = data.get("vulnerabilities", []) + findings = [ + CVEFinding( + vuln_id=v.get("id", "UNKNOWN"), + package=v.get("packageName", "?"), + version=v.get("version", "?"), + severity=v.get("severity", "unknown").lower(), + description=(v.get("title", "") or "")[:200], + ) + for v in vulns + if isinstance(v, dict) + ] + return findings, None + + +def _parse_pip_audit(stdout: str) -> tuple[list[CVEFinding], Optional[str]]: + """Parse ``pip-audit --json`` output. + + pip-audit does not always provide a CVSS severity level. When absent we + conservatively classify the finding as ``"high"`` so it is not silently + ignored in ``warn`` mode. 
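+
+    Expected wrapped shape (assumed, abridged; pip-audit ≥ 2.x JSON output)::
+
+        {"dependencies": [{"name": "requests", "version": "2.19.0",
+                           "vulns": [{"id": "PYSEC-2018-28",
+                                      "description": "..."}]}]}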
+    """
+    if not stdout.strip():
+        return [], "empty pip-audit output"
+    try:
+        data = json.loads(stdout)
+    except json.JSONDecodeError as exc:
+        return [], f"pip-audit JSON parse error: {exc}"
+
+    # pip-audit ≥ 2.x wraps results in {"dependencies": [...]}
+    if isinstance(data, dict):
+        deps = data.get("dependencies", [])
+    else:
+        deps = data  # older versions return a bare list
+
+    findings: list[CVEFinding] = []
+    for dep in deps:
+        if not isinstance(dep, dict):
+            continue
+        for vuln in dep.get("vulns", []):
+            # pip-audit output often omits a severity field; default to
+            # "high" (per the docstring above) so findings are never
+            # silently ignored.
+            sev = (vuln.get("severity") or "high").lower()
+            findings.append(
+                CVEFinding(
+                    vuln_id=vuln.get("id", "UNKNOWN"),
+                    package=dep.get("name", "?"),
+                    version=dep.get("version", "?"),
+                    severity=sev,
+                    description=(vuln.get("description", "") or "")[:200],
+                )
+            )
+    return findings, None
+
+
+# ---------------------------------------------------------------------------
+# Public API
+# ---------------------------------------------------------------------------
+
+
+def scan_skill_dependencies(
+    skill_name: str,
+    skill_path: Path,
+    mode: str,
+    fail_open_if_no_scanner: bool = True,
+) -> ScanResult:
+    """Scan a skill's dependency file for known CVEs.
+
+    Args:
+        skill_name: Name of the skill (used in log messages and audit events).
+        skill_path: Absolute path to the skill's root directory.
+        mode: ``"block"`` | ``"warn"`` | ``"off"``
+        fail_open_if_no_scanner:
+            When *True* (default), silently skip scanning if neither snyk nor
+            pip-audit is in PATH. When *False* and ``mode="block"``, raise
+            :class:`SkillSecurityError` so operators know the gate is absent.
+            Corresponds to ``security_scan.fail_open_if_no_scanner`` in
+            config.yaml. Closes #268.
+
+    Returns:
+        A :class:`ScanResult` describing what was found.
+
+    Raises:
+        :class:`SkillSecurityError`: When ``mode="block"`` and one or more
+            critical/high severity CVEs are found — or when ``mode="block"``,
+            ``fail_open_if_no_scanner=False``, and no scanner is available.
+    """
+    if mode == "off":
+        return ScanResult(skill_name=skill_name, scanner="none", requirements_file=None)
+
+    req_file = _find_requirements(skill_path)
+    if req_file is None:
+        # No requirements file — nothing to scan; not a problem.
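+        # scanner="none" with scan_error=None reads as a clean pass here,
+        # unlike the no-scanner-in-PATH branch below, which sets scan_error.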
+ return ScanResult(skill_name=skill_name, scanner="none", requirements_file=None) + + # ── Select scanner ──────────────────────────────────────────────────────── + scanner_name: str + findings: list[CVEFinding] + scan_error: Optional[str] + + if shutil.which("snyk"): + scanner_name = "snyk" + stdout, run_error = _run_scanner( + ["snyk", "test", f"--file={req_file}", "--json"] + ) + if run_error: + findings, scan_error = [], run_error + else: + findings, scan_error = _parse_snyk(stdout) + + elif shutil.which("pip-audit"): + scanner_name = "pip-audit" + stdout, run_error = _run_scanner( + ["pip-audit", "-r", str(req_file), "--json", "--progress-spinner=off"] + ) + if run_error: + findings, scan_error = [], run_error + else: + findings, scan_error = _parse_pip_audit(stdout) + + else: + logger.info( + "security_scan: no scanner (snyk, pip-audit) in PATH — skipping %s", + skill_name, + ) + log_event( + event_type="security_scan", + action="skill.security_scan", + resource=skill_name, + outcome="skipped", + reason="no_scanner_in_path", + requirements_file=str(req_file), + mode=mode, + ) + # #268: if fail_open_if_no_scanner=False and mode=block, the operator + # explicitly opted in to "fail closed" — raise so the missing scanner + # is visible rather than silently skipped. + if not fail_open_if_no_scanner and mode == "block": + raise SkillSecurityError( + f"Skill '{skill_name}' blocked: no scanner (snyk or pip-audit) " + f"found in PATH and fail_open_if_no_scanner=false" + ) + return ScanResult( + skill_name=skill_name, + scanner="none", + requirements_file=str(req_file), + scan_error="No scanner (snyk or pip-audit) found in PATH", + ) + + result = ScanResult( + skill_name=skill_name, + scanner=scanner_name, + requirements_file=str(req_file), + findings=findings, + scan_error=scan_error, + ) + + # ── Log scan outcome to audit trail ────────────────────────────────────── + audit_outcome = "clean" if not result.has_critical_or_high else "vulnerable" + log_event( + event_type="security_scan", + action="skill.security_scan", + resource=skill_name, + outcome=audit_outcome, + scanner=scanner_name, + requirements_file=str(req_file), + total_findings=len(findings), + critical_or_high_count=len(result.critical_or_high), + scan_error=scan_error, + ) + + if scan_error: + logger.warning( + "security_scan: scanner error for skill '%s': %s", skill_name, scan_error + ) + + # ── Enforce mode ───────────────────────────────────────────────────────── + if result.has_critical_or_high: + summary = ", ".join( + f"{f.vuln_id}({f.severity}) in {f.package}@{f.version}" + for f in result.critical_or_high[:5] + ) + if len(result.critical_or_high) > 5: + summary += f" … and {len(result.critical_or_high) - 5} more" + + msg = ( + f"Skill '{skill_name}' has {len(result.critical_or_high)} " + f"critical/high CVE(s) [{scanner_name}]: {summary}" + ) + + if mode == "block": + logger.error("Blocking skill load — %s", msg) + raise SkillSecurityError(msg) + + # warn mode — continue loading, but make noise + logger.warning("Security warning — %s", msg) + + return result diff --git a/molecule_runtime/builtin_tools/telemetry.py b/molecule_runtime/builtin_tools/telemetry.py new file mode 100644 index 0000000..7b2e3d0 --- /dev/null +++ b/molecule_runtime/builtin_tools/telemetry.py @@ -0,0 +1,418 @@ +"""OpenTelemetry (OTEL) instrumentation for the Molecule AI workspace runtime. + +Architecture +------------ +* One global ``TracerProvider`` is initialised at startup via ``setup_telemetry()``. +* Up to three exporters are wired in: + 1. 
**OTLP/HTTP** — activated when ``OTEL_EXPORTER_OTLP_ENDPOINT`` is set. + Point this at any compatible collector (Jaeger, Tempo, Grafana OTEL, …). + 2. **Langfuse OTLP bridge** — activated when the ``LANGFUSE_HOST``, + ``LANGFUSE_PUBLIC_KEY`` and ``LANGFUSE_SECRET_KEY`` env vars are all present. + Langfuse ≥4 accepts OTLP/HTTP at ``/api/public/otel``. + This is a *second* exporter alongside the existing Langfuse LangChain + callback handler in agent.py — both paths emit spans simultaneously. + 3. **Console** (debug) — activated when ``OTEL_DEBUG=1``. + +* **W3C TraceContext** propagation (``traceparent`` / ``tracestate``) is used for + cross-workspace context injection and extraction so A2A hops form a single + distributed trace. + +* ``make_trace_middleware()`` returns an ASGI middleware that extracts incoming + trace context from HTTP headers and stores it in a ``ContextVar`` so the + A2A executor can access it to parent its spans correctly. + +GenAI semantic conventions +-------------------------- +Attribute constants for ``gen_ai.*`` follow OpenTelemetry GenAI SemConv 1.26. + +Usage example +------------- + # main.py — call once at startup + from builtin_tools.telemetry import setup_telemetry, make_trace_middleware + setup_telemetry(service_name=workspace_id) + instrumented = make_trace_middleware(app.build()) + + # Any module + from builtin_tools.telemetry import get_tracer + tracer = get_tracer() + with tracer.start_as_current_span("my_span") as span: + span.set_attribute("key", "value") + + # Outgoing HTTP — inject W3C headers + from builtin_tools.telemetry import inject_trace_headers + headers = inject_trace_headers({"Content-Type": "application/json"}) + await client.post(url, headers=headers, ...) + + # Incoming HTTP — extract context (done automatically by middleware) + from builtin_tools.telemetry import extract_trace_context + ctx = extract_trace_context(dict(request.headers)) +""" + +from __future__ import annotations + +import base64 +import logging +import os +from contextvars import ContextVar +from typing import Any, Optional + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# GenAI Semantic Convention attribute keys (OTel SemConv 1.26) +# https://opentelemetry.io/docs/specs/semconv/gen-ai/ +# --------------------------------------------------------------------------- +GEN_AI_SYSTEM = "gen_ai.system" +GEN_AI_REQUEST_MODEL = "gen_ai.request.model" +GEN_AI_OPERATION_NAME = "gen_ai.operation.name" +GEN_AI_USAGE_INPUT_TOKENS = "gen_ai.usage.input_tokens" +GEN_AI_USAGE_OUTPUT_TOKENS = "gen_ai.usage.output_tokens" +GEN_AI_RESPONSE_FINISH_REASONS = "gen_ai.response.finish_reasons" + +# --------------------------------------------------------------------------- +# Workspace / A2A attribute keys +# --------------------------------------------------------------------------- +WORKSPACE_ID_ATTR = "workspace.id" +A2A_SOURCE_WORKSPACE = "a2a.source_workspace_id" +A2A_TARGET_WORKSPACE = "a2a.target_workspace_id" +A2A_TASK_ID = "a2a.task_id" +MEMORY_SCOPE = "memory.scope" +MEMORY_QUERY = "memory.query" + +# --------------------------------------------------------------------------- +# Module-level state +# --------------------------------------------------------------------------- +WORKSPACE_ID: str = os.environ.get("WORKSPACE_ID", "unknown") + +_initialized: bool = False +_tracer: Any = None # opentelemetry.trace.Tracer | _NoopTracer + +# ContextVar that carries incoming trace context from the ASGI middleware to +# 
the A2A executor. Using a ContextVar (rather than a global) is safe with +# asyncio because each task inherits a copy of the context at creation time. +_incoming_trace_context: ContextVar[Optional[Any]] = ContextVar( + "otel_incoming_trace_context", default=None +) + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + +def setup_telemetry(service_name: Optional[str] = None) -> None: + """Initialise the global ``TracerProvider``. Safe to call multiple times. + + Reads configuration from environment variables: + + ``OTEL_EXPORTER_OTLP_ENDPOINT`` + Base URL of an OTLP-compatible collector (e.g. ``http://jaeger:4318``). + Spans are sent to ``/v1/traces``. + + ``LANGFUSE_HOST`` + ``LANGFUSE_PUBLIC_KEY`` + ``LANGFUSE_SECRET_KEY`` + When all three are set, a second OTLP exporter is wired to Langfuse's + ingest endpoint using HTTP Basic auth. + + ``OTEL_DEBUG`` + Set to ``1`` / ``true`` to also print spans to stdout. + """ + global _initialized, _tracer + + if _initialized: + return + + try: + from opentelemetry import propagate, trace + from opentelemetry.baggage.propagation import W3CBaggagePropagator + from opentelemetry.propagators.composite import CompositePropagator + from opentelemetry.sdk.resources import SERVICE_NAME as OTEL_SERVICE_NAME + from opentelemetry.sdk.resources import Resource + from opentelemetry.sdk.trace import TracerProvider + from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter + from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator + except ImportError as exc: + logger.warning( + "OTEL: opentelemetry packages not installed — telemetry disabled. " + "Add opentelemetry-api, opentelemetry-sdk, " + "opentelemetry-exporter-otlp-proto-http to requirements.txt. " + "Error: %s", + exc, + ) + return + + svc = service_name or f"molecule-{WORKSPACE_ID}" + + resource = Resource.create( + { + OTEL_SERVICE_NAME: svc, + "service.version": "1.0.0", + WORKSPACE_ID_ATTR: WORKSPACE_ID, + } + ) + + provider = TracerProvider(resource=resource) + + # -- Exporter 1: Generic OTLP/HTTP ---------------------------------------- + otlp_endpoint = os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT", "").rstrip("/") + if otlp_endpoint: + try: + from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter + + exporter = OTLPSpanExporter(endpoint=f"{otlp_endpoint}/v1/traces") + provider.add_span_processor(BatchSpanProcessor(exporter)) + logger.info("OTEL: OTLP/HTTP exporter → %s", otlp_endpoint) + except ImportError: + logger.warning( + "OTEL: OTEL_EXPORTER_OTLP_ENDPOINT is set but " + "opentelemetry-exporter-otlp-proto-http is not installed" + ) + except Exception as exc: + logger.warning("OTEL: OTLP exporter init failed: %s", exc) + + # -- Exporter 2: Langfuse OTLP bridge ------------------------------------- + # Langfuse ≥4 accepts OTLP at /api/public/otel (Basic auth). 
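+    # Illustrative: LANGFUSE_HOST=https://cloud.langfuse.com yields the
+    # exporter endpoint https://cloud.langfuse.com/api/public/otel/v1/traces,
+    # authenticated with `Authorization: Basic base64(public_key:secret_key)`
+    # exactly as constructed below.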
+ lf_host = os.environ.get("LANGFUSE_HOST", "").rstrip("/") + lf_public = os.environ.get("LANGFUSE_PUBLIC_KEY", "") + lf_secret = os.environ.get("LANGFUSE_SECRET_KEY", "") + + if lf_host and lf_public and lf_secret: + try: + from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter + + lf_endpoint = f"{lf_host}/api/public/otel/v1/traces" + token = base64.b64encode(f"{lf_public}:{lf_secret}".encode()).decode() + lf_exporter = OTLPSpanExporter( + endpoint=lf_endpoint, + headers={"Authorization": f"Basic {token}"}, + ) + provider.add_span_processor(BatchSpanProcessor(lf_exporter)) + logger.info("OTEL: Langfuse OTLP bridge → %s", lf_endpoint) + except ImportError: + logger.warning( + "OTEL: Langfuse env vars set but " + "opentelemetry-exporter-otlp-proto-http is not installed" + ) + except Exception as exc: + logger.warning("OTEL: Langfuse OTLP bridge init failed: %s", exc) + + # -- Exporter 3: Console (debug) ------------------------------------------ + if os.environ.get("OTEL_DEBUG", "").lower() in ("1", "true", "yes"): + provider.add_span_processor(BatchSpanProcessor(ConsoleSpanExporter())) + logger.info("OTEL: console debug exporter enabled") + + # -- Register global provider + W3C propagators --------------------------- + trace.set_tracer_provider(provider) + propagate.set_global_textmap( + CompositePropagator( + [ + TraceContextTextMapPropagator(), + W3CBaggagePropagator(), + ] + ) + ) + + _tracer = trace.get_tracer( + "molecule.workspace", + schema_url="https://opentelemetry.io/schemas/1.26.0", + ) + _initialized = True + logger.info("OTEL: telemetry initialised for service '%s'", svc) + + +def get_tracer() -> Any: + """Return the global ``Tracer``. Lazily calls ``setup_telemetry()`` if needed. + + Returns a no-op tracer when the opentelemetry packages are not installed so + that instrumented code never raises ``ImportError``. + """ + global _tracer + + if not _initialized: + setup_telemetry() + + if _tracer is None: + # Packages unavailable — hand back a no-op implementation + try: + from opentelemetry import trace + + return trace.get_tracer("molecule.noop") + except ImportError: + return _NoopTracer() + + return _tracer + + +def inject_trace_headers(headers: dict) -> dict: + """Inject W3C ``traceparent`` / ``tracestate`` into *headers* and return it. + + Mutates the dict in-place so it can be used directly:: + + headers = inject_trace_headers({"Content-Type": "application/json"}) + await client.post(url, headers=headers, ...) + """ + try: + from opentelemetry import propagate + + propagate.inject(headers) + except Exception: + pass # Never let telemetry break the caller + return headers + + +def extract_trace_context(carrier: dict) -> Any: + """Extract W3C trace context from a header mapping. + + Returns an OpenTelemetry ``Context`` object suitable for:: + + tracer.start_as_current_span("name", context=ctx) + + Returns ``None`` when packages are unavailable or no context is present. 
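+
+    Example carrier (sample ids from the W3C Trace Context spec)::
+
+        {"traceparent": "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01"}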
+ """ + try: + from opentelemetry import propagate + + return propagate.extract(carrier) + except Exception: + return None + + +def get_current_traceparent() -> Optional[str]: + """Return the W3C ``traceparent`` string for the active span, or ``None``.""" + try: + from opentelemetry import trace + + span = trace.get_current_span() + ctx = span.get_span_context() + if not ctx.is_valid: + return None + trace_id = format(ctx.trace_id, "032x") + span_id = format(ctx.span_id, "016x") + flags = "01" if ctx.trace_flags else "00" + return f"00-{trace_id}-{span_id}-{flags}" + except Exception: + return None + + +def make_trace_middleware(asgi_app: Any) -> Any: + """Wrap an ASGI application with W3C trace-context extraction middleware. + + The middleware reads ``traceparent`` / ``tracestate`` from every incoming + HTTP request and stores the extracted ``Context`` in the + ``_incoming_trace_context`` ContextVar. The A2A executor reads that + ContextVar to parent its ``task_receive`` span correctly, forming an + unbroken distributed trace across workspace hops. + + Usage:: + + built = app.build() + instrumented = make_trace_middleware(built) + uvicorn.Config(instrumented, ...) + """ + + async def _middleware(scope: dict, receive: Any, send: Any) -> None: # type: ignore[override] + if scope.get("type") != "http": + await asgi_app(scope, receive, send) + return + + # Decode byte-headers from the ASGI scope (latin-1 per HTTP/1.1 spec) + raw_headers: list[tuple[bytes, bytes]] = scope.get("headers", []) + str_headers: dict[str, str] = { + k.decode("latin-1"): v.decode("latin-1") for k, v in raw_headers + } + + ctx = extract_trace_context(str_headers) + token = _incoming_trace_context.set(ctx) + try: + await asgi_app(scope, receive, send) + finally: + _incoming_trace_context.reset(token) + + return _middleware + + +# --------------------------------------------------------------------------- +# Helpers for GenAI attributes +# --------------------------------------------------------------------------- + +def gen_ai_system_from_model(model_str: str) -> str: + """Map a ``provider:model`` string to a ``gen_ai.system`` value.""" + if ":" not in model_str: + return "unknown" + provider = model_str.split(":", 1)[0].lower() + return { + "anthropic": "anthropic", + "openai": "openai", + "openrouter": "openrouter", + "groq": "groq", + "google_genai": "google", + "ollama": "ollama", + }.get(provider, provider) + + +def record_llm_token_usage(span: Any, result: dict) -> None: + """Extract token counts from a LangGraph ainvoke result and set span attrs. + + Handles both Anthropic (``usage``) and OpenAI (``token_usage``) metadata + shapes. Silently skips if metadata is absent. 
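+
+    Typical call site (a sketch; ``agent`` stands in for the LangGraph graph)::
+
+        with get_tracer().start_as_current_span("llm_call") as span:
+            result = await agent.ainvoke({"messages": [...]})
+            record_llm_token_usage(span, result)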
+ """ + try: + messages = result.get("messages", []) + for msg in reversed(messages): + meta = getattr(msg, "response_metadata", {}) or {} + # Anthropic + usage = meta.get("usage", {}) + if usage: + inp = usage.get("input_tokens") or usage.get("prompt_tokens") + out = usage.get("output_tokens") or usage.get("completion_tokens") + if inp is not None: + span.set_attribute(GEN_AI_USAGE_INPUT_TOKENS, int(inp)) + if out is not None: + span.set_attribute(GEN_AI_USAGE_OUTPUT_TOKENS, int(out)) + return + # OpenAI + token_usage = meta.get("token_usage", {}) + if token_usage: + inp = token_usage.get("prompt_tokens") + out = token_usage.get("completion_tokens") + if inp is not None: + span.set_attribute(GEN_AI_USAGE_INPUT_TOKENS, int(inp)) + if out is not None: + span.set_attribute(GEN_AI_USAGE_OUTPUT_TOKENS, int(out)) + return + except Exception: + pass # Best-effort — never break the caller + + +# --------------------------------------------------------------------------- +# No-op fallbacks (used when opentelemetry packages are absent) +# --------------------------------------------------------------------------- + +class _NoopSpan: + """Transparent no-op span that satisfies the context-manager protocol.""" + + def set_attribute(self, key: str, value: Any) -> None: # noqa: ARG002 + pass + + def set_status(self, *args: Any, **kwargs: Any) -> None: + pass + + def record_exception(self, exc: BaseException, *args: Any, **kwargs: Any) -> None: + pass + + def add_event(self, name: str, *args: Any, **kwargs: Any) -> None: + pass + + def __enter__(self) -> "_NoopSpan": + return self + + def __exit__(self, *args: Any) -> None: + pass + + +class _NoopTracer: + """Transparent no-op tracer returned when the SDK is unavailable.""" + + def start_as_current_span(self, name: str, *args: Any, **kwargs: Any) -> _NoopSpan: # noqa: ARG002 + return _NoopSpan() + + def start_span(self, name: str, *args: Any, **kwargs: Any) -> _NoopSpan: # noqa: ARG002 + return _NoopSpan() diff --git a/molecule_runtime/builtin_tools/temporal_workflow.py b/molecule_runtime/builtin_tools/temporal_workflow.py new file mode 100644 index 0000000..bb5c049 --- /dev/null +++ b/molecule_runtime/builtin_tools/temporal_workflow.py @@ -0,0 +1,515 @@ +"""Temporal durable execution wrapper for Molecule AI A2A workspaces. + +Architecture +----------- +A co-located Temporal worker runs as an asyncio background task **inside the +same process** as the A2A server. This means worker activities share the same +memory space as the A2A handler, which lets us bridge non-serialisable objects +(LangGraph agent, EventQueue, RequestContext) through an in-process registry +without having to serialise them through Temporal's state store. + +Workflow stages (names mirror the OTEL span names in a2a_executor.py): + + task_receive → llm_call → task_complete + + task_receive — durable checkpoint: task acknowledged, queued + llm_call — durable checkpoint: LLM execution + SSE streaming (retryable) + task_complete — durable checkpoint: execution finished, telemetry recorded + +Crash-recovery behaviour +------------------------ +If the process crashes while ``llm_call`` is running, Temporal retries the +activity on the restarted process. The in-process registry is empty after a +restart, so the activity detects a registry miss, logs a warning, and returns +an error result. The SSE client connection is already gone at that point so +no response can be delivered — but the task is permanently recorded in +Temporal's history and will not silently disappear. 
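+
+Because each workflow id is derived from the A2A task id (``molecule-{task_id}``;
+see ``TemporalWorkflowWrapper.run``), a crashed task can still be inspected
+after the fact with the Temporal CLI, e.g.::
+
+    temporal workflow show --workflow-id molecule-<task_id>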
+ +Env vars +-------- +TEMPORAL_HOST Temporal gRPC endpoint (default: ``localhost:7233``) + Set this to enable durable execution. Leave unset (or point + at an unreachable host) to run in direct-execution mode. + +Dependencies (optional) +----------- + temporalio>=1.7.0 + +Add to requirements.txt to enable. The module loads and the wrapper class +works without the package installed — all Temporal paths return early with a +graceful fallback to direct execution. +""" + +from __future__ import annotations + +import asyncio +import dataclasses +import logging +import os +import uuid +from datetime import timedelta +from typing import Any, Optional + +logger = logging.getLogger(__name__) + +# ───────────────────────────────────────────────────────────────────────────── +# Constants +# ───────────────────────────────────────────────────────────────────────────── + +_TASK_QUEUE = "molecule-agent-tasks" +_WORKFLOW_EXECUTION_TIMEOUT = timedelta(minutes=30) +_ACTIVITY_START_TO_CLOSE_TIMEOUT = timedelta(minutes=10) + +# ───────────────────────────────────────────────────────────────────────────── +# Serialisable data models +# These are the only objects that cross the Temporal serialisation boundary. +# ───────────────────────────────────────────────────────────────────────────── + + +@dataclasses.dataclass +class AgentTaskInput: + """Serialisable snapshot of an incoming A2A task. + + All fields must be JSON-representable so that Temporal can persist them in + its workflow history (used for crash recovery and replay). + """ + + task_id: str + context_id: str + user_input: str + model: str + workspace_id: str + history: list # [[role, content], ...] — tuples converted to lists + + +@dataclasses.dataclass +class LLMResult: + """Serialisable execution result passed from ``llm_call`` to ``task_complete``.""" + + final_text: str + success: bool + error: str = "" + + +# ───────────────────────────────────────────────────────────────────────────── +# In-process registry +# +# Maps task_id → {executor, context, event_queue, final_text} +# Activities look up non-serialisable objects here. The registry is +# populated by TemporalWorkflowWrapper.run() before the workflow starts and +# cleaned up in the finally block when the workflow completes. +# ───────────────────────────────────────────────────────────────────────────── + +_task_registry: dict[str, dict[str, Any]] = {} + + +# ───────────────────────────────────────────────────────────────────────────── +# Temporal workflow + activities +# Loaded only when the temporalio package is installed. The surrounding +# try/except ensures the module imports cleanly without the package. +# ───────────────────────────────────────────────────────────────────────────── + +_TEMPORAL_AVAILABLE = False + +try: + from temporalio import activity, workflow + from temporalio.client import Client + from temporalio.worker import Worker + + _TEMPORAL_AVAILABLE = True + + # ── Activities ────────────────────────────────────────────────────────── # + + @activity.defn(name="task_receive") + async def task_receive_activity(inp: AgentTaskInput) -> dict: + """Durable checkpoint: task received and queued for LLM execution. + + Mirrors the *task_receive* OTEL span opened in + ``LangGraphA2AExecutor._core_execute()``. This activity is lightweight — + it validates that the in-process registry entry exists and logs receipt. + The actual A2A "working" signal (``updater.start_work()``) is emitted + inside ``_core_execute()`` so that SSE timing is preserved. 
+ """ + logger.info( + "Temporal[task_receive] task_id=%s context_id=%s workspace=%s model=%s", + inp.task_id, + inp.context_id, + inp.workspace_id, + inp.model, + ) + if inp.task_id not in _task_registry: + logger.warning( + "Temporal[task_receive] task_id=%s not found in registry " + "(crash recovery path — no SSE client connection available)", + inp.task_id, + ) + return {"task_id": inp.task_id, "status": "registry_miss"} + + return {"task_id": inp.task_id, "status": "received"} + + @activity.defn(name="llm_call") + async def llm_call_activity(inp: AgentTaskInput) -> LLMResult: + """Durable checkpoint: LLM execution with streaming to the event_queue. + + Mirrors the *llm_call* OTEL span in ``LangGraphA2AExecutor._core_execute()``. + Calls ``executor._core_execute()`` which handles the full execution pipeline: + SSE streaming, OTEL sub-spans, final message emission, and heartbeat updates. + + On crash recovery (empty registry): logs a warning and returns an error + result. Temporal records the failure and will retry if configured to do so. + The original SSE client connection is gone after a crash, so no response + can be delivered, but the task is durably recorded in Temporal's history. + """ + logger.info("Temporal[llm_call] task_id=%s", inp.task_id) + + entry = _task_registry.get(inp.task_id) + if entry is None: + msg = ( + f"task_id={inp.task_id} not in registry — " + "process likely restarted; original SSE client connection is gone" + ) + logger.warning("Temporal[llm_call] registry miss: %s", msg) + return LLMResult(final_text="", success=False, error=msg) + + try: + executor = entry["executor"] + context = entry["context"] + event_queue = entry["event_queue"] + + # _core_execute() is the renamed body of the original execute(). + # It handles: OTEL spans, SSE streaming, final message, heartbeat. + final_text = await executor._core_execute(context, event_queue) + + # Cache for task_complete observability + entry["final_text"] = final_text or "" + return LLMResult(final_text=final_text or "", success=True) + + except Exception as exc: + logger.error( + "Temporal[llm_call] task_id=%s execution error: %s", + inp.task_id, + exc, + exc_info=True, + ) + return LLMResult(final_text="", success=False, error=str(exc)) + + @activity.defn(name="task_complete") + async def task_complete_activity(result: LLMResult) -> None: + """Durable checkpoint: task execution finished. + + Mirrors the *task_complete* OTEL span in ``LangGraphA2AExecutor._core_execute()``. + This activity records the outcome for Temporal observability. The actual + OTEL task_complete span fires inside ``_core_execute()``; this activity + provides a durable, queryable record in Temporal's workflow history. + """ + if result.success: + logger.info( + "Temporal[task_complete] success=True final_text_len=%d", + len(result.final_text), + ) + else: + logger.warning( + "Temporal[task_complete] success=False error=%r", + result.error, + ) + + # ── Workflow ──────────────────────────────────────────────────────────── # + + @workflow.defn + class MoleculeAIAgentWorkflow: + """Durable Temporal workflow for Molecule AI A2A agent task execution. + + Sequences three activities that mirror the OTEL span hierarchy in + ``LangGraphA2AExecutor._core_execute()``: + + task_receive → llm_call → task_complete + + Each activity is a durable checkpoint: if the process crashes between + activities, Temporal resumes from the last completed checkpoint on + restart. 
If an activity fails (exception or timeout), Temporal can + retry it according to the configured retry policy. + """ + + @workflow.run + async def run(self, inp: AgentTaskInput) -> LLMResult: + opts: dict[str, Any] = { + "start_to_close_timeout": _ACTIVITY_START_TO_CLOSE_TIMEOUT, + } + + # Stage 1 — acknowledge receipt (lightweight checkpoint) + await workflow.execute_activity(task_receive_activity, inp, **opts) + + # Stage 2 — LLM execution (main work; retryable on crash/timeout) + result: LLMResult = await workflow.execute_activity( + llm_call_activity, inp, **opts + ) + + # Stage 3 — record completion (lightweight checkpoint) + await workflow.execute_activity(task_complete_activity, result, **opts) + + return result + +except ImportError: + # temporalio not installed — the wrapper class below will gracefully fall + # back to direct execution for every call. + logger.debug( + "Temporal: temporalio package not installed — " + "durable execution disabled (add temporalio>=1.7.0 to requirements.txt)" + ) + + +# ───────────────────────────────────────────────────────────────────────────── +# TemporalWorkflowWrapper +# ───────────────────────────────────────────────────────────────────────────── + + +class TemporalWorkflowWrapper: + """Wraps ``LangGraphA2AExecutor.execute()`` with Temporal durable execution. + + The wrapper intercepts each ``execute()`` call and routes it through a + ``MoleculeAIAgentWorkflow`` Temporal workflow. If Temporal is unavailable + for any reason, execution falls back transparently to the direct path + (``executor._core_execute()``), so the A2A server never crashes due to + Temporal issues. + + Lifecycle + --------- + 1. ``create_wrapper()`` — instantiate and register the global singleton. + 2. ``await wrapper.start()`` — connect to Temporal, launch the background + worker. No-op (with a log warning) if Temporal is unreachable. + 3. Normal operation — ``wrapper.run()`` is called from ``execute()``. + 4. ``await wrapper.stop()`` — cancel the background worker task on shutdown. + + Co-located worker pattern + ------------------------- + The Temporal worker runs as an asyncio background task in the **same event + loop** as the A2A server. This means: + - No separate worker process to manage. + - Activities share the process's memory (registry access works). + - Worker and server share the same asyncio event loop. + + Env vars + -------- + ``TEMPORAL_HOST`` Temporal gRPC address, e.g. ``localhost:7233`` or + ``temporal.internal:7233``. Defaults to + ``localhost:7233``. If Temporal is not reachable at + this address, the wrapper falls back to direct execution. + """ + + def __init__(self) -> None: + self._host: str = os.environ.get("TEMPORAL_HOST", "localhost:7233") + self._client: Optional[Any] = None + self._worker: Optional[Any] = None + self._worker_task: Optional[asyncio.Task] = None # type: ignore[type-arg] + self._available: bool = False + + # ── Lifecycle ─────────────────────────────────────────────────────────── # + + async def start(self) -> None: + """Connect to Temporal and start the co-located background worker. + + Safe to call multiple times (idempotent after first success). + Never raises — logs a warning and returns on any failure. + """ + if not _TEMPORAL_AVAILABLE: + logger.info( + "Temporal: temporalio package not installed — " + "all tasks will use direct execution. 
" + "To enable durable execution: pip install temporalio>=1.7.0" + ) + return + + if self._available: + return # already started + + # Connect to the Temporal server + try: + self._client = await Client.connect(self._host) # type: ignore[name-defined] + logger.info("Temporal: connected to %s", self._host) + except Exception as exc: + logger.warning( + "Temporal: cannot connect to %s (%s) — " + "all tasks will use direct execution (no durable state)", + self._host, + exc, + ) + return + + # Start the worker as an asyncio background task + try: + self._worker = Worker( # type: ignore[name-defined] + self._client, + task_queue=_TASK_QUEUE, + workflows=[MoleculeAIAgentWorkflow], # type: ignore[name-defined] + activities=[ + task_receive_activity, # type: ignore[name-defined] + llm_call_activity, # type: ignore[name-defined] + task_complete_activity, # type: ignore[name-defined] + ], + ) + self._worker_task = asyncio.create_task( + self._worker.run(), + name="temporal-worker", + ) + self._available = True + logger.info( + "Temporal: co-located worker started on task queue '%s'", + _TASK_QUEUE, + ) + except Exception as exc: + logger.warning( + "Temporal: worker initialisation failed (%s) — " + "falling back to direct execution", + exc, + ) + + async def stop(self) -> None: + """Gracefully stop the Temporal worker background task.""" + self._available = False + if self._worker_task and not self._worker_task.done(): + self._worker_task.cancel() + try: + await self._worker_task + except (asyncio.CancelledError, Exception): + pass + logger.info("Temporal: worker stopped") + + # ── Public API ────────────────────────────────────────────────────────── # + + def is_available(self) -> bool: + """Return ``True`` if Temporal is connected and the worker is running.""" + return self._available + + async def run( + self, + executor: Any, + context: Any, + event_queue: Any, + ) -> None: + """Route one A2A task execution through a Temporal durable workflow. + + Steps + ----- + 1. Build a serialisable ``AgentTaskInput`` from the A2A request context. + 2. Store non-serialisable state (executor, context, event_queue) in + the in-process ``_task_registry`` keyed by task_id. + 3. Submit and await ``MoleculeAIAgentWorkflow`` on the Temporal server. + 4. Clean up the registry entry (always, via ``finally``). + + Falls back to ``executor._core_execute()`` if: + - Temporal is not available (``is_available()`` is False). + - Input extraction fails. + - The workflow raises any exception. + + This guarantees that the A2A client always receives a response even + when Temporal is misconfigured or temporarily unreachable. 
+ """ + if not self._available or self._client is None: + # Temporal unavailable — silent direct fallback + await executor._core_execute(context, event_queue) + return + + task_id = getattr(context, "task_id", None) or str(uuid.uuid4()) + context_id = getattr(context, "context_id", None) or str(uuid.uuid4()) + + # Build serialisable AgentTaskInput + try: + from adapters.shared_runtime import ( + extract_history as _extract_history, + extract_message_text, + ) + + user_input = extract_message_text(context) or "" + raw_history = _extract_history(context) + # Convert (role, content) tuples → [role, content] lists (JSON-safe) + history: list = [list(pair) for pair in raw_history] + except Exception as exc: + logger.warning( + "Temporal: failed to extract serialisable task input (%s) — " + "falling back to direct execution", + exc, + ) + await executor._core_execute(context, event_queue) + return + + inp = AgentTaskInput( + task_id=task_id, + context_id=context_id, + user_input=user_input, + model=getattr(executor, "_model", "unknown"), + workspace_id=os.environ.get("WORKSPACE_ID", "unknown"), + history=history, + ) + + # Register non-serialisable in-process state for activities to access + _task_registry[task_id] = { + "executor": executor, + "context": context, + "event_queue": event_queue, + "final_text": "", + } + + try: + logger.info( + "Temporal: starting workflow molecule-%s on queue '%s'", + task_id, + _TASK_QUEUE, + ) + await self._client.execute_workflow( + MoleculeAIAgentWorkflow.run, # type: ignore[name-defined] + inp, + id=f"molecule-{task_id}", + task_queue=_TASK_QUEUE, + execution_timeout=_WORKFLOW_EXECUTION_TIMEOUT, + ) + except Exception as exc: + logger.error( + "Temporal: workflow molecule-%s failed (%s) — " + "falling back to direct execution so client receives a response", + task_id, + exc, + exc_info=True, + ) + # Direct fallback ensures the SSE client is never left hanging + await executor._core_execute(context, event_queue) + finally: + _task_registry.pop(task_id, None) + + +# ───────────────────────────────────────────────────────────────────────────── +# Module-level singleton helpers +# Used by a2a_executor.py and main.py +# ───────────────────────────────────────────────────────────────────────────── + +_global_wrapper: Optional[TemporalWorkflowWrapper] = None + + +def get_wrapper() -> Optional[TemporalWorkflowWrapper]: + """Return the global ``TemporalWorkflowWrapper``, or ``None`` if not set. + + Called from ``LangGraphA2AExecutor.execute()`` on every request. + Returns ``None`` before ``create_wrapper()`` is called (direct-execution mode). + """ + return _global_wrapper + + +def create_wrapper() -> TemporalWorkflowWrapper: + """Create (or return the existing) global ``TemporalWorkflowWrapper``. + + Idempotent — safe to call multiple times. Call ``await wrapper.start()`` + after this to connect to Temporal and launch the background worker. 
+ + Example (in main.py):: + + from builtin_tools.temporal_workflow import create_wrapper as create_temporal_wrapper + temporal_wrapper = create_temporal_wrapper() + await temporal_wrapper.start() # connects + starts worker + try: + await server.serve() + finally: + await temporal_wrapper.stop() + """ + global _global_wrapper + if _global_wrapper is None: + _global_wrapper = TemporalWorkflowWrapper() + return _global_wrapper diff --git a/molecule_runtime/claude_sdk_executor.py b/molecule_runtime/claude_sdk_executor.py new file mode 100644 index 0000000..1389b0b --- /dev/null +++ b/molecule_runtime/claude_sdk_executor.py @@ -0,0 +1,449 @@ +"""SDK-based agent executor for Claude Code runtime. + +Uses the official `claude-agent-sdk` Python package to invoke the Claude Code +engine programmatically — no subprocess, no stdout parsing, no zombie reap. + +Replaces CLIAgentExecutor for the `claude-code` runtime only. Other CLI runtimes +(codex, ollama) keep using `cli_executor.py`. + +Benefits over CLI subprocess: +- No per-message ~500ms startup overhead +- No stdout buffering issues +- Native Python session management (no JSON parsing of stdout) +- Real message stream — can surface tool calls in future for live UX +- Cooperative cancel (closes the query async generator on cancel()) +- Same Claude Code engine, so plugins / skills / CLAUDE.md still apply + +Concurrency model +----------------- +Turns are serialized per-executor via an asyncio.Lock. The old CLI executor +serialized implicitly by spawning one subprocess per message and awaiting it; +the SDK removes that, so we re-introduce serialization explicitly. This keeps +session_id updates race-free and makes cancel() well-defined (there's at most +one active stream at any given moment). +""" + +from __future__ import annotations + +import asyncio +import logging +import os +import sys +from collections.abc import AsyncIterator +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +import claude_agent_sdk as sdk + +from a2a.server.agent_execution import AgentExecutor, RequestContext +from a2a.server.events import EventQueue +from a2a.utils import new_agent_text_message + +from executor_helpers import ( + CONFIG_MOUNT, + MEMORY_CONTENT_MAX_CHARS, + WORKSPACE_MOUNT, + brief_summary, + commit_memory, + extract_message_text, + get_a2a_instructions, + get_mcp_server_path, + get_system_prompt, + read_delegation_results, + recall_memories, + sanitize_agent_error, + set_current_task, +) + +if TYPE_CHECKING: + from heartbeat import HeartbeatLoop + +logger = logging.getLogger(__name__) + +_NO_TEXT_MSG = "Error: message contained no text content." +_NO_RESPONSE_MSG = "(no response generated)" +_MAX_RETRIES = 3 +_BASE_RETRY_DELAY_S = 5 +# Cap for stderr captured from the CLI subprocess in the executor log. Keeps +# log lines bounded while still surfacing enough context to diagnose crashes. +# Fixes #66 (previously the executor logged nothing beyond the generic +# "Check stderr output for details" message). +_PROCESS_ERROR_STDERR_MAX_CHARS = 4096 + +# Substrings in error messages that indicate a transient failure worth retrying. +_RETRYABLE_PATTERNS = ( + "rate", + "limit", + "429", + "overloaded", + "capacity", + "exit code 1", + "try again", +) + + +_SWALLOWED_STDERR_MARKER = "Check stderr output for details" + + +def _probe_claude_cli_error() -> str | None: + """Run ``claude --print`` directly and capture its stderr + stdout. 
+
+    Used as a fallback when the claude-agent-sdk raises a bare ``Exception``
+    with the swallowed "Check stderr output for details" placeholder — that
+    happens when the SDK wraps a stream error from the CLI subprocess and
+    loses both the ``.stderr`` attribute and the exit code. At that point
+    the only way to see the real failure reason (rate limit, auth error,
+    network outage, missing token) is to run the CLI ourselves.
+
+    Bounded by a 30s timeout so a hung CLI can't stall the error path.
+    Returns None if the probe itself failed, so probe noise never corrupts
+    the main error message.
+    """
+    try:
+        import subprocess
+
+        # --print reads stdin, prints the response, and exits. A tiny stdin
+        # prompt gives the CLI something to work with while keeping the call
+        # cheap when it's going to fail anyway.
+        proc = subprocess.run(
+            ["claude", "--print"],
+            input="probe",
+            capture_output=True,
+            text=True,
+            timeout=30,
+        )
+        if proc.returncode == 0:
+            # CLI succeeded — the original error was a transient state that
+            # resolved between the SDK failure and our probe. Signal that by
+            # returning an empty (falsy) probe result.
+            return ""
+        raw = (proc.stderr or "") + (proc.stdout or "")
+        raw = raw.strip()
+        if not raw:
+            return f"<claude --print exited {proc.returncode} with no output>"
+        if len(raw) > _PROCESS_ERROR_STDERR_MAX_CHARS:
+            raw = raw[:_PROCESS_ERROR_STDERR_MAX_CHARS] + "... [truncated]"
+        return raw
+    except Exception as probe_exc:  # pragma: no cover — best-effort diagnostic
+        logger.debug("claude --print probe failed: %s", probe_exc)
+        return None
+
+
+def _format_process_error(exc: BaseException) -> str:
+    """Render a Claude-SDK ProcessError (or any ClaudeSDKError) with its full
+    captured context — exit code, stderr, exception type. Non-SDK exceptions
+    fall back to plain str(exc).
+
+    Bounded at _PROCESS_ERROR_STDERR_MAX_CHARS so a runaway CLI can't spam
+    the log. Used by the executor's error path (fixes #66 — the SDK's
+    ProcessError carries `.stderr`/`.exit_code` attributes that the previous
+    code silently discarded, leaving every CLI crash with an identical
+    "Check stderr output for details" message in the workspace log).
+
+    Fixes #160: when the SDK raises a bare ``Exception`` containing the
+    "Check stderr output for details" placeholder (which happens when the
+    CLI subprocess emits a stream error the SDK can't categorize — rate
+    limit, auth, network), there's no ``.stderr``/``.exit_code`` to read.
+    In that case we fall back to running the CLI ourselves via
+    ``_probe_claude_cli_error`` so the operator sees the real failure
+    reason (e.g. ``You've hit your limit · resets Apr 17``) instead of
+    chasing ghosts in the workspace logs.
+    """
+    parts = [f"{type(exc).__name__}: {exc}"]
+    exit_code = getattr(exc, "exit_code", None)
+    if exit_code is not None:
+        parts.append(f"exit_code={exit_code}")
+    stderr = getattr(exc, "stderr", None)
+    if stderr:
+        trimmed = stderr[:_PROCESS_ERROR_STDERR_MAX_CHARS]
+        if len(stderr) > _PROCESS_ERROR_STDERR_MAX_CHARS:
+            trimmed += f"... [{len(stderr) - _PROCESS_ERROR_STDERR_MAX_CHARS} more chars truncated]"
+        parts.append(f"stderr={trimmed!r}")
+    elif exit_code is None and _SWALLOWED_STDERR_MARKER in str(exc):
+        # #160: generic exception with the swallowed-stderr placeholder.
+        # Probe the CLI directly — this is the only way to surface the real
+        # error when the SDK lost it in translation.
+        probed = _probe_claude_cli_error()
+        if probed:
+            parts.append(f"probed_cli_error={probed!r}")
+    return " | ".join(parts)
+
+
+@dataclass
+class QueryResult:
+    """Outcome of a single `query()` stream.
+ + `text` is the canonical final response; `session_id` is the id the SDK + reports in its ResultMessage (used for resume on the next turn). + """ + text: str + session_id: str | None + + +class ClaudeSDKExecutor(AgentExecutor): + """Executes agent tasks via the claude-agent-sdk programmatic API.""" + + def __init__( + self, + system_prompt: str | None, + config_path: str, + heartbeat: "HeartbeatLoop | None", + model: str = "sonnet", + ): + self.system_prompt = system_prompt + self.config_path = config_path + self.heartbeat = heartbeat + self.model = model + self._session_id: str | None = None + self._active_stream: AsyncIterator[Any] | None = None + # Serializes concurrent execute() calls on the same executor so + # session_id / _active_stream mutations stay race-free. + self._run_lock = asyncio.Lock() + + # ------------------------------------------------------------------ + # Prompt + options builders + # ------------------------------------------------------------------ + + def _resolve_cwd(self) -> str: + """Run in /workspace if it has been populated, otherwise /configs.""" + if os.path.isdir(WORKSPACE_MOUNT) and os.listdir(WORKSPACE_MOUNT): + return WORKSPACE_MOUNT + return CONFIG_MOUNT + + def _build_system_prompt(self) -> str | None: + """Compose system prompt from file + A2A delegation instructions.""" + base = get_system_prompt(self.config_path, fallback=self.system_prompt) + a2a = get_a2a_instructions(mcp=True) + if base and a2a: + return f"{base}\n\n{a2a}" + return base or a2a + + def _prepare_prompt(self, user_input: str) -> str: + """Prepend delegation results that arrived while idle.""" + delegation_context = read_delegation_results() + if delegation_context: + return ( + "[Delegation results received while you were idle]\n" + f"{delegation_context}\n\n[New message]\n{user_input}" + ) + return user_input + + async def _inject_memories_if_first_turn(self, prompt: str) -> str: + if self._session_id: + return prompt + memories = await recall_memories() + if not memories: + return prompt + return f"[Prior context from memory]\n{memories}\n\n{prompt}" + + def _build_options(self) -> Any: + """Build ClaudeAgentOptions. + + No allowed_tools allowlist — bypassPermissions grants full access, + matching the old CLI `--dangerously-skip-permissions` so Claude can + use every built-in tool (Task, TodoWrite, NotebookEdit, BashOutput/ + KillShell, ExitPlanMode, etc.) plus all MCP tools. + + The MCP server launcher uses `sys.executable` so tests and alternate + virtual-env layouts don't depend on a `python3` shim being on PATH. + """ + mcp_servers = { + "a2a": { + "command": sys.executable, + "args": [get_mcp_server_path()], + } + } + return sdk.ClaudeAgentOptions( + model=self.model, + permission_mode="bypassPermissions", + cwd=self._resolve_cwd(), + mcp_servers=mcp_servers, + system_prompt=self._build_system_prompt(), + resume=self._session_id, + ) + + # ------------------------------------------------------------------ + # Query streaming + # ------------------------------------------------------------------ + + async def _run_query(self, prompt: str, options: Any) -> QueryResult: + """Drive the SDK query stream and return a QueryResult. + + Prefers ResultMessage.result (the canonical final text — same field + the CLI's --output-format json used) and only falls back to the + concatenation of AssistantMessage TextBlocks when result is absent. + Otherwise pre-tool reasoning and post-tool summary get double-emitted. 
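+
+        Stream shape consumed below (a sketch of just the fields this method
+        reads from the claude-agent-sdk message types)::
+
+            AssistantMessage(content=[TextBlock(text="..."), ...])
+            ResultMessage(session_id="...", result="final text")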
+ + Pure: does not mutate executor state other than setting / clearing + `self._active_stream` so cancel() can reach in. The caller decides + whether to persist the returned session_id. + """ + assistant_chunks: list[str] = [] + result_text: str | None = None + session_id: str | None = None + self._active_stream = sdk.query(prompt=prompt, options=options) + try: + async for message in self._active_stream: + if isinstance(message, sdk.AssistantMessage): + for block in message.content: + if isinstance(block, sdk.TextBlock): + assistant_chunks.append(block.text) + elif isinstance(message, sdk.ResultMessage): + sid = getattr(message, "session_id", None) + if sid: + session_id = sid + result_text = getattr(message, "result", None) + finally: + self._active_stream = None + text = result_text if result_text is not None else "".join(assistant_chunks) + return QueryResult(text=text, session_id=session_id) + + # ------------------------------------------------------------------ + # AgentExecutor interface + # ------------------------------------------------------------------ + + async def execute(self, context: RequestContext, event_queue: EventQueue): + """Run a turn through the Claude Agent SDK and emit the response. + + Serialized via `self._run_lock` — concurrent A2A messages to the same + workspace queue rather than racing on `_session_id` / `_active_stream`. + """ + user_input = extract_message_text(context.message) + if not user_input: + await event_queue.enqueue_event(new_agent_text_message(_NO_TEXT_MSG)) + return + + async with self._run_lock: + response_text = await self._execute_locked(user_input) + + # Enqueue outside the lock so the next queued turn can start + # preparing its prompt while this turn's response ships. Event + # ordering is preserved per-queue by the A2A server, so no races. + await event_queue.enqueue_event(new_agent_text_message(response_text)) + + @staticmethod + def _is_retryable(exc: BaseException) -> bool: + """Check if an SDK exception looks like a transient rate-limit or + capacity error that's worth retrying with backoff.""" + msg = str(exc).lower() + return any(p in msg for p in _RETRYABLE_PATTERNS) + + def _reset_session_after_error(self, exc: BaseException) -> None: + """Clear `_session_id` if the exception looks like a subprocess + crash (#75). On the next `_build_options()` call `resume=None` is + passed to the SDK, so the CLI boots a brand-new session instead of + trying to resume one the previous subprocess left in an + unrecoverable state. + + Kept in its own method so the policy can evolve (e.g. also clear + on MessageParseError) without touching the retry loop. Logs at + INFO when a session was actually cleared; silent when there was + nothing to reset. + """ + exc_name = type(exc).__name__ + # Conservative: reset only on subprocess-level failures. Pure + # rate-limit / capacity errors don't leave the session in a bad + # state — keep the session_id so the resumed turn preserves + # conversational continuity. + is_subprocess_error = ( + exc_name in ("ProcessError", "CLIConnectionError") + or getattr(exc, "exit_code", None) is not None + or "exit code" in str(exc).lower() + ) + if not is_subprocess_error: + return + if self._session_id is None: + return + logger.info( + "SDK session reset after %s: clearing session_id so the next " + "attempt starts fresh (fixes #75 session contamination)", + exc_name, + ) + self._session_id = None + + async def _execute_locked(self, user_input: str) -> str: + """Body of execute() that runs under the run lock. 
+ + Retries transient errors (rate limits, capacity, exit-code-1) up to + _MAX_RETRIES times with exponential backoff (5s, 10s, 20s). + """ + # Keep a clean copy of the user's actual message for the memory record, + # BEFORE any delegation or memory injection. + original_input = user_input + await set_current_task(self.heartbeat, brief_summary(user_input)) + logger.debug("SDK execute [claude-code]: %s", user_input[:200]) + + prompt = self._prepare_prompt(user_input) + prompt = await self._inject_memories_if_first_turn(prompt) + + response_text: str = "" + try: + for attempt in range(_MAX_RETRIES): + options = self._build_options() + try: + result = await self._run_query(prompt=prompt, options=options) + if result.session_id: + self._session_id = result.session_id + response_text = result.text + break # success + except Exception as exc: + formatted = _format_process_error(exc) + # #75: CLI subprocess crashes leave our _session_id + # referencing a session the next subprocess can't + # resume. Without this reset the next attempt would + # crash identically even when the underlying cause + # was transient, cascading into "crashed once → + # crashes forever until container restart." Clear + # the session_id so the next attempt (retry or + # next user turn) starts fresh. + self._reset_session_after_error(exc) + if attempt < _MAX_RETRIES - 1 and self._is_retryable(exc): + delay = _BASE_RETRY_DELAY_S * (2 ** attempt) + logger.warning( + "SDK agent [claude-code] transient error (attempt %d/%d), " + "retrying in %ds: %s", + attempt + 1, _MAX_RETRIES, delay, formatted, + ) + await asyncio.sleep(delay) + continue + # Non-retryable or exhausted retries. Log exit_code + + # stderr explicitly (fixes #66) so operators don't have + # to reproduce the crash manually to find out why the + # subprocess died. + logger.error("SDK agent error [claude-code]: %s", formatted) + logger.exception("SDK agent error [claude-code] — full traceback follows") + response_text = sanitize_agent_error(exc) + break + finally: + await set_current_task(self.heartbeat, "") + await commit_memory( + f"Conversation: {original_input[:MEMORY_CONTENT_MAX_CHARS]}" + ) + + return response_text or _NO_RESPONSE_MSG + + async def cancel(self, context: RequestContext, event_queue: EventQueue): + """Cooperatively cancel the currently running turn. + + cancel() targets whatever turn is in flight *right now*, not the + specific turn the caller may have been looking at when they sent + the cancel request. If turn A has finished and turn B is already + running under the run lock by the time cancel arrives, turn B is + the one that gets aborted. This matches how a "stop" button in a + chat UI typically behaves (stop whatever is running) and is a + conscious trade-off against per-turn bookkeeping. + + Implementation: the SDK's query() is an async generator; calling + aclose() raises GeneratorExit inside the running turn and unwinds + cleanly. We read `self._active_stream` into a local BEFORE calling + aclose so the reference can't be reassigned by another turn + mid-cancel. Best-effort — if no stream is active (cancel arrived + between turns, or the stream has no aclose), this is a no-op. 
+ """ + stream = self._active_stream + if stream is None: + return + aclose = getattr(stream, "aclose", None) + if aclose is None: + return + try: + await aclose() + except Exception: + logger.exception("SDK cancel: aclose() raised") diff --git a/molecule_runtime/cli_executor.py b/molecule_runtime/cli_executor.py new file mode 100644 index 0000000..2f2802e --- /dev/null +++ b/molecule_runtime/cli_executor.py @@ -0,0 +1,456 @@ +"""CLI-based agent executor for A2A protocol. + +Supports CLI agents that accept a prompt and output a response: +- OpenAI Codex: codex --print -p "..." +- Ollama: ollama run "..." +- Custom: any command that reads stdin or accepts -p + +NOTE: the `claude-code` runtime no longer routes here. It uses +ClaudeSDKExecutor (see claude_sdk_executor.py) which wraps the +claude-agent-sdk Python package. This executor is reserved for CLI-only +runtimes that don't yet have a programmatic SDK integration. + +The runtime is selected via config.yaml: + runtime: codex | ollama | custom + runtime_config: + command: "codex" # for custom + args: ["--extra-flag"] # additional CLI args + auth_token_env: "OPENAI_API_KEY" + auth_token_file: ".auth-token" + timeout: 300 + model: "sonnet" +""" + +import asyncio +import atexit +import json +import logging +import os +import shlex +import shutil +import sys +import tempfile +from pathlib import Path + +from a2a.server.agent_execution import AgentExecutor, RequestContext +from a2a.server.events import EventQueue +from a2a.utils import new_agent_text_message + +from config import RuntimeConfig +from executor_helpers import ( + CONFIG_MOUNT, + MEMORY_CONTENT_MAX_CHARS, + WORKSPACE_MOUNT, + brief_summary, + classify_subprocess_error, + commit_memory, + extract_message_text, + get_a2a_instructions, + get_mcp_server_path, + get_system_prompt, + read_delegation_results, + recall_memories, + sanitize_agent_error, + set_current_task, +) + +logger = logging.getLogger(__name__) + + +# Built-in runtime presets. +# The `claude-code` runtime uses ClaudeSDKExecutor (claude_sdk_executor.py) +# and intentionally has no entry here. +RUNTIME_PRESETS: dict[str, dict] = { + "codex": { + "command": "codex", + "base_args": ["--print", "--dangerously-skip-permissions"], + "prompt_flag": "-p", + "model_flag": "--model", + "system_prompt_flag": "--system-prompt", + "auth_pattern": "env", # uses OPENAI_API_KEY env var + "default_auth_env": "OPENAI_API_KEY", + "default_auth_file": "", + }, + "ollama": { + "command": "ollama", + "base_args": ["run"], + "prompt_flag": None, # prompt is positional + "model_flag": None, # model is positional after "run" + "system_prompt_flag": "--system", + "auth_pattern": None, # no auth needed + "default_auth_env": "", + "default_auth_file": "", + }, + # Gemini CLI (github.com/google-gemini/gemini-cli, Apache 2.0). + # Auth via GEMINI_API_KEY env var; MCP is wired via ~/.gemini/settings.json + # (not --mcp-config) — the adapter's setup() handles that step. + # System prompt is seeded into GEMINI.md (equivalent of CLAUDE.md). 
+ "gemini-cli": { + "command": "gemini", + "base_args": ["--yolo"], # auto-approve all tool calls (non-interactive) + "prompt_flag": "-p", + "model_flag": "--model", + "system_prompt_flag": None, # GEMINI.md used instead; seeded by adapter.setup() + "auth_pattern": "env", # GEMINI_API_KEY; also enables A2A MCP instructions + "default_auth_env": "GEMINI_API_KEY", + "default_auth_file": "", + "mcp_via_settings": True, # MCP injected into ~/.gemini/settings.json, not --mcp-config + }, +} + + +class CLIAgentExecutor(AgentExecutor): + """Executes agent tasks by invoking a CLI tool. + + Works with any CLI agent that accepts a prompt and outputs text. + """ + + def __init__( + self, + runtime: str, + runtime_config: RuntimeConfig, + system_prompt: str | None = None, + config_path: str = "/configs", + heartbeat: "HeartbeatLoop | None" = None, + ): + if runtime == "claude-code": + # Defensive — the adapter should never construct a CLI executor + # for claude-code. Fail loud rather than silently falling back. + raise ValueError( + "claude-code runtime is served by ClaudeSDKExecutor, not " + "CLIAgentExecutor. Check adapters/claude_code/adapter.py." + ) + self.runtime = runtime + self.config = runtime_config + self.system_prompt = system_prompt + self.config_path = config_path + self._heartbeat = heartbeat + + # Resolve preset or use custom + if runtime in RUNTIME_PRESETS: + self.preset = RUNTIME_PRESETS[runtime] + elif runtime == "custom": + self.preset = { + "command": runtime_config.command, + "base_args": [], # args go in config.args, appended at end + "prompt_flag": "-p", + "model_flag": None, + "system_prompt_flag": None, + "auth_pattern": None, + "default_auth_env": "", + "default_auth_file": "", + } + else: + raise ValueError(f"Unknown runtime: {runtime}. Use: {', '.join(RUNTIME_PRESETS.keys())}, custom") + + # Resolve auth token + self._auth_token = self._resolve_auth_token() + self._auth_helper_path: str | None = None + self._temp_files: list[str] = [] # Track temp files for cleanup + + if self._auth_token and self.preset.get("auth_pattern") == "apiKeyHelper": + self._auth_helper_path = self._create_auth_helper(self._auth_token) + + # Create MCP config once (reuse across invocations) + self._mcp_config_path: str | None = None + if self.preset.get("auth_pattern") in ("apiKeyHelper", "env"): + mcp_config = json.dumps({ + "mcpServers": { + "a2a": {"command": sys.executable, "args": [get_mcp_server_path()]} + } + }) + fd, self._mcp_config_path = tempfile.mkstemp(suffix=".json", prefix="a2a-mcp-") + self._temp_files.append(self._mcp_config_path) # Track immediately + os.close(fd) + with open(self._mcp_config_path, "w") as f: + f.write(mcp_config) + + # Register cleanup for reliable temp file removal (atexit is more reliable than __del__) + atexit.register(self._cleanup_temp_files) + + # Verify command exists + cmd = self.config.command or self.preset["command"] + if not shutil.which(cmd): + logger.warning(f"CLI command '{cmd}' not found in PATH") + + def _resolve_auth_token(self) -> str | None: + """Resolve auth token from env var or file. + + Resolution order: + 1. required_env — first entry that exists in the environment + 2. auth_token_env (deprecated) — explicit env var name + 3. Preset default_auth_env — adapter-declared fallback + 4. auth_token_file (deprecated) — file on disk + 5. Preset default_auth_file — adapter-declared file fallback + """ + # 1. 
New path: required_env (first match wins) + for env_name in (self.config.required_env or []): + token = os.environ.get(env_name) + if token: + return token + + # 2. Legacy: explicit env var from config + env_name = self.config.auth_token_env or self.preset.get("default_auth_env", "") + if env_name: + token = os.environ.get(env_name) + if token: + return token + + # 3. Legacy: token file from config + file_name = self.config.auth_token_file or self.preset.get("default_auth_file", "") + if file_name: + token_path = Path(self.config_path) / file_name + if token_path.exists(): + return token_path.read_text().strip() + + return None + + def _create_auth_helper(self, token: str) -> str: + """Create a shell script that outputs the auth token (for apiKeyHelper pattern).""" + fd, helper_path = tempfile.mkstemp(suffix=".sh", prefix="agent-auth-") + self._temp_files.append(helper_path) # Track immediately before any exception can leak + os.close(fd) + with open(helper_path, "w") as f: + f.write(f"#!/bin/sh\necho {shlex.quote(token)}\n") + os.chmod(helper_path, 0o700) + return helper_path + + def _build_command(self, message: str) -> list[str]: + """Build the full CLI command from preset + config + message.""" + cmd = self.config.command or self.preset["command"] + args = list(self.preset.get("base_args", [])) + + # Model + model = self.config.model or None + model_flag = self.preset.get("model_flag") + if model and model_flag: + args.extend([model_flag, model]) + elif model and self.runtime == "ollama": + # Ollama: model is positional after "run" + args.append(model) + + # System prompt (+ A2A instructions). The remaining CLI runtimes don't + # support session resume, so we inject the system prompt on every call. + system_prompt = get_system_prompt(self.config_path, fallback=self.system_prompt) or "" + mcp_capable = self.preset.get("auth_pattern") in ("apiKeyHelper", "env") + a2a_instructions = get_a2a_instructions(mcp=mcp_capable) + if a2a_instructions: + system_prompt = ( + f"{system_prompt}\n\n{a2a_instructions}" if system_prompt else a2a_instructions + ) + system_flag = self.preset.get("system_prompt_flag") + if system_prompt and system_flag: + args.extend([system_flag, system_prompt]) + + # Auth (apiKeyHelper pattern — reserved for future CLI runtimes) + if self._auth_helper_path and self.preset.get("auth_pattern") == "apiKeyHelper": + settings = json.dumps({"apiKeyHelper": self._auth_helper_path}) + args.extend(["--settings", settings]) + + # A2A MCP server — inject for MCP-compatible runtimes (created once in __init__). + # Runtimes that declare `mcp_via_settings: True` (e.g. gemini-cli) wire MCP + # through their own settings file (adapter.setup()) instead of --mcp-config. 
+ if self._mcp_config_path and not self.preset.get("mcp_via_settings"): + args.extend(["--mcp-config", self._mcp_config_path]) + + # Extra args from config (before prompt so flags are parsed correctly) + args.extend(self.config.args) + + # Prompt (must be last — some CLIs treat final arg as the prompt) + prompt_flag = self.preset.get("prompt_flag") + if prompt_flag: + args.extend([prompt_flag, message]) + else: + # Positional prompt (ollama) + args.append(message) + + return [cmd] + args + + async def execute(self, context: RequestContext, event_queue: EventQueue): + """Execute a task by invoking the CLI agent.""" + user_input = extract_message_text(context.message) + if not user_input: + await event_queue.enqueue_event( + new_agent_text_message("Error: message contained no text content.") + ) + return + + # Keep a clean copy of the user's actual message for memory BEFORE any + # delegation or memory injection happens. + original_input = user_input + + # Show current task on canvas — extract a brief one-line summary + await set_current_task(self._heartbeat, brief_summary(user_input)) + + logger.debug("CLI execute [%s]: %s", self.runtime, user_input[:200]) + + # Inject delegation results that arrived since last message + delegation_context = read_delegation_results() + if delegation_context: + user_input = f"[Delegation results received while you were idle]\n{delegation_context}\n\n[New message]\n{user_input}" + + # Auto-recall: inject prior memories into every prompt. (The CLI + # runtimes don't keep a session, so there's no "first turn" concept.) + memories = await recall_memories() + if memories: + user_input = f"[Prior context from memory]\n{memories}\n\n{user_input}" + + try: + await self._run_cli(user_input, event_queue) + finally: + await set_current_task(self._heartbeat, "") + # Auto-commit: save the original user request (not the memory-injected version) + await commit_memory( + f"Conversation: {original_input[:MEMORY_CONTENT_MAX_CHARS]}" + ) + + async def _run_cli(self, user_input: str, event_queue: EventQueue): + """Run the CLI subprocess and enqueue the result.""" + cmd = self._build_command(user_input) + timeout = self.config.timeout or None # None = no timeout (wait until agent finishes) + max_retries = 3 + base_delay = 5 # seconds + + # Build env — pass through auth env var if using env pattern + env = dict(os.environ) + if self._auth_token and self.preset.get("auth_pattern") == "env": + # Use first required_env entry, or fall back to legacy auth_token_env + auth_env = (self.config.required_env or [None])[0] if self.config.required_env else None + auth_env = auth_env or self.config.auth_token_env or self.preset.get("default_auth_env", "") + if auth_env: + env[auth_env] = self._auth_token + + for attempt in range(max_retries): + proc = None + try: + # Run in /workspace if it exists and has content (cloned repo), + # otherwise /configs (agent config files) + cwd = ( + WORKSPACE_MOUNT + if os.path.isdir(WORKSPACE_MOUNT) and os.listdir(WORKSPACE_MOUNT) + else CONFIG_MOUNT + ) + + proc = await asyncio.create_subprocess_exec( + *cmd, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + env=env, + cwd=cwd, + ) + if timeout: + stdout, stderr = await asyncio.wait_for( + proc.communicate(), timeout=timeout + ) + else: + stdout, stderr = await proc.communicate() + + stdout_text = stdout.decode().strip() + stderr_text = stderr.decode().strip() + + if proc.returncode != 0: + logger.error("CLI agent [%s] exit=%d stdout=%s stderr=%s", + self.runtime, proc.returncode, + 
stdout_text[:200] if stdout_text else "(empty)", + stderr_text[:500] if stderr_text else "(empty)") + + if proc.returncode == 0 or stdout_text: + # Success, or non-zero exit but produced output (some CLIs exit 1 with valid output) + result = stdout_text + if result: + await event_queue.enqueue_event( + new_agent_text_message(result) + ) + return + else: + # Empty response — likely rate limited, retry with backoff + if attempt < max_retries - 1: + delay = base_delay * (2 ** attempt) + logger.warning("CLI agent [%s] returned empty (attempt %d/%d), retrying in %ds", + self.runtime, attempt + 1, max_retries, delay) + await asyncio.sleep(delay) + continue + await event_queue.enqueue_event( + new_agent_text_message("(no response generated after retries)") + ) + return + else: + error_msg = stderr_text or f"Exit code {proc.returncode}" + # Classify once — used both for retry policy and the + # sanitized user-facing error message. + category = classify_subprocess_error(error_msg, proc.returncode) + if category in ("rate_limited", "session_error", "auth_failed") \ + and attempt < max_retries - 1: + delay = base_delay * (2 ** attempt) + logger.warning( + "CLI agent [%s] %s (attempt %d/%d), retrying in %ds", + self.runtime, category, attempt + 1, max_retries, delay, + ) + await asyncio.sleep(delay) + continue + + # Log the full stderr (may contain paths/tokens); surface + # only the sanitized category to the user. + logger.error("CLI agent error [%s]: %s", self.runtime, error_msg[:500]) + await event_queue.enqueue_event( + new_agent_text_message(sanitize_agent_error(category=category)) + ) + return + + except asyncio.TimeoutError: + logger.error("CLI agent timeout [%s] after %ds", self.runtime, timeout) + if proc: + # Kill and reap the process to prevent zombies + try: + proc.kill() + except ProcessLookupError: + pass # already exited + except Exception as kill_err: + logger.warning("CLI kill error: %s", kill_err) + # Always await wait() to reap zombie, even if kill failed + try: + await asyncio.wait_for(proc.wait(), timeout=5) + except asyncio.TimeoutError: + logger.error("CLI agent: proc.wait() also timed out — possible zombie") + except Exception as wait_err: + logger.warning("CLI wait error: %s", wait_err) + await event_queue.enqueue_event( + new_agent_text_message(sanitize_agent_error(category="timeout")) + ) + return + except Exception as exc: + logger.exception("CLI agent exception [%s]", self.runtime) + await event_queue.enqueue_event( + new_agent_text_message(sanitize_agent_error(exc)) + ) + return + + def _cleanup_temp_files(self): # pragma: no cover + """Clean up temp files. 
Called via atexit for reliable cleanup.""" + for f in self._temp_files: + try: + os.unlink(f) + except OSError: + pass + if self._auth_helper_path: + try: + os.unlink(self._auth_helper_path) + except OSError: + pass + + def __del__(self): # pragma: no cover + """Clean up temp files (fallback — prefer atexit-registered _cleanup_temp_files).""" + for f in getattr(self, "_temp_files", []): + try: + os.unlink(f) + except OSError: + pass + if getattr(self, "_auth_helper_path", None): + try: + os.unlink(self._auth_helper_path) + except OSError: + pass + + async def cancel(self, context: RequestContext, event_queue: EventQueue): # pragma: no cover + """Cancel a running task.""" + pass diff --git a/molecule_runtime/config.py b/molecule_runtime/config.py new file mode 100644 index 0000000..6f7dbc5 --- /dev/null +++ b/molecule_runtime/config.py @@ -0,0 +1,349 @@ +"""Load workspace configuration from config.yaml.""" + +import os +from dataclasses import dataclass, field +from pathlib import Path +from typing import Optional + +import yaml + + +@dataclass +class RBACConfig: + """Role-based access control settings for this workspace. + + ``roles`` declares what this workspace is *allowed* to do. Each role + name maps to a set of permitted actions. Built-in roles are defined in + ``tools/audit.ROLE_PERMISSIONS``; custom roles can be added via + ``allowed_actions``. + + Built-in roles + -------------- + admin All actions (delegate, approve, memory.read, memory.write) + operator Same as admin — standard agent role (default) + read-only memory.read only + no-delegation approve + memory.read + memory.write + no-approval delegate + memory.read + memory.write + memory-readonly memory.read only + + Example config.yaml snippet:: + + rbac: + roles: + - operator + allowed_actions: + analyst: + - memory.read + - memory.write + """ + + roles: list[str] = field(default_factory=lambda: ["operator"]) + """List of role names granted to this workspace.""" + + allowed_actions: dict[str, list[str]] = field(default_factory=dict) + """Custom role → [action, ...] overrides. Takes precedence over built-ins.""" + + +@dataclass +class HITLConfig: + """Human-In-The-Loop settings loaded from the ``hitl:`` block in config.yaml. + + Example config.yaml snippet:: + + hitl: + channels: + - type: dashboard # always active + - type: slack + webhook_url: https://hooks.slack.com/services/… + - type: email + smtp_host: smtp.example.com + from: alerts@example.com + to: ops@example.com + default_timeout: 300 # seconds + bypass_roles: [admin] + """ + channels: list[dict] = field(default_factory=lambda: [{"type": "dashboard"}]) + default_timeout: float = 300.0 + bypass_roles: list[str] = field(default_factory=list) + + +@dataclass +class DelegationConfig: + retry_attempts: int = 3 + retry_delay: float = 5.0 + timeout: float = 120.0 + escalate: bool = True + + +@dataclass +class A2AConfig: + port: int = 8000 + streaming: bool = True + push_notifications: bool = True + + +@dataclass +class SandboxConfig: + backend: str = "subprocess" # subprocess | docker + memory_limit: str = "256m" + timeout: int = 30 + +@dataclass +class RuntimeConfig: + """Configuration for CLI-based agent runtimes (claude-code, codex, ollama, custom).""" + command: str = "" # e.g. "claude", "codex", "ollama" (model goes in model field) + args: list[str] = field(default_factory=list) # additional CLI args + required_env: list[str] = field(default_factory=list) # env vars required to run (e.g. 
["CLAUDE_CODE_OAUTH_TOKEN"]) + timeout: int = 0 # seconds (0 = no timeout — agents wait until done) + model: str = "" # model override for the CLI + # Deprecated — use required_env + secrets API instead. Kept for backward compat. + auth_token_env: str = "" + auth_token_file: str = "" + + +@dataclass +class GovernanceConfig: + """Microsoft Agent Governance Toolkit integration settings. + + When ``enabled`` is True, Molecule AI's RBAC and audit trail are bridged + to the Agent Governance Toolkit (agent-os-kernel) for policy evaluation. + + ``toolkit`` is reserved for future extensibility — only ``"microsoft"`` + is supported today. + + ``policy_mode`` controls enforcement: + strict RBAC *and* toolkit policy must both allow — strictest mode + permissive RBAC must allow; toolkit denials are logged but not enforced + audit RBAC only; toolkit evaluated and logged but never blocks + + ``policy_file`` path to a Rego (.rego), YAML (.yaml/.yml), or Cedar + (.cedar) policy file, loaded into the PolicyEvaluator at startup. + + ``blocked_patterns`` is a list of regex patterns that the toolkit will + always deny regardless of roles or policy. + """ + + enabled: bool = False + toolkit: str = "microsoft" + policy_endpoint: str = "" + policy_mode: str = "audit" # strict | permissive | audit + policy_file: str = "" + blocked_patterns: list[str] = field(default_factory=list) + max_tool_calls_per_task: int = 50 + + +@dataclass +class SecurityScanConfig: + """Skill dependency security scanning settings. + + ``mode`` controls what happens when critical/high CVEs are found: + + block — raise ``SkillSecurityError``; the skill is NOT loaded. + warn — emit a WARNING + audit event; the skill is loaded anyway (default). + off — skip scanning entirely (air-gapped or CI environments). + + Scanners tried in order: Snyk CLI (requires ``SNYK_TOKEN``), then + pip-audit. If neither is available the scan is silently skipped. + + Example config.yaml snippet:: + + security_scan: warn # shorthand string form + # or verbose form: + security_scan: + mode: block + """ + + mode: str = "warn" + """One of: block | warn | off.""" + + fail_open_if_no_scanner: bool = True + """When True (default), silently skip scanning if no scanner (snyk/pip-audit) + is in PATH. When False and mode='block', raise SkillSecurityError so that + operators who require a CVE gate know the gate is absent. Closes #268.""" + + +@dataclass +class ComplianceConfig: + """OWASP Top 10 for Agentic Applications compliance settings. + + Set ``mode: owasp_agentic`` to enable all checks. When ``mode`` is + empty or absent the compliance layer is a complete no-op. + + Example config.yaml snippet:: + + compliance: + mode: owasp_agentic + prompt_injection: block # detect | block (default: detect) + max_tool_calls_per_task: 30 + max_task_duration_seconds: 180 + """ + + mode: str = "" + """Enable compliance mode. 
Set to ``owasp_agentic`` to activate.""" + + prompt_injection: str = "detect" + """``detect`` logs injection attempts; ``block`` raises PromptInjectionError.""" + + max_tool_calls_per_task: int = 50 + """Maximum number of tool invocations per task before ExcessiveAgencyError.""" + + max_task_duration_seconds: int = 300 + """Maximum wall-clock seconds per task before ExcessiveAgencyError.""" + + +@dataclass +class WorkspaceConfig: + name: str = "Workspace" + description: str = "" + version: str = "1.0.0" + tier: int = 1 + model: str = "anthropic:claude-sonnet-4-6" + runtime: str = "langgraph" # langgraph | claude-code | codex | ollama | custom + runtime_config: RuntimeConfig = field(default_factory=RuntimeConfig) + initial_prompt: str = "" + """Auto-sent as the first A2A message after startup. Default empty = no auto-message. + Can be an inline string or a file reference (initial_prompt_file in yaml).""" + idle_prompt: str = "" + """Auto-sent every `idle_interval_seconds` while the workspace has no active + task (heartbeat.active_tasks == 0). Default empty = no idle loop. This is + the reflection-on-completion / backlog-pull pattern from the Hermes/Letta + playbook: the workspace self-wakes when idle, runs a lightweight reflection + prompt, and either picks up queued work or stops. Cost scales with useful + activity (the prompt returns quickly if there's nothing to do). Can be + inline or a file reference via `idle_prompt_file`.""" + idle_interval_seconds: int = 600 + """How often the idle loop checks in (seconds). Default 600 (10 min). + Ignored when idle_prompt is empty.""" + skills: list[str] = field(default_factory=list) + plugins: list[str] = field(default_factory=list) # installed plugin names + tools: list[str] = field(default_factory=list) + prompt_files: list[str] = field(default_factory=list) + shared_context: list[str] = field(default_factory=list) + a2a: A2AConfig = field(default_factory=A2AConfig) + delegation: DelegationConfig = field(default_factory=DelegationConfig) + sandbox: SandboxConfig = field(default_factory=SandboxConfig) + rbac: RBACConfig = field(default_factory=RBACConfig) + hitl: HITLConfig = field(default_factory=HITLConfig) + governance: GovernanceConfig = field(default_factory=GovernanceConfig) + security_scan: SecurityScanConfig = field(default_factory=SecurityScanConfig) + compliance: ComplianceConfig = field(default_factory=ComplianceConfig) + sub_workspaces: list[dict] = field(default_factory=list) + + +def load_config(config_path: Optional[str] = None) -> WorkspaceConfig: + """Load config from WORKSPACE_CONFIG_PATH or the given path.""" + if config_path is None: + config_path = os.environ.get("WORKSPACE_CONFIG_PATH", "/configs") + + config_file = Path(config_path) / "config.yaml" + if not config_file.exists(): + raise FileNotFoundError(f"Config file not found: {config_file}") + + with open(config_file) as f: + raw = yaml.safe_load(f) or {} + + # Override model from env if provided + model = os.environ.get("MODEL_PROVIDER", raw.get("model", "anthropic:claude-sonnet-4-6")) + + runtime = raw.get("runtime", "langgraph") + runtime_raw = raw.get("runtime_config", {}) + + a2a_raw = raw.get("a2a", {}) + delegation_raw = raw.get("delegation", {}) + sandbox_raw = raw.get("sandbox", {}) + rbac_raw = raw.get("rbac", {}) + hitl_raw = raw.get("hitl", {}) + governance_raw = raw.get("governance", {}) + # security_scan accepts both shorthand string ("warn") and dict ({"mode": "warn"}) + _ss_raw = raw.get("security_scan", {}) + security_scan_raw = _ss_raw if 
isinstance(_ss_raw, dict) else {"mode": str(_ss_raw)} + compliance_raw = raw.get("compliance", {}) + + # Resolve initial_prompt: inline string or file reference + initial_prompt = raw.get("initial_prompt", "") + initial_prompt_file = raw.get("initial_prompt_file", "") + if not initial_prompt and initial_prompt_file: + prompt_path = Path(config_path) / initial_prompt_file + if prompt_path.exists(): + initial_prompt = prompt_path.read_text().strip() + + # Resolve idle_prompt: same pattern as initial_prompt + idle_prompt = raw.get("idle_prompt", "") + idle_prompt_file = raw.get("idle_prompt_file", "") + if not idle_prompt and idle_prompt_file: + idle_path = Path(config_path) / idle_prompt_file + if idle_path.exists(): + idle_prompt = idle_path.read_text().strip() + idle_interval_seconds = int(raw.get("idle_interval_seconds", 600)) + + return WorkspaceConfig( + name=raw.get("name", "Workspace"), + description=raw.get("description", ""), + version=raw.get("version", "1.0.0"), + tier=int(raw.get("tier", 1)) if str(raw.get("tier", 1)).isdigit() else 1, + model=model, + runtime=runtime, + initial_prompt=initial_prompt, + idle_prompt=idle_prompt, + idle_interval_seconds=idle_interval_seconds, + runtime_config=RuntimeConfig( + command=runtime_raw.get("command", ""), + args=runtime_raw.get("args", []), + required_env=runtime_raw.get("required_env", []), + timeout=runtime_raw.get("timeout", 0), + model=runtime_raw.get("model", ""), + # Deprecated fields — kept for backward compat + auth_token_env=runtime_raw.get("auth_token_env", ""), + auth_token_file=runtime_raw.get("auth_token_file", ""), + ), + skills=raw.get("skills", []), + plugins=raw.get("plugins", []), + tools=raw.get("tools", []), + prompt_files=raw.get("prompt_files", []), + shared_context=raw.get("shared_context", []), + a2a=A2AConfig( + port=a2a_raw.get("port", 8000), + streaming=a2a_raw.get("streaming", True), + push_notifications=a2a_raw.get("push_notifications", True), + ), + delegation=DelegationConfig( + retry_attempts=delegation_raw.get("retry_attempts", 3), + retry_delay=delegation_raw.get("retry_delay", 5.0), + timeout=delegation_raw.get("timeout", 120.0), + escalate=delegation_raw.get("escalate", True), + ), + sandbox=SandboxConfig( + backend=sandbox_raw.get("backend", "subprocess"), + memory_limit=sandbox_raw.get("memory_limit", "256m"), + timeout=sandbox_raw.get("timeout", 30), + ), + rbac=RBACConfig( + roles=rbac_raw.get("roles", ["operator"]), + allowed_actions=rbac_raw.get("allowed_actions", {}), + ), + hitl=HITLConfig( + channels=hitl_raw.get("channels", [{"type": "dashboard"}]), + default_timeout=float(hitl_raw.get("default_timeout", 300)), + bypass_roles=hitl_raw.get("bypass_roles", []), + ), + governance=GovernanceConfig( + enabled=governance_raw.get("enabled", False), + toolkit=governance_raw.get("toolkit", "microsoft"), + policy_endpoint=governance_raw.get("policy_endpoint", ""), + policy_mode=governance_raw.get("policy_mode", "audit"), + policy_file=governance_raw.get("policy_file", ""), + blocked_patterns=governance_raw.get("blocked_patterns", []), + max_tool_calls_per_task=governance_raw.get("max_tool_calls_per_task", 50), + ), + security_scan=SecurityScanConfig( + mode=security_scan_raw.get("mode", "warn"), + fail_open_if_no_scanner=security_scan_raw.get("fail_open_if_no_scanner", True), + ), + compliance=ComplianceConfig( + mode=compliance_raw.get("mode", ""), + prompt_injection=compliance_raw.get("prompt_injection", "detect"), + max_tool_calls_per_task=int(compliance_raw.get("max_tool_calls_per_task", 50)), + 
max_task_duration_seconds=int(compliance_raw.get("max_task_duration_seconds", 300)), + ), + sub_workspaces=raw.get("sub_workspaces", []), + ) diff --git a/molecule_runtime/consolidation.py b/molecule_runtime/consolidation.py new file mode 100644 index 0000000..38e4b58 --- /dev/null +++ b/molecule_runtime/consolidation.py @@ -0,0 +1,131 @@ +"""Memory consolidation loop. + +When an agent is idle (no active tasks for a configurable period), +the consolidation loop wakes up and summarizes noisy local memory +entries into dense, high-value knowledge facts. + +Similar to human sleep consolidation — raw scratchpad entries get +compressed into reusable knowledge. +""" + +import asyncio +import logging +import os + +import httpx + +from platform_auth import auth_headers + +logger = logging.getLogger(__name__) + +PLATFORM_URL = os.environ.get("PLATFORM_URL", "http://platform:8080") +WORKSPACE_ID = os.environ.get("WORKSPACE_ID", "") +CONSOLIDATION_INTERVAL = float(os.environ.get("CONSOLIDATION_INTERVAL", "300")) # 5 min +CONSOLIDATION_THRESHOLD = int(os.environ.get("CONSOLIDATION_THRESHOLD", "10")) # min memories before consolidating + + +class ConsolidationLoop: + """Background loop that consolidates local memories when idle.""" + + def __init__(self, agent=None): + self.agent = agent + self._running = False + + async def start(self): + """Start the consolidation loop.""" + self._running = True + logger.info("Memory consolidation loop started (interval=%ss, threshold=%d)", + CONSOLIDATION_INTERVAL, CONSOLIDATION_THRESHOLD) + + while self._running: + await asyncio.sleep(CONSOLIDATION_INTERVAL) + + if not self._running: + break + + try: + await self._consolidate() + except Exception as e: + logger.warning("Consolidation error: %s", e) + + async def _consolidate(self): + """Check if consolidation is needed and run it.""" + async with httpx.AsyncClient(timeout=10.0) as client: + # Fetch local memories + resp = await client.get( + f"{PLATFORM_URL}/workspaces/{WORKSPACE_ID}/memories", + params={"scope": "LOCAL"}, + headers=auth_headers(), + ) + if resp.status_code != 200: + return + + memories = resp.json() + if len(memories) < CONSOLIDATION_THRESHOLD: + return + + logger.info("Consolidating %d local memories", len(memories)) + + # Build a summary of all local memories + contents = [m["content"] for m in memories] + summary_prompt = ( + "Summarize the following workspace memories into 3-5 key facts. 
" + "Each fact should be a single, clear sentence capturing the most " + "important and reusable knowledge:\n\n" + + "\n".join(f"- {c}" for c in contents) + ) + + # Use the agent to generate the summary if available + summary = "" + if self.agent: + try: + result = await self.agent.ainvoke( + {"messages": [("user", summary_prompt)]}, + config={"configurable": {"thread_id": "consolidation"}}, + ) + messages = result.get("messages", []) + summary = "" + for msg in reversed(messages): + content = getattr(msg, "content", "") + if isinstance(content, str) and content.strip(): + msg_type = getattr(msg, "type", "") + if msg_type != "human": + summary = content + break + + if summary: + # Store consolidated summary as a TEAM memory — only delete originals if POST succeeds + resp = await client.post( + f"{PLATFORM_URL}/workspaces/{WORKSPACE_ID}/memories", + json={"content": f"[Consolidated] {summary}", "scope": "TEAM"}, + headers=auth_headers(), + ) + if resp.status_code in (200, 201): + # Safe to delete originals — consolidated version is saved + for m in memories: + await client.delete( + f"{PLATFORM_URL}/workspaces/{WORKSPACE_ID}/memories/{m['id']}", + headers=auth_headers(), + ) + logger.info("Consolidated %d memories into team knowledge", len(memories)) + else: + logger.warning("Consolidation POST failed (status %d) — keeping originals", resp.status_code) + except Exception as e: + logger.error( + "CONSOLIDATION: Agent summarization failed (rate limit? model error?): %s. " + "Falling back to simple concatenation.", e + ) + # Fall through to concatenation below + + # Fallback: concatenate without agent summarization + if not (self.agent and summary): + combined = " | ".join(contents[:20]) + await client.post( + f"{PLATFORM_URL}/workspaces/{WORKSPACE_ID}/memories", + json={"content": f"[Consolidated] {combined}", "scope": "TEAM"}, + headers=auth_headers(), + ) + logger.info("Consolidated %d memories via concatenation fallback", len(memories)) + + def stop(self): + self._running = False diff --git a/molecule_runtime/coordinator.py b/molecule_runtime/coordinator.py new file mode 100644 index 0000000..99e9adb --- /dev/null +++ b/molecule_runtime/coordinator.py @@ -0,0 +1,136 @@ +"""Coordinator pattern for team workspaces. + +When a workspace is expanded into a team, the parent agent becomes a +coordinator that routes incoming tasks to the appropriate child workspace +based on the task content and children's capabilities. + +The coordinator: +1. Fetches its children's Agent Cards (skills, capabilities) +2. Analyzes each incoming task to determine which child is best suited +3. Delegates to the chosen child via the delegation tool +4. Aggregates responses if a task requires multiple children +5. Falls back to handling the task itself if no child is appropriate +""" + +import logging +import os + +import httpx +from langchain_core.tools import tool +from adapters.shared_runtime import build_peer_section +from policies.routing import build_team_routing_payload + +logger = logging.getLogger(__name__) + +PLATFORM_URL = os.environ.get("PLATFORM_URL", "http://platform:8080") +WORKSPACE_ID = os.environ.get("WORKSPACE_ID", "") + + +async def get_parent_context() -> list[dict]: + """Fetch shared context files from this workspace's parent. + + Returns a list of {"path": str, "content": str} dicts. + Returns empty list if no parent, parent unreachable, or no shared context. 
+ """ + parent_id = os.environ.get("PARENT_ID", "") + if not parent_id: + return [] + + try: + async with httpx.AsyncClient(timeout=10.0) as client: + resp = await client.get( + f"{PLATFORM_URL}/workspaces/{parent_id}/shared-context", + headers={"X-Workspace-ID": WORKSPACE_ID}, + ) + if resp.status_code == 200: + return resp.json() + except Exception as e: + logger.warning("Failed to fetch parent context: %s", e) + return [] + + +async def get_children() -> list[dict]: + """Fetch this workspace's children from the platform.""" + try: + async with httpx.AsyncClient(timeout=10.0) as client: + resp = await client.get( + f"{PLATFORM_URL}/registry/{WORKSPACE_ID}/peers", + headers={"X-Workspace-ID": WORKSPACE_ID}, + ) + if resp.status_code == 200: + peers = resp.json() + # Filter to only children (parent_id == our ID) + return [p for p in peers if p.get("parent_id") == WORKSPACE_ID] + except Exception as e: + logger.warning("Failed to fetch children: %s", e) + return [] + + +def build_children_description(children: list[dict]) -> str: + """Build a description of children's capabilities for the coordinator prompt.""" + if not children: + return "" + + team_section = build_peer_section( + children, + heading="## Your Team (sub-workspaces you coordinate)", + instruction=( + "Use the `delegate_to_workspace` tool to send tasks to the chosen member. " + "Only delegate to members listed above." + ), + ) + + return "\n".join( + [ + team_section, + "", + "### Coordination Rules — MANDATORY", + "1. You are a COORDINATOR. Your ONLY job is to delegate and synthesize. NEVER do the work yourself.", + "2. For EVERY task, use `delegate_to_workspace` to send it to the appropriate team member(s). " + "Do this BEFORE writing any analysis, code, or research yourself.", + "3. If a task spans multiple members, delegate to ALL of them in parallel and aggregate results.", + "4. If ALL members are offline/paused, tell the caller which members are unavailable. " + "Do NOT attempt the work yourself — you lack the specialist context.", + "5. If a delegation FAILS (error, timeout): try another member first. " + "Only provide your own brief summary if NO member can respond. Never forward raw errors.", + "6. Your response should be a SYNTHESIS of your team's work, not your own analysis.", + "7. Always respond in the same language the caller uses.", + ] + ) + + +@tool +async def route_task_to_team( + task: str, + preferred_member_id: str = "", +) -> dict: + """Route a task to the most appropriate team member. + + As the team coordinator, analyze the task and delegate to the best-suited + child workspace. If preferred_member_id is provided, delegate directly to + that member. + + Args: + task: The task description to route. + preferred_member_id: Optional — directly delegate to this member. + """ + from builtin_tools.delegation import delegate_to_workspace as delegate + + children = await get_children() + decision = build_team_routing_payload( + children, + task=task, + preferred_member_id=preferred_member_id, + ) + + if decision.get("action") == "delegate_to_preferred_member": + # Async delegation — returns immediately with task_id + result = await delegate.ainvoke( + { + "workspace_id": decision["preferred_member_id"], + "task": task, + } + ) + return result + + return decision diff --git a/molecule_runtime/events.py b/molecule_runtime/events.py new file mode 100644 index 0000000..a682dca --- /dev/null +++ b/molecule_runtime/events.py @@ -0,0 +1,96 @@ +"""WebSocket subscriber for platform events. 
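+
+Event frames are JSON objects; the fields consumed below are, as a sketch:
+
+    {"event": "WORKSPACE_ONLINE", "workspace_id": "ws-123"}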
+ +Subscribes to the platform WebSocket with X-Workspace-ID header +so the workspace only receives events about reachable peers. +Triggers system prompt rebuild on relevant peer changes. +""" + +import asyncio +import json +import logging + +import httpx + +logger = logging.getLogger(__name__) + +# Events that should trigger a system prompt rebuild +REBUILD_EVENTS = { + "WORKSPACE_ONLINE", + "WORKSPACE_OFFLINE", + "WORKSPACE_EXPANDED", + "WORKSPACE_COLLAPSED", + "WORKSPACE_REMOVED", + "AGENT_CARD_UPDATED", +} + + +class PlatformEventSubscriber: + """Subscribes to platform WebSocket for peer events.""" + + def __init__( + self, + platform_url: str, + workspace_id: str, + on_peer_change=None, + ): + self.ws_url = platform_url.replace("http://", "ws://").replace("https://", "wss://") + "/ws" + self.workspace_id = workspace_id + self.on_peer_change = on_peer_change + self._running = False + self._reconnect_delay = 1.0 + + async def start(self): + """Connect to platform WebSocket with exponential backoff reconnect.""" + self._running = True + + while self._running: + try: + await self._connect() + except Exception as e: + if not self._running: + break + logger.warning("WebSocket disconnected: %s. Reconnecting in %.0fs...", e, self._reconnect_delay) + await asyncio.sleep(self._reconnect_delay) + self._reconnect_delay = min(self._reconnect_delay * 2, 30.0) + + async def _connect(self): + """Establish WebSocket connection and process events.""" + try: + import websockets + except ImportError: + logger.warning("websockets package not installed, skipping event subscription") + self._running = False + return + + # Fix D (Cycle 5): include bearer token in WebSocket upgrade so the + # server's new auth check can validate this agent connection. + # Graceful fallback for workspaces that have no token yet. + headers = {"X-Workspace-ID": self.workspace_id} + try: + from platform_auth import auth_headers as _auth_headers + headers.update(_auth_headers()) + except Exception: + pass # No token available — connect unauthenticated (grandfathered) + logger.info("Connecting to platform WebSocket: %s", self.ws_url) + + async with websockets.connect(self.ws_url, additional_headers=headers) as ws: + self._reconnect_delay = 1.0 # Reset on successful connect + logger.info("Platform WebSocket connected") + + async for message in ws: + try: + event = json.loads(message) + event_type = event.get("event", "") + + if event_type in REBUILD_EVENTS: + logger.info("Peer event: %s for workspace %s", + event_type, event.get("workspace_id", "")) + if self.on_peer_change: + await self.on_peer_change(event) + except json.JSONDecodeError: + continue + except Exception as e: + logger.warning("Error processing event: %s", e) + + def stop(self): + self._running = False diff --git a/molecule_runtime/executor_helpers.py b/molecule_runtime/executor_helpers.py new file mode 100644 index 0000000..c435f84 --- /dev/null +++ b/molecule_runtime/executor_helpers.py @@ -0,0 +1,389 @@ +"""Shared helpers for AgentExecutor implementations. + +Used by both CLIAgentExecutor (codex, ollama) and ClaudeSDKExecutor (claude-code). 
+Provides: +- Memory recall/commit (HTTP to platform /memories endpoints) +- Delegation results consumption (atomic file rename) +- Current task heartbeat updates +- System prompt loading from /configs +- A2A instructions text for system prompt injection (MCP and CLI variants) +- Brief task summary extraction (markdown-aware) +- Error message sanitization (exception classes and subprocess categories) +- Shared workspace path constants and the MCP server path resolver +""" + +from __future__ import annotations + +import json +import logging +import os +import re +from pathlib import Path +from typing import TYPE_CHECKING, Any + +import httpx + +if TYPE_CHECKING: + from heartbeat import HeartbeatLoop + + +logger = logging.getLogger(__name__) + + +# ======================================================================== +# Constants — workspace container layout +# ======================================================================== + +WORKSPACE_MOUNT = "/workspace" +CONFIG_MOUNT = "/configs" +DEFAULT_MCP_SERVER_PATH = "/app/a2a_mcp_server.py" +DEFAULT_DELEGATION_RESULTS_FILE = "/tmp/delegation_results.jsonl" +PLATFORM_HTTP_TIMEOUT_S = 5.0 +MEMORY_RECALL_LIMIT = 10 +MEMORY_CONTENT_MAX_CHARS = 200 +BRIEF_SUMMARY_MAX_LEN = 80 + + +def get_mcp_server_path() -> str: + """Return the path to the stdio MCP server script. + + Overridable via A2A_MCP_SERVER_PATH for tests and non-default layouts. + """ + return os.environ.get("A2A_MCP_SERVER_PATH", DEFAULT_MCP_SERVER_PATH) + + +# ======================================================================== +# HTTP client (shared, lazily initialised) +# ======================================================================== + +_http_client: httpx.AsyncClient | None = None + + +def get_http_client() -> httpx.AsyncClient: + """Lazy-init a shared httpx client for platform API calls.""" + global _http_client + if _http_client is None or _http_client.is_closed: + _http_client = httpx.AsyncClient(timeout=PLATFORM_HTTP_TIMEOUT_S) + return _http_client + + +def reset_http_client_for_tests() -> None: + """Test helper — drop the shared client so the next call rebuilds it. + + Not for production use. Exposed so tests can guarantee a clean slate + between cases without touching module internals. + """ + global _http_client + _http_client = None + + +# ======================================================================== +# Memory recall + commit +# ======================================================================== + +async def recall_memories() -> str: + """Recall recent memories from the platform API. + + Returns a newline-joined bullet list of up to MEMORY_RECALL_LIMIT most recent + memories, or empty string when the platform is unreachable / not configured + / returns a non-200 / returns an unexpected payload shape. + """ + workspace_id = os.environ.get("WORKSPACE_ID", "") + platform_url = os.environ.get("PLATFORM_URL", "") + if not workspace_id or not platform_url: + return "" + # Fix E (Cycle 5): send auth headers so the WorkspaceAuth middleware + # (Fix A) allows access once the workspace has a live token on file. 
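+    # Response shape consumed below (sketch); the tail slice keeps the most
+    # recent entries, so newest memories are assumed to come last:
+    #   [{"scope": "LOCAL", "content": "..."}, ...]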
+ try: + from platform_auth import auth_headers as _platform_auth + _auth = _platform_auth() + except Exception: + _auth = {} + try: + resp = await get_http_client().get( + f"{platform_url}/workspaces/{workspace_id}/memories", + headers=_auth, + ) + if not 200 <= resp.status_code < 300: + logger.debug( + "recall_memories: non-2xx response %s from platform", + resp.status_code, + ) + return "" + data = resp.json() + except Exception as exc: + logger.debug("recall_memories: request failed: %s", exc) + return "" + if not isinstance(data, list) or not data: + return "" + lines = [ + f"- [{m.get('scope', '?')}] {m.get('content', '')}" + for m in data[-MEMORY_RECALL_LIMIT:] + ] + return "\n".join(lines) + + +async def commit_memory(content: str) -> None: + """Save a memory to the platform API. Best-effort, no error propagation.""" + workspace_id = os.environ.get("WORKSPACE_ID", "") + platform_url = os.environ.get("PLATFORM_URL", "") + if not workspace_id or not platform_url or not content: + return + # Fix E (Cycle 5): include auth header so WorkspaceAuth middleware allows access. + try: + from platform_auth import auth_headers as _platform_auth + _auth = _platform_auth() + except Exception: + _auth = {} + try: + await get_http_client().post( + f"{platform_url}/workspaces/{workspace_id}/memories", + json={"content": content, "scope": "LOCAL"}, + headers=_auth, + ) + except Exception as exc: + logger.debug("commit_memory: request failed: %s", exc) + + +# ======================================================================== +# Delegation results — written by heartbeat loop, consumed atomically +# ======================================================================== + +def read_delegation_results() -> str: + """Read and consume delegation results written by the heartbeat loop. + + Uses atomic rename to prevent races with the heartbeat writer. + Returns formatted text suitable for prompt injection, or empty string. 
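+
+    Example JSONL record and its rendering, with hypothetical values:
+
+        {"status": "completed", "summary": "audit task", "response_preview": "All checks passed"}
+
+    becomes:
+
+        - [completed] audit task
+          Response: All checks passed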
+ """ + results_file = Path( + os.environ.get("DELEGATION_RESULTS_FILE", DEFAULT_DELEGATION_RESULTS_FILE) + ) + if not results_file.exists(): + return "" + consumed = results_file.with_suffix(".consumed") + try: + results_file.rename(consumed) + except OSError: + return "" # File disappeared between exists() and rename() + try: + raw = consumed.read_text(encoding="utf-8", errors="replace") + except OSError: + return "" + finally: + consumed.unlink(missing_ok=True) + + parts: list[str] = [] + for line in raw.strip().split("\n"): + if not line.strip(): + continue + try: + record = json.loads(line) + except json.JSONDecodeError: + continue + status = record.get("status", "?") + summary = record.get("summary", "") + preview = record.get("response_preview", "") + parts.append(f"- [{status}] {summary}") + if preview: + parts.append(f" Response: {preview[:200]}") + return "\n".join(parts) + + +# ======================================================================== +# Current task heartbeat update +# ======================================================================== + +async def set_current_task(heartbeat: "HeartbeatLoop | None", task: str) -> None: + """Update current task on heartbeat and push immediately via platform API.""" + if heartbeat is not None: + heartbeat.current_task = task + heartbeat.active_tasks = 1 if task else 0 + workspace_id = os.environ.get("WORKSPACE_ID", "") + platform_url = os.environ.get("PLATFORM_URL", "") + if not (workspace_id and platform_url): + return + try: + try: + from platform_auth import auth_headers as _auth + _headers = _auth() + except Exception: + _headers = {} + await get_http_client().post( + f"{platform_url}/registry/heartbeat", + json={ + "workspace_id": workspace_id, + "current_task": task, + "active_tasks": 1 if task else 0, + "error_rate": 0, + "sample_error": "", + "uptime_seconds": 0, + }, + headers=_headers, + ) + except Exception as exc: + logger.debug("set_current_task: heartbeat push failed: %s", exc) + + +# ======================================================================== +# System prompt loading +# ======================================================================== + +def get_system_prompt(config_path: str, fallback: str | None = None) -> str | None: + """Read system-prompt.md from the config dir each call (supports hot-reload). + + Falls back to the provided string if the file doesn't exist. + """ + prompt_file = Path(config_path) / "system-prompt.md" + if prompt_file.exists(): + return prompt_file.read_text(encoding="utf-8", errors="replace").strip() + return fallback + + +_A2A_INSTRUCTIONS_MCP = """## Inter-Agent Communication +You have MCP tools for communicating with other workspaces: +- list_peers: discover available peer workspaces (name, ID, status, role) +- delegate_task: send a task and WAIT for the response (for quick tasks) +- delegate_task_async: send a task and return immediately with a task_id (for long tasks) +- check_task_status: poll an async task's status and get results when done +- get_workspace_info: get your own workspace info + +For quick questions, use delegate_task (synchronous). +For long-running work (building pages, running audits), use delegate_task_async + check_task_status. +Always use list_peers first to discover available workspace IDs. +Access control is enforced — you can only reach siblings and parent/children. + +PROACTIVE MESSAGING: Use send_message_to_user to push messages to the user's chat at ANY time: +- Acknowledge tasks immediately: "Got it, delegating to the team now..." 
+- Send progress updates during long work: "Research Lead finished, waiting on Dev Lead..." +- Deliver follow-up results: "All teams reported back. Here's the synthesis: ..." +This lets you respond quickly ("I'll work on this") and come back later with results. + +If delegate_task returns a DELEGATION FAILED message, do NOT forward the raw error to the user. +Instead: (1) try delegating to a different peer, (2) handle the task yourself, or +(3) tell the user which peer is unavailable and provide your own best answer.""" + + +_A2A_INSTRUCTIONS_CLI = """## Inter-Agent Communication +You can delegate tasks to other workspaces using the a2a command: + python3 /app/a2a_cli.py peers # List available peers + python3 /app/a2a_cli.py delegate # Sync: wait for response + python3 /app/a2a_cli.py delegate --async # Async: return task_id + python3 /app/a2a_cli.py status # Check async task + python3 /app/a2a_cli.py info # Your workspace info + +For quick questions, use sync delegate. For long tasks, use --async + status. +Only delegate to peers listed by the peers command (access control enforced).""" + + +def get_a2a_instructions(mcp: bool = True) -> str: + """Return inter-agent communication instructions for system-prompt injection. + + Pass `mcp=True` (default) for MCP-capable runtimes (Claude Code via SDK, + Codex). Pass `mcp=False` for CLI-only runtimes (Ollama, custom) that have + to call a2a_cli.py as a subprocess. + """ + return _A2A_INSTRUCTIONS_MCP if mcp else _A2A_INSTRUCTIONS_CLI + + +# ======================================================================== +# Misc text helpers +# ======================================================================== + +_MARKDOWN_FENCE = "```" +_MARKDOWN_HR = "---" + + +_BRIEF_SUMMARY_MIN_LEN = 4 # 1 char + 3-char ellipsis + + +def brief_summary(text: str, max_len: int = BRIEF_SUMMARY_MAX_LEN) -> str: + """Extract a one-line task summary for the canvas card display. + + Strips markdown headers (#, ##, ###), bold/italic markers (**, __), + and skips code fences and horizontal rules. Returns the first meaningful + line, truncated with an ellipsis when it exceeds `max_len`. + + `max_len` is clamped to at least 4 (one real character plus a 3-char + ellipsis) so degenerate callers can't produce negative slice indices. + """ + max_len = max(max_len, _BRIEF_SUMMARY_MIN_LEN) + for raw_line in text.split("\n"): + line = raw_line.strip() + while line.startswith("#"): + line = line[1:] + line = line.strip() + if not line or line.startswith(_MARKDOWN_FENCE) or line == _MARKDOWN_HR: + continue + line = line.replace("**", "").replace("__", "") + if len(line) > max_len: + return line[: max_len - 3] + "..." + return line + return text[:max_len] + + +def extract_message_text(message: Any) -> str: + """Extract text from an A2A message (handles both .text and .root.text patterns).""" + parts = getattr(message, "parts", None) or [] + text_parts: list[str] = [] + for part in parts: + text = getattr(part, "text", None) + if text: + text_parts.append(text) + continue + root = getattr(part, "root", None) + if root is not None: + root_text = getattr(root, "text", None) + if root_text: + text_parts.append(root_text) + return " ".join(text_parts).strip() + + +# Word-boundary patterns for subprocess stderr classification. Using word +# boundaries avoids false positives like "author" matching "auth" or +# "generate" matching "rate". 
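+# Sketch of the resulting classification, with hypothetical stderr fragments:
+#   "HTTP 429 Too Many Requests" → rate_limited
+#   "invalid api_key"            → auth_failed
+#   "no conversation found"      → session_error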
+_RATE_LIMIT_RE = re.compile(r"\brate\b|\b429\b|\boverloaded\b", re.IGNORECASE) +_AUTH_RE = re.compile(r"\bauth(?:entication|orization)?\b|\bapi[_-]?key\b", re.IGNORECASE) +_SESSION_RE = re.compile(r"\bsession\b|\bno conversation found\b", re.IGNORECASE) + + +def classify_subprocess_error(stderr_text: str, exit_code: int | None) -> str: + """Map a subprocess stderr blob to a short, user-safe category tag. + + The full stderr goes to the workspace logs via `logger.error`; only the + category is surfaced to the user to avoid leaking tokens, internal paths, + or stack traces in the chat UI. Used with `sanitize_agent_error` to + produce a user-facing message for subprocess failures. + """ + if _RATE_LIMIT_RE.search(stderr_text): + return "rate_limited" + if _AUTH_RE.search(stderr_text): + return "auth_failed" + if _SESSION_RE.search(stderr_text): + return "session_error" + if exit_code is not None and exit_code != 0: + return f"exit_{exit_code}" + return "subprocess_error" + + +def sanitize_agent_error( + exc: BaseException | None = None, + category: str | None = None, +) -> str: + """Render an agent-side failure into a user-safe error message. + + Either pass an exception (class name is used as the tag) or an explicit + category string (e.g. from `classify_subprocess_error`). If both are + given, `category` wins. If neither, the tag defaults to "unknown". + + The message body is deliberately dropped — exception messages and + subprocess stderr frequently leak stack traces, paths, tokens, and + API keys. Full detail is available in the workspace logs via + `logger.exception()` / `logger.error()`. + """ + if category: + tag = category + elif exc is not None: + tag = type(exc).__name__ + else: + tag = "unknown" + return f"Agent error ({tag}) — see workspace logs for details." diff --git a/molecule_runtime/heartbeat.py b/molecule_runtime/heartbeat.py new file mode 100644 index 0000000..a67bec7 --- /dev/null +++ b/molecule_runtime/heartbeat.py @@ -0,0 +1,291 @@ +"""Heartbeat loop — alive signal + delegation status checker. + +Every 30 seconds: +1. Send heartbeat to platform (alive signal with current_task, error_rate) +2. Check pending delegations — any results back? +3. Store completed delegation results for the agent to pick up + +Resilient: recreates HTTP client on failure, auto-restarts on crash. 
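+
+Heartbeat POST body sketch (fields mirror the payload built in _loop below):
+
+    {"workspace_id": "ws-123", "error_rate": 0.0, "sample_error": "",
+     "active_tasks": 1, "current_task": "summarize repo", "uptime_seconds": 3600}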
+""" + +import asyncio +import json +import logging +import os +import time +from pathlib import Path + +import httpx + +from platform_auth import auth_headers + +logger = logging.getLogger(__name__) + +HEARTBEAT_INTERVAL = 30 # seconds +MAX_CONSECUTIVE_FAILURES = 10 +MAX_SEEN_DELEGATION_IDS = 200 +SELF_MESSAGE_COOLDOWN = 60 # seconds — minimum between self-messages to prevent loops +# Shared path — also used by cli_executor._read_delegation_results() +DELEGATION_RESULTS_FILE = os.environ.get("DELEGATION_RESULTS_FILE", "/tmp/delegation_results.jsonl") + + +class HeartbeatLoop: + def __init__(self, platform_url: str, workspace_id: str): + self.platform_url = platform_url + self.workspace_id = workspace_id + self.start_time = time.time() + self.error_count = 0 + self.request_count = 0 + self.active_tasks = 0 + self.current_task = "" + self.sample_error = "" + self._task = None + self._consecutive_failures = 0 + self._seen_delegation_ids: set[str] = set() + self._last_self_message_time = 0.0 + self._parent_name: str | None = None # Cached after first lookup + + @property + def error_rate(self) -> float: + if self.request_count == 0: + return 0.0 + return self.error_count / self.request_count + + def record_error(self, error: str): + self.error_count += 1 + self.request_count += 1 + self.sample_error = error + + def record_success(self): + self.request_count += 1 + + def start(self): + self._task = asyncio.create_task(self._loop()) + self._task.add_done_callback(self._on_done) + + def _on_done(self, task): + if not task.cancelled() and task.exception(): + logger.error("Heartbeat loop died: %s — restarting", task.exception()) + self._task = asyncio.create_task(self._loop()) + self._task.add_done_callback(self._on_done) + + async def stop(self): + if self._task: + self._task.cancel() + try: + await self._task + except asyncio.CancelledError: + pass + + async def _loop(self): + while True: + client = None + try: + client = httpx.AsyncClient(timeout=10.0) + while True: + # 1. Send heartbeat (Phase 30.1: include auth header if token known) + try: + await client.post( + f"{self.platform_url}/registry/heartbeat", + json={ + "workspace_id": self.workspace_id, + "error_rate": self.error_rate, + "sample_error": self.sample_error, + "active_tasks": self.active_tasks, + "current_task": self.current_task, + "uptime_seconds": int(time.time() - self.start_time), + }, + headers=auth_headers(), + ) + self.error_count = 0 + self.request_count = 0 + self._consecutive_failures = 0 + except Exception as e: + self._consecutive_failures += 1 + if self._consecutive_failures <= 3 or self._consecutive_failures % MAX_CONSECUTIVE_FAILURES == 0: + logger.warning("Heartbeat failed (%d consecutive): %s", self._consecutive_failures, e) + if self._consecutive_failures >= MAX_CONSECUTIVE_FAILURES: + logger.info("Heartbeat: recreating HTTP client after %d failures", self._consecutive_failures) + try: + await client.aclose() + except Exception: + pass + break + + # 2. 
Check delegation status
+                    try:
+                        await self._check_delegations(client)
+                    except Exception as e:
+                        logger.debug("Delegation check failed: %s", e)
+
+                    await asyncio.sleep(HEARTBEAT_INTERVAL)
+
+            except asyncio.CancelledError:
+                raise
+            except Exception as e:
+                logger.error("Heartbeat loop error: %s — retrying in 30s", e)
+                await asyncio.sleep(HEARTBEAT_INTERVAL)
+            finally:
+                if client:
+                    try:
+                        await client.aclose()
+                    except Exception:
+                        pass
+
+    async def _check_delegations(self, client: httpx.AsyncClient):
+        """Check for completed delegations and store results for the agent."""
+        try:
+            resp = await client.get(
+                f"{self.platform_url}/workspaces/{self.workspace_id}/delegations",
+                headers=auth_headers(),
+            )
+            if resp.status_code != 200:
+                return
+
+            delegations = resp.json()
+            if not isinstance(delegations, list):
+                return
+
+            new_results = []
+            for d in delegations:
+                did = d.get("delegation_id", "")
+                status = d.get("status", "")
+
+                if not did or did in self._seen_delegation_ids:
+                    continue
+
+                if status in ("completed", "failed"):
+                    # Fix B (Cycle 5): validate source_id before accepting delegation
+                    # results. Only process delegations that THIS workspace created
+                    # (source_id == self.workspace_id). Attacker-crafted delegation
+                    # records with a foreign source_id cannot inject instructions.
+                    source_id = d.get("source_id", "")
+                    if source_id != self.workspace_id:
+                        logger.warning(
+                            "Heartbeat: skipping delegation %s — source_id %r does not "
+                            "match this workspace %r; possible injection attempt",
+                            did, source_id, self.workspace_id,
+                        )
+                        self._seen_delegation_ids.add(did)  # mark seen so we don't warn again
+                        continue
+
+                    self._seen_delegation_ids.add(did)
+                    new_results.append({
+                        "delegation_id": did,
+                        "target_id": d.get("target_id", ""),
+                        "source_id": source_id,
+                        "status": status,
+                        "summary": d.get("summary", ""),
+                        "response_preview": d.get("response_preview", ""),
+                        "error": d.get("error", ""),
+                        "timestamp": time.time(),
+                    })
+
+            # Evict seen IDs if over limit. Sets are unordered, so this drops
+            # an arbitrary half (not strictly the oldest); that is enough to
+            # bound memory, and the worst case is one repeat notification if a
+            # still-listed ID is evicted and re-seen on a later poll.
+            if len(self._seen_delegation_ids) > MAX_SEEN_DELEGATION_IDS:
+                self._seen_delegation_ids = set(list(self._seen_delegation_ids)[MAX_SEEN_DELEGATION_IDS // 2:])
+
+            if new_results:
+                # Append to results file for context injection on next message
+                with open(DELEGATION_RESULTS_FILE, "a") as f:
+                    for r in new_results:
+                        f.write(json.dumps(r) + "\n")
+                logger.info("Heartbeat: %d new delegation results — triggering self-message", len(new_results))
+
+                # Build a summary message for the agent.
+                # Fix B (Cycle 5): do NOT embed raw response_preview text in
+                # user-role A2A messages — remote-agent output is the
+                # prompt-injection vector. Reference only the delegation ID,
+                # status, and this workspace's own request summary; the agent
+                # reads full content from DELEGATION_RESULTS_FILE which was
+                # written above from trusted platform data.
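+                # Shape of each summary line built below (IDs hypothetical):
+                #   - [completed] Delegation 3f9a12bc: audit competitor pricing pages
+                #     Error: <appended only for failed delegations>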
+                summary_lines = []
+                for r in new_results:
+                    line = f"- [{r['status']}] Delegation {r['delegation_id'][:8]}: {r['summary'][:80]}"
+                    if r.get("error"):
+                        line += f"\n  Error: {r['error'][:100]}"
+                    summary_lines.append(line)
+
+                # Look up parent workspace (cached after first call)
+                if self._parent_name is None:
+                    try:
+                        parent_resp = await client.get(
+                            f"{self.platform_url}/workspaces/{self.workspace_id}",
+                            headers=auth_headers(),
+                        )
+                        if parent_resp.status_code == 200:
+                            parent_id = parent_resp.json().get("parent_id", "")
+                            if parent_id:
+                                parent_info = await client.get(
+                                    f"{self.platform_url}/workspaces/{parent_id}",
+                                    headers=auth_headers(),
+                                )
+                                if parent_info.status_code == 200:
+                                    self._parent_name = parent_info.json().get("name", "")
+                            if self._parent_name is None:
+                                self._parent_name = ""  # No parent — cache empty
+                    except Exception:
+                        pass  # Will retry next cycle
+                parent_name = self._parent_name or ""
+
+                if parent_name:
+                    report_instruction = (
+                        f"\n\nIMPORTANT: Report these results back to your parent '{parent_name}' "
+                        f"by delegating a summary to them. Use delegate_task or delegate_task_async "
+                        f"with a concise status report. Also use send_message_to_user to notify the user."
+                    )
+                else:
+                    report_instruction = (
+                        "\n\nReport results using send_message_to_user to notify the user."
+                    )
+
+                trigger_msg = (
+                    "Delegation results are ready. Review them and take appropriate action:\n"
+                    + "\n".join(summary_lines)
+                    + report_instruction
+                )
+
+                # Send A2A self-message to wake the agent.
+                # Minimum 60s between self-messages to avoid spam. Results that
+                # arrive during the cooldown are not lost — they are already in
+                # DELEGATION_RESULTS_FILE and get injected as context on the
+                # next message; only this wake-up trigger is skipped.
+                now = time.time()
+                if now - self._last_self_message_time < SELF_MESSAGE_COOLDOWN:
+                    logger.debug("Heartbeat: self-message cooldown (60s) — results stay queued in %s", DELEGATION_RESULTS_FILE)
+                else:
+                    self._last_self_message_time = now
+                    try:
+                        await client.post(
+                            f"{self.platform_url}/workspaces/{self.workspace_id}/a2a",
+                            json={
+                                "method": "message/send",
+                                "params": {
+                                    "message": {
+                                        "role": "user",
+                                        "parts": [{"kind": "text", "text": trigger_msg}],
+                                    },
+                                },
+                            },
+                            headers=auth_headers(),
+                            timeout=120.0,
+                        )
+                        logger.info("Heartbeat: self-message sent to process delegation results")
+                    except Exception as e:
+                        logger.warning("Heartbeat: failed to send self-message: %s", e)
+
+                # Also push notification to user via canvas
+                for r in new_results:
+                    try:
+                        msg = f"Delegation {r['status']}: {r['summary'][:100]}"
+                        if r.get("response_preview"):
+                            msg += f"\nResult: {r['response_preview'][:200]}"
+                        await client.post(
+                            f"{self.platform_url}/workspaces/{self.workspace_id}/notify",
+                            json={"message": msg, "type": "delegation_result"},
+                            headers=auth_headers(),
+                        )
+                    except Exception:
+                        pass
+
+        except Exception as e:
+            logger.debug("Delegation check error: %s", e)
diff --git a/molecule_runtime/initial_prompt.py b/molecule_runtime/initial_prompt.py
new file mode 100644
index 0000000..e5ba69b
--- /dev/null
+++ b/molecule_runtime/initial_prompt.py
@@ -0,0 +1,51 @@
+"""Helpers for the workspace's one-shot initial_prompt.
+
+Kept as a standalone module (no heavy imports like uvicorn) so the marker
+logic is unit-testable without standing up the full workspace runtime.
+
+Background: the workspace runtime supports an `initial_prompt` that runs once
+on first boot (clone the repo, set git hooks, read CLAUDE.md, commit_memory).
+A marker file `.initial_prompt_done` prevents the prompt from re-running on
+subsequent boots.
+
+Prior behaviour wrote the marker AFTER the prompt completed successfully. If
+the prompt crashed mid-execution (e.g. ProcessError from a stale Claude
+session), the marker was never written; every subsequent container boot
+replayed the same failing prompt, cascading into "every message crashes until
+an operator intervenes." See GitHub issue #71.
+
+Fix (2026-04-12): write the marker BEFORE firing the prompt. If the prompt
+fails, operators re-send it manually via chat — cheap and available — instead
+of trapping the workspace in a crash loop.
+"""
+from __future__ import annotations
+
+import os
+
+
+def resolve_initial_prompt_marker(config_path: str) -> str:
+    """Return the path where the `.initial_prompt_done` marker should live.
+
+    Prefers ``<config_path>/.initial_prompt_done`` when the directory is
+    writable; falls back to ``/workspace/.initial_prompt_done`` for containers
+    where ``/configs`` is read-only.
+    """
+    if os.access(config_path, os.W_OK):
+        return os.path.join(config_path, ".initial_prompt_done")
+    return "/workspace/.initial_prompt_done"
+
+
+def mark_initial_prompt_attempted(marker_path: str) -> bool:
+    """Write the marker best-effort. Return True on success, False on I/O error.
+
+    Called BEFORE the initial-prompt self-message is sent. If the attempt
+    later fails, the marker is still present — so the next container boot
+    does NOT replay the same failing prompt. Operators retry manually via
+    the chat interface instead of relying on auto-replay.
+    """
+    try:
+        with open(marker_path, "w") as f:
+            f.write("attempted")
+        return True
+    except OSError:
+        return False
diff --git a/molecule_runtime/main.py b/molecule_runtime/main.py
new file mode 100644
index 0000000..4ff54e7
--- /dev/null
+++ b/molecule_runtime/main.py
@@ -0,0 +1,556 @@
+"""Workspace runtime entry point.
+
+Loads config -> discovers adapter -> setup -> create executor -> wrap in A2A -> register -> heartbeat.
+"""
+
+import asyncio
+import json
+import os
+import socket
+import sys
+
+# When running as the installed `molecule-runtime` console script, the flat
+# module names (config, heartbeat, adapters, etc.) need to resolve to this
+# package's submodules. We inject the package directory onto sys.path so that
+# `from config import ...` resolves to `molecule_runtime/config.py`.
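+# e.g. with the package installed at .../site-packages/molecule_runtime, the
+# `from heartbeat import HeartbeatLoop` import below resolves to
+# molecule_runtime/heartbeat.py whether main() runs via the console script or
+# as `python -m molecule_runtime.main` (illustration — install paths vary).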
+_PKG_DIR = os.path.dirname(__file__) +if _PKG_DIR not in sys.path: + sys.path.insert(0, _PKG_DIR) + +import httpx +import uvicorn +from a2a.server.apps import A2AStarletteApplication +from a2a.server.request_handlers import DefaultRequestHandler +from a2a.server.tasks import InMemoryTaskStore +from a2a.types import AgentCard, AgentCapabilities, AgentSkill + +from adapters import get_adapter, AdapterConfig +from config import load_config +from heartbeat import HeartbeatLoop +from preflight import run_preflight, render_preflight_report +from builtin_tools.awareness_client import get_awareness_config +import uuid as _uuid + +from builtin_tools.telemetry import setup_telemetry, make_trace_middleware +from policies.namespaces import resolve_awareness_namespace + + +from initial_prompt import ( + mark_initial_prompt_attempted, + resolve_initial_prompt_marker, +) +from platform_auth import auth_headers + + +def get_machine_ip() -> str: # pragma: no cover + """Get the machine's IP for A2A discovery.""" + try: + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + s.connect(("8.8.8.8", 80)) + ip = s.getsockname()[0] + s.close() + return ip + except Exception: + return "127.0.0.1" + + +# Re-exported from transcript_auth for the inline /transcript handler. +# Separate module keeps the security-critical gate import-light + unit-testable. +from transcript_auth import transcript_authorized as _transcript_authorized + + +async def main(): # pragma: no cover + workspace_id = os.environ.get("WORKSPACE_ID", "workspace-default") + config_path = os.environ.get("WORKSPACE_CONFIG_PATH", "/configs") + platform_url = os.environ.get("PLATFORM_URL", "http://platform:8080") + awareness_config = get_awareness_config() + + # 0. Initialise OpenTelemetry (no-op if packages not installed) + setup_telemetry(service_name=workspace_id) + + # 1. Load config + config = load_config(config_path) + port = config.a2a.port + preflight = run_preflight(config, config_path) + render_preflight_report(preflight) + if not preflight.ok: + raise SystemExit(1) + if awareness_config: + awareness_namespace = resolve_awareness_namespace( + workspace_id, + awareness_config.get("namespace", ""), + ) + print(f"Awareness enabled for namespace: {awareness_namespace}") + + # 1.5 Initialise governance adapter (no-op if disabled or package absent) + from builtin_tools.governance import initialize_governance + if config.governance.enabled: + await initialize_governance(config.governance) + print(f"Governance: Microsoft Agent Governance Toolkit enabled (mode={config.governance.policy_mode})") + else: + print("Governance: disabled (set governance.enabled: true in config.yaml to activate)") + + # 2. Create heartbeat (passed to adapter for task tracking) + heartbeat = HeartbeatLoop(platform_url, workspace_id) + + # 3. Get adapter for this runtime + runtime = config.runtime or "langgraph" + adapter_cls = get_adapter(runtime) # Raises KeyError if unknown — no silent fallback + + adapter = adapter_cls() + print(f"Runtime: {runtime} ({adapter.display_name()})") + + # 4. Build adapter config + adapter_config = AdapterConfig( + model=config.model, + system_prompt=None, # Adapter builds its own prompt + tools=config.skills, # Skill names from config.yaml + runtime_config=vars(config.runtime_config) if config.runtime_config else {}, + config_path=config_path, + workspace_id=workspace_id, + prompt_files=config.prompt_files, + a2a_port=port, + heartbeat=heartbeat, + ) + + # 5. 
Setup adapter and create executor + # If setup fails, ensure heartbeat is stopped to prevent resource leak + try: + await adapter.setup(adapter_config) + executor = await adapter.create_executor(adapter_config) + except Exception: + # heartbeat hasn't started yet but may have async tasks pending + if hasattr(heartbeat, "stop"): + try: + await heartbeat.stop() + except Exception: + pass + raise + + # 5.5. Initialise Temporal durable execution wrapper (optional) + # Connects to TEMPORAL_HOST (default: localhost:7233) and starts a + # co-located Temporal worker as a background asyncio task. + # No-op with a warning log if Temporal is unreachable or temporalio + # is not installed — all tasks fall back to direct execution transparently. + from builtin_tools.temporal_workflow import create_wrapper as _create_temporal_wrapper + temporal_wrapper = _create_temporal_wrapper() + await temporal_wrapper.start() + + # Get loaded skills for agent card (adapter may have populated them) + loaded_skills = getattr(adapter, "loaded_skills", []) + + # 6. Build Agent Card + machine_ip = os.environ.get("HOSTNAME", get_machine_ip()) + workspace_url = f"http://{machine_ip}:{port}" + + agent_card = AgentCard( + name=config.name, + description=config.description or config.name, + version=config.version, + url=workspace_url, + capabilities=AgentCapabilities( + streaming=config.a2a.streaming, + pushNotifications=config.a2a.push_notifications, + stateTransitionHistory=True, + ), + skills=[ + AgentSkill( + id=skill.metadata.id, + name=skill.metadata.name, + description=skill.metadata.description, + tags=skill.metadata.tags, + examples=skill.metadata.examples, + ) + for skill in loaded_skills + ], + defaultInputModes=["text/plain", "application/json"], + defaultOutputModes=["text/plain", "application/json"], + ) + + # 7. Wrap in A2A. + # + # Regression fix (#204): PR #198 tried to wire push_config_store + + # push_sender to satisfy #175 (push notification capability), but + # PushNotificationSender is an abstract base class in the a2a-sdk and + # can't be instantiated directly. Passing it crashed main.py on startup + # with `TypeError: Can't instantiate abstract class`. Dropped back to + # DefaultRequestHandler's own defaults — pushNotifications capability + # in the AgentCard below is still advertised via AgentCapabilities so + # clients know we COULD do pushes; actually implementing them requires + # a concrete sender subclass, tracked as a Phase-H follow-up to #175. + handler = DefaultRequestHandler( + agent_executor=executor, + task_store=InMemoryTaskStore(), + ) + + app = A2AStarletteApplication( + agent_card=agent_card, + http_handler=handler, + ) + + # 8. Register with platform + agent_card_dict = { + "name": config.name, + "description": config.description, + "version": config.version, + "url": workspace_url, + "skills": [ + { + "id": s.metadata.id, + "name": s.metadata.name, + "description": s.metadata.description, + "tags": s.metadata.tags, + } + for s in loaded_skills + ], + "capabilities": { + "streaming": config.a2a.streaming, + "pushNotifications": config.a2a.push_notifications, + }, + } + + async with httpx.AsyncClient(timeout=10.0) as client: + try: + resp = await client.post( + f"{platform_url}/registry/register", + json={ + "id": workspace_id, + "url": workspace_url, + "agent_card": agent_card_dict, + }, + headers=auth_headers(), + ) + print(f"Registered with platform: {resp.status_code}") + # Phase 30.1 — capture the auth token issued at first register. 
+            # The platform only mints one on first register per workspace,
+            # so a subsequent restart gets an empty auth_token and we
+            # keep using the on-disk copy from the original issuance.
+            if resp.status_code == 200:
+                try:
+                    body = resp.json()
+                    tok = body.get("auth_token")
+                    if tok:
+                        from platform_auth import save_token
+                        save_token(tok)
+                        print(f"Saved workspace auth token (prefix={tok[:8]}…)")
+                except Exception as parse_exc:
+                    print(f"Warning: couldn't parse register response for token: {parse_exc}")
+        except Exception as e:
+            print(f"Warning: failed to register with platform: {e}")
+
+    # 9. Start heartbeat
+    heartbeat.start()
+
+    # 9b. Start skills hot-reload watcher (background task)
+    # When a skill file changes the watcher reloads the skill module and calls
+    # back into the adapter so the next A2A request uses the updated tools.
+    if config.skills:
+        try:
+            from skill_loader.watcher import SkillsWatcher
+
+            def _on_skill_reload(updated_skill):
+                """Rebuild the LangGraph agent when a skill changes in-place."""
+                if not hasattr(adapter, "loaded_skills"):
+                    return
+                # Replace the matching skill in the adapter's skill list
+                adapter.loaded_skills = [
+                    updated_skill if s.metadata.id == updated_skill.metadata.id else s
+                    for s in adapter.loaded_skills
+                ]
+                # Rebuild the agent's tool list from updated skills
+                if hasattr(adapter, "all_tools") and hasattr(adapter, "system_prompt"):
+                    from builtin_tools.approval import request_approval
+                    from builtin_tools.delegation import delegate_to_workspace
+                    from builtin_tools.memory import commit_memory, search_memory
+                    from builtin_tools.sandbox import run_code
+                    base_tools = [delegate_to_workspace, request_approval,
+                                  commit_memory, search_memory, run_code]
+                    skill_tools = []
+                    for sk in adapter.loaded_skills:
+                        skill_tools.extend(sk.tools)
+                    adapter.all_tools = base_tools + skill_tools
+                    # Rebuild compiled agent so next ainvoke picks up new tools
+                    try:
+                        from agent import create_agent
+                        new_agent = create_agent(
+                            config.model, adapter.all_tools, adapter.system_prompt
+                        )
+                        executor.agent = new_agent
+                        print(f"Skills hot-reload: '{updated_skill.metadata.id}' reloaded — "
+                              f"{len(updated_skill.tools)} tool(s)")
+                    except Exception as rebuild_err:
+                        print(f"Skills hot-reload: agent rebuild failed: {rebuild_err}")
+
+            skills_watcher = SkillsWatcher(
+                config_path=config_path,
+                skill_names=config.skills,
+                on_reload=_on_skill_reload,
+            )
+            asyncio.create_task(skills_watcher.start())
+            print(f"Skills hot-reload enabled for: {config.skills}")
+        except Exception as e:
+            print(f"Warning: skills watcher could not start: {e}")
+
+    # 10. Run A2A server
+    print(f"Workspace {workspace_id} starting on port {port}")
+    # Wrap the ASGI app with W3C TraceContext extraction middleware so incoming
+    # A2A HTTP requests propagate their trace context into _incoming_trace_context.
+    starlette_app = app.build()
+
+    # Add /transcript route — exposes the most-recent agent session log
+    # (claude-code reads ~/.claude/projects/<project>/<session>.jsonl). Other
+    # runtimes return supported:false.
+    from starlette.responses import JSONResponse
+
+    async def _transcript_handler(request):
+        # Require workspace bearer token — the same token issued at registration
+        # and stored in /configs/.auth_token. Any container on molecule-monorepo-net
+        # could otherwise read the full session log. Closes #287.
+        #
+        # #328: fail CLOSED when the token file is unavailable.
get_token() + # returns None during the bootstrap window (first register hasn't + # completed), if /configs/.auth_token was deleted, or on OSError. + # The old `if expected:` guard treated all three cases as "skip + # auth" — an unauthenticated container on the same Docker network + # could read the entire session log during that window. Deny + # instead. The platform's TranscriptHandler acquires the token + # during registration, so once the bootstrap completes it always + # has a valid credential to present. + from platform_auth import get_token + if not _transcript_authorized(get_token(), request.headers.get("Authorization", "")): + return JSONResponse({"error": "unauthorized"}, status_code=401) + try: + since = int(request.query_params.get("since", "0")) + limit = int(request.query_params.get("limit", "100")) + except (TypeError, ValueError): + return JSONResponse({"error": "since and limit must be integers"}, status_code=400) + result = await adapter.transcript_lines(since=since, limit=limit) + return JSONResponse(result) + + starlette_app.add_route("/transcript", _transcript_handler, methods=["GET"]) + + built_app = make_trace_middleware(starlette_app) + + server_config = uvicorn.Config( + built_app, + host="0.0.0.0", + port=port, + log_level="info", + ) + server = uvicorn.Server(server_config) + + # 10b. Schedule initial_prompt self-message after server is ready. + # Only runs on first boot — creates a marker file to prevent re-execution on restart. + initial_prompt_task = None + initial_prompt_marker = resolve_initial_prompt_marker(config_path) + if config.initial_prompt and not os.path.exists(initial_prompt_marker): + # Write the marker UP FRONT (#71): if the prompt later crashes or + # times out, we do NOT replay on next boot — that created a + # ProcessError cascade where every message kept crashing. Operators + # can always re-send via chat. Log loudly if the marker write + # fails so the situation is visible. + if not mark_initial_prompt_attempted(initial_prompt_marker): + print( + f"Initial prompt: WARNING — could not write marker at " + f"{initial_prompt_marker}; this boot may replay if it crashes.", + flush=True, + ) + async def _send_initial_prompt(): + """Wait for server to be ready, then send initial_prompt as self-message.""" + # Wait for the A2A server to accept connections + ready = False + for attempt in range(30): + await asyncio.sleep(1) + try: + async with httpx.AsyncClient(timeout=5.0) as client: + resp = await client.get(f"http://127.0.0.1:{port}/.well-known/agent.json") + if resp.status_code == 200: + ready = True + break + except Exception: + continue + + if not ready: + print("Initial prompt: server not ready after 30s, skipping", flush=True) + return + + # Send initial prompt through the platform A2A proxy (not directly to self). + # The proxy logs an a2a_receive with source_id=NULL (canvas-style), + # broadcasts A2A_RESPONSE via WebSocket so the chat shows both the + # prompt (as user message) and the response (as agent message). + # Uses urllib in a thread to avoid asyncio/httpx streaming hangs. + import json as _json + import urllib.request + + def _do_send_sync(): + import time as _time + payload = _json.dumps({ + "method": "message/send", + "params": { + "message": { + "role": "user", + "messageId": f"initial-{_uuid.uuid4().hex[:8]}", + "parts": [{"kind": "text", "text": config.initial_prompt}], + }, + }, + }).encode() + + # #220: include platform bearer token so the request isn't + # silently rejected once any workspace has a live token on + # file. 
Without this, initial_prompt 401s in multi-tenant
+                # mode exactly like /registry/register did in #215.
+                headers = {"Content-Type": "application/json", **auth_headers()}
+
+                # Retry with backoff — the platform proxy may not be able to
+                # reach us yet (container networking takes a moment to settle).
+                max_retries = 5
+                for attempt in range(max_retries):
+                    try:
+                        req = urllib.request.Request(
+                            f"{platform_url}/workspaces/{workspace_id}/a2a",
+                            data=payload,
+                            headers=headers,
+                        )
+                        with urllib.request.urlopen(req, timeout=600) as resp:
+                            resp.read()
+                        print(f"Initial prompt: completed (status={resp.status})", flush=True)
+                        break
+                    except Exception as e:
+                        if attempt < max_retries - 1:
+                            delay = 2 ** attempt  # 1, 2, 4, 8 seconds
+                            print(f"Initial prompt: attempt {attempt + 1} failed ({e}), retrying in {delay}s...", flush=True)
+                            _time.sleep(delay)
+                        else:
+                            print(f"Initial prompt: failed after {max_retries} attempts — {e}", flush=True)
+                            return
+
+            # Marker was already written up front (#71). Nothing to do here.
+
+            print("Initial prompt: sending via platform proxy...", flush=True)
+            loop = asyncio.get_running_loop()
+            loop.run_in_executor(None, _do_send_sync)
+
+        initial_prompt_task = asyncio.create_task(_send_initial_prompt())
+
+    # 10c. Idle loop — reflection-on-completion / backlog-pull pattern.
+    # Fires config.idle_prompt every config.idle_interval_seconds while the
+    # workspace has no active task. This turns every role from "waits for cron"
+    # into "self-wakes when idle" — the Hermes/Letta shape from today's
+    # multi-framework survey (see docs/ecosystem-watch.md). Cost collapses to
+    # event-driven in practice: the idle check is local (no LLM call, just
+    # heartbeat.active_tasks==0), and the prompt only fires when there's
+    # actually nothing to do. Gated on idle_prompt being non-empty so existing
+    # workspaces upgrade opt-in — set idle_prompt in org.yaml defaults or
+    # per-workspace to enable.
+    idle_loop_task = None
+    if config.idle_prompt:
+        # Idle-fire HTTP timeout. Kept tight relative to the fire cadence so a
+        # hung platform doesn't accumulate dangling requests — a fire that
+        # takes longer than the idle interval itself is almost certainly stuck.
+        IDLE_FIRE_TIMEOUT_SECONDS = max(60, min(300, config.idle_interval_seconds))
+        # Initial settle delay — capped at 60s so cold-start races don't stall
+        # the first fire, and never longer than the configured interval itself
+        # (short intervals get an equally short settle, not a full minute).
+        IDLE_INITIAL_SETTLE_SECONDS = min(config.idle_interval_seconds, 60)
+
+        async def _run_idle_loop():
+            """Self-sends config.idle_prompt periodically when the workspace is idle."""
+            await asyncio.sleep(IDLE_INITIAL_SETTLE_SECONDS)
+
+            import json as _json
+            from urllib import request as _urlreq, error as _urlerr
+
+            while True:
+                try:
+                    await asyncio.sleep(config.idle_interval_seconds)
+                except asyncio.CancelledError:
+                    return
+
+                # Local idle check — no platform API call, no LLM call.
+                # heartbeat.active_tasks == 0 means no in-flight work.
+                if heartbeat.active_tasks > 0:
+                    continue
+
+                # Self-post the idle prompt via the platform A2A proxy (same
+                # path as initial_prompt). The agent's own concurrency control
+                # rejects if the workspace becomes busy between this check and
+                # the post — that's the expected safety valve.
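+                # Worked example of the clamps above (illustration):
+                #   idle_interval_seconds=900 -> settle 60s, fire timeout 300s
+                #   idle_interval_seconds=30  -> settle 30s, fire timeout 60s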
+ payload = _json.dumps({ + "method": "message/send", + "params": { + "message": { + "role": "user", + "messageId": f"idle-{_uuid.uuid4().hex[:8]}", + "parts": [{"kind": "text", "text": config.idle_prompt}], + }, + }, + }).encode() + + def _post_sync(): + # Returns (status_code, error_type) so the caller logs the + # actual outcome instead of a bare "post failed" line. + # #220: include auth_headers() on every idle fire. Without + # this, the idle loop 401s in multi-tenant mode. + headers = {"Content-Type": "application/json", **auth_headers()} + try: + req = _urlreq.Request( + f"{platform_url}/workspaces/{workspace_id}/a2a", + data=payload, + headers=headers, + ) + with _urlreq.urlopen(req, timeout=IDLE_FIRE_TIMEOUT_SECONDS) as resp: + resp.read() + return resp.status, None + except _urlerr.HTTPError as e: + return e.code, type(e).__name__ + except _urlerr.URLError as e: + return None, f"URLError: {e.reason}" + except Exception as e: # pragma: no cover — catch-all safety net + return None, type(e).__name__ + + print( + f"Idle loop: firing (active_tasks=0, interval={config.idle_interval_seconds}s, " + f"timeout={IDLE_FIRE_TIMEOUT_SECONDS}s)", + flush=True, + ) + loop_ref = asyncio.get_running_loop() + + def _log_result(future): + try: + status, err = future.result() + if err: + print( + f"Idle loop: post failed — status={status} err={err}", + flush=True, + ) + else: + print(f"Idle loop: post ok status={status}", flush=True) + except Exception as e: # pragma: no cover + print(f"Idle loop: executor callback crashed — {e}", flush=True) + + fut = loop_ref.run_in_executor(None, _post_sync) + fut.add_done_callback(_log_result) + + idle_loop_task = asyncio.create_task(_run_idle_loop()) + + try: + await server.serve() + finally: + # Cancel initial prompt if still running + if initial_prompt_task and not initial_prompt_task.done(): + initial_prompt_task.cancel() + # Cancel idle loop if running + if idle_loop_task and not idle_loop_task.done(): + idle_loop_task.cancel() + # Gracefully stop the Temporal worker background task on shutdown + await temporal_wrapper.stop() + + +def main_sync(): # pragma: no cover + """Synchronous entry point for the molecule-runtime console script.""" + asyncio.run(main()) + + +if __name__ == "__main__": # pragma: no cover + main_sync() diff --git a/molecule_runtime/molecule_ai_status.py b/molecule_runtime/molecule_ai_status.py new file mode 100644 index 0000000..27a03b9 --- /dev/null +++ b/molecule_runtime/molecule_ai_status.py @@ -0,0 +1,72 @@ +#!/usr/bin/env python3 +"""Update workspace task status on the canvas. + +Usage (from any script, cron job, or shell inside the container): + + # Set current task (shows on canvas card) + python3 /app/molecule_ai_status.py "Running weekly SEO audit..." + + # Clear task (removes banner from canvas) + python3 /app/molecule_ai_status.py "" + + # Or use the shell alias: + molecule-monorepo-status "Analyzing competitor data..." + molecule-monorepo-status "" + +The status appears as an amber banner on the workspace card in the canvas, +visible to the project owner in real-time. 
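+
+Cron sketch (hypothetical schedule):
+
+    0 9 * * 1   python3 /app/molecule_ai_status.py "Running weekly SEO audit..."
+    0 18 * * 1  python3 /app/molecule_ai_status.py ""   # clear when done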
+""" + +import os +import sys + +import httpx + +WORKSPACE_ID = os.environ.get("WORKSPACE_ID", "") +PLATFORM_URL = os.environ.get("PLATFORM_URL", "http://platform:8080") + + +def set_status(task: str): + """Push current_task to platform via heartbeat.""" + try: + try: + from platform_auth import auth_headers as _auth + _headers = _auth() + except Exception: + _headers = {} + httpx.post( + f"{PLATFORM_URL}/registry/heartbeat", + json={ + "workspace_id": WORKSPACE_ID, + "current_task": task, + "active_tasks": 1 if task else 0, + "error_rate": 0, + "sample_error": "", + "uptime_seconds": 0, + }, + headers=_headers, + timeout=5.0, + ) + if task: + # Also log as activity for traceability + httpx.post( + f"{PLATFORM_URL}/workspaces/{WORKSPACE_ID}/activity", + json={ + "activity_type": "task_update", + "source_id": WORKSPACE_ID, + "summary": task, + "status": "ok", + }, + timeout=5.0, + ) + except Exception as e: + print(f"molecule-monorepo-status: failed to update: {e}", file=sys.stderr) + + +if __name__ == "__main__": # pragma: no cover + if len(sys.argv) < 2: + print("Usage: molecule-monorepo-status 'task description'") + print(" molecule-monorepo-status '' # clear") + sys.exit(1) + + set_status(sys.argv[1]) diff --git a/molecule_runtime/platform_auth.py b/molecule_runtime/platform_auth.py new file mode 100644 index 0000000..d4a1e18 --- /dev/null +++ b/molecule_runtime/platform_auth.py @@ -0,0 +1,105 @@ +"""Workspace auth-token store (Phase 30.1). + +Single source of truth for this workspace's authentication token. The +token is issued by the platform on the first successful +``POST /registry/register`` call and travels with every subsequent +heartbeat / update-card / (later) secrets-pull / A2A request. + +The token is persisted to ``/.auth_token`` so it survives +restarts — we only expect to receive it once from the platform, since +``/registry/register`` no-ops token issuance for workspaces that already +have one on file. + +Storage: + ${CONFIGS_DIR}/.auth_token # 0600, one line, no trailing newline + +Callers interact with three functions: + :func:`get_token` — returns the cached token or None + :func:`save_token` — persists a freshly-issued token + :func:`auth_headers`— builds the Authorization header dict for httpx +""" +from __future__ import annotations + +import logging +import os +from pathlib import Path + +logger = logging.getLogger(__name__) + +# In-process cache so we don't hit disk on every heartbeat. The heartbeat +# loop fires on a short interval and reading a tiny file 10x per minute +# is wasteful. The file is the durable copy; this var is the hot path. +_cached_token: str | None = None + + +def _token_file() -> Path: + """Path to the on-disk token file. Respects CONFIGS_DIR, falls back + to /configs for the default container layout.""" + return Path(os.environ.get("CONFIGS_DIR", "/configs")) / ".auth_token" + + +def get_token() -> str | None: + """Return the cached token, reading it from disk on first call.""" + global _cached_token + if _cached_token is not None: + return _cached_token + path = _token_file() + if not path.exists(): + return None + try: + tok = path.read_text().strip() + except OSError as exc: + logger.warning("platform_auth: failed to read %s: %s", path, exc) + return None + if not tok: + return None + _cached_token = tok + return tok + + +def save_token(token: str) -> None: + """Persist a newly-issued token. Creates the file with 0600 mode atomically. + + Uses ``os.open(O_CREAT, 0o600)`` so the file is never world-readable, + even transiently. 
The previous ``write_text()`` + ``chmod()`` approach
+    had a TOCTOU window where a concurrent reader could access the token
+    between the two syscalls (M4 — flagged in security audit cycle 10).
+
+    Idempotent — if an identical token is already on disk we skip the
+    write so we don't churn the file's mtime or trigger spurious
+    filesystem watchers."""
+    global _cached_token
+    token = token.strip()
+    if not token:
+        raise ValueError("platform_auth: refusing to save empty token")
+    if get_token() == token:
+        return
+    path = _token_file()
+    path.parent.mkdir(parents=True, exist_ok=True)
+    # O_CREAT | O_WRONLY | O_TRUNC with mode=0o600 atomically creates (or
+    # truncates) the file with restricted permissions in a single syscall,
+    # eliminating the TOCTOU window.
+    fd = os.open(str(path), os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
+    try:
+        os.write(fd, token.encode())
+    finally:
+        os.close(fd)
+    _cached_token = token
+
+
+def auth_headers() -> dict[str, str]:
+    """Return a header dict to merge into httpx calls. Empty if no token
+    is available yet — callers send the request as-is and the platform's
+    heartbeat handler grandfathers pre-token workspaces through until
+    their next /registry/register issues one."""
+    tok = get_token()
+    if not tok:
+        return {}
+    return {"Authorization": f"Bearer {tok}"}
+
+
+def clear_cache() -> None:
+    """Reset the in-memory cache. Used by tests that write fresh token
+    files between cases."""
+    global _cached_token
+    _cached_token = None
diff --git a/molecule_runtime/plugins.py b/molecule_runtime/plugins.py
new file mode 100644
index 0000000..8fd7f33
--- /dev/null
+++ b/molecule_runtime/plugins.py
@@ -0,0 +1,154 @@
+"""Plugin system for loading per-workspace and shared plugins.
+
+Plugins provide skills, rules, and prompt fragments to agent workspaces.
+Each plugin is a directory containing:
+  - plugin.yaml — manifest (name, version, description, skills, rules)
+  - rules/*.md — always-on guidelines injected into every prompt
+  - skills/ — skill directories with SKILL.md + tools/*.py
+  - *.md — prompt fragments (excluding README, CHANGELOG, etc.)
+
+Loading priority:
+  1. Per-workspace: /configs/plugins/<name>/ (installed via API)
+  2. Shared fallback: /plugins/<name>/ (legacy bind mount)
+  Deduplication by name — per-workspace wins.
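+
+Example layout (hypothetical plugin named "seo-toolkit"):
+
+    /configs/plugins/seo-toolkit/
+        plugin.yaml
+        rules/tone.md
+        skills/keyword-audit/SKILL.md
+        overview.md        # prompt fragment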
+""" + +import logging +import os +from pathlib import Path +from dataclasses import dataclass, field + +import yaml + +logger = logging.getLogger(__name__) + +WORKSPACE_PLUGINS_DIR = "/configs/plugins" +SHARED_PLUGINS_DIR = os.environ.get("PLUGINS_DIR", "/plugins") + + +@dataclass +class PluginManifest: + name: str = "" + version: str = "0.0.0" + description: str = "" + author: str = "" + tags: list[str] = field(default_factory=list) + skills: list[str] = field(default_factory=list) + rules: list[str] = field(default_factory=list) + prompt_fragments: list[str] = field(default_factory=list) + adapters: dict = field(default_factory=dict) + runtimes: list[str] = field(default_factory=list) # declared supported runtimes + + +@dataclass +class Plugin: + name: str + path: str + manifest: PluginManifest = field(default_factory=PluginManifest) + rules: list[str] = field(default_factory=list) # rule content strings + prompt_fragments: list[str] = field(default_factory=list) # extra prompt content + skills_dir: str = "" # path to skills/ inside plugin + + +@dataclass +class LoadedPlugins: + rules: list[str] = field(default_factory=list) + prompt_fragments: list[str] = field(default_factory=list) + skill_dirs: list[str] = field(default_factory=list) # dirs to scan for extra skills + plugin_names: list[str] = field(default_factory=list) + plugins: list[Plugin] = field(default_factory=list) + + +def load_plugin_manifest(plugin_path: str) -> PluginManifest: + """Parse plugin.yaml from a plugin directory. Returns empty manifest if not found.""" + manifest_file = os.path.join(plugin_path, "plugin.yaml") + if not os.path.isfile(manifest_file): + return PluginManifest(name=os.path.basename(plugin_path)) + try: + with open(manifest_file) as f: + raw = yaml.safe_load(f) or {} + return PluginManifest( + name=raw.get("name", os.path.basename(plugin_path)), + version=raw.get("version", "0.0.0"), + description=raw.get("description", ""), + author=raw.get("author", ""), + tags=raw.get("tags", []), + skills=raw.get("skills", []), + rules=raw.get("rules", []), + prompt_fragments=raw.get("prompt_fragments", []), + adapters=raw.get("adapters", {}), + runtimes=raw.get("runtimes", []), + ) + except Exception as e: + logger.warning("Failed to parse plugin manifest %s: %s", manifest_file, e) + return PluginManifest(name=os.path.basename(plugin_path)) + + +def _load_single_plugin(plugin_path: str) -> Plugin: + """Load a single plugin from a directory.""" + name = os.path.basename(plugin_path) + manifest = load_plugin_manifest(plugin_path) + plugin = Plugin(name=name, path=plugin_path, manifest=manifest) + + # Load rules + rules_dir = os.path.join(plugin_path, "rules") + if os.path.isdir(rules_dir): + for rule_file in sorted(os.listdir(rules_dir)): + if rule_file.endswith(".md"): + content = Path(os.path.join(rules_dir, rule_file)).read_text().strip() + if content: + plugin.rules.append(content) + logger.info("Plugin %s: loaded rule %s", name, rule_file) + + # Load prompt fragments (any .md in root of plugin) + skip = {"readme.md", "changelog.md", "license.md", "contributing.md", "plugin.yaml"} + for f in sorted(os.listdir(plugin_path)): + if f.endswith(".md") and f.lower() not in skip and os.path.isfile(os.path.join(plugin_path, f)): + content = Path(os.path.join(plugin_path, f)).read_text().strip() + if content: + plugin.prompt_fragments.append(content) + logger.info("Plugin %s: loaded prompt fragment %s", name, f) + + # Register skills directory + skills_dir = os.path.join(plugin_path, "skills") + if 
os.path.isdir(skills_dir):
+        plugin.skills_dir = skills_dir
+        skill_count = len([d for d in os.listdir(skills_dir) if os.path.isdir(os.path.join(skills_dir, d))])
+        logger.info("Plugin %s: found %d skills", name, skill_count)
+
+    return plugin
+
+
+def load_plugins(
+    workspace_plugins_dir: str | None = None,
+    shared_plugins_dir: str | None = None,
+) -> LoadedPlugins:
+    """Scan per-workspace plugins first, then shared plugins. Deduplicate by name."""
+    ws_dir = workspace_plugins_dir or WORKSPACE_PLUGINS_DIR
+    shared_dir = shared_plugins_dir or SHARED_PLUGINS_DIR
+    result = LoadedPlugins()
+    seen_names: set[str] = set()
+
+    # Scan both dirs: per-workspace first (higher priority)
+    for base_dir in [ws_dir, shared_dir]:
+        if not os.path.isdir(base_dir):
+            continue
+        for entry in sorted(os.listdir(base_dir)):
+            plugin_path = os.path.join(base_dir, entry)
+            if not os.path.isdir(plugin_path) or entry in seen_names:
+                continue
+
+            plugin = _load_single_plugin(plugin_path)
+            seen_names.add(entry)
+
+            result.rules.extend(plugin.rules)
+            result.prompt_fragments.extend(plugin.prompt_fragments)
+            if plugin.skills_dir:
+                result.skill_dirs.append(plugin.skills_dir)
+            result.plugin_names.append(entry)
+            result.plugins.append(plugin)
+
+    if result.plugin_names:
+        logger.info("Loaded %d plugins: %s", len(result.plugin_names), ", ".join(result.plugin_names))
+
+    return result
diff --git a/molecule_runtime/plugins_registry/__init__.py b/molecule_runtime/plugins_registry/__init__.py
new file mode 100644
index 0000000..ef1c5b3
--- /dev/null
+++ b/molecule_runtime/plugins_registry/__init__.py
@@ -0,0 +1,135 @@
+"""Per-runtime plugin adaptor registry with hybrid resolution.
+
+Resolution order for ``(plugin_name, runtime)``:
+
+    1. Platform registry → ``workspace-template/plugins_registry/<plugin_name>/<runtime>.py``
+    2. Plugin-shipped → ``<plugin_root>/adapters/<runtime>.py``
+    3. Raw filesystem → :class:`RawDropAdaptor` (warns, drops files only)
+
+Path #1 wins so the platform can override or hot-fix a third-party adaptor
+without forking the upstream plugin repo. Path #2 is the SDK contract: a
+single GitHub repo ships its own adaptors and is installable on day one.
+Path #3 is the escape hatch — power users can still bring unsupported
+plugins onto a workspace, they just don't get tools wired up.
+
+A registered adaptor module must expose either:
+  - ``Adaptor`` class implementing :class:`PluginAdaptor`, OR
+  - ``def get_adaptor(plugin_name, runtime) -> PluginAdaptor``
+"""
+
+from __future__ import annotations
+
+import importlib.util
+import logging
+from pathlib import Path
+from typing import Optional
+
+from .protocol import InstallContext, InstallResult, PluginAdaptor
+from .raw_drop import RawDropAdaptor
+
+logger = logging.getLogger(__name__)
+
+# Where the platform-curated registry lives. Resolved relative to this file
+# so it works regardless of CWD or how workspace-template is installed.
+_REGISTRY_ROOT = Path(__file__).parent
+
+__all__ = [
+    "InstallContext",
+    "InstallResult",
+    "PluginAdaptor",
+    "RawDropAdaptor",
+    "resolve",
+    "AdaptorSource",
+]
+
+
+class AdaptorSource:
+    REGISTRY = "registry"
+    PLUGIN = "plugin"
+    RAW_DROP = "raw_drop"
+
+
+def _load_module_from_path(module_name: str, path: Path):
+    """Import a Python file by absolute path.
Returns the module or None on failure.""" + spec = importlib.util.spec_from_file_location(module_name, path) + if spec is None or spec.loader is None: + return None + module = importlib.util.module_from_spec(spec) + try: + spec.loader.exec_module(module) + except Exception as exc: + logger.warning("Failed to load adaptor module %s: %s", path, exc) + return None + return module + + +def _instantiate(module, plugin_name: str, runtime: str) -> Optional[PluginAdaptor]: + """Build a PluginAdaptor from an adaptor module. + + Two conventions are supported so plugin authors can pick whichever fits: + a class named ``Adaptor`` (zero-arg constructor or ``(plugin_name, runtime)``), + or a factory function ``get_adaptor(plugin_name, runtime)``. + """ + factory = getattr(module, "get_adaptor", None) + if callable(factory): + try: + return factory(plugin_name, runtime) + except Exception as exc: + logger.warning("get_adaptor() failed for %s/%s: %s", plugin_name, runtime, exc) + return None + + cls = getattr(module, "Adaptor", None) + if cls is None: + return None + try: + try: + return cls(plugin_name, runtime) + except TypeError: + return cls() + except Exception as exc: + logger.warning("Adaptor() construction failed for %s/%s: %s", plugin_name, runtime, exc) + return None + + +def _resolve_registry(plugin_name: str, runtime: str) -> Optional[PluginAdaptor]: + path = _REGISTRY_ROOT / plugin_name / f"{runtime}.py" + if not path.is_file(): + return None + module = _load_module_from_path(f"plugins_registry.{plugin_name}.{runtime}", path) + if module is None: + return None + return _instantiate(module, plugin_name, runtime) + + +def _resolve_plugin_shipped(plugin_root: Path, plugin_name: str, runtime: str) -> Optional[PluginAdaptor]: + path = plugin_root / "adapters" / f"{runtime}.py" + if not path.is_file(): + return None + module = _load_module_from_path(f"_plugin_adaptor.{plugin_name}.{runtime}", path) + if module is None: + return None + return _instantiate(module, plugin_name, runtime) + + +def resolve( + plugin_name: str, + runtime: str, + plugin_root: Path, +) -> tuple[PluginAdaptor, str]: + """Resolve the adaptor for ``(plugin_name, runtime)``. + + Returns ``(adaptor, source)`` where ``source`` is one of + :class:`AdaptorSource` (``"registry"``, ``"plugin"``, ``"raw_drop"``). + Always returns an adaptor — the raw-drop fallback ensures plugin installs + never hard-fail on missing adaptors; instead the warning is surfaced via + :class:`InstallResult.warnings`. + """ + adaptor = _resolve_registry(plugin_name, runtime) + if adaptor is not None: + return adaptor, AdaptorSource.REGISTRY + + adaptor = _resolve_plugin_shipped(plugin_root, plugin_name, runtime) + if adaptor is not None: + return adaptor, AdaptorSource.PLUGIN + + return RawDropAdaptor(plugin_name, runtime), AdaptorSource.RAW_DROP diff --git a/molecule_runtime/plugins_registry/builtins.py b/molecule_runtime/plugins_registry/builtins.py new file mode 100644 index 0000000..634d5fb --- /dev/null +++ b/molecule_runtime/plugins_registry/builtins.py @@ -0,0 +1,327 @@ +"""Built-in plugin adaptors — one per agent shape. + +The adapter layer is our extensibility surface. Each agent "shape" (form +of installable capability) gets its own named sub-type adapter. A plugin +picks which sub-type to use by importing it as ``Adaptor`` in its +per-runtime file: + +.. 
code-block:: python
+
+    # plugins/<plugin>/adapters/claude_code.py
+    from plugins_registry.builtins import AgentskillsAdaptor as Adaptor
+
+Shape taxonomy (one class per shape; add more as the ecosystem evolves):
+
+* :class:`AgentskillsAdaptor` — skills in the `agentskills.io
+  <https://agentskills.io>`_ format (``SKILL.md`` + ``scripts/`` +
+  ``references/`` + ``assets/``), plus Molecule AI's optional ``rules/`` and
+  root-level prompt fragments at the plugin level. Works on every runtime
+  we support (the spec's filesystem layout makes activation trivial on
+  Claude Code, our adapter code does the equivalent on DeepAgents /
+  LangGraph / etc.). **This is the default and covers the common case.**
+
+Planned as the ecosystem matures (none are implemented yet — rule of
+three: promote a class here only after 3+ plugins ship the same custom
+shape via their own ``adapters/<runtime>.py``):
+
+* ``MCPServerAdaptor`` — install a plugin as an MCP server *(TODO)*
+* ``DeepAgentsSubagentAdaptor`` — register a DeepAgents sub-agent
+  (runtime-locked to deepagents) *(TODO)*
+* ``LangGraphSubgraphAdaptor`` — install a LangGraph sub-graph *(TODO)*
+* ``RAGPipelineAdaptor`` — wire a retriever + index *(TODO)*
+* ``SwarmAdaptor`` — bind an OpenAI-swarm / AutoGen-swarm *(TODO)*
+* ``WebhookAdaptor`` — register an event handler *(TODO)*
+
+Plugins whose shape doesn't match any built-in ship their own adapter
+class in ``plugins/<plugin>/adapters/<runtime>.py`` — full Python, no
+constraint. When 3+ plugins ship the same custom pattern, we promote
+the class into this module.
+"""
+
+from __future__ import annotations
+
+import json
+import os
+import shutil
+import subprocess
+from pathlib import Path
+
+from .protocol import SKILLS_SUBDIR, InstallContext, InstallResult
+
+# Files at the plugin root that are never treated as prompt fragments,
+# even if they're markdown. Module-level so tests and other adapters can
+# import the set rather than re-declaring it.
+SKIP_ROOT_MD = frozenset({"readme.md", "changelog.md", "license.md", "contributing.md"})
+
+
+def _read_md_files(directory: Path) -> list[tuple[str, str]]:
+    """Return [(filename, content)] for all *.md files in directory, sorted."""
+    if not directory.is_dir():
+        return []
+    out: list[tuple[str, str]] = []
+    for p in sorted(directory.iterdir()):
+        if p.is_file() and p.suffix == ".md":
+            out.append((p.name, p.read_text().strip()))
+    return out
+
+
+class AgentskillsAdaptor:
+    """Sub-type adaptor for `agentskills.io <https://agentskills.io>`_-format skills.
+
+    This is the default adapter for the "skills + rules" shape — the most
+    common pattern. A plugin using this adapter ships:
+
+    * ``skills/<skill>/SKILL.md`` (+ optional ``scripts/``, ``references/``,
+      ``assets/``) — each skill is a spec-compliant agentskills unit,
+      portable to Claude Code, Cursor, Codex, and ~35 other skill-compatible
+      tools without modification.
+    * ``rules/*.md`` (optional, Molecule AI extension) — always-on prose that
+      gets appended to the runtime's memory file (CLAUDE.md).
+    * Root-level ``*.md`` (optional) — prompt fragments, also appended to
+      memory.
+
+    On ``install()``:
+    1. Rules → append to ``/configs/<memory_filename>`` (CLAUDE.md by
+       default), wrapped in a ``# Plugin: <name>`` marker for idempotent
+       re-install.
+    2. Prompt fragments (``*.md`` at plugin root, excl. README/CHANGELOG/etc.)
+       → same treatment.
+    3. Skills (``skills/<skill>/``) → copied to
+       ``/configs/skills/<skill>/``. Runtimes with native agentskills
+       activation (Claude Code) pick them up automatically; other runtimes'
+       loaders scan the same path.
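+
+    Marker lines written into the memory file by ``install()`` below
+    (hypothetical plugin name shown):
+
+        # Plugin: seo-toolkit / rule: tone.md
+        # Plugin: seo-toolkit / fragment: overview.md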
+
+    Uninstall reverses the file copies and strips the rule/fragment block by
+    marker (best-effort — if the user edited CLAUDE.md manually, only the
+    marker line itself is removed).
+
+    For shapes other than agentskills (MCP server, DeepAgents sub-agent,
+    LangGraph sub-graph, RAG pipeline, swarm, webhook handler, etc.), see
+    the module docstring for the planned sibling adapters, or ship a custom
+    adapter class in the plugin's ``adapters/<runtime>.py``.
+    """
+
+    def __init__(self, plugin_name: str, runtime: str) -> None:
+        self.plugin_name = plugin_name
+        self.runtime = runtime
+
+    # ------------------------------------------------------------------
+    # install
+    # ------------------------------------------------------------------
+
+    async def install(self, ctx: InstallContext) -> InstallResult:
+        result = InstallResult(
+            plugin_name=self.plugin_name,
+            runtime=self.runtime,
+            source="plugin",  # overridden by registry caller if source==registry
+        )
+
+        # 1. Rules — append to memory file.
+        rules = _read_md_files(ctx.plugin_root / "rules")
+        # 2. Prompt fragments — any *.md at plugin root except skip list.
+        root_fragments: list[tuple[str, str]] = []
+        if ctx.plugin_root.is_dir():
+            for p in sorted(ctx.plugin_root.iterdir()):
+                if p.is_file() and p.suffix == ".md" and p.name.lower() not in SKIP_ROOT_MD:
+                    content = p.read_text().strip()
+                    if content:
+                        root_fragments.append((p.name, content))
+
+        memory_blocks: list[str] = []
+        for filename, content in rules:
+            memory_blocks.append(f"# Plugin: {self.plugin_name} / rule: {filename}\n\n{content}")
+        for filename, content in root_fragments:
+            memory_blocks.append(f"# Plugin: {self.plugin_name} / fragment: {filename}\n\n{content}")
+
+        if memory_blocks:
+            joined = "\n\n".join(memory_blocks)
+            ctx.append_to_memory(ctx.memory_filename, joined)
+            ctx.logger.info(
+                "%s: injected %d rule+fragment block(s) into %s",
+                self.plugin_name, len(memory_blocks), ctx.memory_filename,
+            )
+
+        # 3. Skills — copy each skill dir to /configs/skills/.
+        src_skills_dir = ctx.plugin_root / "skills"
+        if src_skills_dir.is_dir():
+            dst_skills_root = ctx.configs_dir / SKILLS_SUBDIR
+            dst_skills_root.mkdir(parents=True, exist_ok=True)
+            copied = 0
+            for entry in sorted(src_skills_dir.iterdir()):
+                if not entry.is_dir():
+                    continue
+                dst = dst_skills_root / entry.name
+                if dst.exists():
+                    ctx.logger.debug("%s: skill %s already present, skipping", self.plugin_name, entry.name)
+                    continue
+                shutil.copytree(entry, dst)
+                copied += 1
+                for p in dst.rglob("*"):
+                    if p.is_file():
+                        result.files_written.append(str(p.relative_to(ctx.configs_dir)))
+            if copied:
+                ctx.logger.info("%s: copied %d skill dir(s) to %s", self.plugin_name, copied, dst_skills_root)
+
+        # 4. Setup script — run setup.sh if present (for npm/pip dependencies).
+        # Mirrors sdk/python/molecule_plugin/builtins.py — must stay in sync
+        # (drift guard: tests/test_plugins_builtins_drift.py).
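+        # A typical setup.sh (hypothetical) just installs skill deps, e.g.:
+        #   pip install -r "$CONFIGS_DIR/skills/keyword-audit/requirements.txt"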
+        setup_script = ctx.plugin_root / "setup.sh"
+        if setup_script.is_file():
+            ctx.logger.info("%s: running setup.sh", self.plugin_name)
+            try:
+                proc = subprocess.run(
+                    ["bash", str(setup_script)],
+                    capture_output=True, text=True, timeout=120,
+                    cwd=str(ctx.plugin_root),
+                    env={**os.environ, "CONFIGS_DIR": str(ctx.configs_dir)},
+                )
+                if proc.returncode == 0:
+                    ctx.logger.info("%s: setup.sh completed successfully", self.plugin_name)
+                else:
+                    result.warnings.append(f"setup.sh exited {proc.returncode}: {proc.stderr[:200]}")
+                    ctx.logger.warning("%s: setup.sh failed: %s", self.plugin_name, proc.stderr[:200])
+            except subprocess.TimeoutExpired:
+                result.warnings.append("setup.sh timed out (120s)")
+                ctx.logger.warning("%s: setup.sh timed out", self.plugin_name)
+
+        # 5. Hooks — copy hooks/* into <configs>/.claude/hooks/ (Claude Code-
+        # style harness hooks). No-op when the plugin doesn't ship any.
+        # 6. Commands — copy commands/*.md into <configs>/.claude/commands/.
+        # 7. settings-fragment.json — merge into <configs>/.claude/settings.json,
+        # rewriting ${CLAUDE_DIR} to the absolute install path. Existing
+        # user hooks are preserved (deep-merge by event).
+        _install_claude_layer(ctx, result, self.plugin_name)
+
+        return result
+
+    # ------------------------------------------------------------------
+    # uninstall
+    # ------------------------------------------------------------------
+
+    async def uninstall(self, ctx: InstallContext) -> None:
+        # Remove copied skill dirs.
+        src_skills_dir = ctx.plugin_root / "skills"
+        if src_skills_dir.is_dir():
+            for entry in src_skills_dir.iterdir():
+                dst = ctx.configs_dir / SKILLS_SUBDIR / entry.name
+                if dst.exists() and dst.is_dir():
+                    shutil.rmtree(dst)
+                    ctx.logger.info("%s: removed %s", self.plugin_name, dst)
+
+        # Best-effort strip of our markers from CLAUDE.md. Users can always
+        # edit manually; we only guarantee the injected block's first line
+        # is removed so re-install re-adds cleanly.
+        memory_path = ctx.configs_dir / ctx.memory_filename
+        if not memory_path.exists():
+            return
+        text = memory_path.read_text()
+        prefix = f"# Plugin: {self.plugin_name} / "
+        lines = text.splitlines(keepends=True)
+        kept = [line for line in lines if not line.startswith(prefix)]
+        if len(kept) != len(lines):
+            memory_path.write_text("".join(kept))
+            ctx.logger.info("%s: stripped markers from %s", self.plugin_name, ctx.memory_filename)
+
+
+
+
+# ----------------------------------------------------------------------
+# Claude Code layer — hooks, slash commands, settings.json fragments.
+# Promoted from the molecule-guardrails plugin so any plugin can ship
+# these by dropping the right files; no custom adapter needed.
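+# Files picked up (all optional):
+#   hooks/*                  -> <configs>/.claude/hooks/      (*.sh chmod 0755)
+#   commands/*.md            -> <configs>/.claude/commands/
+#   settings-fragment.json   -> deep-merged into <configs>/.claude/settings.json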
# ----------------------------------------------------------------------

def _install_claude_layer(ctx: InstallContext, result: InstallResult, plugin_name: str) -> None:
    claude_dir = ctx.configs_dir / ".claude"
    claude_dir.mkdir(parents=True, exist_ok=True)

    _copy_dir_files(
        ctx.plugin_root / "hooks",
        claude_dir / "hooks",
        result,
        executable_suffix=".sh",
    )
    _copy_dir_files(
        ctx.plugin_root / "commands",
        claude_dir / "commands",
        result,
        only_suffix=".md",
    )
    _merge_settings_fragment(ctx, claude_dir, result, plugin_name)


def _copy_dir_files(
    src: Path,
    dst: Path,
    result: InstallResult,
    executable_suffix: str | None = None,
    only_suffix: str | None = None,
) -> None:
    if not src.is_dir():
        return
    dst.mkdir(parents=True, exist_ok=True)
    for f in src.iterdir():
        if not f.is_file():
            continue
        # Suffix filter. The .py escape hatch only applies when a caller sets
        # both only_suffix and executable_suffix (hook scripts with Python
        # companions); with the current call sites it never fires.
        if only_suffix and f.suffix != only_suffix:
            if not (executable_suffix and f.suffix == ".py"):
                continue
        target = dst / f.name
        shutil.copy2(f, target)
        if executable_suffix and f.suffix == executable_suffix:
            target.chmod(0o755)
        # Record the path relative to /configs (target.parents[2]),
        # e.g. ".claude/hooks/guard.sh".
        result.files_written.append(str(target.relative_to(target.parents[2])))


def _merge_settings_fragment(
    ctx: InstallContext,
    claude_dir: Path,
    result: InstallResult,
    plugin_name: str,
) -> None:
    fragment_path = ctx.plugin_root / "settings-fragment.json"
    if not fragment_path.is_file():
        return
    try:
        fragment = json.loads(fragment_path.read_text())
    except Exception as e:
        result.warnings.append(f"settings-fragment.json invalid: {e}")
        return

    settings_path = claude_dir / "settings.json"
    if settings_path.is_file():
        try:
            existing = json.loads(settings_path.read_text())
        except Exception:
            existing = {}
    else:
        existing = {}

    rewritten = _rewrite_hook_paths(fragment, claude_dir)
    merged = _deep_merge_hooks(existing, rewritten)
    settings_path.write_text(json.dumps(merged, indent=2) + "\n")
    result.files_written.append(str(settings_path.relative_to(ctx.configs_dir)))
    ctx.logger.info("%s: merged hook config into %s", plugin_name, settings_path)


def _rewrite_hook_paths(fragment: dict, claude_dir: Path) -> dict:
    out = json.loads(json.dumps(fragment))  # deep copy via roundtrip
    for handlers in out.get("hooks", {}).values():
        for handler in handlers:
            for h in handler.get("hooks", []):
                cmd = h.get("command", "")
                h["command"] = cmd.replace("${CLAUDE_DIR}", str(claude_dir))
    return out


def _deep_merge_hooks(existing: dict, fragment: dict) -> dict:
    out = dict(existing)
    out.setdefault("hooks", {})
    for event, handlers in fragment.get("hooks", {}).items():
        out["hooks"].setdefault(event, [])
        out["hooks"][event].extend(handlers)
    for key, val in fragment.items():
        if key == "hooks":
            continue
        out.setdefault(key, val)
    return out
diff --git a/molecule_runtime/plugins_registry/protocol.py b/molecule_runtime/plugins_registry/protocol.py
new file mode 100644
index 0000000..3b60a39
--- /dev/null
+++ b/molecule_runtime/plugins_registry/protocol.py
@@ -0,0 +1,104 @@
"""Protocol + context types for per-runtime plugin adaptors.

Each plugin ships (or has registered for it) a per-runtime adaptor implementing
``PluginAdaptor``. The platform resolves the adaptor for ``(plugin_name, runtime)``
via :func:`plugins_registry.resolve` and calls ``install(ctx)`` to wire the
plugin into a workspace.
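
A hypothetical call site (``resolve`` is named above; its exact signature is
assumed here for illustration)::

    adaptor = plugins_registry.resolve("web-tools", runtime="claude_code")
    result = await adaptor.install(ctx)
    for w in result.warnings:
        logger.warning(w)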
+ +The :class:`InstallContext` deliberately gives adaptors ONLY the hooks they +need (``register_tool``, ``register_subagent``, ``append_to_memory``) — it +does not leak runtime internals. This keeps adaptors thin and lets the +workspace runtime adapter (claude_code, deepagents, …) own its own state. +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Callable, Protocol, runtime_checkable + + +# Default filename for the runtime's long-lived memory file. Claude Code +# and DeepAgents both read CLAUDE.md natively; other runtimes override via +# BaseAdapter.memory_filename() and that value flows through +# InstallContext.memory_filename so adaptors don't hardcode the name. +DEFAULT_MEMORY_FILENAME = "CLAUDE.md" + +# Subdirectory under /configs where skills get installed. +SKILLS_SUBDIR = "skills" + + +@dataclass +class InstallContext: + """Hooks + state passed to every PluginAdaptor.install() call. + + Adaptors should treat unknown verbs as no-ops on runtimes that don't + support them (e.g. ``register_subagent`` is a no-op on Claude Code). + """ + + configs_dir: Path + """Workspace's /configs directory (where CLAUDE.md, plugins/, skills/ live).""" + + workspace_id: str + """Workspace UUID — useful for per-workspace state or logging.""" + + runtime: str + """Runtime identifier (``claude_code``, ``deepagents``, …).""" + + plugin_root: Path + """Path to the plugin's directory (where plugin.yaml + content lives).""" + + memory_filename: str = DEFAULT_MEMORY_FILENAME + """Runtime's long-lived memory file (populated from + :meth:`BaseAdapter.memory_filename`). Adaptors pass this to + :attr:`append_to_memory` instead of hardcoding a filename so runtimes + with non-standard memory files (e.g. ``AGENTS.md``) work unchanged.""" + + register_tool: Callable[[str, Callable[..., Any]], None] = field( + default=lambda name, fn: None + ) + """Register a callable as a runtime tool. No-op on runtimes without a + dynamic tool registry — those runtimes pick tools up at startup via + filesystem scan instead.""" + + register_subagent: Callable[[str, dict[str, Any]], None] = field( + default=lambda name, spec: None + ) + """Register a sub-agent specification (DeepAgents-only). No-op elsewhere.""" + + append_to_memory: Callable[[str, str], None] = field( + default=lambda filename, content: None + ) + """Append text to a runtime memory file (e.g. CLAUDE.md). The default + no-op lets adaptors run in test harnesses that don't have a real + workspace filesystem.""" + + logger: logging.Logger = field(default_factory=lambda: logging.getLogger(__name__)) + + +@dataclass +class InstallResult: + """Outcome of a PluginAdaptor.install() call.""" + + plugin_name: str + runtime: str + source: str # "registry" | "plugin" | "raw_drop" + files_written: list[str] = field(default_factory=list) + tools_registered: list[str] = field(default_factory=list) + subagents_registered: list[str] = field(default_factory=list) + warnings: list[str] = field(default_factory=list) + + +@runtime_checkable +class PluginAdaptor(Protocol): + """Contract every per-runtime adaptor must implement.""" + + plugin_name: str + runtime: str + + async def install(self, ctx: InstallContext) -> InstallResult: + ... + + async def uninstall(self, ctx: InstallContext) -> None: + ... 
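

# A minimal adaptor satisfying this protocol might look like the sketch
# below (names illustrative, not a shipped implementation):
#
#     class MemoryNoteAdaptor:
#         plugin_name = "memory-note"
#         runtime = "claude_code"
#
#         async def install(self, ctx: InstallContext) -> InstallResult:
#             ctx.append_to_memory(ctx.memory_filename, "# Plugin: memory-note")
#             return InstallResult(
#                 plugin_name=self.plugin_name,
#                 runtime=self.runtime,
#                 source="plugin",
#             )
#
#         async def uninstall(self, ctx: InstallContext) -> None:
#             pass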
diff --git a/molecule_runtime/plugins_registry/raw_drop.py b/molecule_runtime/plugins_registry/raw_drop.py
new file mode 100644
index 0000000..6c979c7
--- /dev/null
+++ b/molecule_runtime/plugins_registry/raw_drop.py
@@ -0,0 +1,71 @@
"""Fallback adaptor used when no per-runtime adaptor is found.

Behaviour: copy the plugin's content into ``/configs/plugins/<plugin_name>/``
so a user can still inspect or hand-wire it, then surface a warning that no
tools or sub-agents were registered.

This preserves the "power users can drop raw files" escape hatch without
silently breaking — the warning is propagated up via :class:`InstallResult`
so the API can surface it to the user.
"""

from __future__ import annotations

import shutil

from .protocol import InstallContext, InstallResult, PluginAdaptor


class RawDropAdaptor:
    """Filesystem-only fallback. Implements :class:`PluginAdaptor`."""

    def __init__(self, plugin_name: str, runtime: str) -> None:
        self.plugin_name = plugin_name
        self.runtime = runtime

    async def install(self, ctx: InstallContext) -> InstallResult:
        dst = ctx.configs_dir / "plugins" / self.plugin_name
        files_written: list[str] = []

        if ctx.plugin_root.is_dir():
            dst.parent.mkdir(parents=True, exist_ok=True)
            if dst.exists():
                # Idempotent — leave existing copy alone.
                ctx.logger.info(
                    "raw_drop: %s already present at %s, skipping copy",
                    self.plugin_name, dst,
                )
            else:
                shutil.copytree(ctx.plugin_root, dst)
                for p in dst.rglob("*"):
                    if p.is_file():
                        files_written.append(str(p.relative_to(ctx.configs_dir)))
                ctx.logger.info(
                    "raw_drop: copied %s → %s (%d files)",
                    self.plugin_name, dst, len(files_written),
                )

        warning = (
            f"plugin '{self.plugin_name}' has no adaptor for runtime "
            f"'{self.runtime}' — files dropped at /configs/plugins/{self.plugin_name} "
            f"but no tools/sub-agents were wired in"
        )
        ctx.logger.warning(warning)

        return InstallResult(
            plugin_name=self.plugin_name,
            runtime=self.runtime,
            source="raw_drop",
            files_written=files_written,
            warnings=[warning],
        )

    async def uninstall(self, ctx: InstallContext) -> None:
        dst = ctx.configs_dir / "plugins" / self.plugin_name
        if dst.exists():
            shutil.rmtree(dst)
            ctx.logger.info("raw_drop: removed %s", dst)


# Static check: RawDropAdaptor satisfies PluginAdaptor.
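# (The annotated assignment below makes static type checkers verify the
# structural match; at runtime it merely builds a throwaway instance.)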
+_: PluginAdaptor = RawDropAdaptor("_", "_") diff --git a/molecule_runtime/policies/__init__.py b/molecule_runtime/policies/__init__.py new file mode 100644 index 0000000..cb1d605 --- /dev/null +++ b/molecule_runtime/policies/__init__.py @@ -0,0 +1,11 @@ +"""Policy helpers for routing and execution decisions.""" + +from .namespaces import resolve_awareness_namespace, workspace_awareness_namespace +from .routing import build_team_routing_payload, summarize_children + +__all__ = [ + "build_team_routing_payload", + "resolve_awareness_namespace", + "summarize_children", + "workspace_awareness_namespace", +] diff --git a/molecule_runtime/policies/namespaces.py b/molecule_runtime/policies/namespaces.py new file mode 100644 index 0000000..7d26d6c --- /dev/null +++ b/molecule_runtime/policies/namespaces.py @@ -0,0 +1,18 @@ +"""Canonical namespace helpers for workspace-scoped resources.""" + +from __future__ import annotations + + +def workspace_awareness_namespace(workspace_id: str) -> str: + """Return the default awareness namespace for a workspace.""" + workspace_id = workspace_id.strip() + return f"workspace:{workspace_id}" if workspace_id else "workspace:unknown" + + +def resolve_awareness_namespace( + workspace_id: str, + configured_namespace: str | None = None, +) -> str: + """Return the configured namespace, or the workspace default when unset.""" + namespace = (configured_namespace or "").strip() + return namespace or workspace_awareness_namespace(workspace_id) diff --git a/molecule_runtime/policies/routing.py b/molecule_runtime/policies/routing.py new file mode 100644 index 0000000..908cd2b --- /dev/null +++ b/molecule_runtime/policies/routing.py @@ -0,0 +1,98 @@ +"""Explicit routing policy for coordinator workspaces.""" + +from __future__ import annotations + +import json +from typing import Any + + +def _load_agent_card(agent_card: Any) -> dict[str, Any]: + if isinstance(agent_card, str): + try: + loaded = json.loads(agent_card) + except json.JSONDecodeError: + return {} + return loaded if isinstance(loaded, dict) else {} + return agent_card if isinstance(agent_card, dict) else {} + + +def summarize_children(children: list[dict]) -> list[dict[str, Any]]: + """Return the minimal child summary needed for routing and prompts.""" + members: list[dict[str, Any]] = [] + for child in children: + card = _load_agent_card(child.get("agent_card", {})) + members.append( + { + "id": child.get("id"), + "name": child.get("name"), + "status": child.get("status"), + "skills": [ + s.get("name", s.get("id", "")) + for s in card.get("skills", []) + if isinstance(s, dict) + ], + } + ) + return members + + +def build_team_routing_payload( + children: list[dict], + task: str, + preferred_member_id: str = "", +) -> dict[str, Any]: + """Return the deterministic routing payload for coordinator tasks.""" + if preferred_member_id: + return { + "success": True, + "action": "delegate_to_preferred_member", + "preferred_member_id": preferred_member_id, + "task": task, + } + + members = summarize_children(children) + if not members: + return { + "success": False, + "error": "No team members available. Handle this task yourself.", + "task": task, + "members": [], + } + + return { + "success": True, + "action": "choose_member", + "message": ( + f"You have {len(members)} team members. " + "Choose the best one for this task and call delegate_to_workspace with their ID." 
+ ), + "task": task, + "members": members, + } + + +def decide_team_route( + children: list[dict], + *, + task: str, + preferred_member_id: str = "", +) -> dict[str, Any]: + """Compatibility wrapper for older callers.""" + return build_team_routing_payload( + children, + task=task, + preferred_member_id=preferred_member_id, + ) + + +def build_team_route_decision( + children: list[dict], + task: str, + preferred_member_id: str = "", +) -> dict[str, Any]: + """Compatibility wrapper for tests and older imports.""" + return build_team_routing_payload( + children, + task=task, + preferred_member_id=preferred_member_id, + ) diff --git a/molecule_runtime/preflight.py b/molecule_runtime/preflight.py new file mode 100644 index 0000000..672d3f2 --- /dev/null +++ b/molecule_runtime/preflight.py @@ -0,0 +1,143 @@ +"""Startup preflight checks for workspace runtime configs.""" + +import os +from dataclasses import dataclass, field +from pathlib import Path + +from config import WorkspaceConfig + +SUPPORTED_RUNTIMES = { + "langgraph", + "claude-code", + "codex", + "ollama", + "custom", + "crewai", + "autogen", + "deepagents", + "openclaw", +} + + +@dataclass +class PreflightIssue: + severity: str + title: str + detail: str + fix: str = "" + + +@dataclass +class PreflightReport: + warnings: list[PreflightIssue] = field(default_factory=list) + failures: list[PreflightIssue] = field(default_factory=list) + + @property + def ok(self) -> bool: + return not self.failures + + +def run_preflight(config: WorkspaceConfig, config_path: str) -> PreflightReport: + """Check the workspace config for obvious startup blockers.""" + report = PreflightReport() + config_dir = Path(config_path) + + if config.runtime not in SUPPORTED_RUNTIMES: + report.failures.append( + PreflightIssue( + severity="fail", + title="Runtime", + detail=f"Unsupported runtime '{config.runtime}'", + fix="Choose one of the supported runtimes or install the matching adapter.", + ) + ) + + if not 1 <= int(config.a2a.port) <= 65535: + report.failures.append( + PreflightIssue( + severity="fail", + title="A2A port", + detail=f"Invalid A2A port: {config.a2a.port}", + fix="Set a2a.port to a value between 1 and 65535.", + ) + ) + + # Check required environment variables (e.g. CLAUDE_CODE_OAUTH_TOKEN, OPENAI_API_KEY). + # These are declared per-runtime in config.yaml and injected via the secrets API. + required_env = getattr(config.runtime_config, "required_env", []) or [] + for env_var in required_env: + if not os.environ.get(env_var): + report.failures.append( + PreflightIssue( + severity="fail", + title="Required env", + detail=f"Missing required environment variable: {env_var}", + fix=f"Set {env_var} via the secrets API (global or workspace-level).", + ) + ) + + # Backward compat: if legacy auth_token_file is set, warn but don't block + # if the token is available via required_env or auth_token_env. 
+ token_file = getattr(config.runtime_config, "auth_token_file", "") + if token_file: + token_path = config_dir / token_file + if not token_path.exists(): + token_env = getattr(config.runtime_config, "auth_token_env", "") + env_has_token = bool(token_env and os.environ.get(token_env)) + # Also check if any required_env is set (covers the new path) + if not env_has_token and required_env: + env_has_token = all(os.environ.get(e) for e in required_env) + + if not env_has_token: + report.failures.append( + PreflightIssue( + severity="fail", + title="Auth token", + detail=f"Missing auth token file: {token_file}", + fix="Remove auth_token_file and use required_env + secrets API instead.", + ) + ) + + prompt_files = config.prompt_files or ["system-prompt.md"] + for prompt_file in prompt_files: + prompt_path = config_dir / prompt_file + if not prompt_path.exists(): + report.warnings.append( + PreflightIssue( + severity="warn", + title="Prompt file", + detail=f"Missing prompt file: {prompt_file}", + fix="Add the file or remove it from prompt_files.", + ) + ) + + skills_dir = config_dir / "skills" + for skill_name in config.skills: + skill_path = skills_dir / skill_name / "SKILL.md" + if not skill_path.exists(): + report.warnings.append( + PreflightIssue( + severity="warn", + title="Skill", + detail=f"Missing skill package: {skill_name}", + fix="Restore the skill folder or remove it from config.yaml.", + ) + ) + + return report + + +def render_preflight_report(report: PreflightReport) -> None: + """Print a concise startup report.""" + if not report.warnings and not report.failures: + return + + print("Preflight checks:") + for issue in report.failures: + print(f"[FAIL] {issue.title}: {issue.detail}") + if issue.fix: + print(f" Fix: {issue.fix}") + for issue in report.warnings: + print(f"[WARN] {issue.title}: {issue.detail}") + if issue.fix: + print(f" Fix: {issue.fix}") diff --git a/molecule_runtime/prompt.py b/molecule_runtime/prompt.py new file mode 100644 index 0000000..a9876d4 --- /dev/null +++ b/molecule_runtime/prompt.py @@ -0,0 +1,132 @@ +"""Build the system prompt for the workspace agent.""" + +from pathlib import Path + +from skill_loader.loader import LoadedSkill +from adapters.shared_runtime import build_peer_section + +DEFAULT_MEMORY_SNAPSHOT_FILES = ("MEMORY.md", "USER.md") + + +async def get_peer_capabilities(platform_url: str, workspace_id: str) -> list[dict]: + """Fetch peer workspace capabilities from the platform.""" + try: + import httpx + + async with httpx.AsyncClient(timeout=10.0) as client: + resp = await client.get( + f"{platform_url}/registry/{workspace_id}/peers", + headers={"X-Workspace-ID": workspace_id}, + ) + if resp.status_code == 200: + return resp.json() + except Exception as e: + print(f"Warning: could not fetch peers: {e}") + return [] + + +def build_system_prompt( + config_path: str, + workspace_id: str, + loaded_skills: list[LoadedSkill], + peers: list[dict], + prompt_files: list[str] | None = None, + plugin_rules: list[str] | None = None, + plugin_prompts: list[str] | None = None, + parent_context: list[dict] | None = None, +) -> str: + """Build the complete system prompt. + + Loads prompt files in order from config_path. If prompt_files is specified + in config.yaml, those files are loaded in order. Otherwise falls back to + system-prompt.md for backwards compatibility. + If MEMORY.md or USER.md exist alongside the config, they are appended as a + frozen memory snapshot without needing to list them explicitly. 
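
    For example, an OpenClaw-style workspace might set (illustrative)::

        prompt_files: ["SOUL.md", "BOOTSTRAP.md", "AGENTS.md"]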
+ + This allows different agent frameworks to use their own file structures: + - OpenClaw: SOUL.md, BOOTSTRAP.md, AGENTS.md, HEARTBEAT.md, TOOLS.md, USER.md + - Claude Code: CLAUDE.md + - Default: system-prompt.md + """ + parts = [] + + # Load prompt files in order + files_to_load = list(prompt_files or []) + if not files_to_load: + # Backwards compatible: fall back to system-prompt.md + files_to_load = ["system-prompt.md"] + + seen_files = set(files_to_load) + + for filename in files_to_load: + file_path = Path(config_path) / filename + if file_path.exists(): + content = file_path.read_text().strip() + if content: + parts.append(content) + else: + print(f"Warning: prompt file not found: {file_path}") + + # Hermes-style memory snapshot files: load automatically when present. + # These stay as thin markdown files so the runtime does not need a new storage layer. + for filename in DEFAULT_MEMORY_SNAPSHOT_FILES: + if filename in seen_files: + continue + file_path = Path(config_path) / filename + if file_path.exists(): + content = file_path.read_text().strip() + if content: + parts.append(content) + + # Inject parent's shared context (if this workspace is a child) + if parent_context: + parts.append("\n## Parent Context\n") + parts.append("The following context was shared by your parent workspace:\n") + for ctx_file in parent_context: + path = ctx_file.get("path", "unknown") + content = ctx_file.get("content", "") + if content.strip(): + parts.append(f"### {path}") + parts.append(content.strip()) + parts.append("") + + # Inject plugin rules (always-on guidelines from ECC, Superpowers, etc.) + if plugin_rules: + parts.append("\n## Platform Rules\n") + for rule in plugin_rules: + parts.append(rule) + parts.append("") + + # Inject plugin prompt fragments + if plugin_prompts: + parts.append("\n## Platform Guidelines\n") + for fragment in plugin_prompts: + parts.append(fragment) + parts.append("") + + # Add skill instructions + if loaded_skills: + parts.append("\n## Your Skills\n") + for skill in loaded_skills: + parts.append(f"### {skill.metadata.name}") + if skill.metadata.description: + parts.append(skill.metadata.description) + parts.append(skill.instructions) + parts.append("") + + # Add peer capabilities with a single shared renderer. + peer_section = build_peer_section(peers) + if peer_section: + parts.append(peer_section) + + # Add delegation failure handling + parts.append(""" +## Handling delegation failures +If a delegation fails: +1. Check if the task is blocking — if not, continue other work +2. Retry transient failures (connection errors) after 30 seconds +3. For persistent failures, report to the caller with context +4. 
Never silently drop a failed task +""") + + return "\n".join(parts) diff --git a/molecule_runtime/skill_loader/__init__.py b/molecule_runtime/skill_loader/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/molecule_runtime/skill_loader/loader.py b/molecule_runtime/skill_loader/loader.py new file mode 100644 index 0000000..0533356 --- /dev/null +++ b/molecule_runtime/skill_loader/loader.py @@ -0,0 +1,191 @@ +"""Load skill packages from the workspace config directory.""" + +import importlib.util +import logging +import os +import sys +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +import yaml + +logger = logging.getLogger(__name__) + +try: + from builtin_tools.security_scan import SkillSecurityError, scan_skill_dependencies + _SECURITY_SCAN_AVAILABLE = True +except ImportError: # lightweight test environments without tools/ on sys.path + _SECURITY_SCAN_AVAILABLE = False + + +@dataclass +class SkillMetadata: + id: str + name: str + description: str + tags: list[str] = field(default_factory=list) + examples: list[str] = field(default_factory=list) + + +@dataclass +class LoadedSkill: + metadata: SkillMetadata + instructions: str + tools: list[Any] = field(default_factory=list) + + +def parse_skill_frontmatter(skill_md_path: Path) -> tuple[dict, str]: + """Parse YAML frontmatter from a SKILL.md file. + + Runtime-side: tolerant of malformed frontmatter (returns ``({}, body)`` + so the skill loads with empty metadata rather than crashing the + workspace at startup). The SDK's :func:`molecule_plugin.parse_skill_md` + is the authoring-time strict validator that surfaces the same errors. + Keep behaviour aligned: if you change acceptance rules here, mirror + them in the SDK's parser. + """ + content = skill_md_path.read_text() + + if not content.startswith("---"): + return {}, content + + parts = content.split("---", 2) + if len(parts) < 3: + return {}, content + + try: + frontmatter = yaml.safe_load(parts[1]) or {} + except yaml.YAMLError: + logger.warning("SKILL.md at %s has malformed frontmatter; loading with empty metadata", skill_md_path) + frontmatter = {} + if not isinstance(frontmatter, dict): + logger.warning("SKILL.md at %s frontmatter is not a mapping; ignoring", skill_md_path) + frontmatter = {} + + body = parts[2].strip() + return frontmatter, body + + +def load_skill_tools(scripts_dir: Path) -> list[Any]: + """Dynamically load tool functions from a skill's scripts/ directory. + + Follows the agentskills.io spec layout: each skill's executable code + lives under ``scripts/``. Returns an empty list if the directory + doesn't exist. + """ + tools = [] + if not scripts_dir.exists(): + return tools + + # Import langchain only when we actually have scripts to process. + # Keeps test environments (and empty skills) from needing langchain. + from langchain_core.tools import BaseTool + + # Sensitive env vars that must not be readable by skill scripts. + # Fix C (Cycle 5): scrub before exec_module() so a malicious skill cannot + # exfiltrate credentials even if it somehow bypasses the POST /plugins + # auth gate (defence in depth). 
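    # Note: the scrub below guards module *import* time only — a tool
    # function that reads os.environ when later invoked still sees the
    # restored (full) environment.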
+ _SCRUB_KEYS = ( + "CLAUDE_CODE_OAUTH_TOKEN", + "ANTHROPIC_API_KEY", + "OPENAI_API_KEY", + "WORKSPACE_AUTH_TOKEN", + "GITHUB_TOKEN", + "GH_TOKEN", + ) + + for py_file in sorted(scripts_dir.glob("*.py")): + if py_file.name.startswith("_"): + continue + + # Verify the script is actually inside the expected scripts directory + # (path traversal guard — glob shouldn't produce outside paths, but + # belt-and-suspenders for symlink attacks). + try: + py_file.resolve().relative_to(scripts_dir.resolve()) + except ValueError: + logger.warning("skill_loader: rejecting script outside scripts_dir: %s", py_file) + continue + + module_name = f"skill_tool_{py_file.stem}" + spec = importlib.util.spec_from_file_location(module_name, py_file) + if spec is None or spec.loader is None: + continue + + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + + # Temporarily remove sensitive env vars before running skill code. + _saved_env = {k: os.environ.pop(k) for k in _SCRUB_KEYS if k in os.environ} + try: + spec.loader.exec_module(module) + finally: + # Always restore so the rest of the agent process retains them. + os.environ.update(_saved_env) + + # Look for functions decorated with @tool (BaseTool instances) + for attr_name in dir(module): + attr = getattr(module, attr_name) + if isinstance(attr, BaseTool): + tools.append(attr) + + return tools + + +def load_skills(config_path: str, skill_names: list[str]) -> list[LoadedSkill]: + """Load all skills specified in the config.""" + skills_dir = Path(config_path) / "skills" + loaded = [] + + # Resolve security scan mode once before the loop + scan_mode = "warn" + fail_open_if_no_scanner = True # safe default matches security_scan.py default + if _SECURITY_SCAN_AVAILABLE: + try: + from config import load_config + _cfg = load_config(config_path) + scan_mode = _cfg.security_scan.mode + fail_open_if_no_scanner = _cfg.security_scan.fail_open_if_no_scanner + except Exception: + pass # use defaults — never block on config error + + for skill_name in skill_names: + skill_path = skills_dir / skill_name + skill_md = skill_path / "SKILL.md" + + if not skill_md.exists(): + logger.warning("SKILL.md not found for %s, skipping", skill_name) + continue + + # --- Security scan before loading any code from the skill ------------ + if _SECURITY_SCAN_AVAILABLE and scan_mode != "off": + try: + scan_skill_dependencies( + skill_name, skill_path, scan_mode, + fail_open_if_no_scanner=fail_open_if_no_scanner, + ) + except SkillSecurityError as exc: + logger.warning("Skipping skill '%s': blocked by security scan — %s", skill_name, exc) + continue + + frontmatter, instructions = parse_skill_frontmatter(skill_md) + + metadata = SkillMetadata( + id=skill_name, + name=frontmatter.get("name", skill_name), + description=frontmatter.get("description", ""), + tags=frontmatter.get("tags", []), + examples=frontmatter.get("examples", []), + ) + + # Executables live under scripts/ per the agentskills.io spec. + tools = load_skill_tools(skill_path / "scripts") + + loaded.append(LoadedSkill( + metadata=metadata, + instructions=instructions, + tools=tools, + )) + + return loaded diff --git a/molecule_runtime/skill_loader/watcher.py b/molecule_runtime/skill_loader/watcher.py new file mode 100644 index 0000000..03b2372 --- /dev/null +++ b/molecule_runtime/skill_loader/watcher.py @@ -0,0 +1,227 @@ +"""Skills hot-reload watcher. + +Monitors the workspace's ``skills/`` directory for file changes and reloads +affected skill modules in-place — no coordinator restart required. 

Architecture
------------
``SkillsWatcher`` runs as a background asyncio task alongside the agent. It
polls the skill directories every ``POLL_INTERVAL`` seconds (default 3 s),
computes SHA-256 hashes of every file, and fires ``_reload_skill()`` when any
file inside a skill's folder changes.

``_reload_skill()`` calls ``load_skills()`` from ``skill_loader.loader`` for
the changed skill and passes the fresh ``LoadedSkill`` to every registered
``on_reload`` callback. Adapters register a callback that rebuilds the
LangGraph agent with the updated tool set, so the change takes effect on
the very next incoming A2A task — zero downtime.

Audit event
-----------
Every successful reload emits::

    event_type    : "skill_reload"
    action        : "reload"
    resource      : "<skill_name>"
    outcome       : "success" | "failure"
    changed_files : [list of relative paths that triggered the reload]

Usage::

    watcher = SkillsWatcher(
        config_path="/configs",
        skill_names=["web_search", "code_review"],
        on_reload=lambda skill: rebuild_agent_with_skill(skill),
    )
    asyncio.create_task(watcher.start())
"""

from __future__ import annotations

import asyncio
import hashlib
import logging
import sys
from pathlib import Path
from typing import Callable

logger = logging.getLogger(__name__)

POLL_INTERVAL = 3.0  # seconds between filesystem polls
DEBOUNCE_SECS = 1.5  # wait for writes to settle before reloading


class SkillsWatcher:
    """Watches skill directories and reloads changed skills without restarting.

    Args:
        config_path: Path to the workspace config directory (contains ``skills/``).
        skill_names: List of skill IDs to watch (subfolder names under ``skills/``).
        on_reload: Async or sync callable invoked with a fresh ``LoadedSkill``
                   every time a skill is reloaded. May be called concurrently
                   for multiple skills if several change at once.
    """

    def __init__(
        self,
        config_path: str,
        skill_names: list[str],
        on_reload: Callable | None = None,
    ) -> None:
        self.config_path = config_path
        self.skill_names = list(skill_names)
        self.on_reload = on_reload
        self._hashes: dict[str, str] = {}  # rel_path → sha256 hex
        self._running = False

    # ------------------------------------------------------------------
    # Public interface
    # ------------------------------------------------------------------

    async def start(self) -> None:
        """Start the poll loop in the current event loop. Runs until ``stop()``."""
        self._running = True
        self._hashes = self._scan()
        logger.info(
            "SkillsWatcher: monitoring %d skill(s) in %s",
            len(self.skill_names), self.config_path,
        )

        while self._running:
            await asyncio.sleep(POLL_INTERVAL)
            await self._tick()

    def stop(self) -> None:
        self._running = False

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _skills_root(self) -> Path:
        return Path(self.config_path) / "skills"

    def _hash_file(self, path: Path) -> str:
        try:
            # H1: SHA-256 replaces MD5 for file-integrity change detection.
            return hashlib.sha256(path.read_bytes()).hexdigest()
        except OSError:
            return ""

    def _scan(self) -> dict[str, str]:
        """Return {relative_path: sha256} for every file in watched skill dirs."""
        hashes: dict[str, str] = {}
        root = self._skills_root()
        for skill_name in self.skill_names:
            skill_dir = root / skill_name
            if not skill_dir.is_dir():
                continue
            for fpath in skill_dir.rglob("*"):
                if fpath.is_file() and not fpath.name.startswith("."):
                    rel = str(fpath.relative_to(root))
                    hashes[rel] = self._hash_file(fpath)
        return hashes

    def _changed_skills(self, new_hashes: dict[str, str]) -> dict[str, list[str]]:
        """Return {skill_name: [changed_file, …]} for skills with file changes."""
        changed: dict[str, list[str]] = {}

        all_paths = set(new_hashes) | set(self._hashes)
        for rel_path in all_paths:
            old = self._hashes.get(rel_path, "")
            new = new_hashes.get(rel_path, "")
            if old != new:
                # rel_path is like "web_search/SKILL.md" or "web_search/tools/foo.py"
                skill_name = rel_path.split("/")[0]
                if skill_name in self.skill_names:
                    changed.setdefault(skill_name, []).append(rel_path)

        return changed

    async def _tick(self) -> None:
        """One poll cycle: detect changes, debounce, reload."""
        new_hashes = self._scan()
        changed = self._changed_skills(new_hashes)

        if not changed:
            return

        logger.info("SkillsWatcher: changes detected in %s", list(changed.keys()))
        await asyncio.sleep(DEBOUNCE_SECS)

        # Re-scan after debounce to absorb any writes still in-flight
        new_hashes = self._scan()
        changed = self._changed_skills(new_hashes)

        self._hashes = new_hashes  # commit new baseline

        for skill_name, files in changed.items():
            await self._reload_skill(skill_name, files)

    async def _reload_skill(self, skill_name: str, changed_files: list[str]) -> None:
        """Reload *skill_name*'s modules and notify the callback."""
        logger.info("SkillsWatcher: reloading skill '%s' (changed: %s)", skill_name, changed_files)

        # Evict stale module entries so importlib loads fresh copies
        stale = [k for k in sys.modules if k.startswith("skill_tool_")]
        for key in stale:
            del sys.modules[key]

        try:
            from skill_loader.loader import load_skills
            loaded = load_skills(self.config_path, [skill_name])

            if loaded:
                skill = loaded[0]
                logger.info(
                    "SkillsWatcher: skill '%s' reloaded — %d tool(s)",
                    skill_name, len(skill.tools),
                )

                # Audit event
                try:
                    from builtin_tools.audit import log_event
                    log_event(
                        event_type="skill_reload",
                        action="reload",
                        resource=skill_name,
                        outcome="success",
                        changed_files=changed_files,
                        tool_count=len(skill.tools),
                    )
                except Exception:
                    pass

                # Notify adapter callback
                if self.on_reload is not None:
                    try:
                        result = self.on_reload(skill)
                        if asyncio.iscoroutine(result):
                            await result
                    except Exception as exc:
                        logger.error(
                            "SkillsWatcher: on_reload callback failed for '%s': %s",
                            skill_name, exc,
                        )
            else:
                logger.warning("SkillsWatcher: no LoadedSkill returned for '%s'", skill_name)
                self._audit_failure(skill_name, changed_files, "no_skill_returned")

        except Exception as exc:
            logger.error("SkillsWatcher: reload failed for '%s': %s", skill_name, exc)
            self._audit_failure(skill_name, changed_files, str(exc))

    @staticmethod
    def _audit_failure(skill_name: str, changed_files: list[str], error: str) -> None:
        try:
            from builtin_tools.audit import log_event
            log_event(
                event_type="skill_reload",
                action="reload",
                resource=skill_name,
                outcome="failure",
                changed_files=changed_files,
                error=error,
            )
        except Exception:
            pass
diff --git a/molecule_runtime/transcript_auth.py b/molecule_runtime/transcript_auth.py
new file mode 100644
index 0000000..49b0f62
--- /dev/null
+++ b/molecule_runtime/transcript_auth.py
@@ -0,0 +1,30 @@
"""Auth gate for the /transcript Starlette route.

Extracted from main.py so the security-critical logic is unit-testable
without standing up the full uvicorn/a2a/httpx import stack.

#328: the route must fail CLOSED when the expected token is unavailable
(bootstrap window, missing file, OSError). The previous implementation
treated a missing token as "skip auth entirely" — any container on the
same Docker network could read the session log during provisioning.
"""


def transcript_authorized(expected_token: str | None, auth_header: str) -> bool:
    """Return True iff /transcript should serve the request.

    Args:
        expected_token: the workspace's registered bearer token, or None
            if `/configs/.auth_token` is absent / unreadable.
        auth_header: raw value of the Authorization request header.

    Behavior:
    - None/empty expected → fail closed (401). This is the #328 fix;
      a missing token file is an auth failure, not a bypass.
    - Non-empty expected: strict equality check against "Bearer <token>".
      Bearer prefix is case-sensitive (matches the platform's
      wsauth.BearerTokenFromHeader contract).
    """
    if not expected_token:
        return False
    return auth_header == f"Bearer {expected_token}"
diff --git a/molecule_runtime/watcher.py b/molecule_runtime/watcher.py
new file mode 100644
index 0000000..ca22042
--- /dev/null
+++ b/molecule_runtime/watcher.py
@@ -0,0 +1,120 @@
"""File watcher for hot-reloading skills and config changes.

Monitors the config directory for file changes and triggers
agent rebuild + Agent Card update broadcast.
"""

import asyncio
import hashlib
import logging
import os
from pathlib import Path

import httpx

logger = logging.getLogger(__name__)

DEBOUNCE_SECONDS = 2.0
POLL_INTERVAL = 3.0  # seconds between filesystem checks


class ConfigWatcher:
    """Watches the config directory for changes and triggers reload callbacks."""

    def __init__(
        self,
        config_path: str,
        platform_url: str,
        workspace_id: str,
        on_reload=None,
    ):
        self.config_path = config_path
        self.platform_url = platform_url
        self.workspace_id = workspace_id
        self.on_reload = on_reload
        self._file_hashes: dict[str, str] = {}
        self._running = False

    def _hash_file(self, path: str) -> str:
        try:
            # H1: SHA-256 replaces MD5 for file-integrity change detection.
            # MD5 is collision-prone; using SHA-256 prevents a crafted config
            # file from producing the same hash as a benign one, which would
            # silently suppress the hot-reload callback.
            return hashlib.sha256(Path(path).read_bytes()).hexdigest()
        except OSError:
            return ""

    def _scan_hashes(self) -> dict[str, str]:
        """Scan all files in config directory and return hash map."""
        hashes = {}
        for root, _, files in os.walk(self.config_path):
            for fname in files:
                if fname.startswith("."):
                    continue
                fpath = os.path.join(root, fname)
                rel = os.path.relpath(fpath, self.config_path)
                hashes[rel] = self._hash_file(fpath)
        return hashes

    def _detect_changes(self) -> list[str]:
        """Compare current state with cached hashes, return changed files."""
        current = self._scan_hashes()
        changed = []

        for path, h in current.items():
            if path not in self._file_hashes or self._file_hashes[path] != h:
                changed.append(path)

        for path in self._file_hashes:
            if path not in current:
                changed.append(path)

        self._file_hashes = current
        return changed

    async def _notify_platform(self, agent_card: dict):
        """Push updated Agent Card to the platform."""
        try:
            async with httpx.AsyncClient(timeout=10.0) as client:
                await client.post(
                    f"{self.platform_url}/registry/update-card",
                    json={
                        "workspace_id": self.workspace_id,
                        "agent_card": agent_card,
                    },
                )
            logger.info("Agent Card updated via platform")
        except Exception as e:
            logger.warning("Failed to update Agent Card: %s", e)

    async def start(self):
        """Start watching for changes in a background loop."""
        self._running = True
        self._file_hashes = self._scan_hashes()
        logger.info("Config watcher started for %s", self.config_path)

        while self._running:
            await asyncio.sleep(POLL_INTERVAL)

            changed = self._detect_changes()
            if not changed:
                continue

            logger.info("Config changes detected: %s", changed)

            # Debounce — wait for writes to settle
            await asyncio.sleep(DEBOUNCE_SECONDS)

            # Re-scan after debounce (more changes may have occurred)
            self._detect_changes()

            # Trigger reload callback
            if self.on_reload:
                try:
                    await self.on_reload()
                except Exception as e:
                    logger.error("Reload callback failed: %s", e)

    def stop(self):
        self._running = False
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..ed64cbc
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,35 @@
[build-system]
requires = ["setuptools>=68.0", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "molecule-ai-workspace-runtime"
version = "0.1.0"
description = "Molecule AI workspace runtime — shared infrastructure for all agent adapters"
requires-python = ">=3.11"
license = {text = "BSL-1.1"}
readme = "README.md"
# Don't pin heavy deps — each adapter adds its own
dependencies = [
    "a2a-sdk[http-server]>=0.3.25",
    "httpx>=0.27.0",
    "uvicorn>=0.30.0",
    "starlette>=0.38.0",
    "websockets>=12.0",
    "pyyaml>=6.0",
    "langchain-core>=0.3.0",
    "opentelemetry-api>=1.24.0",
    "opentelemetry-sdk>=1.24.0",
    "opentelemetry-exporter-otlp-proto-http>=1.24.0",
    "temporalio>=1.7.0",
]

[project.scripts]
molecule-runtime = "molecule_runtime.main:main_sync"

[tool.setuptools.packages.find]
where = ["."]
include = ["molecule_runtime*"]

[tool.setuptools.package-data]
"molecule_runtime" = ["py.typed"]