feat: initial release of molecule-ai-workspace-runtime 0.1.0
Extracts shared workspace runtime from molecule-monorepo/workspace-template into a publishable PyPI package.

- molecule_runtime/ package with all shared infrastructure modules
- Adapter discovery via ADAPTER_MODULE env var (standalone repos) + built-in scan
- molecule-runtime console script entry point (main_sync)
- CI workflow to publish on version tags
- Published to PyPI as molecule-ai-workspace-runtime==0.1.0

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
commit 851a6d7bfd
.github/workflows/publish.yml (new file, 29 lines, vendored)
@@ -0,0 +1,29 @@
name: Publish to PyPI

on:
  push:
    tags:
      - "v*"
  workflow_dispatch:

jobs:
  build-and-publish:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4

      - uses: actions/setup-python@v5
        with:
          python-version: "3.11"

      - name: Install build tools
        run: pip install build twine

      - name: Build package
        run: python -m build

      - name: Publish to PyPI
        env:
          TWINE_USERNAME: __token__
          TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }}
        run: python -m twine upload dist/*
.gitignore (new file, 11 lines, vendored)
@@ -0,0 +1,11 @@
__pycache__/
*.py[cod]
*.pyo
*.egg-info/
dist/
build/
.eggs/
*.egg
.venv/
venv/
*.pyc
README.md (new file, 65 lines)
@@ -0,0 +1,65 @@
# molecule-ai-workspace-runtime

Shared Python runtime infrastructure for all Molecule AI agent adapters.

This package provides the core machinery that every Molecule AI workspace container needs:

- **A2A server** — Registers with the platform, heartbeats, serves A2A JSON-RPC
- **Adapter interface** — `BaseAdapter` / `AdapterConfig` / `SetupResult`
- **Built-in tools** — delegation, memory, approvals, sandbox, telemetry
- **Skill loader** — loads and hot-reloads skill modules from `/configs/skills/`
- **Plugin system** — per-workspace + shared plugin discovery and install
- **Config / preflight** — YAML config loading with validation

## Installation

```bash
pip install molecule-ai-workspace-runtime
```

## Adapter Discovery

The runtime discovers adapters in two ways:

1. **`ADAPTER_MODULE` env var** (standalone adapter repos):

   ```bash
   ADAPTER_MODULE=my_adapter molecule-runtime
   ```

   The module must export an `Adapter` class extending `BaseAdapter`.

2. **Built-in subdirectory scan** (monorepo local dev):
   Scans `molecule_runtime/adapters/` subdirectories for `Adapter` classes.

## Writing an Adapter

```python
from molecule_runtime.adapters.base import BaseAdapter, AdapterConfig
from a2a.server.agent_execution import AgentExecutor


class Adapter(BaseAdapter):
    @staticmethod
    def name() -> str:
        return "my-runtime"

    @staticmethod
    def display_name() -> str:
        return "My Runtime"

    @staticmethod
    def description() -> str:
        return "My custom agent runtime"

    async def setup(self, config: AdapterConfig) -> None:
        result = await self._common_setup(config)
        # Store result attributes for create_executor

    async def create_executor(self, config: AdapterConfig) -> AgentExecutor:
        # Return an AgentExecutor instance
        ...
```

Set `ADAPTER_MODULE=my_package.adapter` and run `molecule-runtime`.

## License

BSL-1.1 — see LICENSE for details.
molecule_runtime/__init__.py (new file, 6 lines)
@@ -0,0 +1,6 @@
"""Molecule AI workspace runtime — shared infrastructure for all agent adapters."""

from molecule_runtime.adapters.base import BaseAdapter, AdapterConfig, SetupResult

__version__ = "0.1.0"
__all__ = ["BaseAdapter", "AdapterConfig", "SetupResult"]
molecule_runtime/a2a_cli.py (new file, 245 lines)
@@ -0,0 +1,245 @@
#!/usr/bin/env python3
"""A2A CLI — command-line tools for inter-workspace communication.

Supports both synchronous and asynchronous delegation:
    a2a delegate <id> <task>            — Send task, wait for response (sync)
    a2a delegate --async <id> <task>    — Send task, return task ID immediately
    a2a status <workspace_id> <task_id> — Check task status / get result
    a2a peers                           — List available peers
    a2a info                            — Show this workspace's info

Environment variables:
    WORKSPACE_ID — this workspace's ID
    PLATFORM_URL — platform API base URL
"""

import asyncio
import json
import os
import sys
import uuid

import httpx

WORKSPACE_ID = os.environ.get("WORKSPACE_ID", "")
PLATFORM_URL = os.environ.get("PLATFORM_URL", "http://platform:8080")


async def discover(target_id: str) -> dict | None:
    """Discover a peer workspace's URL."""
    async with httpx.AsyncClient(timeout=30.0) as client:
        resp = await client.get(
            f"{PLATFORM_URL}/registry/discover/{target_id}",
            headers={"X-Workspace-ID": WORKSPACE_ID},
        )
        if resp.status_code == 200:
            return resp.json()
        return None


async def delegate(target_id: str, task: str, async_mode: bool = False):
    """Delegate a task to another workspace."""
    peer = await discover(target_id)
    if not peer:
        print(f"Error: cannot reach workspace {target_id} (access denied or offline)", file=sys.stderr)
        sys.exit(1)

    target_url = peer.get("url", "")
    if not target_url:
        print(f"Error: workspace {target_id} has no URL", file=sys.stderr)
        sys.exit(1)

    task_id = str(uuid.uuid4())

    if async_mode:
        # Async: post with a short timeout and return immediately — we only
        # want to confirm receipt, not wait for the response.
        async with httpx.AsyncClient(timeout=10.0) as client:
            try:
                await client.post(
                    target_url,
                    json={
                        "jsonrpc": "2.0",
                        "id": task_id,
                        "method": "message/send",
                        "params": {
                            "message": {
                                "role": "user",
                                "messageId": str(uuid.uuid4()),
                                "parts": [{"kind": "text", "text": task}],
                            }
                        },
                    },
                )
                # Got a response within the short window — receipt confirmed
                print(json.dumps({
                    "task_id": task_id,
                    "target": target_id,
                    "status": "submitted",
                    "target_url": target_url,
                }))
            except httpx.TimeoutException:
                # Request was sent but we didn't get confirmation — the task
                # may or may not have been received by the target.
                print(json.dumps({
                    "task_id": task_id,
                    "target": target_id,
                    "status": "uncertain",
                    "note": "Request sent but response timed out — delivery unconfirmed. Use 'a2a status' to check.",
                }), file=sys.stderr)
        return

    # Sync: wait for full response with retry on rate limit
    max_retries = 3
    for attempt in range(max_retries):
        async with httpx.AsyncClient(timeout=300.0) as client:
            try:
                resp = await client.post(
                    target_url,
                    json={
                        "jsonrpc": "2.0",
                        "id": task_id,
                        "method": "message/send",
                        "params": {
                            "message": {
                                "role": "user",
                                "messageId": str(uuid.uuid4()),
                                "parts": [{"kind": "text", "text": task}],
                            }
                        },
                    },
                )
                try:
                    data = resp.json()
                except Exception:
                    print(f"Error: invalid JSON response (status {resp.status_code})", file=sys.stderr)
                    sys.exit(1)
                if "result" in data:
                    parts = data["result"].get("parts", [])
                    text = parts[0].get("text", "") if parts else ""
                    if text and text != "(no response generated)":
                        print(text)
                        return
                    # Empty or no-response — might be rate limited, retry
                    if attempt < max_retries - 1:
                        delay = 5 * (2 ** attempt)
                        print(f"(empty response, retrying in {delay}s...)", file=sys.stderr)
                        await asyncio.sleep(delay)
                        continue
                    print(text or "(no response after retries)")
                elif "error" in data:
                    error_msg = data["error"].get("message", "unknown")
                    if ("rate" in error_msg.lower() or "overloaded" in error_msg.lower()) and attempt < max_retries - 1:
                        delay = 5 * (2 ** attempt)
                        print(f"(rate limited, retrying in {delay}s...)", file=sys.stderr)
                        await asyncio.sleep(delay)
                        continue
                    print(f"Error: {error_msg}", file=sys.stderr)
                    sys.exit(1)
                return
            except httpx.TimeoutException:
                if attempt < max_retries - 1:
                    delay = 5 * (2 ** attempt)
                    print(f"(timeout, retrying in {delay}s...)", file=sys.stderr)
                    await asyncio.sleep(delay)
                    continue
                print("Error: request timed out after retries", file=sys.stderr)
                sys.exit(1)


async def check_status(target_id: str, task_id: str):
    """Check the status of an async task."""
    peer = await discover(target_id)
    if not peer:
        print(f"Error: cannot reach workspace {target_id}", file=sys.stderr)
        sys.exit(1)

    target_url = peer.get("url", "")
    async with httpx.AsyncClient(timeout=30.0) as client:
        resp = await client.post(
            target_url,
            json={
                "jsonrpc": "2.0",
                "id": str(uuid.uuid4()),
                "method": "tasks/get",
                "params": {"id": task_id},
            },
        )
        data = resp.json()
        if "result" in data:
            task = data["result"]
            status = task.get("status", {}).get("state", "unknown")
            print(f"Status: {status}")
            if status == "completed":
                artifacts = task.get("artifacts", [])
                for a in artifacts:
                    for p in a.get("parts", []):
                        if p.get("text"):
                            print(p["text"])
        elif "error" in data:
            print(f"Error: {data['error'].get('message', 'unknown')}")


async def peers():
    """List available peers."""
    async with httpx.AsyncClient(timeout=10.0) as client:
        resp = await client.get(f"{PLATFORM_URL}/registry/{WORKSPACE_ID}/peers")
        if resp.status_code != 200:
            print("Error: could not fetch peers", file=sys.stderr)
            sys.exit(1)
        for p in resp.json():
            status = p.get("status", "?")
            role = p.get("role", "")
            print(f"{p['id']} {p['name']:30s} {status:10s} {role}")


async def info():
    """Get this workspace's info."""
    async with httpx.AsyncClient(timeout=10.0) as client:
        resp = await client.get(f"{PLATFORM_URL}/workspaces/{WORKSPACE_ID}")
        if resp.status_code == 200:
            d = resp.json()
            print(f"ID: {d['id']}")
            print(f"Name: {d['name']}")
            print(f"Role: {d.get('role', '')}")
            print(f"Tier: {d['tier']}")
            print(f"Status: {d['status']}")
            print(f"Parent: {d.get('parent_id', '(root)')}")


def main():
    if len(sys.argv) < 2:
        print("Usage: a2a <command> [args]")
        print("Commands:")
        print("  delegate <workspace_id> <task>         — Send task, wait for response")
        print("  delegate --async <workspace_id> <task> — Send task, return immediately")
        print("  status <workspace_id> <task_id>        — Check async task status")
        print("  peers                                  — List available peers")
        print("  info                                   — Show workspace info")
        sys.exit(1)

    cmd = sys.argv[1]

    if cmd == "delegate":
        async_mode = "--async" in sys.argv
        args = [a for a in sys.argv[2:] if a != "--async"]
        if len(args) < 2:
            print("Usage: a2a delegate [--async] <workspace_id> <task>", file=sys.stderr)
            sys.exit(1)
        asyncio.run(delegate(args[0], " ".join(args[1:]), async_mode))
    elif cmd == "status":
        if len(sys.argv) < 4:
            print("Usage: a2a status <workspace_id> <task_id>", file=sys.stderr)
            sys.exit(1)
        asyncio.run(check_status(sys.argv[2], sys.argv[3]))
    elif cmd == "peers":
        asyncio.run(peers())
    elif cmd == "info":
        asyncio.run(info())
    else:
        print(f"Unknown command: {cmd}", file=sys.stderr)
        sys.exit(1)


if __name__ == "__main__":  # pragma: no cover
    main()
molecule_runtime/a2a_client.py (new file, 111 lines)
@@ -0,0 +1,111 @@
"""A2A protocol client — peer discovery, messaging, and workspace info.

Shared constants (WORKSPACE_ID, PLATFORM_URL) live here so that
a2a_tools and a2a_mcp_server can import them from a single place.
"""

import logging
import os
import uuid

import httpx

from platform_auth import auth_headers

logger = logging.getLogger(__name__)

WORKSPACE_ID = os.environ.get("WORKSPACE_ID", "")
PLATFORM_URL = os.environ.get("PLATFORM_URL", "http://platform:8080")

# Cache workspace ID → name mappings (populated by list_peers calls)
_peer_names: dict[str, str] = {}

# Sentinel prefix for errors originating from send_a2a_message / child agents.
# Used by delegate_task to distinguish real errors from normal response text.
_A2A_ERROR_PREFIX = "[A2A_ERROR] "


async def discover_peer(target_id: str) -> dict | None:
    """Discover a peer workspace's URL via the platform registry."""
    async with httpx.AsyncClient(timeout=10.0) as client:
        try:
            resp = await client.get(
                f"{PLATFORM_URL}/registry/discover/{target_id}",
                headers={"X-Workspace-ID": WORKSPACE_ID, **auth_headers()},
            )
            if resp.status_code == 200:
                return resp.json()
            return None
        except Exception as e:
            logger.error(f"Discovery failed for {target_id}: {e}")
            return None


async def send_a2a_message(target_url: str, message: str) -> str:
    """Send an A2A message/send to a target workspace."""
    # Fix F (Cycle 5 / H2 — flagged 5 consecutive audits): timeout=None allowed
    # a hung upstream to block the agent indefinitely. Use a generous but bounded
    # timeout: 30s connect + 300s read (long enough for slow LLM responses).
    async with httpx.AsyncClient(
        timeout=httpx.Timeout(connect=30.0, read=300.0, write=30.0, pool=30.0)
    ) as client:
        try:
            resp = await client.post(
                target_url,
                headers=auth_headers(),
                json={
                    "jsonrpc": "2.0",
                    "id": str(uuid.uuid4()),
                    "method": "message/send",
                    "params": {
                        "message": {
                            "role": "user",
                            "messageId": str(uuid.uuid4()),
                            "parts": [{"kind": "text", "text": message}],
                        }
                    },
                },
            )
            data = resp.json()
            if "result" in data:
                parts = data["result"].get("parts", [])
                text = parts[0].get("text", "") if parts else "(no response)"
                # Tag child-reported errors so the caller can detect them reliably
                if text.startswith("Agent error:"):
                    return f"{_A2A_ERROR_PREFIX}{text}"
                return text
            elif "error" in data:
                return f"{_A2A_ERROR_PREFIX}{data['error'].get('message', 'unknown')}"
            return str(data)
        except Exception as e:
            return f"{_A2A_ERROR_PREFIX}{e}"


async def get_peers() -> list[dict]:
    """Get this workspace's peers from the platform registry."""
    async with httpx.AsyncClient(timeout=10.0) as client:
        try:
            resp = await client.get(
                f"{PLATFORM_URL}/registry/{WORKSPACE_ID}/peers",
                headers={"X-Workspace-ID": WORKSPACE_ID, **auth_headers()},
            )
            if resp.status_code == 200:
                return resp.json()
            return []
        except Exception:
            return []


async def get_workspace_info() -> dict:
    """Get this workspace's info from the platform."""
    async with httpx.AsyncClient(timeout=10.0) as client:
        try:
            resp = await client.get(
                f"{PLATFORM_URL}/workspaces/{WORKSPACE_ID}",
                headers=auth_headers(),
            )
            if resp.status_code == 200:
                return resp.json()
            return {"error": "not found"}
        except Exception as e:
            return {"error": str(e)}
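The `_A2A_ERROR_PREFIX` sentinel only pays off if callers check it. Here is a minimal consumption sketch, assuming the runtime's flat import layout; the workspace ID and task text are hypothetical, and `tool_delegate_task` in `a2a_tools.py` later in this commit is the real in-tree consumer:

```python
import asyncio

from a2a_client import _A2A_ERROR_PREFIX, discover_peer, send_a2a_message


async def demo(target_id: str) -> None:
    # Resolve the peer's URL first; None means access denied or offline
    peer = await discover_peer(target_id)
    if not peer or not peer.get("url"):
        print("peer unreachable or not permitted")
        return
    result = await send_a2a_message(peer["url"], "Summarize today's alerts")
    if result.startswith(_A2A_ERROR_PREFIX):
        # Error path: retry, pick another peer, or handle the task locally
        print("delegation failed:", result[len(_A2A_ERROR_PREFIX):])
    else:
        print("reply:", result)


# asyncio.run(demo("ws-1234"))  # hypothetical workspace ID
```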
molecule_runtime/a2a_executor.py (new file, 419 lines)
@@ -0,0 +1,419 @@
"""Bridge between LangGraph agent and A2A protocol, with SSE streaming support.

SSE streaming architecture
--------------------------
The A2A SDK (``DefaultRequestHandler`` + ``EventQueue``) owns the SSE transport
layer. This executor's job is to push the right event types into the queue as
work progresses:

1. ``TaskStatusUpdateEvent(state=working)`` — immediately signals start
2. ``TaskArtifactUpdateEvent(chunk, append=…)`` — one per LLM text token
3. ``Message(final_text)`` — terminal event

Client compatibility
--------------------
*Non-streaming* (``message/send``):
    ``ResultAggregator.consume_all()`` processes status/artifact events
    (updating the task in the store) and returns the final ``Message``
    immediately — backward-compatible with ``a2a_client.py`` which reads
    ``data["result"]["parts"][0]["text"]``.

*Streaming* (``message/stream``):
    ``consume_and_emit()`` yields every event above as SSE, letting the client
    render tokens in real time.

LangGraph integration
---------------------
Uses ``agent.astream_events(version="v2")`` to receive ``on_chat_model_stream``
events with ``AIMessageChunk`` payloads. Text is extracted from both plain
strings (OpenAI / Groq) and Anthropic-style content-block lists. Non-text
content (tool_use, etc.) is silently skipped. A fresh ``artifact_id`` is
generated for each new LLM ``run_id`` so tool-call cycles are grouped cleanly.
"""

import functools
import logging
import os
import uuid

from a2a.server.agent_execution import AgentExecutor, RequestContext
from a2a.server.events import EventQueue
from a2a.server.tasks import TaskUpdater
from a2a.types import Part, TextPart
from a2a.utils import new_agent_text_message
from adapters.shared_runtime import (
    extract_history as _extract_history,
    extract_message_text,
    brief_task,
    set_current_task,
)
from builtin_tools.telemetry import (
    A2A_TASK_ID,
    GEN_AI_OPERATION_NAME,
    GEN_AI_REQUEST_MODEL,
    GEN_AI_SYSTEM,
    WORKSPACE_ID_ATTR,
    _incoming_trace_context,
    gen_ai_system_from_model,
    get_tracer,
    record_llm_token_usage,
)

logger = logging.getLogger(__name__)

_WORKSPACE_ID = os.environ.get("WORKSPACE_ID", "unknown")

# LangGraph ReAct cycle budget per turn. Library default is 25; 500 covers
# PM fan-outs (plan → 6 delegations → 6 awaits → 6 results → synthesize ≈
# 30+ steps even before retries). Overridable via LANGGRAPH_RECURSION_LIMIT.
DEFAULT_RECURSION_LIMIT = 500


def _parse_recursion_limit() -> int:
    """Read LANGGRAPH_RECURSION_LIMIT; fall back to DEFAULT_RECURSION_LIMIT
    with a WARNING log on any unparseable or non-positive value."""
    raw = os.environ.get("LANGGRAPH_RECURSION_LIMIT", "")
    if not raw:
        return DEFAULT_RECURSION_LIMIT
    try:
        n = int(raw)
    except ValueError:
        logger.warning(
            "LANGGRAPH_RECURSION_LIMIT=%r is not an integer; using default %d",
            raw, DEFAULT_RECURSION_LIMIT,
        )
        return DEFAULT_RECURSION_LIMIT
    if n <= 0:
        logger.warning(
            "LANGGRAPH_RECURSION_LIMIT=%d is not positive; using default %d",
            n, DEFAULT_RECURSION_LIMIT,
        )
        return DEFAULT_RECURSION_LIMIT
    return n


# ---------------------------------------------------------------------------
# Compliance (OWASP Top 10 for Agentic Apps) — optional, lazy-loaded
# ---------------------------------------------------------------------------

try:
    from builtin_tools.compliance import (
        AgencyTracker,
        ExcessiveAgencyError,
        PromptInjectionError,
        redact_pii as _redact_pii,
        sanitize_input as _sanitize_input,
    )
    _COMPLIANCE_AVAILABLE = True
except ImportError:  # pragma: no cover
    _COMPLIANCE_AVAILABLE = False


@functools.lru_cache(maxsize=1)
def _get_compliance_cfg():
    """Return ComplianceConfig or None (cached for process lifetime)."""
    try:
        from config import load_config
        return load_config().compliance
    except Exception:
        return None


def _extract_chunk_text(content) -> list[str]:
    """Extract text strings from an LLM streaming chunk's content field.

    Handles both provider content styles:
    - OpenAI / Groq: ``content`` is a plain ``str`` (empty for tool-call chunks).
    - Anthropic: ``content`` is a list of typed blocks, e.g.
      ``[{"type": "text", "text": "Hello"}, {"type": "tool_use", ...}]``

    Only ``"text"`` blocks are returned; ``tool_use``, ``tool_result``, and
    other non-text blocks are filtered out so raw tool JSON never appears in
    the SSE stream.

    Args:
        content: ``chunk.content`` value from an ``on_chat_model_stream`` event.

    Returns:
        List of non-empty text strings.
    """
    if isinstance(content, str):
        return [content] if content else []
    if isinstance(content, list):
        texts: list[str] = []
        for block in content:
            if isinstance(block, dict) and block.get("type") == "text":
                text = block.get("text", "")
                if text:
                    texts.append(text)
            elif isinstance(block, str) and block:
                texts.append(block)
        return texts
    return []


class LangGraphA2AExecutor(AgentExecutor):
    """Bridges LangGraph agent to A2A event model with SSE streaming support.

    Always uses ``agent.astream_events()`` so that:
    - Streaming clients (``message/stream``) receive token-level SSE events.
    - Non-streaming clients (``message/send``) receive the final ``Message``
      collected from the same stream — no duplicate LLM call, full compat.
    """

    def __init__(self, agent, heartbeat=None, model: str = "unknown"):
        self.agent = agent  # Compiled LangGraph graph (create_react_agent output)
        self._heartbeat = heartbeat
        self._model = model  # e.g. "anthropic:claude-sonnet-4-6"

    async def execute(self, context: RequestContext, event_queue: EventQueue) -> None:
        """Execute a task from an A2A request with SSE streaming.

        Routes through the Temporal durable workflow when a global
        ``TemporalWorkflowWrapper`` is initialised and connected to Temporal;
        otherwise falls back to ``_core_execute()`` (direct path).

        Event emission sequence:
        1. TaskStatusUpdateEvent(working)  — immediate start signal
        2. TaskArtifactUpdateEvent chunks  — token-by-token via astream_events
        3. Message(final_text)             — terminal; non-streaming clients
           return on this; streaming clients also receive it as the last SSE event.
        """
        # ── Optional Temporal durable execution wrapper ──────────────────────
        # When a TemporalWorkflowWrapper is active this routes execution through
        # a MoleculeAIAgentWorkflow (task_receive → llm_call → task_complete).
        # Falls back silently to _core_execute() on any error or if Temporal
        # is unavailable, so the client always receives a response.
        try:
            from builtin_tools.temporal_workflow import get_wrapper as _get_temporal_wrapper

            _tw = _get_temporal_wrapper()
            if _tw is not None and _tw.is_available():
                return await _tw.run(self, context, event_queue)
        except Exception:
            pass  # Never let the wrapper path crash the executor

        await self._core_execute(context, event_queue)

    async def _core_execute(self, context: RequestContext, event_queue: EventQueue) -> str:
        """Core execution pipeline — called directly or from a Temporal activity.

        This is the original ``execute()`` body, extracted so that the Temporal
        ``llm_call`` activity can invoke it without re-entering the wrapper
        check and causing infinite recursion.

        Returns the final response text (empty string on empty input or error).

        Event emission sequence:
        1. TaskStatusUpdateEvent(working)  — immediate start signal
        2. TaskArtifactUpdateEvent chunks  — token-by-token via astream_events
        3. Message(final_text)             — terminal event
        """
        user_input = extract_message_text(context)
        if not user_input:
            parts = getattr(getattr(context, "message", None), "parts", None)
            logger.warning("A2A execute: no text content in message parts: %s", parts)
            await event_queue.enqueue_event(
                new_agent_text_message("Error: message contained no text content.")
            )
            return ""

        # ── OA-01: Prompt injection check (OWASP Agentic Top 10) ────────────
        _compliance_cfg = _get_compliance_cfg() if _COMPLIANCE_AVAILABLE else None
        if _COMPLIANCE_AVAILABLE and _compliance_cfg and _compliance_cfg.mode == "owasp_agentic":
            try:
                user_input = _sanitize_input(
                    user_input,
                    prompt_injection_mode=_compliance_cfg.prompt_injection,
                    context_id=context.context_id or "",
                )
            except PromptInjectionError as exc:
                await event_queue.enqueue_event(
                    new_agent_text_message(f"Request blocked: {exc}")
                )
                return ""

        logger.info("A2A execute: user_input=%s", user_input[:200])

        # ── OTEL: task_receive span ──────────────────────────────────────────
        parent_ctx = _incoming_trace_context.get()
        tracer = get_tracer()

        _result: str = ""  # captured inside the span for return after it closes

        with tracer.start_as_current_span("task_receive", context=parent_ctx) as task_span:
            task_span.set_attribute(WORKSPACE_ID_ATTR, _WORKSPACE_ID)
            task_span.set_attribute(A2A_TASK_ID, context.context_id or "")
            task_span.set_attribute("a2a.input_preview", user_input[:256])

            await set_current_task(self._heartbeat, brief_task(user_input))

            # Resolve IDs — the RequestContextBuilder always sets them, but
            # we generate fallbacks for safety (e.g. in unit tests).
            task_id = context.task_id or str(uuid.uuid4())
            context_id = context.context_id or str(uuid.uuid4())

            updater = TaskUpdater(event_queue, task_id, context_id)

            try:
                messages = _extract_history(context)
                if messages:
                    logger.info("A2A execute: injecting %d history messages", len(messages))
                messages.append(("human", user_input))

                # Recursion limit: see DEFAULT_RECURSION_LIMIT and
                # _parse_recursion_limit() at module top. Re-read on every
                # call so the env var can be hot-changed between requests.
                recursion_limit = _parse_recursion_limit()
                run_config = {
                    "configurable": {"thread_id": context_id},
                    "run_name": f"a2a-{context_id[:8]}",
                    "recursion_limit": recursion_limit,
                }

                # ── OTEL: llm_call span ──────────────────────────────────────
                with tracer.start_as_current_span("llm_call") as llm_span:
                    llm_span.set_attribute(GEN_AI_OPERATION_NAME, "chat")
                    llm_span.set_attribute(GEN_AI_SYSTEM, gen_ai_system_from_model(self._model))
                    llm_span.set_attribute(GEN_AI_REQUEST_MODEL, self._model)
                    llm_span.set_attribute(WORKSPACE_ID_ATTR, _WORKSPACE_ID)

                    # ── Step 1: signal "working" to streaming clients ─────────
                    await updater.start_work()

                    # ── Step 2: stream tokens via LangGraph astream_events ────
                    # Each "on_chat_model_stream" event carries an AIMessageChunk.
                    # We emit one TaskArtifactUpdateEvent per text chunk so SSE
                    # clients can render tokens in real time.
                    # artifact_id resets on each new LLM run_id so agent→tool→agent
                    # cycles each get their own artifact slot.

                    artifact_id = str(uuid.uuid4())
                    has_streamed = False  # True after first chunk for current artifact
                    current_run_id = None  # Detects new LLM call in a ReAct cycle
                    accumulated: list[str] = []  # All text for the final Message
                    last_ai_message = None  # Saved for token-usage telemetry

                    # ── OA-03: Excessive agency tracker ──────────────────────
                    _agency = (
                        AgencyTracker(
                            max_tool_calls=_compliance_cfg.max_tool_calls_per_task,
                            max_duration_seconds=float(_compliance_cfg.max_task_duration_seconds),
                        )
                        if _COMPLIANCE_AVAILABLE and _compliance_cfg and _compliance_cfg.mode == "owasp_agentic"
                        else None
                    )

                    async for event in self.agent.astream_events(
                        {"messages": messages},
                        config=run_config,
                        version="v2",
                    ):
                        kind = event.get("event", "")

                        if kind == "on_chat_model_stream":
                            run_id = event.get("run_id", "")
                            if run_id and run_id != current_run_id:
                                # New LLM run started — fresh artifact slot
                                current_run_id = run_id
                                artifact_id = str(uuid.uuid4())
                                has_streamed = False

                            chunk = event.get("data", {}).get("chunk")
                            if chunk is not None:
                                texts = _extract_chunk_text(chunk.content)
                                for text in texts:
                                    await updater.add_artifact(
                                        parts=[Part(root=TextPart(text=text))],
                                        artifact_id=artifact_id,
                                        append=has_streamed,  # False=first, True=append
                                        last_chunk=False,
                                    )
                                    has_streamed = True
                                    accumulated.append(text)

                        elif kind == "on_tool_start":
                            tool_name = event.get("name", "?")
                            logger.debug("SSE: tool start — %s", tool_name)
                            if _agency is not None:
                                _agency.on_tool_call(
                                    tool_name=tool_name,
                                    context_id=context_id,
                                )

                        elif kind == "on_tool_end":
                            logger.debug("SSE: tool end — %s", event.get("name", "?"))

                        elif kind == "on_chat_model_end":
                            # Capture the last completed AIMessage for token telemetry
                            output = event.get("data", {}).get("output")
                            if output is not None:
                                last_ai_message = output

                    # Record token usage from the last completed LLM call
                    if last_ai_message is not None:
                        record_llm_token_usage(llm_span, {"messages": [last_ai_message]})

                # Build final text from all accumulated streaming tokens
                final_text = "".join(accumulated).strip() or "(no response generated)"
                logger.info("A2A execute: response length=%d chars", len(final_text))

                # ── OA-02 / OA-06: Output PII redaction ──────────────────────
                if _COMPLIANCE_AVAILABLE and _compliance_cfg and _compliance_cfg.mode == "owasp_agentic":
                    final_text, _pii_types = _redact_pii(final_text)
                    if _pii_types:
                        from builtin_tools.audit import log_event as _audit_log
                        _audit_log(
                            event_type="compliance",
                            action="pii.redact",
                            resource="task_output",
                            outcome="redacted",
                            pii_types=_pii_types,
                            context_id=context_id,
                        )

                # ── OTEL: task_complete span ─────────────────────────────────
                with tracer.start_as_current_span("task_complete") as done_span:
                    done_span.set_attribute(WORKSPACE_ID_ATTR, _WORKSPACE_ID)
                    done_span.set_attribute(A2A_TASK_ID, context_id)
                    done_span.set_attribute("task.has_response", bool(accumulated))
                    done_span.set_attribute("task.response_length", len(final_text))

                # ── Step 3: emit final Message ────────────────────────────────
                # Non-streaming: ResultAggregator.consume_all() returns this
                # immediately as the response (a2a_client.py reads .parts[0].text).
                # Streaming: yielded as the last SSE event in the stream.
                await event_queue.enqueue_event(
                    new_agent_text_message(final_text, task_id=task_id, context_id=context_id)
                )
                _result = final_text

            except Exception as e:
                logger.error("A2A execute error: %s", e, exc_info=True)
                try:
                    task_span.record_exception(e)
                    from opentelemetry.trace import StatusCode
                    task_span.set_status(StatusCode.ERROR, str(e))
                except Exception:
                    pass
                # Emit a Message so both streaming and non-streaming clients
                # receive an error response rather than hanging.
                await event_queue.enqueue_event(
                    new_agent_text_message(
                        f"Agent error: {e}", task_id=task_id, context_id=context_id
                    )
                )
            finally:
                await set_current_task(self._heartbeat, "")

        return _result

    async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None:
        """Cancel a running task — emits canceled state to comply with A2A protocol."""
        from a2a.types import TaskStatus, TaskState, TaskStatusUpdateEvent
        await event_queue.enqueue_event(
            TaskStatusUpdateEvent(
                status=TaskStatus(state=TaskState.canceled),
                final=True,
            )
        )
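Since `_extract_chunk_text` is the piece most likely to need attention when a new provider is added, here is a quick sketch of its contract on the two content shapes the docstring describes, assuming the runtime's flat import layout (expected outputs shown as comments):

```python
from a2a_executor import _extract_chunk_text

# OpenAI / Groq style: content is a plain string (empty for tool-call chunks)
print(_extract_chunk_text("Hello"))  # ['Hello']
print(_extract_chunk_text(""))       # []

# Anthropic style: content is a list of typed blocks; only "text" blocks survive,
# so raw tool_use JSON never leaks into the SSE stream
print(_extract_chunk_text([
    {"type": "text", "text": "Hel"},
    {"type": "tool_use", "id": "tu_1", "name": "list_peers", "input": {}},
    {"type": "text", "text": "lo"},
]))  # ['Hel', 'lo']
```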
molecule_runtime/a2a_mcp_server.py (new file, 293 lines)
@@ -0,0 +1,293 @@
#!/usr/bin/env python3
"""A2A MCP Server — runs inside each workspace container.

Exposes A2A delegation, peer discovery, and workspace info as MCP tools
so CLI-based runtimes (Claude Code, Codex) can communicate with other workspaces.

Launched automatically by main.py for CLI runtimes. Runs on stdio transport
and is configured as a local MCP server for the claude --print invocation.

Environment variables (set by the workspace container):
    WORKSPACE_ID — this workspace's ID
    PLATFORM_URL — platform API base URL (e.g. http://platform:8080)
"""

import asyncio
import json
import logging
import sys

from a2a_tools import (
    tool_check_task_status,
    tool_commit_memory,
    tool_delegate_task,
    tool_delegate_task_async,
    tool_get_workspace_info,
    tool_list_peers,
    tool_recall_memory,
    tool_send_message_to_user,
)

logger = logging.getLogger(__name__)

# Re-export constants and client functions so existing imports
# (e.g. tests that do `import a2a_mcp_server`) still work.
from a2a_client import (  # noqa: F401, E402
    PLATFORM_URL,
    WORKSPACE_ID,
    _A2A_ERROR_PREFIX,
    _peer_names,
    discover_peer,
    get_peers,
    get_workspace_info,
    send_a2a_message,
)
from a2a_tools import report_activity  # noqa: F401, E402

# --- Tool definitions (schemas) ---

TOOLS = [
    {
        "name": "delegate_task",
        "description": "Delegate a task to another workspace via A2A protocol and WAIT for the response. Use for quick tasks. The target must be a peer (sibling or parent/child). Use list_peers to find available targets.",
        "inputSchema": {
            "type": "object",
            "properties": {
                "workspace_id": {
                    "type": "string",
                    "description": "Target workspace ID (from list_peers)",
                },
                "task": {
                    "type": "string",
                    "description": "The task description to send to the target workspace",
                },
            },
            "required": ["workspace_id", "task"],
        },
    },
    {
        "name": "delegate_task_async",
        "description": "Send a task to another workspace with a short timeout (fire-and-forget). Returns immediately — the target continues processing. Best when you don't need the result right away. Note: check_task_status may not work with all workspace implementations.",
        "inputSchema": {
            "type": "object",
            "properties": {
                "workspace_id": {
                    "type": "string",
                    "description": "Target workspace ID (from list_peers)",
                },
                "task": {
                    "type": "string",
                    "description": "The task description to send to the target workspace",
                },
            },
            "required": ["workspace_id", "task"],
        },
    },
    {
        "name": "check_task_status",
        "description": "Check the status of a previously submitted async task via tasks/get. Note: only works if the target workspace's A2A implementation supports task persistence. May return 'not found' for completed tasks.",
        "inputSchema": {
            "type": "object",
            "properties": {
                "workspace_id": {
                    "type": "string",
                    "description": "The workspace ID the task was sent to",
                },
                "task_id": {
                    "type": "string",
                    "description": "The task_id returned by delegate_task_async",
                },
            },
            "required": ["workspace_id", "task_id"],
        },
    },
    {
        "name": "list_peers",
        "description": "List all workspaces this agent can communicate with (siblings and parent/children). Returns name, ID, status, and role for each peer.",
        "inputSchema": {"type": "object", "properties": {}},
    },
    {
        "name": "get_workspace_info",
        "description": "Get this workspace's own info — ID, name, role, tier, parent, status.",
        "inputSchema": {"type": "object", "properties": {}},
    },
    {
        "name": "send_message_to_user",
        "description": "Send a message directly to the user's canvas chat — pushed instantly via WebSocket. Use this to: (1) acknowledge a task immediately ('Got it, I'll start working on this'), (2) send interim progress updates while doing long work, (3) deliver follow-up results after delegation completes. The message appears in the user's chat as if you're proactively reaching out.",
        "inputSchema": {
            "type": "object",
            "properties": {
                "message": {
                    "type": "string",
                    "description": "The message to send to the user",
                },
            },
            "required": ["message"],
        },
    },
    {
        "name": "commit_memory",
        "description": "Save important information to persistent memory. Use this to remember decisions, conversation context, task results, and anything that should survive a restart. Scope: LOCAL (this workspace only), TEAM (parent + siblings), GLOBAL (entire org).",
        "inputSchema": {
            "type": "object",
            "properties": {
                "content": {
                    "type": "string",
                    "description": "The information to remember — be detailed and specific",
                },
                "scope": {
                    "type": "string",
                    "enum": ["LOCAL", "TEAM", "GLOBAL"],
                    "description": "Memory scope (default: LOCAL)",
                },
            },
            "required": ["content"],
        },
    },
    {
        "name": "recall_memory",
        "description": "Search persistent memory for previously saved information. Returns all matching memories. Use this at the start of conversations to recall prior context.",
        "inputSchema": {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "Search query (empty returns all memories)",
                },
                "scope": {
                    "type": "string",
                    "enum": ["LOCAL", "TEAM", "GLOBAL", ""],
                    "description": "Filter by scope (empty returns all accessible)",
                },
            },
        },
    },
]


# --- Tool dispatch ---

async def handle_tool_call(name: str, arguments: dict) -> str:
    """Handle a tool call and return the result as text."""
    if name == "delegate_task":
        return await tool_delegate_task(
            arguments.get("workspace_id", ""),
            arguments.get("task", ""),
        )
    elif name == "delegate_task_async":
        return await tool_delegate_task_async(
            arguments.get("workspace_id", ""),
            arguments.get("task", ""),
        )
    elif name == "check_task_status":
        return await tool_check_task_status(
            arguments.get("workspace_id", ""),
            arguments.get("task_id", ""),
        )
    elif name == "send_message_to_user":
        return await tool_send_message_to_user(arguments.get("message", ""))
    elif name == "list_peers":
        return await tool_list_peers()
    elif name == "get_workspace_info":
        return await tool_get_workspace_info()
    elif name == "commit_memory":
        return await tool_commit_memory(
            arguments.get("content", ""),
            arguments.get("scope", "LOCAL"),
        )
    elif name == "recall_memory":
        return await tool_recall_memory(
            arguments.get("query", ""),
            arguments.get("scope", ""),
        )
    return f"Unknown tool: {name}"


# --- MCP Server (JSON-RPC over stdio) ---

async def main():  # pragma: no cover
    """Run MCP server on stdio — reads JSON-RPC requests, writes responses."""
    reader = asyncio.StreamReader()
    protocol = asyncio.StreamReaderProtocol(reader)
    await asyncio.get_event_loop().connect_read_pipe(lambda: protocol, sys.stdin)

    writer_transport, writer_protocol = await asyncio.get_event_loop().connect_write_pipe(
        asyncio.streams.FlowControlMixin, sys.stdout
    )
    writer = asyncio.StreamWriter(writer_transport, writer_protocol, None, asyncio.get_event_loop())

    async def write_response(response: dict):
        data = json.dumps(response) + "\n"
        writer.write(data.encode())
        await writer.drain()

    buffer = ""
    while True:
        try:
            chunk = await reader.read(65536)
            if not chunk:
                break
            buffer += chunk.decode(errors="replace")

            while "\n" in buffer:
                line, buffer = buffer.split("\n", 1)
                line = line.strip()
                if not line:
                    continue

                try:
                    request = json.loads(line)
                except json.JSONDecodeError:
                    continue

                req_id = request.get("id")
                method = request.get("method", "")

                if method == "initialize":
                    await write_response({
                        "jsonrpc": "2.0",
                        "id": req_id,
                        "result": {
                            "protocolVersion": "2024-11-05",
                            "capabilities": {"tools": {"listChanged": False}},
                            "serverInfo": {"name": "a2a-delegation", "version": "1.0.0"},
                        },
                    })

                elif method == "notifications/initialized":
                    pass  # No response needed

                elif method == "tools/list":
                    await write_response({
                        "jsonrpc": "2.0",
                        "id": req_id,
                        "result": {"tools": TOOLS},
                    })

                elif method == "tools/call":
                    params = request.get("params", {})
                    tool_name = params.get("name", "")
                    tool_args = params.get("arguments", {})
                    result_text = await handle_tool_call(tool_name, tool_args)
                    await write_response({
                        "jsonrpc": "2.0",
                        "id": req_id,
                        "result": {
                            "content": [{"type": "text", "text": result_text}],
                        },
                    })

                else:
                    await write_response({
                        "jsonrpc": "2.0",
                        "id": req_id,
                        "error": {"code": -32601, "message": f"Method not found: {method}"},
                    })

        except Exception as e:
            logger.error(f"MCP server error: {e}")
            break


if __name__ == "__main__":  # pragma: no cover
    asyncio.run(main())
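For reference, a rough sketch of driving this stdio server by hand. The invocation and environment (WORKSPACE_ID, PLATFORM_URL) are assumptions; main.py normally launches the server for CLI runtimes. Each request and each response is one JSON line:

```python
import json
import subprocess

# Spawn the server on stdio (assumed invocation: script is on the current path)
proc = subprocess.Popen(
    ["python", "a2a_mcp_server.py"],
    stdin=subprocess.PIPE,
    stdout=subprocess.PIPE,
    text=True,
)


def rpc(method: str, params: dict | None = None, req_id: int = 1) -> dict:
    """Send one JSON-RPC request line and read one response line."""
    req: dict = {"jsonrpc": "2.0", "id": req_id, "method": method}
    if params is not None:
        req["params"] = params
    proc.stdin.write(json.dumps(req) + "\n")
    proc.stdin.flush()
    return json.loads(proc.stdout.readline())


print(rpc("initialize")["result"]["serverInfo"])  # {'name': 'a2a-delegation', 'version': '1.0.0'}
print(len(rpc("tools/list")["result"]["tools"]))  # 8
print(rpc("tools/call", {"name": "get_workspace_info", "arguments": {}}))
```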
molecule_runtime/a2a_tools.py (new file, 269 lines)
@@ -0,0 +1,269 @@
"""A2A MCP tool implementations — the body of each tool handler.

Imports shared client functions and constants from a2a_client.
"""

import json
import uuid

import httpx

from a2a_client import (
    PLATFORM_URL,
    WORKSPACE_ID,
    _A2A_ERROR_PREFIX,
    _peer_names,
    discover_peer,
    get_peers,
    get_workspace_info,
    send_a2a_message,
)


def _auth_headers_for_heartbeat() -> dict[str, str]:
    """Return Phase 30.1 auth headers; tolerate platform_auth being absent
    in older installs (e.g. during rolling upgrade)."""
    try:
        from platform_auth import auth_headers
        return auth_headers()
    except Exception:
        return {}


async def report_activity(
    activity_type: str, target_id: str = "", summary: str = "", status: str = "ok",
    task_text: str = "", response_text: str = "",
):
    """Report activity to the platform for live progress tracking."""
    try:
        async with httpx.AsyncClient(timeout=5.0) as client:
            payload: dict = {
                "activity_type": activity_type,
                "source_id": WORKSPACE_ID,
                "target_id": target_id,
                "method": "message/send",
                "summary": summary,
                "status": status,
            }
            if task_text:
                payload["request_body"] = {"task": task_text}
            if response_text:
                payload["response_body"] = {"result": response_text}
            await client.post(
                f"{PLATFORM_URL}/workspaces/{WORKSPACE_ID}/activity",
                json=payload,
                headers=_auth_headers_for_heartbeat(),
            )
            # Also push current_task via heartbeat for canvas card display
            if summary:
                await client.post(
                    f"{PLATFORM_URL}/registry/heartbeat",
                    json={
                        "workspace_id": WORKSPACE_ID,
                        "current_task": summary,
                        "active_tasks": 1,
                        "error_rate": 0,
                        "sample_error": "",
                        "uptime_seconds": 0,
                    },
                    headers=_auth_headers_for_heartbeat(),
                )
    except Exception:
        pass  # Best-effort — don't block delegation on activity reporting


async def tool_delegate_task(workspace_id: str, task: str) -> str:
    """Delegate a task to another workspace via A2A (synchronous — waits for response)."""
    if not workspace_id or not task:
        return "Error: workspace_id and task are required"

    # Discover the target
    peer = await discover_peer(workspace_id)
    if not peer:
        return f"Error: workspace {workspace_id} not found or not accessible (check access control)"

    target_url = peer.get("url", "")
    if not target_url:
        return f"Error: workspace {workspace_id} has no URL (may be offline)"

    # Report delegation start — include the task text for traceability
    peer_name = peer.get("name") or _peer_names.get(workspace_id) or workspace_id[:8]
    _peer_names[workspace_id] = peer_name  # cache for future use
    # Brief summary for canvas display — just the delegation target
    await report_activity("a2a_send", workspace_id, f"Delegating to {peer_name}", task_text=task)

    # Send A2A message and log the full round-trip
    result = await send_a2a_message(target_url, task)

    # Detect delegation failures — wrap them clearly so the calling agent
    # can decide to retry, use another peer, or handle the task itself.
    is_error = result.startswith(_A2A_ERROR_PREFIX)
    await report_activity(
        "a2a_receive", workspace_id,
        f"{peer_name} responded ({len(result)} chars)" if not is_error else f"{peer_name} failed",
        task_text=task, response_text=result,
        status="error" if is_error else "ok",
    )
    if is_error:
        return (
            f"DELEGATION FAILED to {peer_name}: {result}\n"
            f"You should either: (1) try a different peer, (2) handle this task yourself, "
            f"or (3) inform the user that {peer_name} is unavailable and provide your best answer."
        )
    return result


async def tool_delegate_task_async(workspace_id: str, task: str) -> str:
    """Delegate a task via the platform's async delegation API (fire-and-forget).

    Uses POST /workspaces/:id/delegate which runs the A2A request in the background.
    Results are tracked in the platform DB and broadcast via WebSocket.
    Use check_task_status to poll for results.
    """
    if not workspace_id or not task:
        return "Error: workspace_id and task are required"

    try:
        async with httpx.AsyncClient(timeout=10.0) as client:
            resp = await client.post(
                f"{PLATFORM_URL}/workspaces/{WORKSPACE_ID}/delegate",
                json={"target_id": workspace_id, "task": task},
                headers=_auth_headers_for_heartbeat(),
            )
            if resp.status_code == 202:
                data = resp.json()
                return json.dumps({
                    "delegation_id": data.get("delegation_id", ""),
                    "workspace_id": workspace_id,
                    "status": "delegated",
                    "note": "Task delegated. The platform runs it in the background. Use check_task_status to poll for results.",
                })
            else:
                return f"Error: delegation failed with status {resp.status_code}: {resp.text[:200]}"
    except Exception as e:
        return f"Error: delegation failed — {e}"


async def tool_check_task_status(workspace_id: str, task_id: str) -> str:
    """Check delegations for this workspace via the platform API.

    Args:
        workspace_id: Ignored (kept for backward compat). Checks this workspace's delegations.
        task_id: Optional delegation_id to filter. If empty, returns all recent delegations.
    """
    try:
        async with httpx.AsyncClient(timeout=10.0) as client:
            resp = await client.get(
                f"{PLATFORM_URL}/workspaces/{WORKSPACE_ID}/delegations",
                headers=_auth_headers_for_heartbeat(),
            )
            if resp.status_code != 200:
                return f"Error: failed to check delegations ({resp.status_code})"
            delegations = resp.json()
            if task_id:
                # Filter by delegation_id
                matching = [d for d in delegations if d.get("delegation_id") == task_id]
                if matching:
                    return json.dumps(matching[0])
                return json.dumps({"status": "not_found", "delegation_id": task_id})
            # Return all recent delegations
            summary = []
            for d in delegations[:10]:
                summary.append({
                    "delegation_id": d.get("delegation_id", ""),
                    "target_id": d.get("target_id", ""),
                    "status": d.get("status", ""),
                    "summary": d.get("summary", ""),
                    "response_preview": d.get("response_preview", ""),
                })
            return json.dumps({"delegations": summary, "count": len(delegations)})
    except Exception as e:
        return f"Error checking delegations: {e}"


async def tool_send_message_to_user(message: str) -> str:
    """Send a message directly to the user's canvas chat via WebSocket."""
    if not message:
        return "Error: message is required"
    try:
        async with httpx.AsyncClient(timeout=5.0) as client:
            resp = await client.post(
                f"{PLATFORM_URL}/workspaces/{WORKSPACE_ID}/notify",
                json={"message": message},
                headers=_auth_headers_for_heartbeat(),
            )
            if resp.status_code == 200:
                return "Message sent to user"
            return f"Error: platform returned {resp.status_code}"
    except Exception as e:
        return f"Error sending message: {e}"


async def tool_list_peers() -> str:
    """List all workspaces this agent can communicate with."""
    peers = await get_peers()
    if not peers:
        return "No peers available (this workspace may be isolated)"
    lines = []
    for p in peers:
        status = p.get("status", "unknown")
        role = p.get("role", "")
        # Cache name for use in delegate_task
        _peer_names[p["id"]] = p["name"]
        lines.append(f"- {p['name']} (ID: {p['id']}, status: {status}, role: {role})")
    return "\n".join(lines)


async def tool_get_workspace_info() -> str:
    """Get this workspace's own info."""
    info = await get_workspace_info()
    return json.dumps(info, indent=2)


async def tool_commit_memory(content: str, scope: str = "LOCAL") -> str:
    """Save important information to persistent memory."""
    if not content:
        return "Error: content is required"
    scope = scope.upper()
    if scope not in ("LOCAL", "TEAM", "GLOBAL"):
        scope = "LOCAL"
    try:
        async with httpx.AsyncClient(timeout=10.0) as client:
            resp = await client.post(
                f"{PLATFORM_URL}/workspaces/{WORKSPACE_ID}/memories",
                json={"content": content, "scope": scope},
                headers=_auth_headers_for_heartbeat(),
            )
            data = resp.json()
            if resp.status_code in (200, 201):
                return json.dumps({"success": True, "id": data.get("id"), "scope": scope})
            return f"Error: {data.get('error', resp.text)}"
    except Exception as e:
        return f"Error saving memory: {e}"


async def tool_recall_memory(query: str = "", scope: str = "") -> str:
    """Search persistent memory for previously saved information."""
    params = {}
    if query:
        params["q"] = query
    if scope:
        params["scope"] = scope.upper()
    try:
        async with httpx.AsyncClient(timeout=10.0) as client:
            resp = await client.get(
                f"{PLATFORM_URL}/workspaces/{WORKSPACE_ID}/memories",
                params=params,
                headers=_auth_headers_for_heartbeat(),
            )
            data = resp.json()
            if isinstance(data, list):
                if not data:
                    return "No memories found."
                lines = []
                for m in data:
                    lines.append(f"[{m.get('scope', '?')}] {m.get('content', '')}")
                return "\n".join(lines)
            return json.dumps(data)
    except Exception as e:
        return f"Error recalling memory: {e}"
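A sketch of pairing `tool_delegate_task_async` with `tool_check_task_status`. The target ID, the poll interval, and the terminal status values are illustrative assumptions, not part of the documented API:

```python
import asyncio
import json

from a2a_tools import tool_check_task_status, tool_delegate_task_async


async def fan_out(target: str) -> None:
    raw = await tool_delegate_task_async(target, "Generate the weekly report")
    if raw.startswith("Error"):
        print(raw)
        return
    delegation_id = json.loads(raw)["delegation_id"]
    for _ in range(10):  # bounded poll; 30s interval is illustrative
        await asyncio.sleep(30)
        state = json.loads(await tool_check_task_status(target, delegation_id))
        # "completed" / "failed" are assumed terminal states in the platform DB
        if state.get("status") in ("completed", "failed"):
            print(state.get("response_preview") or state)
            return


# asyncio.run(fan_out("ws-research-1"))  # hypothetical workspace ID
```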
molecule_runtime/adapters/__init__.py (new file, 86 lines)
@ -0,0 +1,86 @@
"""Adapter registry — discovers and loads agent infrastructure adapters."""

import importlib
import logging
import os

from .base import BaseAdapter, AdapterConfig, SetupResult

logger = logging.getLogger(__name__)

_ADAPTER_CACHE: dict[str, type[BaseAdapter]] = {}


def discover_adapters() -> dict[str, type[BaseAdapter]]:
    """Scan subdirectories for adapter modules. Each must export an Adapter class.

    This is used for local development inside the monorepo where adapters
    live as subdirectories. In standalone adapter repos, use the ADAPTER_MODULE
    env var instead.
    """
    if _ADAPTER_CACHE:
        return _ADAPTER_CACHE

    from pathlib import Path

    adapters_dir = Path(__file__).parent
    for entry in sorted(adapters_dir.iterdir()):
        if not entry.is_dir() or entry.name.startswith("_"):
            continue
        try:
            mod = importlib.import_module(f"molecule_runtime.adapters.{entry.name}")
            adapter_cls = getattr(mod, "Adapter", None)
            if adapter_cls and issubclass(adapter_cls, BaseAdapter):
                _ADAPTER_CACHE[adapter_cls.name()] = adapter_cls
                logger.debug(f"Loaded adapter: {adapter_cls.name()} ({adapter_cls.display_name()})")
        except Exception as e:
            # Log but don't crash — adapter may have uninstalled deps
            logger.debug(f"Skipped adapter {entry.name}: {e}")

    return _ADAPTER_CACHE


def get_adapter(runtime: str) -> type[BaseAdapter]:
    """Get adapter class by runtime name.

    Resolution order:
    1. ADAPTER_MODULE env var — used by standalone adapter repos to register
       their adapter without modifying the runtime package.
    2. Built-in discovery — scans subdirectories (for local monorepo dev).

    Raises KeyError if the adapter cannot be found.
    """
    # First check env override (standalone adapter repos set this)
    adapter_module = os.environ.get("ADAPTER_MODULE")
    if adapter_module:
        try:
            mod = importlib.import_module(adapter_module)
            cls = getattr(mod, "Adapter")
            if cls and issubclass(cls, BaseAdapter):
                return cls
        except Exception as e:
            raise KeyError(
                f"ADAPTER_MODULE={adapter_module!r} could not be loaded: {e}"
            ) from e

    # Fall back to built-in discovery (for local dev / monorepo)
    adapters = discover_adapters()
    if runtime not in adapters:
        available = ", ".join(sorted(adapters.keys()))
        raise KeyError(f"Unknown runtime '{runtime}'. Available: {available}")
    return adapters[runtime]


def list_adapters() -> list[dict]:
    """Return metadata for all discovered adapters (for API/UI)."""
    adapters = discover_adapters()
    return [
        {
            "name": cls.name(),
            "display_name": cls.display_name(),
            "description": cls.description(),
            "config_schema": cls.get_config_schema(),
        }
        for cls in adapters.values()
    ]


__all__ = ["BaseAdapter", "AdapterConfig", "SetupResult", "get_adapter", "list_adapters", "discover_adapters"]
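

# --- Illustrative usage sketch (not part of this commit) -------------------
# How the two discovery paths behave; the "langgraph" runtime name is a
# hypothetical example.
#
#     import os
#     from molecule_runtime.adapters import get_adapter, list_adapters
#
#     # Standalone repo: ADAPTER_MODULE wins regardless of the runtime arg.
#     os.environ["ADAPTER_MODULE"] = "my_adapter"   # module exporting `Adapter`
#     AdapterCls = get_adapter("anything")
#
#     # Monorepo dev: built-in scan of molecule_runtime/adapters/ subdirs.
#     del os.environ["ADAPTER_MODULE"]
#     AdapterCls = get_adapter("langgraph")
#     print([a["name"] for a in list_adapters()])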
309
molecule_runtime/adapters/base.py
Normal file
@ -0,0 +1,309 @@
"""Base adapter interface for agent infrastructure providers."""

import logging
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any

from a2a.server.agent_execution import AgentExecutor

logger = logging.getLogger(__name__)


@dataclass
class SetupResult:
    """Result from the shared _common_setup() pipeline."""
    system_prompt: str
    loaded_skills: list    # LoadedSkill instances
    langchain_tools: list  # LangChain BaseTool instances
    is_coordinator: bool
    children: list         # child workspace dicts


@dataclass
class AdapterConfig:
    """Standardized config passed to every adapter."""
    model: str  # e.g. "anthropic:claude-sonnet-4-6" or "openrouter:google/gemini-2.5-flash"
    system_prompt: str | None = None                        # Assembled system prompt text
    tools: list[str] = field(default_factory=list)          # Tool names from config.yaml
    runtime_config: dict[str, Any] = field(default_factory=dict)  # Raw runtime_config block
    config_path: str = "/configs"                           # Path to configs directory
    workspace_id: str = ""                                  # Workspace identifier
    prompt_files: list[str] = field(default_factory=list)   # Ordered prompt file names
    a2a_port: int = 8000                                    # Port for A2A server
    heartbeat: Any = None                                   # HeartbeatLoop instance


class BaseAdapter(ABC):
    """Interface every agent infrastructure adapter must implement.

    To add a new agent infra:
    1. Create workspace-template/adapters/<your_infra>/
    2. Implement adapter.py with a class extending BaseAdapter
    3. Add requirements.txt with your infra's dependencies
    4. Export as Adapter in __init__.py
    5. Submit a PR
    """

    @staticmethod
    @abstractmethod
    def name() -> str:  # pragma: no cover
        """Return the runtime identifier (e.g. 'langgraph', 'crewai').
        This must match the 'runtime' field in config.yaml."""
        ...

    @staticmethod
    @abstractmethod
    def display_name() -> str:  # pragma: no cover
        """Human-readable name for UI display."""
        ...

    @staticmethod
    @abstractmethod
    def description() -> str:  # pragma: no cover
        """Short description of what this adapter provides."""
        ...

    @staticmethod
    def get_config_schema() -> dict:
        """Return JSON Schema for runtime_config fields this adapter supports.
        Used by the Config tab UI to render the right form fields.
        Override in subclasses for adapter-specific settings."""
        return {}

    # ------------------------------------------------------------------
    # Plugin install hooks
    # ------------------------------------------------------------------
    # New pipeline: each plugin ships per-runtime adaptors resolved via
    # `plugins_registry.resolve()`. Adapters expose hooks below that
    # adaptors call to wire plugin content into the runtime.
    #
    # Default implementations are filesystem-only (write to /configs,
    # append to CLAUDE.md). Runtimes with a dynamic tool registry
    # (e.g. DeepAgents sub-agents) override the hooks to also register
    # in-process state.

    def memory_filename(self) -> str:
        """File under /configs that the runtime treats as long-lived memory.

        Both Claude Code and DeepAgents read CLAUDE.md natively, so this is
        the sensible default. Override only if a runtime expects a different
        filename.
        """
        return "CLAUDE.md"

    def register_tool_hook(self, name: str, fn) -> None:
        """Default no-op. Override on runtimes with a dynamic tool registry.

        Runtimes that pick tools up at startup via filesystem scan (Claude
        Code reads /configs/skills, LangGraph globs **/*.py) don't need to
        do anything here — the adaptor's file-write step is enough.
        """
        return None

    async def transcript_lines(self, since: int = 0, limit: int = 100) -> dict:
        """Return live transcript entries for the most-recent agent session.

        Default implementation returns ``supported: False`` for runtimes
        that don't expose a per-session log on disk. Override in subclasses
        that DO (Claude Code reads ``~/.claude/projects/<cwd>/<session>.jsonl``).

        This is the "look over the agent's shoulder" feature — lets canvas /
        operators see live tool calls + AI thinking instead of waiting for
        the high-level activity log to flush.

        Args:
            since: line offset to skip — caller's last cursor (0 = from start)
            limit: max lines to return (caller-side cap, default 100, max 1000)

        Returns:
            ``{runtime, supported, lines, cursor, more, source}`` where
            ``cursor`` is the new offset to pass on the next poll, ``more``
            is True if additional lines remain past ``limit``, and ``source``
            is the file path lines were read from (useful for debugging).
        """
        return {
            "runtime": self.name(),
            "supported": False,
            "lines": [],
            "cursor": since,
            "more": False,
            "source": None,
        }
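
    # --- Illustrative polling sketch (not part of this commit) -----------
    # How a caller might page through transcript_lines() using the cursor
    # contract above; the adapter instance and event loop are assumed.
    #
    #     async def tail_transcript(adapter: "BaseAdapter") -> None:
    #         cursor = 0
    #         while True:
    #             page = await adapter.transcript_lines(since=cursor, limit=100)
    #             if not page["supported"]:
    #                 break
    #             for line in page["lines"]:
    #                 print(line)
    #             cursor = page["cursor"]
    #             if not page["more"]:
    #                 break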

    def register_subagent_hook(self, name: str, spec: dict) -> None:
        """Default no-op. DeepAgents overrides to register a sub-agent."""
        return None

    def append_to_memory_hook(self, config: AdapterConfig, filename: str, content: str) -> None:
        """Append text to /configs/<filename> if the marker isn't already present.

        Idempotent: looks for the first line of `content` as a marker so a
        re-install doesn't duplicate the block. Adaptors should pass content
        beginning with a unique header (e.g. ``# Plugin: molecule-dev-conventions``).
        """
        target = os.path.join(config.config_path, filename)
        marker = content.splitlines()[0].strip() if content else ""
        existing = ""
        if os.path.exists(target):
            with open(target) as f:
                existing = f.read()
        if marker and marker in existing:
            logger.info("append_to_memory: %s already contains %r — skipping", filename, marker)
            return
        os.makedirs(os.path.dirname(target) or ".", exist_ok=True)
        with open(target, "a") as f:
            if existing and not existing.endswith("\n"):
                f.write("\n")
            f.write(content if content.endswith("\n") else content + "\n")
        logger.info("append_to_memory: appended %d chars to %s", len(content), filename)

    async def install_plugins_via_registry(
        self,
        config: AdapterConfig,
        plugins,
    ) -> list:
        """Drive the new per-runtime adaptor pipeline for every loaded plugin.

        For each plugin in `plugins.plugins`, resolve the adaptor for this
        runtime (via :func:`plugins_registry.resolve`) and invoke
        ``install(ctx)``. Returns the list of :class:`InstallResult` so
        callers can surface warnings (e.g. raw-drop fallback hits).

        Adapters whose runtime supports the new pipeline call this from
        ``setup()`` instead of the legacy ``inject_plugins()``.
        """
        from pathlib import Path
        from plugins_registry import InstallContext, resolve

        results = []
        runtime = self.name().replace("-", "_")  # e.g. "claude-code" -> "claude_code"

        for plugin in plugins.plugins:
            adaptor, source = resolve(plugin.name, runtime, Path(plugin.path))
            ctx = InstallContext(
                configs_dir=Path(config.config_path),
                workspace_id=config.workspace_id,
                runtime=runtime,
                plugin_root=Path(plugin.path),
                memory_filename=self.memory_filename(),
                register_tool=self.register_tool_hook,
                register_subagent=self.register_subagent_hook,
                append_to_memory=lambda fn, c, _cfg=config: self.append_to_memory_hook(_cfg, fn, c),
            )
            try:
                result = await adaptor.install(ctx)
                results.append(result)
                logger.info(
                    "Plugin %s installed via %s adaptor (warnings: %d)",
                    plugin.name, source, len(result.warnings),
                )
            except Exception as exc:
                logger.exception("Plugin %s install via %s failed: %s", plugin.name, source, exc)

        return results

    async def inject_plugins(self, config: AdapterConfig, plugins) -> None:
        """Legacy hook — kept for backwards compatibility during migration.

        Default: drive the new per-runtime adaptor pipeline. Adapters not yet
        migrated may still override this with their own logic.
        """
        await self.install_plugins_via_registry(config, plugins)
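
    # --- Illustrative adaptor sketch (not part of this commit) -----------
    # The rough shape of a per-runtime plugin adaptor as consumed above;
    # the InstallResult constructor shown here is assumed from the docstring,
    # not from this diff.
    #
    #     class MyRuntimeAdaptor:
    #         async def install(self, ctx) -> "InstallResult":
    #             # Write a prompt fragment into the runtime's memory file.
    #             ctx.append_to_memory(ctx.memory_filename, "# Plugin: my-plugin\n...")
    #             # Register an in-process tool where the runtime supports it.
    #             ctx.register_tool("my_tool", lambda: "ok")
    #             return InstallResult(warnings=[])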

    async def _common_setup(self, config: AdapterConfig) -> SetupResult:
        """Shared setup pipeline — loads plugins, skills, tools, coordinator, and builds system prompt.

        All adapters can call this to get the full platform feature set.
        Returns a SetupResult with LangChain BaseTool instances that adapters
        convert to their native format if needed.
        """
        from plugins import load_plugins
        from skill_loader.loader import load_skills
        from coordinator import get_children, get_parent_context, build_children_description
        from prompt import build_system_prompt, get_peer_capabilities
        from builtin_tools.approval import request_approval
        from builtin_tools.delegation import delegate_to_workspace, check_delegation_status
        from builtin_tools.memory import commit_memory, search_memory
        from builtin_tools.sandbox import run_code

        platform_url = os.environ.get("PLATFORM_URL", "http://platform:8080")

        # Load plugins from per-workspace dir first, then shared fallback
        workspace_plugins_dir = os.path.join(config.config_path, "plugins")
        plugins = load_plugins(
            workspace_plugins_dir=workspace_plugins_dir,
            shared_plugins_dir=os.environ.get("PLUGINS_DIR", "/plugins"),
        )
        await self.inject_plugins(config, plugins)
        if plugins.plugin_names:
            logger.info(f"Plugins: {', '.join(plugins.plugin_names)}")

        # Load skills (workspace + plugin skills, deduped)
        loaded_skills = load_skills(config.config_path, config.tools)
        seen_skill_ids = {s.metadata.id for s in loaded_skills}
        for plugin_skills_dir in plugins.skill_dirs:
            plugin_skill_names = [
                d for d in os.listdir(plugin_skills_dir)
                if os.path.isdir(os.path.join(plugin_skills_dir, d))
            ]
            for skill in load_skills(plugin_skills_dir, plugin_skill_names):
                if skill.metadata.id not in seen_skill_ids:
                    loaded_skills.append(skill)
                    seen_skill_ids.add(skill.metadata.id)
        logger.info(f"Loaded {len(loaded_skills)} skills: {[s.metadata.id for s in loaded_skills]}")

        # Assemble tools: 6 core + skill tools
        all_tools = [delegate_to_workspace, check_delegation_status, request_approval, commit_memory, search_memory, run_code]
        for skill in loaded_skills:
            all_tools.extend(skill.tools)

        # Coordinator mode: detect children and add routing tool
        children = await get_children()
        is_coordinator = len(children) > 0
        if is_coordinator:
            from coordinator import route_task_to_team
            logger.info(f"Coordinator mode: {len(children)} children")
            all_tools.append(route_task_to_team)

        # Parent context (if this is a child workspace)
        parent_context = await get_parent_context()

        # Build system prompt with all context
        peers = await get_peer_capabilities(platform_url, config.workspace_id)
        coordinator_prompt = build_children_description(children) if is_coordinator else ""
        extra_prompts = list(plugins.prompt_fragments)
        if coordinator_prompt:
            extra_prompts.append(coordinator_prompt)

        system_prompt = build_system_prompt(
            config.config_path, config.workspace_id, loaded_skills, peers,
            prompt_files=config.prompt_files,
            plugin_rules=plugins.rules,
            plugin_prompts=extra_prompts,
            parent_context=parent_context,
        )

        return SetupResult(
            system_prompt=system_prompt,
            loaded_skills=loaded_skills,
            langchain_tools=all_tools,
            is_coordinator=is_coordinator,
            children=children,
        )

    @abstractmethod
    async def setup(self, config: AdapterConfig) -> None:
        """One-time setup: validate config, prepare internal state.
        Called after deps are installed but before create_executor().
        Raise RuntimeError if setup fails (missing deps, bad config, etc.)."""
        ...  # pragma: no cover

    @abstractmethod
    async def create_executor(self, config: AdapterConfig) -> AgentExecutor:
        """Create and return an AgentExecutor ready for A2A integration.
        The returned executor's execute() method will be called by the
        A2A server's DefaultRequestHandler."""
        ...  # pragma: no cover
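

# --- Illustrative usage sketch (not part of this commit) -------------------
# How a concrete adapter might lean on _common_setup(); MyExecutor is a
# hypothetical AgentExecutor subclass, and the static metadata methods are
# elided.
#
#     class Adapter(BaseAdapter):
#         ...
#         async def setup(self, config: AdapterConfig) -> None:
#             self._setup = await self._common_setup(config)
#
#         async def create_executor(self, config: AdapterConfig) -> AgentExecutor:
#             return MyExecutor(
#                 system_prompt=self._setup.system_prompt,
#                 tools=self._setup.langchain_tools,
#             )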
190
molecule_runtime/adapters/shared_runtime.py
Normal file
@ -0,0 +1,190 @@
"""Shared runtime helpers for A2A-backed workspace executors."""

from __future__ import annotations

from typing import Any

from a2a.server.agent_execution import RequestContext


def _extract_part_text(part) -> str:
    """Extract text from a message part, handling dicts and A2A objects."""
    if isinstance(part, dict):
        text = part.get("text", "")
        if text:
            return text
        root = part.get("root")
        if isinstance(root, dict):
            return root.get("text", "")
        return ""
    if hasattr(part, "text") and part.text:
        return part.text
    if hasattr(part, "root") and hasattr(part.root, "text") and part.root.text:
        return part.root.text
    return ""


def extract_message_text(context_or_parts) -> str:
    """Extract concatenated plain text from A2A message parts."""
    parts = getattr(getattr(context_or_parts, "message", None), "parts", None)
    if parts is None:
        parts = context_or_parts
    return " ".join(
        text for part in (parts or []) if (text := _extract_part_text(part))
    ).strip()


def extract_history(context: RequestContext) -> list[tuple[str, str]]:
    """Extract conversation history from A2A request metadata."""
    messages: list[tuple[str, str]] = []
    request = getattr(context, "request", None)
    metadata = getattr(request, "metadata", None) if request else None
    if not isinstance(metadata, dict):
        metadata = getattr(context, "metadata", None) or {}
    history = metadata.get("history", []) if isinstance(metadata, dict) else []
    if not isinstance(history, list):
        return messages

    for entry in history:
        if not isinstance(entry, dict):
            continue
        role = entry.get("role", "user")
        parts = entry.get("parts", [])
        text = " ".join(
            text for part in (parts or []) if (text := _extract_part_text(part))
        ).strip()
        if text:
            mapped_role = "human" if role == "user" else "ai"
            messages.append((mapped_role, text))
    return messages


def format_conversation_history(history: list[tuple[str, str]]) -> str:
    """Render `(role, text)` history into a stable human-readable transcript."""
    return "\n".join(
        f"{'User' if role == 'human' else 'Agent'}: {text}" for role, text in history
    )


def build_task_text(user_message: str, history: list[tuple[str, str]]) -> str:
    """Build a single task/request string with optional prepended conversation history."""
    if not history:
        return user_message
    transcript = format_conversation_history(history)
    return f"Conversation so far:\n{transcript}\n\nCurrent request: {user_message}"
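

# --- Illustrative example (not part of this commit) ------------------------
# What build_task_text() produces for a short history; values are made up.
#
#     history = [("human", "What's our uptime SLA?"), ("ai", "99.9% monthly.")]
#     print(build_task_text("And the error budget?", history))
#     # Conversation so far:
#     # User: What's our uptime SLA?
#     # Agent: 99.9% monthly.
#     #
#     # Current request: And the error budget?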


def append_peer_guidance(
    base_text: str | None,
    peers_info: str,
    *,
    default_text: str,
    tool_name: str,
) -> str:
    """Append peer guidance text when peers are available."""
    text = (base_text or default_text).strip()
    if peers_info:
        text += f"\n\n## Peers\n{peers_info}\nUse {tool_name} to communicate with them."
    return text


def summarize_peer_cards(peers: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Return compact peer metadata for prompt rendering."""
    summaries: list[dict[str, Any]] = []
    for peer in peers:
        agent_card = peer.get("agent_card")
        if not agent_card:
            continue
        if isinstance(agent_card, str):
            try:
                import json

                agent_card = json.loads(agent_card)
            except Exception:
                continue
        if not isinstance(agent_card, dict):
            continue

        skills = agent_card.get("skills", [])
        summaries.append(
            {
                "id": peer.get("id", "unknown"),
                "name": agent_card.get("name", peer.get("name", "Unknown")),
                "status": peer.get("status", "unknown"),
                "skills": [
                    s.get("name", s.get("id", ""))
                    for s in skills
                    if isinstance(s, dict)
                ],
            }
        )
    return summaries


def build_peer_section(
    peers: list[dict[str, Any]],
    *,
    heading: str = "## Your Peers (workspaces you can delegate to)",
    instruction: str = (
        "Use the `delegate_to_workspace` tool to send tasks to peers. "
        "Only delegate to peers listed above."
    ),
) -> str:
    """Render a stable peer section for system prompts."""
    summaries = summarize_peer_cards(peers)
    if not summaries:
        return ""

    parts = [heading, ""]
    for peer in summaries:
        parts.append(f"- **{peer['name']}** (id: `{peer['id']}`, status: {peer['status']})")
        if peer["skills"]:
            parts.append(f"  Skills: {', '.join(peer['skills'])}")
    parts.append("")
    parts.append(instruction)
    return "\n".join(parts)


def brief_task(text: str, limit: int = 60) -> str:
    """Create a short human-readable task label for the heartbeat banner."""
    return text[:limit] + ("..." if len(text) > limit else "")


async def set_current_task(heartbeat: Any, task: str) -> None:
    """Update current task on heartbeat and push immediately to platform.

    The heartbeat loop only fires every 30s, so quick tasks would finish
    before the canvas ever sees them. Setting a task pushes immediately.
    Clearing a task only updates the heartbeat object — the next heartbeat
    cycle will broadcast the clear, keeping the task visible longer.
    """
    if heartbeat:
        heartbeat.current_task = task
        heartbeat.active_tasks = 1 if task else 0

    # Only push immediately when SETTING a task (not clearing)
    # Clearing is handled by the next heartbeat cycle, which keeps
    # the task visible on the canvas for quick A2A responses
    if not task:
        return

    import os

    workspace_id = os.environ.get("WORKSPACE_ID", "")
    platform_url = os.environ.get("PLATFORM_URL", "")
    if workspace_id and platform_url:
        try:
            import httpx

            async with httpx.AsyncClient(timeout=3.0) as client:
                await client.post(
                    f"{platform_url}/registry/heartbeat",
                    json={
                        "workspace_id": workspace_id,
                        "current_task": task,
                        "active_tasks": 1,
                        "error_rate": 0,
                        "sample_error": "",
                        "uptime_seconds": 0,
                    },
                )
        except Exception:
            pass  # Best-effort
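

# --- Illustrative usage sketch (not part of this commit) -------------------
# Typical executor flow around a task; `heartbeat` is whatever HeartbeatLoop
# instance the adapter was handed via AdapterConfig.
#
#     async def run_task(heartbeat, text: str):
#         await set_current_task(heartbeat, brief_task(text))
#         try:
#             ...  # do the work
#         finally:
#             await set_current_task(heartbeat, "")  # cleared on next cycle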
133
molecule_runtime/agent.py
Normal file
@ -0,0 +1,133 @@
"""Create the Deep Agent with model + skills + tools."""

import os
import logging

from langgraph.prebuilt import create_react_agent

logger = logging.getLogger(__name__)


def create_agent(model_str: str, tools: list, system_prompt: str):
    """Create a LangGraph ReAct agent.

    Args:
        model_str: LangChain-compatible model string (e.g., 'anthropic:claude-sonnet-4-6')
        tools: List of tool functions
        system_prompt: The system prompt for the agent
    """
    # Parse provider:model format
    if ":" in model_str:
        provider, model_name = model_str.split(":", 1)
    else:
        provider = "anthropic"
        model_name = model_str

    # Import the provider package
    try:
        if provider in ("anthropic",):
            from langchain_anthropic import ChatAnthropic as LLMClass
        elif provider in ("openai", "openrouter", "groq", "cerebras", "qianfan"):
            from langchain_openai import ChatOpenAI as LLMClass
        elif provider == "google_genai":
            from langchain_google_genai import ChatGoogleGenerativeAI as LLMClass
        elif provider == "ollama":
            from langchain_ollama import ChatOllama as LLMClass
        else:
            raise ValueError(f"Unsupported model provider: {provider}")
    except ImportError as e:
        # All OpenAI-compatible providers ship via langchain-openai; others
        # follow the langchain-<provider> naming (with underscores dashed).
        openai_compatible = ("openai", "openrouter", "groq", "cerebras", "qianfan")
        pkg = "langchain-openai" if provider in openai_compatible else f"langchain-{provider.replace('_', '-')}"
        raise ImportError(f"Provider '{provider}' requires package '{pkg}'. Install: pip install {pkg}") from e

    # Instantiate the LLM
    if provider == "anthropic":
        llm_kwargs = {"model": model_name}
        anthropic_base_url = os.environ.get("ANTHROPIC_BASE_URL", "")
        if anthropic_base_url:
            llm_kwargs["anthropic_api_url"] = anthropic_base_url
        llm = LLMClass(**llm_kwargs)
    elif provider == "openrouter":
        api_key = os.environ.get("OPENROUTER_API_KEY", os.environ.get("OPENAI_API_KEY", ""))
        max_tokens = int(os.environ.get("MAX_TOKENS", "2048"))
        llm = LLMClass(
            model=model_name,
            openai_api_key=api_key,
            openai_api_base="https://openrouter.ai/api/v1",
            max_tokens=max_tokens,
        )
    elif provider == "groq":
        api_key = os.environ.get("GROQ_API_KEY", "")
        llm = LLMClass(
            model=model_name,
            openai_api_key=api_key,
            openai_api_base="https://api.groq.com/openai/v1",
        )
    elif provider == "cerebras":
        api_key = os.environ.get("CEREBRAS_API_KEY", "")
        llm = LLMClass(
            model=model_name,
            openai_api_key=api_key,
            openai_api_base="https://api.cerebras.ai/v1",
        )
    elif provider == "qianfan":
        api_key = os.environ.get("QIANFAN_API_KEY", os.environ.get("AISTUDIO_API_KEY", ""))
        llm = LLMClass(
            model=model_name,
            openai_api_key=api_key,
            openai_api_base="https://qianfan.baidubce.com/v2",
        )
    elif provider == "openai":
        llm_kwargs = {"model": model_name}
        openai_base_url = os.environ.get("OPENAI_BASE_URL", "")
        if openai_base_url:
            llm_kwargs["openai_api_base"] = openai_base_url
        llm = LLMClass(**llm_kwargs)
    else:
        llm = LLMClass(model=model_name)

    # Auto-inject Langfuse tracing if env vars are present
    callbacks = _setup_langfuse()
    if callbacks:
        llm.callbacks = callbacks

    agent = create_react_agent(
        model=llm,
        tools=tools,
        prompt=system_prompt,
    )

    return agent


def _setup_langfuse():
    """Set up Langfuse tracing if LANGFUSE_* env vars are present.

    Returns list of callbacks to pass to agent invocations, or empty list.
    """
    langfuse_host = os.environ.get("LANGFUSE_HOST")
    langfuse_public = os.environ.get("LANGFUSE_PUBLIC_KEY")
    langfuse_secret = os.environ.get("LANGFUSE_SECRET_KEY")

    if not (langfuse_host and langfuse_public and langfuse_secret):
        return []

    try:
        from langfuse.callback import CallbackHandler

        handler = CallbackHandler(
            host=langfuse_host,
            public_key=langfuse_public,
            secret_key=langfuse_secret,
        )
        logger.info("Langfuse tracing enabled: %s", langfuse_host)

        # Also set LANGSMITH_TRACING for LangGraph native integration
        os.environ.setdefault("LANGSMITH_TRACING", "true")

        return [handler]
    except ImportError:
        logger.warning("Langfuse env vars set but langfuse package not installed")
        return []
    except Exception as e:
        logger.warning("Langfuse setup failed: %s", e)
        return []
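

# --- Illustrative usage sketch (not part of this commit) -------------------
# Invoking the returned LangGraph agent; the message shape follows
# LangGraph's standard {"messages": [...]} state.
#
#     agent = create_agent(
#         "anthropic:claude-sonnet-4-6",
#         tools=[],
#         system_prompt="You are a helpful workspace agent.",
#     )
#     result = agent.invoke({"messages": [("user", "ping")]})
#     print(result["messages"][-1].content)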
0
molecule_runtime/builtin_tools/__init__.py
Normal file
85
molecule_runtime/builtin_tools/a2a_tools.py
Normal file
@ -0,0 +1,85 @@
"""A2A communication tools — framework-agnostic delegation and peer discovery.

These are plain async functions that any adapter can wrap in its native tool format.
The LangChain @tool versions are in tools/delegation.py.
"""

import os
import uuid

import httpx

PLATFORM_URL = os.environ.get("PLATFORM_URL", "http://platform:8080")
WORKSPACE_ID = os.environ.get("WORKSPACE_ID", "")


async def list_peers() -> list[dict]:
    """Get this workspace's peers from the platform registry."""
    async with httpx.AsyncClient(timeout=10.0) as client:
        try:
            resp = await client.get(f"{PLATFORM_URL}/registry/{WORKSPACE_ID}/peers")
            if resp.status_code == 200:
                return resp.json()
            return []
        except Exception:
            return []


async def delegate_task(workspace_id: str, task: str) -> str:
    """Send a task to a peer workspace via A2A and return the response text."""
    async with httpx.AsyncClient(timeout=120.0) as client:
        # Discover target URL
        try:
            resp = await client.get(
                f"{PLATFORM_URL}/registry/discover/{workspace_id}",
                headers={"X-Workspace-ID": WORKSPACE_ID},
            )
            if resp.status_code != 200:
                return f"Error: cannot reach workspace {workspace_id} (status {resp.status_code})"
            target_url = resp.json().get("url", "")
            if not target_url:
                return f"Error: workspace {workspace_id} has no URL"
        except Exception as e:
            return f"Error discovering workspace: {e}"

        # Send A2A message
        try:
            a2a_resp = await client.post(
                target_url,
                json={
                    "jsonrpc": "2.0",
                    "id": str(uuid.uuid4()),
                    "method": "message/send",
                    "params": {
                        "message": {
                            "role": "user",
                            "messageId": str(uuid.uuid4()),
                            "parts": [{"kind": "text", "text": task}],
                        },
                    },
                },
            )
            data = a2a_resp.json()
            if "result" in data:
                parts = data["result"].get("parts", [])
                return parts[0].get("text", "(no text)") if parts else str(data["result"])
            elif "error" in data:
                return f"Error: {data['error'].get('message', str(data['error']))}"
            return str(data)
        except Exception as e:
            return f"Error sending A2A message: {e}"


async def get_peers_summary() -> str:
    """Return a formatted string of available peers for system prompts."""
    peers = await list_peers()
    if not peers:
        return "No peers available."
    lines = []
    for p in peers:
        name = p.get("name", "Unknown")
        pid = p.get("id", "")
        role = p.get("role", "")
        status = p.get("status", "")
        lines.append(f"- {name} (ID: {pid}) — {role} [{status}]")
    return "Available peers:\n" + "\n".join(lines)
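

# --- Illustrative wrapping sketch (not part of this commit) ----------------
# As the module docstring notes, adapters wrap these plain functions in
# their native tool format; a LangChain wrapper might look like this.
#
#     from langchain_core.tools import tool
#
#     @tool
#     async def delegate(workspace_id: str, task: str) -> str:
#         """Delegate a task to a peer workspace and return its reply."""
#         return await delegate_task(workspace_id, task)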
320
molecule_runtime/builtin_tools/approval.py
Normal file
@ -0,0 +1,320 @@
"""Approval tool for human-in-the-loop workflows.

When an agent encounters a destructive, expensive, or unauthorized action,
it calls request_approval() which creates a request and waits for a decision.

## Notification strategy

By default this module uses a **WebSocket subscription** (APPROVAL_USE_WEBSOCKET=true
or when the ``websockets`` package is installed). The platform pushes an
``APPROVAL_DECIDED`` event to the workspace WebSocket as soon as a human
clicks Approve / Deny on the canvas — no polling required, instant delivery.

If WebSocket is unavailable (env var opt-out or import error) the module
falls back to a **polling loop** so existing deployments without WebSocket
support continue to work without any config change.

RBAC enforcement
----------------
The calling workspace must hold a role that grants the ``"approve"`` action.
Roles are read from ``config.yaml`` under ``rbac.roles`` (default: operator).

Audit trail
-----------
Every approval lifecycle emits structured JSON Lines records:

1. ``approval / approve / requested`` — request submitted to platform
2. ``approval / approve / granted`` — human approved (actor = decided_by)
3. ``approval / approve / denied`` — human denied (actor = decided_by)
4. ``approval / approve / timeout`` — no decision within APPROVAL_TIMEOUT

RBAC denials emit an ``rbac / rbac.deny / denied`` event instead.

Environment variables
---------------------
PLATFORM_URL            Platform base URL (default: http://platform:8080)
WORKSPACE_ID            This workspace's ID (default: "")
APPROVAL_TIMEOUT        Max wait in seconds (default: 300)
APPROVAL_POLL_INTERVAL  Polling interval in seconds (default: 5, polling path only)
APPROVAL_USE_WEBSOCKET  "true" to force WS, "false" to force polling (default: auto-detect)
AUDIT_LOG_PATH          Path for JSON Lines audit log (default: /var/log/molecule/audit.jsonl)
"""

import asyncio
import json
import logging
import os
import uuid

import httpx
from langchain_core.tools import tool

from builtin_tools.audit import check_permission, get_workspace_roles, log_event

logger = logging.getLogger(__name__)

PLATFORM_URL = os.environ.get("PLATFORM_URL", "http://platform:8080")
WORKSPACE_ID = os.environ.get("WORKSPACE_ID", "")
APPROVAL_POLL_INTERVAL = float(os.environ.get("APPROVAL_POLL_INTERVAL", "5"))
APPROVAL_TIMEOUT = float(os.environ.get("APPROVAL_TIMEOUT", "300"))

# Auto-detect WebSocket support; can be overridden with env var
_ws_env = os.environ.get("APPROVAL_USE_WEBSOCKET", "").lower()
if _ws_env == "false":
    _USE_WEBSOCKET_DEFAULT = False
elif _ws_env == "true":
    _USE_WEBSOCKET_DEFAULT = True
else:
    try:
        import websockets as _ws_probe  # noqa: F401
        _USE_WEBSOCKET_DEFAULT = True
    except ImportError:
        _USE_WEBSOCKET_DEFAULT = False

# Module-level reference so tests can monkeypatch it
try:
    import websockets
except ImportError:
    websockets = None  # type: ignore[assignment]

# Expose for test introspection
APPROVAL_USE_WEBSOCKET = _USE_WEBSOCKET_DEFAULT


# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------

async def _create_approval_request(action: str, reason: str) -> dict:
    """POST to the platform to create an approval request.

    Returns {"approval_id": str} on success or {"error": str} on failure.
    """
    async with httpx.AsyncClient(timeout=10.0) as client:
        try:
            resp = await client.post(
                f"{PLATFORM_URL}/workspaces/{WORKSPACE_ID}/approvals",
                json={"action": action, "reason": reason},
            )
            if resp.status_code != 201:
                return {"error": f"Failed to create request: {resp.status_code}"}
            try:
                approval_id = resp.json().get("approval_id")
            except ValueError:
                # resp.json() raises a ValueError subclass on malformed JSON
                return {"error": f"Platform returned invalid JSON (status {resp.status_code})"}
            logger.info("Approval requested: %s (id=%s)", action, approval_id)
            return {"approval_id": approval_id}
        except Exception as e:
            return {"error": f"Failed to request approval: {e}"}


async def _wait_websocket(approval_id: str, timeout: float) -> dict:
    """Subscribe to the platform WebSocket and wait for APPROVAL_DECIDED event.

    Returns the decision dict or raises asyncio.TimeoutError on expiry.
    """
    ws_url = (
        PLATFORM_URL.replace("http://", "ws://").replace("https://", "wss://")
        + "/ws"
    )
    headers = {"X-Workspace-ID": WORKSPACE_ID}

    logger.debug("Approval %s: waiting via WebSocket %s", approval_id, ws_url)

    async with websockets.connect(ws_url, additional_headers=headers) as ws:
        async for raw_message in ws:
            try:
                event = json.loads(raw_message)
            except json.JSONDecodeError:
                continue

            if event.get("event") != "APPROVAL_DECIDED":
                continue
            if event.get("approval_id") != approval_id:
                continue

            status = event.get("status")
            decided_by = event.get("decided_by", "")
            logger.info("Approval %s decided via WebSocket: %s by %s",
                        approval_id, status, decided_by)

            if status == "approved":
                return {
                    "approved": True,
                    "approval_id": approval_id,
                    "decided_by": decided_by,
                }
            else:
                return {
                    "approved": False,
                    "approval_id": approval_id,
                    "decided_by": decided_by,
                    "message": "Denied by human",
                }


async def _wait_polling(approval_id: str, timeout: float) -> dict:
    """Legacy polling loop — checks platform REST endpoint every APPROVAL_POLL_INTERVAL seconds."""
    elapsed = 0.0
    async with httpx.AsyncClient(timeout=10.0) as client:
        while elapsed < timeout:
            await asyncio.sleep(APPROVAL_POLL_INTERVAL)
            elapsed += APPROVAL_POLL_INTERVAL
            try:
                resp = await client.get(
                    f"{PLATFORM_URL}/workspaces/{WORKSPACE_ID}/approvals",
                )
                if resp.status_code == 200:
                    for a in resp.json():
                        if a.get("id") == approval_id:
                            status = a.get("status")
                            if status == "approved":
                                logger.info("Approval granted (poll): %s", approval_id)
                                return {
                                    "approved": True,
                                    "approval_id": approval_id,
                                    "decided_by": a.get("decided_by"),
                                }
                            elif status == "denied":
                                logger.info("Approval denied (poll): %s", approval_id)
                                return {
                                    "approved": False,
                                    "approval_id": approval_id,
                                    "decided_by": a.get("decided_by"),
                                    "message": "Denied by human",
                                }
            except Exception:
                pass  # transient error — keep retrying

    raise asyncio.TimeoutError()


# ---------------------------------------------------------------------------
# Public tool
# ---------------------------------------------------------------------------

@tool
async def request_approval(
    action: str,
    reason: str,
) -> dict:
    """Request human approval before proceeding with a sensitive action.

    Use this when you're about to do something destructive, expensive,
    or outside your normal authority. The request is sent to the canvas
    where a human can approve or deny it.

    Args:
        action: Short description of what you want to do
        reason: Why this action is necessary
    """
    # One trace_id links every audit event for this approval lifecycle.
    trace_id = str(uuid.uuid4())

    # --- RBAC check -----------------------------------------------------------
    roles, custom_perms = get_workspace_roles()
    if not check_permission("approve", roles, custom_perms):
        log_event(
            event_type="rbac",
            action="rbac.deny",
            resource=action,
            outcome="denied",
            trace_id=trace_id,
            attempted_action="approve",
            roles=roles,
        )
        return {
            "approved": False,
            "error": (
                "RBAC: this workspace does not have the 'approve' permission. "
                f"Current roles: {roles}"
            ),
        }

    # Step 1: Create the approval request
    creation = await _create_approval_request(action, reason)
    if "error" in creation:
        log_event(
            event_type="approval",
            action="approve",
            resource=action,
            outcome="failure",
            trace_id=trace_id,
            reason="submit_failed",
            error=creation["error"],
        )
        return {"approved": False, "error": creation["error"]}

    approval_id = creation["approval_id"]
    log_event(
        event_type="approval",
        action="approve",
        resource=action,
        outcome="requested",
        trace_id=trace_id,
        approval_id=approval_id,
        reason_text=reason,
    )

    timeout = float(os.environ.get("APPROVAL_TIMEOUT", str(APPROVAL_TIMEOUT)))

    # Step 2: Wait for decision — WebSocket preferred, polling as fallback
    use_ws = APPROVAL_USE_WEBSOCKET and websockets is not None

    try:
        if use_ws:
            try:
                result = await asyncio.wait_for(
                    _wait_websocket(approval_id, timeout),
                    timeout=timeout,
                )
            except Exception as ws_err:
                # WebSocket failed (connection error, etc.) — fall through to polling
                logger.warning(
                    "WebSocket approval wait failed (%s), falling back to polling",
                    ws_err,
                )
                result = await asyncio.wait_for(
                    _wait_polling(approval_id, timeout),
                    timeout=timeout + APPROVAL_POLL_INTERVAL,
                )
        else:
            # Polling path (primary when WS disabled)
            result = await asyncio.wait_for(
                _wait_polling(approval_id, timeout),
                timeout=timeout + APPROVAL_POLL_INTERVAL,  # slight grace period
            )

        # Log the human decision
        decided_by = result.get("decided_by")
        outcome = "granted" if result.get("approved") else "denied"
        log_event(
            event_type="approval",
            action="approve",
            resource=action,
            outcome=outcome,
            # Record the human identity as actor when available
            actor=decided_by or WORKSPACE_ID,
            trace_id=trace_id,
            approval_id=approval_id,
            decided_by=decided_by,
        )
        return result

    except asyncio.TimeoutError:
        logger.warning("Approval timed out after %.0fs: %s", timeout, approval_id)
        log_event(
            event_type="approval",
            action="approve",
            resource=action,
            outcome="timeout",
            trace_id=trace_id,
            approval_id=approval_id,
            timeout_seconds=timeout,
        )
        return {
            "approved": False,
            "approval_id": approval_id,
            "error": f"Timed out after {timeout}s waiting for human decision",
        }
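

# --- Illustrative usage sketch (not part of this commit) -------------------
# Invoking the decorated tool directly from adapter code; LangChain tools
# expose `ainvoke` for async calls. Values are placeholders.
#
#     decision = await request_approval.ainvoke({
#         "action": "drop table staging.events",
#         "reason": "Schema migration step 3 of 4",
#     })
#     if not decision["approved"]:
#         raise RuntimeError(decision.get("error", "denied"))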
274
molecule_runtime/builtin_tools/audit.py
Normal file
@ -0,0 +1,274 @@
"""Immutable append-only audit log for EU AI Act compliance.

Fulfils Article 12 (record-keeping), Article 13 (transparency), and
Article 17 (quality-management system) requirements for high-risk AI systems.

Log format: JSON Lines (one UTF-8 JSON object per line), suitable for direct
ingestion by any SIEM (Splunk, Elastic, Datadog, etc.).

Required event fields
---------------------
timestamp      ISO 8601 UTC datetime with timezone offset
event_type     Coarse category: "delegation", "approval", "memory", "rbac"
workspace_id   Workspace that generated this event
actor          Entity that triggered the action; defaults to workspace_id for
               automated events, or the human identity for approval decisions
action         Verb describing what was attempted:
               delegate | approve | memory.read | memory.write | rbac.deny
resource       Object of the action: target workspace ID, memory scope,
               approval action string, etc.
outcome        One of: allowed | denied | success | failure | timeout |
               requested | granted
trace_id       UUID v4 correlating related events across workspaces

The log file is opened in append mode ("a") on every write — it is NEVER
truncated, rewritten, or deleted by this module. Rotate externally using
logrotate (with ``copytruncate`` disabled) or ship to a SIEM before rotating.

Configuration
-------------
AUDIT_LOG_PATH env var — full path to the JSONL file
                         (default: /var/log/molecule/audit.jsonl)
"""

from __future__ import annotations

import functools
import json
import logging
import os
import threading
import uuid
from datetime import datetime, timezone
from pathlib import Path
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    pass  # avoid circular import at runtime

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

AUDIT_LOG_PATH: str = os.environ.get(
    "AUDIT_LOG_PATH", "/var/log/molecule/audit.jsonl"
)
WORKSPACE_ID: str = os.environ.get("WORKSPACE_ID", "")

# Protects the open() + write() sequence; prevents interleaved JSON lines
# when multiple async tasks run in the same event-loop thread.
_write_lock = threading.Lock()


# ---------------------------------------------------------------------------
# Built-in role → permitted-action mappings
# ---------------------------------------------------------------------------

#: Maps each built-in role name to the set of actions it grants.
#: Custom roles can be added in config.yaml under ``rbac.allowed_actions``.
ROLE_PERMISSIONS: dict[str, set[str]] = {
    # Full access — short-circuits all other checks
    "admin": {"delegate", "approve", "memory.read", "memory.write"},
    # Standard agent role
    "operator": {"delegate", "approve", "memory.read", "memory.write"},
    # Read-only observer — no writes, no delegation, no approvals
    "read-only": {"memory.read"},
    # Can approve and write memory, but cannot delegate
    "no-delegation": {"approve", "memory.read", "memory.write"},
    # Can delegate and write memory, but cannot invoke the approval gate
    "no-approval": {"delegate", "memory.read", "memory.write"},
    # Memory reads only (useful for analytic sidecars)
    "memory-readonly": {"memory.read"},
}


# ---------------------------------------------------------------------------
# Config loader (lazy, cached per process)
# ---------------------------------------------------------------------------

@functools.lru_cache(maxsize=1)
def _load_workspace_config():
    """Return the WorkspaceConfig or None if it cannot be loaded."""
    try:
        from config import load_config  # local import avoids circular deps
        return load_config()
    except Exception as exc:
        logger.warning("audit: could not load workspace config for RBAC: %s", exc)
        return None


def get_workspace_roles() -> tuple[list[str], dict[str, list[str]]]:
    """Return ``(roles, custom_permissions)`` from the workspace config.

    Falls back to ``["operator"]`` / ``{}`` when the config is unavailable so
    that agents remain functional in degraded environments.
    """
    cfg = _load_workspace_config()
    if cfg is None:
        return ["operator"], {}
    return list(cfg.rbac.roles), dict(cfg.rbac.allowed_actions)


# ---------------------------------------------------------------------------
# RBAC helpers
# ---------------------------------------------------------------------------

def check_permission(
    action: str,
    roles: list[str],
    custom_permissions: dict[str, list[str]] | None = None,
) -> bool:
    """Return True if *any* of ``roles`` grants ``action``.

    Evaluation order
    ~~~~~~~~~~~~~~~~
    1. ``"admin"`` short-circuits — always grants everything.
    2. Custom role definitions (from ``rbac.allowed_actions`` in config.yaml).
    3. Built-in :data:`ROLE_PERMISSIONS` table.

    When a role appears in *custom_permissions* its built-in definition is
    **ignored** — the custom list is the complete permission set for that role.

    Args:
        action: Action to authorise, e.g. ``"delegate"``.
        roles: Roles assigned to the calling workspace.
        custom_permissions: Optional ``{role: [action, ...]}`` mapping loaded
            from ``WorkspaceConfig.rbac.allowed_actions``.

    Returns:
        ``True`` if the action is permitted, ``False`` otherwise.

    Examples::

        >>> check_permission("delegate", ["operator"])
        True
        >>> check_permission("delegate", ["read-only"])
        False
        >>> check_permission("deploy", ["developer"], {"developer": ["deploy"]})
        True
    """
    for role in roles:
        if role == "admin":
            return True
        if custom_permissions and role in custom_permissions:
            # Custom entry is definitive for this role
            if action in custom_permissions[role]:
                return True
            continue  # Don't fall through to built-ins for custom roles
        if role in ROLE_PERMISSIONS and action in ROLE_PERMISSIONS[role]:
            return True
    return False


# ---------------------------------------------------------------------------
# Public audit API
# ---------------------------------------------------------------------------

def log_event(
    event_type: str,
    action: str,
    resource: str,
    outcome: str,
    actor: str | None = None,
    trace_id: str | None = None,
    **extra: Any,
) -> str:
    """Append one audit event to the immutable JSON Lines log.

    Args:
        event_type: Coarse category — ``"delegation"``, ``"approval"``,
            ``"memory"``, or ``"rbac"``.
        action: Verb — ``"delegate"``, ``"approve"``, ``"memory.write"``,
            ``"memory.read"``, ``"rbac.deny"``.
        resource: Object of the action — target workspace ID, memory scope,
            approval action string, etc.
        outcome: Terminal state — one of ``"allowed"``, ``"denied"``,
            ``"success"``, ``"failure"``, ``"timeout"``,
            ``"requested"``, ``"granted"``.
        actor: Identity that triggered the event. Defaults to
            ``WORKSPACE_ID`` (the running workspace) for automated
            events. Pass ``decided_by`` for human approval decisions.
        trace_id: Caller-supplied UUID v4 for cross-event correlation.
            A fresh UUID is generated when omitted.
        **extra: Additional key-value pairs appended verbatim to the JSON
            object (e.g. ``target_workspace_id``, ``memory_scope``,
            ``attempt``). Built-in keys cannot be overridden.

    Returns:
        The ``trace_id`` used for this event, enabling callers to chain
        related events under a single correlation identifier.

    Example::

        trace = log_event(
            event_type="delegation",
            action="delegate",
            resource="billing-agent",
            outcome="success",
            target_workspace_id="billing-agent",
            attempt=1,
        )
    """
    if trace_id is None:
        trace_id = str(uuid.uuid4())

    event: dict[str, Any] = {
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "event_type": event_type,
        "workspace_id": WORKSPACE_ID,
        "actor": actor if actor is not None else WORKSPACE_ID,
        "action": action,
        "resource": resource,
        "outcome": outcome,
        "trace_id": trace_id,
    }

    # Merge extra fields — built-in keys are not overridable
    for key, value in extra.items():
        if key not in event:
            event[key] = value

    _write_event(event)
    return trace_id


# ---------------------------------------------------------------------------
# Internal writer
# ---------------------------------------------------------------------------

def _ensure_log_dir(path: str) -> None:
    """Create the parent directory for *path* if it does not already exist."""
    Path(path).parent.mkdir(parents=True, exist_ok=True)


def _write_event(event: dict[str, Any]) -> None:
    """Serialise *event* as a JSON line and fsync-append it to the log file.

    The write is atomic with respect to other threads in this process: the
    lock ensures that no two JSON objects are interleaved on the same line.

    Failures are emitted to the standard Python logger at WARNING level but
    are **never** re-raised — the application must not crash because audit
    logging is temporarily unavailable (e.g. disk full, permission error).
    In production, consider wiring an alert on WARNING messages from this
    module so that missing audit records are detected quickly.
    """
    try:
        log_path = AUDIT_LOG_PATH
        _ensure_log_dir(log_path)
        line = json.dumps(event, default=str, ensure_ascii=False) + "\n"
        with _write_lock:
            with open(log_path, "a", encoding="utf-8") as fh:
                fh.write(line)
                fh.flush()
                os.fsync(fh.fileno())
    except Exception as exc:  # pylint: disable=broad-except
        logger.warning(
            "Audit log write failed — event NOT persisted "
            "(trace_id=%s, action=%s): %s",
            event.get("trace_id", "?"),
            event.get("action", "?"),
            exc,
        )
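

# --- Illustrative output sketch (not part of this commit) ------------------
# One JSON line as _write_event would append it for the log_event example
# above; the workspace ID, timestamp, and trace_id are obviously made up.
#
#     {"timestamp": "2025-01-01T00:00:00+00:00", "event_type": "delegation",
#      "workspace_id": "ws-1", "actor": "ws-1", "action": "delegate",
#      "resource": "billing-agent", "outcome": "success",
#      "trace_id": "6f1c...", "target_workspace_id": "billing-agent",
#      "attempt": 1}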
122
molecule_runtime/builtin_tools/awareness_client.py
Normal file
@ -0,0 +1,122 @@
|
||||
"""Workspace-scoped awareness backend wrapper.
|
||||
|
||||
The agent-facing memory tools keep their existing signatures and delegate
|
||||
to this helper when workspace awareness is configured.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
|
||||
from policies.namespaces import resolve_awareness_namespace
|
||||
|
||||
try: # pragma: no cover - optional runtime dependency in lightweight test envs
|
||||
import httpx # type: ignore
|
||||
except ImportError:  # pragma: no cover
    httpx = SimpleNamespace(AsyncClient=None)


DEFAULT_AWARENESS_TIMEOUT = 10.0


def get_awareness_config() -> dict[str, str] | None:
    """Return awareness connection settings if the workspace is configured."""
    base_url = os.environ.get("AWARENESS_URL", "").rstrip("/")
    workspace_id = os.environ.get("WORKSPACE_ID", "")
    configured_namespace = os.environ.get("AWARENESS_NAMESPACE", "")
    if not base_url:
        return None
    if not workspace_id and not configured_namespace:
        return None
    namespace = resolve_awareness_namespace(workspace_id, configured_namespace)
    return {
        "base_url": base_url,
        "namespace": namespace,
    }


class AwarenessClient:
    """Small HTTP client for workspace-scoped awareness memory operations."""

    def __init__(self, base_url: str, namespace: str, timeout: float = DEFAULT_AWARENESS_TIMEOUT):
        self.base_url = base_url.rstrip("/")
        self.namespace = namespace
        self.timeout = timeout

    def _memories_url(self) -> str:
        # Keep the awareness path isolated in one helper so the contract can
        # be adjusted later without touching the agent-facing tools.
        return f"{self.base_url}/api/v1/namespaces/{self.namespace}/memories"

    async def commit(self, content: str, scope: str) -> dict[str, Any]:
        client_cls = _resolve_async_client()
        async with client_cls(timeout=self.timeout) as client:
            resp = await client.post(
                self._memories_url(),
                json={"content": content, "scope": scope},
            )
        return _parse_commit_response(resp, scope)

    async def search(self, query: str = "", scope: str = "") -> dict[str, Any]:
        params: dict[str, str] = {}
        if query:
            params["q"] = query
        if scope:
            params["scope"] = scope

        client_cls = _resolve_async_client()
        async with client_cls(timeout=self.timeout) as client:
            resp = await client.get(self._memories_url(), params=params)
        return _parse_search_response(resp)


def build_awareness_client() -> AwarenessClient | None:
    """Create an awareness client from the current workspace environment."""
    config = get_awareness_config()
    if not config:
        return None
    return AwarenessClient(config["base_url"], config["namespace"])


def _parse_commit_response(resp: httpx.Response, scope: str) -> dict[str, Any]:
    data = _safe_json(resp)
    if resp.status_code in (200, 201):
        return {"success": True, "id": data.get("id"), "scope": scope}
    return {"success": False, "error": data.get("error", resp.text)}


def _parse_search_response(resp: httpx.Response) -> dict[str, Any]:
    data = _safe_json(resp)
    if resp.status_code == 200:
        memories = data if isinstance(data, list) else data.get("memories", [])
        return {
            "success": True,
            "count": len(memories),
            "memories": memories,
        }
    return {"success": False, "error": data.get("error", resp.text)}


def _safe_json(resp: httpx.Response) -> dict[str, Any] | list[Any]:
    try:
        return resp.json()
    except ValueError:
        return {"error": resp.text}


def _resolve_async_client():
    client_cls = getattr(httpx, "AsyncClient", None)
    if client_cls is not None:
        return client_cls

    memory_module = sys.modules.get("builtin_tools.memory")
    if memory_module is not None:
        memory_httpx = getattr(memory_module, "httpx", None)
        client_cls = getattr(memory_httpx, "AsyncClient", None)
        if client_cls is not None:
            return client_cls

    raise RuntimeError("httpx.AsyncClient is unavailable")
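
# Illustrative usage sketch (not part of the runtime): a hypothetical
# agent-facing helper showing the intended call pattern for
# build_awareness_client() plus commit(). The name "remember_note" and the
# "workspace" scope value are examples only, not defined elsewhere.
async def remember_note(note: str) -> dict[str, Any]:
    client = build_awareness_client()
    if client is None:
        # AWARENESS_URL / WORKSPACE_ID not set: degrade instead of raising.
        return {"success": False, "error": "awareness not configured"}
    return await client.commit(note, scope="workspace")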
359
molecule_runtime/builtin_tools/compliance.py
Normal file
@ -0,0 +1,359 @@
"""OWASP Top 10 for Agentic Applications compliance enforcement (Dec 2025).

Enable via config.yaml::

    compliance:
      mode: owasp_agentic
      prompt_injection: detect        # detect | block
      max_tool_calls_per_task: 50
      max_task_duration_seconds: 300

When ``mode`` is absent or empty, this module is a no-op — no overhead, no
behaviour change. This makes it safe to import unconditionally.

Coverage
--------

OA-01 Prompt Injection (``sanitize_input``)
    Scans user-supplied text for instruction-override patterns, role-hijacking
    attempts, system-prompt delimiter injection, and known jailbreak keywords.

    - ``detect`` (default): log an audit event, then return the original text so
      the agent still processes the input. Operators are alerted without
      breaking legitimate use-cases that happen to contain trigger words.

    - ``block``: raise ``PromptInjectionError`` before the agent sees the text.

OA-03 Excessive Agency (``AgencyTracker``)
    Tracks the number of tool calls and wall-clock time elapsed per task.
    When a limit is exceeded, ``ExcessiveAgencyError`` is raised. The caller
    (``a2a_executor.py``) catches it and terminates the task gracefully.

OA-02 / OA-06 Insecure Output / Sensitive Data Exposure (``redact_pii``)
    Scans agent output for credit-card numbers, SSNs, API keys, AWS access
    keys, and e-mail addresses. Detected values are replaced with
    ``[REDACTED:<type>]`` tokens before the response reaches the caller.
    An audit event records the PII types found (not the values themselves).

    Note on streaming: ``redact_pii`` is applied to the *final accumulated
    text* before the terminal ``Message`` event is emitted. Token-by-token
    SSE artifacts that have already been sent to streaming clients are not
    retroactively redacted. For full streaming redaction, integrate
    ``redact_pii`` at the ``TaskArtifactUpdateEvent`` level.

Compliance posture report (``get_compliance_posture``)
    Returns the current effective compliance configuration as a plain ``dict``
    suitable for a health or audit endpoint, letting operators verify that the
    correct settings are active without reading config files.
"""

from __future__ import annotations

import logging
import re
import time
import uuid
from dataclasses import dataclass, field
from typing import Any

from builtin_tools.audit import log_event

logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Public exceptions
# ---------------------------------------------------------------------------


class PromptInjectionError(ValueError):
    """Raised when prompt injection is detected and ``prompt_injection=block``."""


class ExcessiveAgencyError(RuntimeError):
    """Raised when the tool-call count or task-duration limit is exceeded."""


# ---------------------------------------------------------------------------
# OA-01 — Prompt Injection detection
# ---------------------------------------------------------------------------

#: Compiled patterns matched case-insensitively against the raw user input.
#: Add workspace-specific patterns in config if needed.
_INJECTION_PATTERNS: list[tuple[re.Pattern[str], str]] = [
    # Instruction override
    (re.compile(r"ignore\s+(all\s+)?previous\s+instructions?", re.I), "instruction_override"),
    (re.compile(r"disregard\s+(all\s+)?previous", re.I), "instruction_override"),
    (re.compile(r"forget\s+(all\s+)?previous", re.I), "instruction_override"),
    (re.compile(r"override\s+(your\s+)?(instructions?|guidelines?|rules?)", re.I), "instruction_override"),
    # Role hijacking
    (re.compile(r"you\s+are\s+now\s+\w", re.I), "role_hijack"),
    (re.compile(r"act\s+as\s+(a\s+)?(new\s+|different\s+|unrestricted\s+)", re.I), "role_hijack"),
    (re.compile(r"roleplay\s+as", re.I), "role_hijack"),
    (re.compile(r"pretend\s+(you\s+are|to\s+be)\b", re.I), "role_hijack"),
    (re.compile(r"from\s+now\s+on\s+(you\s+are|act\s+as)", re.I), "role_hijack"),
    # System-prompt delimiter injection (LLM-specific tokens)
    (re.compile(r"<\|?\s*(system|im_start|im_end|endoftext)\s*\|?>", re.I), "delimiter_injection"),
    (re.compile(r"\[INST\]|\[/INST\]|\[\[SYS\]\]|\[\[/SYS\]\]", re.I), "delimiter_injection"),
    (re.compile(r"<</SYS>>|<<SYS>>", re.I), "delimiter_injection"),
    # DAN / jailbreak keywords
    (re.compile(r"\bDAN\b.{0,30}(mode|now|enabled|activated)", re.I), "jailbreak"),
    (re.compile(r"do\s+anything\s+now", re.I), "jailbreak"),
    (re.compile(r"\bjailbreak\b", re.I), "jailbreak"),
    (re.compile(r"developer\s+mode\s+(enabled|on)", re.I), "jailbreak"),
    # Prompt exfiltration
    (re.compile(r"(repeat|print|output|show|reveal|display)\s+(your\s+)?(system\s+prompt|initial\s+instructions?)", re.I), "prompt_exfiltration"),
    (re.compile(r"what\s+(are\s+)?your\s+(instructions?|system\s+prompt)", re.I), "prompt_exfiltration"),
]


def detect_prompt_injection(text: str) -> list[tuple[str, str]]:
    """Return a list of ``(matched_snippet, category)`` pairs, one per pattern hit.

    Args:
        text: Raw user input to scan.

    Returns:
        List of ``(matched_snippet, category)`` tuples; empty means clean.
    """
    matches: list[tuple[str, str]] = []
    for pattern, category in _INJECTION_PATTERNS:
        m = pattern.search(text)
        if m:
            matches.append((m.group(0)[:80], category))
    return matches


def sanitize_input(
    text: str,
    *,
    prompt_injection_mode: str = "detect",
    context_id: str = "",
) -> str:
    """Check *text* for prompt injection and enforce the configured response.

    Args:
        text: User-supplied input to the agent.
        prompt_injection_mode: ``"detect"`` or ``"block"``.
        context_id: Task/context identifier for audit correlation.

    Returns:
        The original *text* unchanged (``detect`` mode always returns input).

    Raises:
        :class:`PromptInjectionError`: only when ``prompt_injection_mode="block"``
            and at least one injection pattern is matched.
    """
    matches = detect_prompt_injection(text)
    if not matches:
        return text

    categories = list({cat for _, cat in matches})
    trace_id = str(uuid.uuid4())

    log_event(
        event_type="compliance",
        action="prompt_injection.detect",
        resource="user_input",
        outcome="detected" if prompt_injection_mode == "detect" else "blocked",
        trace_id=trace_id,
        context_id=context_id,
        categories=categories,
        match_count=len(matches),
        # Log category + truncated match, never the full raw text (OA-06)
        matches=[{"category": cat, "snippet": snippet} for snippet, cat in matches[:5]],
    )

    if prompt_injection_mode == "block":
        raise PromptInjectionError(
            f"Prompt injection detected ({', '.join(categories)}). "
            "Request blocked by compliance policy."
        )

    # detect mode — log and continue
    logger.warning(
        "Prompt injection patterns detected (context_id=%s, categories=%s) — "
        "passing to agent in detect mode",
        context_id,
        categories,
    )
    return text
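
# Illustrative sketch (assumption: "_example_gate_input" is not the actual
# executor integration point). Shows how a request handler might run the
# gate in "block" mode and turn a PromptInjectionError into a refusal
# message instead of a crashed task.
def _example_gate_input(user_text: str, context_id: str = "example") -> str:
    try:
        return sanitize_input(user_text, prompt_injection_mode="block", context_id=context_id)
    except PromptInjectionError as exc:
        logger.info("Input rejected by compliance gate: %s", exc)
        return "Request blocked by compliance policy."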

# ---------------------------------------------------------------------------
# OA-03 — Excessive Agency
# ---------------------------------------------------------------------------


@dataclass
class AgencyTracker:
    """Per-task mutable state for excessive-agency enforcement.

    Instantiate once per ``execute()`` call and call :meth:`on_tool_call`
    at each tool-start event.
    """

    max_tool_calls: int = 50
    max_duration_seconds: float = 300.0
    tool_call_count: int = field(default=0, init=False)
    start_time: float = field(default_factory=time.monotonic, init=False)

    def on_tool_call(self, tool_name: str = "", context_id: str = "") -> None:
        """Increment the counter and enforce both limits.

        Raises:
            :class:`ExcessiveAgencyError`: if either limit is exceeded.
        """
        self.tool_call_count += 1
        elapsed = time.monotonic() - self.start_time

        if self.tool_call_count > self.max_tool_calls:
            log_event(
                event_type="compliance",
                action="excessive_agency.tool_limit",
                resource=tool_name or "unknown_tool",
                outcome="blocked",
                context_id=context_id,
                tool_call_count=self.tool_call_count,
                limit=self.max_tool_calls,
                elapsed_seconds=round(elapsed, 2),
            )
            raise ExcessiveAgencyError(
                f"Tool call limit exceeded: {self.tool_call_count} calls > "
                f"max {self.max_tool_calls} per task"
            )

        if elapsed > self.max_duration_seconds:
            log_event(
                event_type="compliance",
                action="excessive_agency.duration_limit",
                resource=tool_name or "unknown_tool",
                outcome="blocked",
                context_id=context_id,
                tool_call_count=self.tool_call_count,
                elapsed_seconds=round(elapsed, 2),
                limit_seconds=self.max_duration_seconds,
            )
            raise ExcessiveAgencyError(
                f"Task duration limit exceeded: {elapsed:.0f}s > "
                f"max {self.max_duration_seconds:.0f}s per task"
            )
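
# Illustrative sketch (assumption: "_example_run_tools" and its loop are not
# the real executor). One tracker per task; every tool start goes through
# on_tool_call(), and any ExcessiveAgencyError propagates to the task runner.
def _example_run_tools(tool_names: list[str], context_id: str = "example") -> int:
    tracker = AgencyTracker(max_tool_calls=5, max_duration_seconds=30.0)
    for name in tool_names:
        tracker.on_tool_call(tool_name=name, context_id=context_id)
        ...  # dispatch the actual tool here
    return tracker.tool_call_count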

# ---------------------------------------------------------------------------
# OA-02 / OA-06 — PII redaction
# ---------------------------------------------------------------------------

#: ``(compiled_pattern, replacement_token)`` pairs applied in order.
#: The replacement tokens are SIEM-friendly: ``[REDACTED:type]``.
_PII_PATTERNS: list[tuple[re.Pattern[str], str]] = [
    # Formatted credit cards: XXXX-XXXX-XXXX-XXXX or XXXX XXXX XXXX XXXX
    (re.compile(r"\b\d{4}[\s\-]\d{4}[\s\-]\d{4}[\s\-]\d{4}\b"), "[REDACTED:credit_card]"),
    # US Social Security Numbers: XXX-XX-XXXX
    (re.compile(r"\b\d{3}-\d{2}-\d{4}\b"), "[REDACTED:ssn]"),
    # OpenAI-style keys: sk-... (≥ 32 chars after prefix)
    (re.compile(r"\bsk-[A-Za-z0-9_\-]{32,}\b"), "[REDACTED:api_key]"),
    # Generic API/secret keys with common prefixes
    (re.compile(r"\b(?:sk|pk|api|secret|token|auth)[-_][A-Za-z0-9_\-]{20,}\b", re.I), "[REDACTED:api_key]"),
    # AWS Access Key IDs
    (re.compile(r"\bAKIA[0-9A-Z]{16}\b"), "[REDACTED:aws_key]"),
    # GitHub personal access tokens — classic format (36-char alphanumeric suffix)
    (re.compile(r"\bghp_[A-Za-z0-9]{36}\b"), "[REDACTED:github_token]"),
    # GitHub personal access tokens — fine-grained format (82-char alphanumeric+underscore suffix)
    (re.compile(r"\bgithub_pat_[A-Za-z0-9_]{82}\b"), "[REDACTED:github_token]"),
    # Email addresses
    (re.compile(r"\b[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}\b"), "[REDACTED:email]"),
]


def redact_pii(text: str) -> tuple[str, list[str]]:
    """Redact PII from *text* and return ``(redacted_text, pii_types_found)``.

    Each unique PII type is reported at most once in ``pii_types_found``.
    The replacement tokens (``[REDACTED:type]``) are SIEM-indexable and
    preserve the structural context of the output while hiding sensitive data.

    Args:
        text: Agent output text to scan.

    Returns:
        Tuple of ``(redacted_text, list_of_pii_type_strings)``. The list is
        empty when no PII is detected (the common case).

    Examples::

        >>> redacted, types = redact_pii("Call me at test@example.com sk-abc123...")
        >>> "email" in types
        True
        >>> "[REDACTED:email]" in redacted
        True
    """
    found: list[str] = []
    result = text
    for pattern, replacement in _PII_PATTERNS:
        new_result = pattern.sub(replacement, result)
        if new_result != result:
            # Extract type from "[REDACTED:type]"
            pii_type = replacement[len("[REDACTED:"):-1]
            if pii_type not in found:
                found.append(pii_type)
            result = new_result
    return result, found


# ---------------------------------------------------------------------------
# Compliance posture report
# ---------------------------------------------------------------------------


def get_compliance_posture() -> dict[str, Any]:
    """Return the current compliance configuration as a serialisable dict.

    Loads ``WorkspaceConfig`` lazily (cached) and returns a snapshot of the
    active compliance settings. Safe to call from a health endpoint.

    Returns a dict with these keys::

        {
            "compliance_mode": "owasp_agentic" | "",
            "enabled": true | false,
            "prompt_injection": "detect" | "block",
            "max_tool_calls_per_task": 50,
            "max_task_duration_seconds": 300,
            "pii_redaction_enabled": true,
            "security_scan_mode": "warn" | "block" | "off",
            "rbac_roles": ["operator"],
        }
    """
    try:
        from builtin_tools.audit import _load_workspace_config
        cfg = _load_workspace_config()
    except Exception:
        cfg = None

    if cfg is None:
        return {
            "compliance_mode": "",
            "enabled": False,
            "prompt_injection": "detect",
            "max_tool_calls_per_task": 50,
            "max_task_duration_seconds": 300,
            "pii_redaction_enabled": False,
            "security_scan_mode": "warn",
            "rbac_roles": [],
            "note": "config unavailable",
        }

    c = cfg.compliance
    enabled = c.mode == "owasp_agentic"
    return {
        "compliance_mode": c.mode,
        "enabled": enabled,
        "prompt_injection": c.prompt_injection,
        "max_tool_calls_per_task": c.max_tool_calls_per_task,
        "max_task_duration_seconds": c.max_task_duration_seconds,
        # PII redaction is active whenever compliance mode is on
        "pii_redaction_enabled": enabled,
        "security_scan_mode": cfg.security_scan.mode,
        "rbac_roles": list(cfg.rbac.roles),
    }
366
molecule_runtime/builtin_tools/delegation.py
Normal file
@ -0,0 +1,366 @@
"""Async delegation tool for sending tasks to peer workspaces via A2A.

Delegations are non-blocking: the tool fires the A2A request in the background
and returns immediately with a task_id. The agent can check status at any time
via check_delegation_status, or simply continue working and check later.

When the delegate responds, the result is stored and the agent is notified
via a status update.
"""

import asyncio
import logging
import os
import uuid
from dataclasses import dataclass
from enum import Enum
from typing import Optional

import httpx
from langchain_core.tools import tool

from builtin_tools.audit import check_permission, get_workspace_roles, log_event
from builtin_tools.telemetry import (
    A2A_SOURCE_WORKSPACE,
    A2A_TARGET_WORKSPACE,
    A2A_TASK_ID,
    WORKSPACE_ID_ATTR,
    get_current_traceparent,
    get_tracer,
    inject_trace_headers,
)

PLATFORM_URL = os.environ.get("PLATFORM_URL", "http://platform:8080")
WORKSPACE_ID = os.environ.get("WORKSPACE_ID", "")
DELEGATION_RETRY_ATTEMPTS = int(os.environ.get("DELEGATION_RETRY_ATTEMPTS", "3"))
DELEGATION_RETRY_DELAY = float(os.environ.get("DELEGATION_RETRY_DELAY", "5.0"))
DELEGATION_TIMEOUT = float(os.environ.get("DELEGATION_TIMEOUT", "300.0"))


class DelegationStatus(str, Enum):
    PENDING = "pending"
    IN_PROGRESS = "in_progress"
    COMPLETED = "completed"
    FAILED = "failed"


@dataclass
class DelegationTask:
    task_id: str
    workspace_id: str
    task_description: str
    status: DelegationStatus = DelegationStatus.PENDING
    result: Optional[str] = None
    error: Optional[str] = None


# In-memory store of delegation tasks for this workspace
_delegations: dict[str, DelegationTask] = {}
_background_tasks: set[asyncio.Task] = set()
MAX_DELEGATION_HISTORY = 100
logger = logging.getLogger(__name__)


def _evict_old_delegations():
    """Remove completed/failed delegations when the store exceeds MAX_DELEGATION_HISTORY."""
    if len(_delegations) <= MAX_DELEGATION_HISTORY:
        return
    # Evict oldest completed/failed first
    removable = [
        tid for tid, d in _delegations.items()
        if d.status in (DelegationStatus.COMPLETED, DelegationStatus.FAILED)
    ]
    for tid in removable[:len(_delegations) - MAX_DELEGATION_HISTORY]:
        del _delegations[tid]


def _on_task_done(task: asyncio.Task):
    """Callback for background tasks — log unhandled exceptions."""
    _background_tasks.discard(task)
    if not task.cancelled() and task.exception():
        logger.error("Delegation background task failed: %s", task.exception())


async def _notify_completion(task_id: str, target_workspace_id: str, status: str):
    """Push a notification to the platform when a delegation completes or fails."""
    try:
        async with httpx.AsyncClient(timeout=10) as client:
            await client.post(
                f"{PLATFORM_URL}/workspaces/{WORKSPACE_ID}/notify",
                json={
                    "type": "delegation_complete",
                    "task_id": task_id,
                    "target_workspace_id": target_workspace_id,
                    "status": status,
                },
            )
    except Exception as e:
        logger.debug("Delegation notify failed (best-effort): %s", e)


async def _record_delegation_on_platform(task_id: str, target_workspace_id: str, task: str):
    """Register the delegation in the platform's activity_logs (#64 fix).

    Best-effort POST to /workspaces/<self>/delegations/record. The agent still
    fires A2A directly for speed + OTEL propagation, but the platform's
    GET /delegations endpoint now mirrors the same set an agent's local
    check_delegation_status sees.
    """
    try:
        async with httpx.AsyncClient(timeout=10) as client:
            await client.post(
                f"{PLATFORM_URL}/workspaces/{WORKSPACE_ID}/delegations/record",
                json={
                    "target_id": target_workspace_id,
                    "task": task,
                    "delegation_id": task_id,
                },
            )
    except Exception as e:
        logger.debug("Delegation record failed (best-effort): %s", e)


async def _update_delegation_on_platform(task_id: str, status: str, error: str = "", response_preview: str = ""):
    """Mirror status changes to the platform's activity_logs (#64 fix).

    Paired with _record_delegation_on_platform — fires on completion/failure
    so the platform view stays in sync with the agent's local dict.
    """
    try:
        async with httpx.AsyncClient(timeout=10) as client:
            await client.post(
                f"{PLATFORM_URL}/workspaces/{WORKSPACE_ID}/delegations/{task_id}/update",
                json={
                    "status": status,
                    "error": error,
                    "response_preview": response_preview[:500],
                },
            )
    except Exception as e:
        logger.debug("Delegation update failed (best-effort): %s", e)


async def _execute_delegation(task_id: str, workspace_id: str, task: str):
    """Background coroutine that sends the A2A request and stores the result."""
    delegation = _delegations[task_id]
    delegation.status = DelegationStatus.IN_PROGRESS

    # #64: register on the platform so GET /workspaces/<self>/delegations
    # sees the same set as check_delegation_status. Best-effort — platform
    # unreachability must not block the actual A2A delegation.
    await _record_delegation_on_platform(task_id, workspace_id, task)

    tracer = get_tracer()
    with tracer.start_as_current_span("task_delegate") as delegate_span:
        delegate_span.set_attribute(WORKSPACE_ID_ATTR, WORKSPACE_ID)
        delegate_span.set_attribute(A2A_SOURCE_WORKSPACE, WORKSPACE_ID)
        delegate_span.set_attribute(A2A_TARGET_WORKSPACE, workspace_id)
        delegate_span.set_attribute(A2A_TASK_ID, task_id)

        async with httpx.AsyncClient(timeout=DELEGATION_TIMEOUT) as client:
            # Discover target URL
            try:
                discover_resp = await client.get(
                    f"{PLATFORM_URL}/registry/discover/{workspace_id}",
                    headers={"X-Workspace-ID": WORKSPACE_ID},
                )
                if discover_resp.status_code != 200:
                    delegation.status = DelegationStatus.FAILED
                    delegation.error = f"Discovery failed: HTTP {discover_resp.status_code}"
                    log_event(event_type="delegation", action="delegate", resource=workspace_id,
                              outcome="failure", trace_id=task_id, reason="discovery_error")
                    return

                target_url = discover_resp.json().get("url")
                if not target_url:
                    delegation.status = DelegationStatus.FAILED
                    delegation.error = "No URL for workspace"
                    return
            except Exception as e:
                delegation.status = DelegationStatus.FAILED
                delegation.error = f"Discovery error: {e}"
                return

            # Send A2A with retry
            outgoing_headers = inject_trace_headers({
                "Content-Type": "application/json",
                "X-Workspace-ID": WORKSPACE_ID,
            })
            traceparent = get_current_traceparent()

            last_error = None
            for attempt in range(DELEGATION_RETRY_ATTEMPTS):
                try:
                    a2a_resp = await client.post(
                        target_url,
                        headers=outgoing_headers,
                        json={
                            "jsonrpc": "2.0",
                            "method": "message/send",
                            "id": f"delegation-{task_id}-{attempt}",
                            "params": {
                                "message": {
                                    "role": "user",
                                    "parts": [{"kind": "text", "text": task}],
                                    "messageId": f"msg-{task_id}-{attempt}",
                                },
                                "metadata": {
                                    "parent_task_id": task_id,
                                    "source_workspace_id": WORKSPACE_ID,
                                    "traceparent": traceparent,
                                },
                            },
                        },
                    )

                    if a2a_resp.status_code == 200:
                        try:
                            result = a2a_resp.json()
                        except Exception:
                            delegation.status = DelegationStatus.FAILED
                            delegation.error = "Invalid JSON response"
                            return

                        if "result" in result:
                            task_result = result["result"]
                            artifacts = task_result.get("artifacts", [])
                            texts = []
                            for artifact in artifacts:
                                for part in artifact.get("parts", []):
                                    if part.get("kind") == "text":
                                        texts.append(part["text"])
                            # Also check top-level parts
                            for part in task_result.get("parts", []):
                                if part.get("kind") == "text":
                                    texts.append(part["text"])

                            delegation.status = DelegationStatus.COMPLETED
                            delegation.result = "\n".join(texts) if texts else str(task_result)
                            log_event(event_type="delegation", action="delegate", resource=workspace_id,
                                      outcome="success", trace_id=task_id, attempt=attempt + 1)
                            await _notify_completion(task_id, workspace_id, "completed")
                            # #64: mirror to platform activity_logs so
                            # GET /delegations shows the completion state.
                            await _update_delegation_on_platform(
                                task_id, "completed", "",
                                delegation.result or "",
                            )
                            return

                        if "error" in result:
                            last_error = result["error"].get("message", str(result["error"]))
                            break

                except (httpx.ConnectError, httpx.TimeoutException) as e:
                    last_error = str(e)
                    if attempt < DELEGATION_RETRY_ATTEMPTS - 1:
                        await asyncio.sleep(DELEGATION_RETRY_DELAY * (attempt + 1))
                        continue

            delegation.status = DelegationStatus.FAILED
            delegation.error = str(last_error)
            log_event(event_type="delegation", action="delegate", resource=workspace_id,
                      outcome="failure", trace_id=task_id, last_error=str(last_error))
            await _notify_completion(task_id, workspace_id, "failed")
            # #64: mirror failure to platform activity_logs.
            await _update_delegation_on_platform(
                task_id, "failed", str(last_error), "",
            )


@tool
async def delegate_to_workspace(
    workspace_id: str,
    task: str,
) -> dict:
    """Delegate a task to a peer workspace via the A2A protocol (non-blocking).

    Sends the task in the background and returns immediately with a task_id.
    Use check_delegation_status to poll for the result, or continue working
    and check later. The delegate works independently.

    Args:
        workspace_id: The ID of the target workspace to delegate to.
        task: The task description to send to the peer.

    Returns:
        A dict with task_id and status="delegated". Use check_delegation_status(task_id) to get results.
    """
    task_id = str(uuid.uuid4())

    # RBAC check
    roles, custom_perms = get_workspace_roles()
    if not check_permission("delegate", roles, custom_perms):
        log_event(event_type="rbac", action="rbac.deny", resource=workspace_id,
                  outcome="denied", trace_id=task_id, attempted_action="delegate", roles=roles)
        return {"success": False, "error": f"RBAC: no 'delegate' permission. Roles: {roles}"}

    log_event(event_type="delegation", action="delegate", resource=workspace_id,
              outcome="dispatched", trace_id=task_id, task_preview=task[:200])

    # Store the delegation and launch the background task
    delegation = DelegationTask(
        task_id=task_id,
        workspace_id=workspace_id,
        task_description=task[:200],
    )
    _delegations[task_id] = delegation
    _evict_old_delegations()

    bg_task = asyncio.create_task(_execute_delegation(task_id, workspace_id, task))
    _background_tasks.add(bg_task)
    bg_task.add_done_callback(_on_task_done)

    return {
        "success": True,
        "task_id": task_id,
        "status": "delegated",
        "message": f"Task delegated to {workspace_id}. Use check_delegation_status('{task_id}') to get the result when ready.",
    }


@tool
async def check_delegation_status(
    task_id: str = "",
) -> dict:
    """Check the status of a delegated task, or list all active delegations.

    Args:
        task_id: The task_id returned by delegate_to_workspace. If empty, lists all delegations.

    Returns:
        Status and result (if completed) of the delegation.
    """
    if not task_id:
        # List all delegations
        summary = []
        for tid, d in _delegations.items():
            entry = {
                "task_id": tid,
                "workspace_id": d.workspace_id,
                "status": d.status.value,
                "task": d.task_description,
            }
            if d.status == DelegationStatus.COMPLETED:
                entry["result_preview"] = (d.result or "")[:200]
            if d.status == DelegationStatus.FAILED:
                entry["error"] = d.error
            summary.append(entry)
        return {"delegations": summary, "count": len(summary)}

    delegation = _delegations.get(task_id)
    if not delegation:
        return {"error": f"No delegation found with task_id {task_id}"}

    result = {
        "task_id": task_id,
        "workspace_id": delegation.workspace_id,
        "status": delegation.status.value,
        "task": delegation.task_description,
    }

    if delegation.status == DelegationStatus.COMPLETED:
        result["result"] = delegation.result
    elif delegation.status == DelegationStatus.FAILED:
        result["error"] = delegation.error

    return result
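
# Illustrative sketch (assumption: "_example_delegate_and_poll" is an example,
# not part of the runtime; it assumes LangChain tools accept a dict of args
# via .ainvoke(), as in recent langchain-core versions). Typical agent flow:
# fire the delegation, keep working, then poll with the returned task_id.
async def _example_delegate_and_poll(peer_id: str, work: str) -> dict:
    dispatched = await delegate_to_workspace.ainvoke(
        {"workspace_id": peer_id, "task": work}
    )
    if not dispatched.get("success"):
        return dispatched
    # ... do other work here, then check back ...
    return await check_delegation_status.ainvoke({"task_id": dispatched["task_id"]})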
403
molecule_runtime/builtin_tools/governance.py
Normal file
@ -0,0 +1,403 @@
"""Bridge between Molecule AI's RBAC + audit subsystem and the Microsoft Agent
Governance Toolkit (agent-os-kernel, released April 2, 2026).

Integration points
------------------
* ``check_permission`` → ``PolicyEvaluator.evaluate()``
  Molecule AI's RBAC gate runs first; if RBAC allows the action, the toolkit
  evaluator is consulted according to ``policy_mode``.

* ``log_event`` → governance audit sink
  Every permission decision (allow or deny) is written via
  ``builtin_tools.audit.log_event`` with extra governance metadata so the full
  decision trail lands in Molecule AI's existing audit stream.

* OTEL traceparent flows through
  ``builtin_tools.telemetry.get_current_traceparent()`` is called inside
  ``emit()`` and the W3C traceparent string is attached to every audit record,
  giving end-to-end distributed tracing across agent boundaries.

Graceful degradation
--------------------
If ``agent-os-kernel`` is not installed, the module falls back to Molecule AI
RBAC alone. No exception propagates to the agent — governance is a
best-effort overlay, never a hard dependency.

Install::

    pip install agent-os-kernel

Minimal config.yaml snippet::

    governance:
      enabled: true
      toolkit: microsoft
      policy_mode: strict          # strict | permissive | audit
      policy_endpoint: https://your-tenant.governance.azure.com
      policy_file: policies/workspace.rego
      blocked_patterns:
        - ".*\\.exec$"
        - "shell\\."
      max_tool_calls_per_task: 50

NOTE: The agent-os-kernel package was released April 2, 2026 and is in
community preview. The API bindings in this module target v3.0.x of the
package (agent_os.policies.PolicyEvaluator). If the package API changes,
update _init_evaluator() accordingly.
"""

import logging
import os
from typing import Any, Optional

logger = logging.getLogger(__name__)
WORKSPACE_ID: str = os.environ.get("WORKSPACE_ID", "")

# Module-level singleton — set by initialize_governance() at startup
_adapter: Optional["GovernanceAdapter"] = None


class GovernanceAdapter:
    """Bridges Molecule AI RBAC + audit trail to the Microsoft Agent Governance Toolkit."""

    def __init__(self, config: Any) -> None:
        self._config = config
        self._evaluator = None
        self._toolkit_available: bool = False

    async def initialize(self) -> None:
        """Async entry point: initialise the evaluator and log the outcome."""
        self._init_evaluator()
        if self._toolkit_available:
            logger.info(
                "GovernanceAdapter initialised — toolkit=%s mode=%s",
                self._config.toolkit,
                self._config.policy_mode,
            )
        else:
            logger.warning(
                "GovernanceAdapter initialised in RBAC-only mode "
                "(agent-os-kernel not available or failed to load)."
            )

    def _init_evaluator(self) -> None:
        """Lazy-import and configure the PolicyEvaluator from agent-os-kernel.

        All failures are caught and logged; the adapter simply runs without
        the toolkit rather than crashing the workspace.
        """
        try:
            try:
                from agent_os.policies import PolicyEvaluator  # type: ignore[import]
            except ImportError:
                logger.warning(
                    "agent-os-kernel is not installed — graceful degradation active. "
                    "Governance will use Molecule AI RBAC only. "
                    "To enable the Microsoft Agent Governance Toolkit run: "
                    "pip install agent-os-kernel"
                )
                return

            kwargs: dict[str, Any] = {
                "policy_mode": self._config.policy_mode,
                "max_tool_calls_per_task": self._config.max_tool_calls_per_task,
                "blocked_patterns": self._config.blocked_patterns,
            }
            if self._config.policy_endpoint:
                kwargs["endpoint"] = self._config.policy_endpoint

            self._evaluator = PolicyEvaluator(**kwargs)

            # Load a policy file if one is configured and exists on disk.
            if self._config.policy_file:
                policy_file = self._config.policy_file
                if os.path.exists(policy_file):
                    ext = os.path.splitext(policy_file)[1].lower()
                    if ext == ".rego":
                        self._evaluator.load_rego(path=policy_file)
                        logger.info("Loaded Rego policy file: %s", policy_file)
                    elif ext in (".yaml", ".yml"):
                        self._evaluator.load_yaml(path=policy_file)
                        logger.info("Loaded YAML policy file: %s", policy_file)
                    elif ext == ".cedar":
                        self._evaluator.load_cedar(path=policy_file)
                        logger.info("Loaded Cedar policy file: %s", policy_file)
                    else:
                        logger.warning(
                            "Unrecognised policy file extension '%s' — skipping load.",
                            ext,
                        )
                else:
                    logger.warning(
                        "policy_file '%s' does not exist — skipping load.",
                        policy_file,
                    )

            self._toolkit_available = True
            logger.info(
                "agent-os-kernel PolicyEvaluator ready — policy_mode=%s",
                self._config.policy_mode,
            )

        except Exception as exc:  # noqa: BLE001
            logger.warning(
                "Failed to initialise agent-os-kernel PolicyEvaluator: %s — "
                "graceful degradation active (RBAC only).",
                exc,
            )

    def check_permission(
        self,
        action: str,
        roles: list[str],
        custom_permissions: dict | None = None,
        context: dict | None = None,
    ) -> tuple[bool, str]:
        """Evaluate an action against Molecule AI RBAC and (optionally) the toolkit.

        Returns
        -------
        tuple[bool, str]
            ``(allowed, reason)`` — reason is a short human-readable string
            explaining the decision.
        """
        from builtin_tools import audit  # inline import to avoid circular dependencies

        context = context or {}

        # --- Step 1: Molecule AI RBAC gate (always runs) ---
        rbac_allowed: bool = audit.check_permission(action, roles, custom_permissions)

        if not rbac_allowed:
            self.emit(
                event_type="permission_check",
                action=action,
                resource=context.get("resource", ""),
                outcome="denied",
                actor=context.get("actor"),
                policy_decision="rbac_deny",
                roles=roles,
            )
            return False, f"RBAC denied action '{action}' for roles {roles}"

        # --- Step 2: If the toolkit is unavailable, or in audit-only mode, return the RBAC result ---
        if not self._toolkit_available or self._config.policy_mode == "audit":
            self.emit(
                event_type="permission_check",
                action=action,
                resource=context.get("resource", ""),
                outcome="allowed",
                actor=context.get("actor"),
                policy_decision="rbac_allowed",
                roles=roles,
                toolkit_mode=self._config.policy_mode,
            )
            return rbac_allowed, "rbac_allowed"

        # --- Step 3: Toolkit evaluation ---
        eval_context: dict[str, Any] = {
            "action": action,
            "resource": context.get("resource", ""),
            "roles": roles,
            "workspace_id": WORKSPACE_ID,
        }
        # Merge any extra context keys the caller supplied.
        for key, value in context.items():
            if key not in eval_context:
                eval_context[key] = value

        toolkit_allowed: bool = True
        reason: str = ""
        evaluator_name: str = "agent-os-kernel"

        try:
            decision = self._evaluator.evaluate(eval_context)
            toolkit_allowed = getattr(decision, "allowed", True)
            reason = getattr(decision, "reason", "")
            evaluator_name = getattr(decision, "evaluator_name", "agent-os-kernel")
        except Exception as exc:  # noqa: BLE001
            logger.warning(
                "agent-os-kernel evaluation raised an exception: %s — "
                "falling back to RBAC result to avoid blocking the agent.",
                exc,
            )
            self.emit(
                event_type="permission_check",
                action=action,
                resource=context.get("resource", ""),
                outcome="allowed",
                actor=context.get("actor"),
                policy_decision="toolkit_evaluation_error",
                toolkit_mode=self._config.policy_mode,
                roles=roles,
            )
            return rbac_allowed, "toolkit_evaluation_error"

        # --- Step 4: Combine results according to policy_mode ---
        if self._config.policy_mode == "permissive":
            # Toolkit denial is advisory only in permissive mode.
            if not toolkit_allowed:
                logger.warning(
                    "Governance toolkit denied action '%s' (reason=%s) but policy_mode "
                    "is 'permissive' — allowing and logging advisory denial.",
                    action,
                    reason,
                )
            final_allowed = rbac_allowed
        else:
            # strict: both gates must allow.
            final_allowed = rbac_allowed and toolkit_allowed

        outcome = "allowed" if final_allowed else "denied"
        self.emit(
            event_type="permission_check",
            action=action,
            resource=context.get("resource", ""),
            outcome=outcome,
            actor=context.get("actor"),
            policy_decision=reason or outcome,
            evaluator=evaluator_name,
            toolkit_mode=self._config.policy_mode,
            roles=roles,
        )
        return final_allowed, reason or "allowed"

    def emit(
        self,
        event_type: str,
        action: str,
        resource: str,
        outcome: str,
        actor: str | None = None,
        trace_id: str | None = None,
        **extra: Any,
    ) -> str:
        """Write a governance-annotated audit event.

        Pulls the current W3C traceparent from the active OTEL span so that
        governance decisions are traceable across service boundaries.

        Returns
        -------
        str
            The ``trace_id`` produced by ``audit.log_event``.
        """
        from builtin_tools import audit  # inline import to avoid circular dependencies
        from builtin_tools.telemetry import get_current_traceparent  # inline import

        traceparent: str | None = get_current_traceparent()

        recorded_trace_id: str = audit.log_event(
            event_type,
            action,
            resource,
            outcome,
            actor=actor,
            trace_id=trace_id,
            governance_toolkit=(
                self._config.toolkit if self._toolkit_available else "disabled"
            ),
            traceparent=traceparent or "",
            **extra,
        )
        return recorded_trace_id


# ---------------------------------------------------------------------------
# Module-level functions
# ---------------------------------------------------------------------------


async def initialize_governance(config: Any) -> Optional[GovernanceAdapter]:
    """Initialize the module-level GovernanceAdapter singleton.

    Called once at startup by main.py when governance.enabled is True.
    Returns the adapter, or None if initialization fails.
    """
    global _adapter

    try:
        adapter = GovernanceAdapter(config)
        await adapter.initialize()
        _adapter = adapter
        logger.info(
            "Governance singleton initialised — toolkit=%s mode=%s",
            config.toolkit,
            config.policy_mode,
        )
        return adapter
    except Exception as exc:  # noqa: BLE001
        logger.warning(
            "initialize_governance() failed: %s — governance disabled for this session.",
            exc,
        )
        return None


def get_governance_adapter() -> Optional[GovernanceAdapter]:
    """Return the module-level GovernanceAdapter singleton (may be None)."""
    return _adapter


def check_permission_with_governance(
    action: str,
    roles: list[str],
    custom_permissions: dict | None = None,
    context: dict | None = None,
) -> tuple[bool, str]:
    """Convenience wrapper: use the GovernanceAdapter when available, else RBAC only.

    Parameters
    ----------
    action:
        The action name to evaluate (e.g. ``"memory.write"``).
    roles:
        The list of role names held by the requesting actor.
    custom_permissions:
        Optional custom role→action mapping to overlay on built-in roles.
    context:
        Optional extra context forwarded to the PolicyEvaluator.

    Returns
    -------
    tuple[bool, str]
        ``(allowed, reason)``
    """
    if _adapter is None:
        from builtin_tools import audit  # inline import to avoid circular dependencies

        result: bool = audit.check_permission(action, roles, custom_permissions)
        return result, "rbac_only"

    return _adapter.check_permission(action, roles, custom_permissions, context)
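
# Illustrative sketch (assumption: "_example_guarded_write" and the
# "memory.write" / "workspace_memory" names are examples, not part of the
# runtime). Shows a tool consulting the combined RBAC + toolkit gate before
# doing privileged work; the wrapper degrades to RBAC-only when no adapter
# singleton has been initialised.
def _example_guarded_write(roles: list[str]) -> tuple[bool, str]:
    allowed, reason = check_permission_with_governance(
        "memory.write",
        roles,
        context={"resource": "workspace_memory", "actor": "agent"},
    )
    if not allowed:
        logger.info("memory.write denied: %s", reason)
    return allowed, reason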

# ---------------------------------------------------------------------------
# Private helper
# ---------------------------------------------------------------------------


def _emit_governance_event(
    event_type: str,
    action: str,
    resource: str,
    outcome: str,
    actor: str | None = None,
    trace_id: str | None = None,
    **extra: Any,
) -> Optional[str]:
    """Emit a governance audit event via the singleton adapter if one is set.

    Returns the trace_id produced by log_event, or None if no adapter is set.
    """
    if _adapter is None:
        return None
    return _adapter.emit(
        event_type,
        action,
        resource,
        outcome,
        actor=actor,
        trace_id=trace_id,
        **extra,
    )
531
molecule_runtime/builtin_tools/hitl.py
Normal file
@ -0,0 +1,531 @@
"""Human-In-The-Loop (HITL) workflow primitives.

Generalizes the approval tool into reusable HITL building blocks that work
across all Molecule AI adapters.

Features
--------
@requires_approval
    Decorator that gates *any* async callable (tool, method, standalone fn)
    behind a human approval request. The decorated function only runs if
    the request is granted. Roles in ``hitl.bypass_roles`` skip the gate.

pause_task / resume_task
    LangChain tools for explicit pause/resume of in-flight tasks. An agent
    calls ``pause_task(task_id, reason)`` to suspend itself; an external
    signal (webhook, dashboard click, another agent) calls ``resume_task``
    with the same task_id to wake it up.

Notification channels
---------------------
Configured under ``hitl:`` in ``config.yaml``::

    hitl:
      channels:
        - type: dashboard              # always active; uses platform approval API
        - type: slack
          webhook_url: https://hooks.slack.com/services/…
        - type: email
          smtp_host: smtp.example.com
          smtp_port: 587
          from: alerts@example.com
          to: ops@example.com
          username: alerts@example.com   # optional; password from SMTP_PASSWORD env
      default_timeout: 300               # seconds before an unanswered request times out
      bypass_roles: [admin]              # roles that skip the approval gate entirely

Environment variables
---------------------
SMTP_PASSWORD    Password for SMTP authentication (preferred over config file)
"""

from __future__ import annotations

import asyncio
import functools
import logging
import os
import smtplib
from dataclasses import dataclass, field
from email.mime.text import MIMEText
from typing import Any, Callable

import httpx
from langchain_core.tools import tool

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------


@dataclass
class HITLConfig:
    """HITL settings loaded from the ``hitl:`` block in config.yaml."""

    channels: list[dict] = field(default_factory=lambda: [{"type": "dashboard"}])
    default_timeout: float = 300.0
    bypass_roles: list[str] = field(default_factory=list)


def _load_hitl_config() -> HITLConfig:
    """Load HITL config from the workspace config; fall back to safe defaults."""
    try:
        from config import load_config
        cfg = load_config()
        raw = getattr(cfg, "hitl", None)
        if raw is None:
            return HITLConfig()
        return HITLConfig(
            channels=raw.channels if hasattr(raw, "channels") else [{"type": "dashboard"}],
            default_timeout=float(raw.default_timeout if hasattr(raw, "default_timeout") else 300),
            bypass_roles=list(raw.bypass_roles if hasattr(raw, "bypass_roles") else []),
        )
    except Exception:
        return HITLConfig()


# ---------------------------------------------------------------------------
# Pause / Resume registry
# ---------------------------------------------------------------------------


class _TaskPauseRegistry:
    """In-process registry mapping task_id → asyncio.Event + optional result.

    Multiple coroutines awaiting the same task_id are all unblocked when
    ``resume()`` is called. Results survive until the awaiting coroutine
    calls ``pop_result()``.
    """

    def __init__(self) -> None:
        self._events: dict[str, asyncio.Event] = {}
        self._results: dict[str, dict] = {}
        # #265: owner map — workspace_id that created each task.
        # Empty string means "no owner / legacy" (bypasses ownership check).
        self._owners: dict[str, str] = {}

    def register(self, task_id: str, owner: str = "") -> asyncio.Event:
        """Create and store an Event for *task_id*. Returns the event.

        Args:
            task_id: Unique task identifier.
            owner: Workspace ID that owns this task. When set, ``resume``
                will reject callers from a different workspace.
        """
        ev = asyncio.Event()
        self._events[task_id] = ev
        self._owners[task_id] = owner
        return ev

    def resume(self, task_id: str, result: dict | None = None, owner: str = "") -> bool:
        """Signal the Event for *task_id*. Returns False if not registered.

        Args:
            task_id: The identifier used in ``register``.
            result: Optional result payload forwarded to the waiting coroutine.
            owner: Caller's workspace ID. When both the stored owner and
                *owner* are non-empty and they differ, the call is rejected
                (returns False) — prevents cross-workspace prompt injection
                (#265). Passing ``owner=""`` bypasses the check (used in
                direct registry calls from tests and platform code).
        """
        # #265 ownership check
        stored_owner = self._owners.get(task_id, "")
        if owner and stored_owner and owner != stored_owner:
            logger.warning(
                "HITL: resume rejected for task %s — caller workspace %r != owner %r",
                task_id, owner, stored_owner,
            )
            return False
        ev = self._events.get(task_id)
        if ev is None:
            return False
        self._results[task_id] = result or {}
        ev.set()
        return True

    def pop_result(self, task_id: str) -> dict:
        """Return and remove the stored result for *task_id*."""
        return self._results.pop(task_id, {})

    def cleanup(self, task_id: str) -> None:
        """Remove *task_id* from all dicts."""
        self._events.pop(task_id, None)
        self._results.pop(task_id, None)
        self._owners.pop(task_id, None)

    def list_paused(self) -> list[str]:
        """Return IDs of tasks whose events have not yet been set."""
        return [tid for tid, ev in self._events.items() if not ev.is_set()]


# Global singleton — safe within one asyncio event loop / process
pause_registry = _TaskPauseRegistry()
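
# Illustrative sketch (assumption: "_example_pause_roundtrip" and the
# "example-task" / "ws-a" identifiers are examples, not part of the runtime).
# One coroutine registers and waits; another (e.g. a webhook handler) resumes
# it with a result payload and the owning workspace ID, which must match the
# one given at register time or the resume is rejected (#265).
async def _example_pause_roundtrip() -> dict:
    ev = pause_registry.register("example-task", owner="ws-a")
    pause_registry.resume("example-task", {"decision": "approved"}, owner="ws-a")
    await ev.wait()
    result = pause_registry.pop_result("example-task")
    pause_registry.cleanup("example-task")
    return result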
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Notification channels
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def _notify_channels(
|
||||
action: str,
|
||||
reason: str,
|
||||
approval_id: str,
|
||||
cfg: HITLConfig,
|
||||
) -> None:
|
||||
"""Fire-and-forget notifications to all configured channels.
|
||||
|
||||
Errors in individual channels are logged but never re-raised so that a
|
||||
misconfigured Slack webhook cannot block the approval flow.
|
||||
"""
|
||||
platform_url = os.environ.get("PLATFORM_URL", "http://platform:8080")
|
||||
workspace_id = os.environ.get("WORKSPACE_ID", "")
|
||||
|
||||
for channel in cfg.channels:
|
||||
ch_type = channel.get("type", "dashboard")
|
||||
try:
|
||||
if ch_type == "slack":
|
||||
await _notify_slack(channel, action, reason, approval_id,
|
||||
platform_url, workspace_id)
|
||||
elif ch_type == "email":
|
||||
await _notify_email(channel, action, reason, approval_id,
|
||||
platform_url, workspace_id)
|
||||
# "dashboard" is handled by the platform via the approval POST
|
||||
except Exception as exc:
|
||||
logger.warning("HITL: channel '%s' notification failed: %s", ch_type, exc)
|
||||
|
||||
|
||||
async def _notify_slack(
|
||||
cfg: dict,
|
||||
action: str,
|
||||
reason: str,
|
||||
approval_id: str,
|
||||
platform_url: str,
|
||||
workspace_id: str,
|
||||
) -> None:
|
||||
webhook_url = cfg.get("webhook_url", "")
|
||||
if not webhook_url:
|
||||
return
|
||||
|
||||
approve_url = f"{platform_url}/workspaces/{workspace_id}/approvals/{approval_id}/approve"
|
||||
deny_url = f"{platform_url}/workspaces/{workspace_id}/approvals/{approval_id}/deny"
|
||||
|
||||
payload = {
|
||||
"text": f":warning: Approval required from workspace `{workspace_id}`",
|
||||
"blocks": [
|
||||
{
|
||||
"type": "section",
|
||||
"text": {
|
||||
"type": "mrkdwn",
|
||||
"text": (
|
||||
f"*Action:* {action}\n"
|
||||
f"*Reason:* {reason}\n"
|
||||
f"*Approval ID:* `{approval_id}`"
|
||||
),
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "actions",
|
||||
"elements": [
|
||||
{
|
||||
"type": "button",
|
||||
"text": {"type": "plain_text", "text": "Approve"},
|
||||
"style": "primary",
|
||||
"url": approve_url,
|
||||
},
|
||||
{
|
||||
"type": "button",
|
||||
"text": {"type": "plain_text", "text": "Deny"},
|
||||
"style": "danger",
|
||||
"url": deny_url,
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
}
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
await client.post(webhook_url, json=payload)
|
||||
logger.info("HITL: Slack notification sent for approval %s", approval_id)
|
||||
|
||||
|
||||
async def _notify_email(
|
||||
cfg: dict,
|
||||
action: str,
|
||||
reason: str,
|
||||
approval_id: str,
|
||||
platform_url: str,
|
||||
workspace_id: str,
|
||||
) -> None:
|
||||
smtp_host = cfg.get("smtp_host", "")
|
||||
smtp_port = int(cfg.get("smtp_port", 587))
|
||||
from_addr = cfg.get("from", "")
|
||||
to_addr = cfg.get("to", "")
|
||||
|
||||
if not all([smtp_host, from_addr, to_addr]):
|
||||
logger.warning("HITL: email channel missing smtp_host/from/to — skipping")
|
||||
return
|
||||
|
||||
approve_url = f"{platform_url}/workspaces/{workspace_id}/approvals/{approval_id}/approve"
|
||||
deny_url = f"{platform_url}/workspaces/{workspace_id}/approvals/{approval_id}/deny"
|
||||
|
||||
body = (
|
||||
f"Approval required from workspace {workspace_id}\n\n"
|
||||
f"Action : {action}\n"
|
||||
f"Reason : {reason}\n"
|
||||
f"ID : {approval_id}\n\n"
|
||||
f"Approve: {approve_url}\n"
|
||||
f"Deny : {deny_url}\n"
|
||||
)
|
||||
|
||||
msg = MIMEText(body, "plain", "utf-8")
|
||||
msg["Subject"] = f"[Molecule AI] Approval required: {action}"
|
||||
msg["From"] = from_addr
|
||||
msg["To"] = to_addr
|
||||
|
||||
username = cfg.get("username", "")
|
||||
password = cfg.get("password", os.environ.get("SMTP_PASSWORD", ""))
|
||||
|
||||
def _send() -> None:
|
||||
with smtplib.SMTP(smtp_host, smtp_port) as srv:
|
||||
srv.ehlo()
|
||||
srv.starttls()
|
||||
if username and password:
|
||||
srv.login(username, password)
|
||||
srv.send_message(msg)
|
||||
|
||||
await asyncio.to_thread(_send)
|
||||
logger.info("HITL: email notification sent for approval %s", approval_id)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# @requires_approval decorator
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def requires_approval(
|
||||
action_description: str = "",
|
||||
reason_template: str = "",
|
||||
bypass_roles: list[str] | None = None,
|
||||
) -> Callable[[Callable], Callable]:
|
||||
"""Decorator that gates an async callable behind a human approval request.
|
||||
|
||||
The wrapped function executes only when a human approves. Use this on
|
||||
any tool or async helper that performs destructive or high-impact work.
|
||||
|
||||
Args:
|
||||
action_description: Short label for the action shown to the approver.
|
||||
Defaults to the function's ``name`` attribute or
|
||||
``__name__``.
|
||||
reason_template: f-string template for the reason line. Keyword
|
||||
arguments of the decorated function are available,
|
||||
e.g. ``"Delete table {table_name}"``).
|
||||
bypass_roles: Roles that skip the gate entirely. Overrides
|
||||
``hitl.bypass_roles`` in config.yaml when given.
|
||||
|
||||
Returns:
|
||||
A decorator; applying it to a function returns an async wrapper.
|
||||
|
||||
Usage::
|
||||
|
||||
@tool
|
||||
@requires_approval("Wipe production DB", bypass_roles=["admin"])
|
||||
async def drop_table(table_name: str) -> dict:
|
||||
...
|
||||
|
||||
# Works with plain async functions too:
|
||||
@requires_approval("Send customer email")
|
||||
async def send_email(to: str, body: str) -> dict:
|
||||
...
|
||||
"""
|
||||
def decorator(fn: Callable) -> Callable:
|
||||
action = action_description or getattr(fn, "name", None) or fn.__name__
|
||||
|
||||
        @functools.wraps(fn)
        async def wrapper(*args: Any, **kwargs: Any) -> Any:
            hitl_cfg = _load_hitl_config()

            # --- Check bypass roles -----------------------------------------
            active_bypass = bypass_roles if bypass_roles is not None else hitl_cfg.bypass_roles
            if active_bypass:
                try:
                    from builtin_tools.audit import get_workspace_roles
                    roles, _ = get_workspace_roles()
                    if any(r in active_bypass for r in roles):
                        logger.info(
                            "@requires_approval bypassed (role %s) for '%s'", roles, action
                        )
                        return await fn(*args, **kwargs)
                except Exception:
                    pass  # If RBAC check fails, proceed to approval gate

            # --- Build reason string -----------------------------------------
            if reason_template:
                try:
                    reason = reason_template.format(**kwargs)
                except (KeyError, IndexError):
                    reason = reason_template
            else:
                arg_parts = [f"{k}={str(v)[:60]}" for k, v in list(kwargs.items())[:3]]
                reason = f"Args: {', '.join(arg_parts)}" if arg_parts else "Automated action"

            # --- Fire non-dashboard notifications (async, non-blocking) ------
            asyncio.create_task(
                _notify_channels(action, reason, "pending", hitl_cfg)
            )

            # --- Request approval via approval tool --------------------------
            try:
                from builtin_tools.approval import request_approval
                approval_result = await request_approval.ainvoke(
                    {"action": action, "reason": reason}
                )
            except Exception as exc:
                logger.error("@requires_approval: approval call failed: %s", exc)
                return {
                    "success": False,
                    "error": f"Approval gate error: {exc}",
                }

            if not approval_result.get("approved"):
                return {
                    "success": False,
                    "error": (
                        f"Action '{action}' not approved: "
                        f"{approval_result.get('message', approval_result.get('error', 'denied'))}"
                    ),
                    "approval_id": approval_result.get("approval_id"),
                }

            # --- Approved — run the original function ------------------------
            return await fn(*args, **kwargs)

        return wrapper

    return decorator


# ---------------------------------------------------------------------------
# Pause / Resume LangChain tools
# ---------------------------------------------------------------------------

@tool
async def pause_task(task_id: str, reason: str = "") -> dict:
    """Suspend the current task and wait for a resume signal.

    The agent calls this to pause itself at a decision point. Execution
    resumes when ``resume_task`` is called with the same task_id, or after
    the configured ``hitl.default_timeout`` seconds.

    Args:
        task_id: Unique identifier for this pause point (use the A2A task ID
            or any stable string that the caller can reference later).
        reason: Human-readable description of why the task is pausing.
    """
    # #265: record workspace ownership on registration so resume_task can
    # reject callers from a different workspace (cross-workspace prompt-injection
    # prevention). External task_id is unchanged — only internal ownership
    # metadata is added, so no tests or callers need to update their task IDs.
    _ws = os.environ.get("WORKSPACE_ID", "")

    try:
        from builtin_tools.audit import log_event
        log_event(
            event_type="hitl",
            action="pause",
            resource=task_id,
            outcome="paused",
            trace_id=task_id,
            reason=reason,
        )
    except Exception:
        pass

    event = pause_registry.register(task_id, owner=_ws)
    timeout = _load_hitl_config().default_timeout
    logger.info("HITL: task %s paused — %s", task_id, reason or "(no reason given)")

    try:
        await asyncio.wait_for(event.wait(), timeout=timeout)
        result = pause_registry.pop_result(task_id)
        logger.info("HITL: task %s resumed", task_id)
        try:
            from builtin_tools.audit import log_event
            log_event(
                event_type="hitl",
                action="resume",
                resource=task_id,
                outcome="resumed",
                trace_id=task_id,
            )
        except Exception:
            pass
        return {"resumed": True, "task_id": task_id, **result}

    except asyncio.TimeoutError:
        logger.warning("HITL: task %s timed out after %.0fs", task_id, timeout)
        try:
            from builtin_tools.audit import log_event
            log_event(
                event_type="hitl",
                action="pause",
                resource=task_id,
                outcome="timeout",
                trace_id=task_id,
                timeout_seconds=timeout,
            )
        except Exception:
            pass
        return {
            "resumed": False,
            "task_id": task_id,
            "error": f"Timed out after {timeout:.0f}s waiting for resume signal",
        }
    finally:
        pause_registry.cleanup(task_id)


@tool
async def resume_task(task_id: str, message: str = "") -> dict:
    """Resume a previously paused task.

    Signals the ``pause_task`` coroutine waiting on *task_id* to continue.
    Safe to call even if the task has already resumed or timed out (returns
    success=False in that case).

    Args:
        task_id: The identifier passed to ``pause_task``.
        message: Optional message forwarded to the resumed task.
    """
    # #265: pass caller's workspace ID so the registry can reject a resume
    # from a different workspace (ownership check in _TaskPauseRegistry.resume).
    _ws = os.environ.get("WORKSPACE_ID", "")

    result_payload = {"message": message} if message else {}
    success = pause_registry.resume(task_id, result_payload, owner=_ws)

    if success:
        logger.info("HITL: resume signal sent for task %s", task_id)
        try:
            from builtin_tools.audit import log_event
            log_event(
                event_type="hitl",
                action="resume",
                resource=task_id,
                outcome="success",
                trace_id=task_id,
                message=message,
            )
        except Exception:
            pass
        return {"success": True, "task_id": task_id}

    return {
        "success": False,
        "task_id": task_id,
        "error": "Task not found or already resumed",
    }


@tool
async def list_paused_tasks() -> dict:
    """List all tasks currently suspended and waiting for a resume signal."""
    paused = pause_registry.list_paused()
    return {"paused_tasks": paused, "count": len(paused)}
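

if __name__ == "__main__":
    # Illustrative local smoke test of the pause/resume round-trip, added for
    # review only; not part of the published module. A minimal sketch assuming
    # a loadable hitl config, that both coroutines share one event loop, and
    # that LangChain's BaseTool.ainvoke accepts a dict of tool arguments.
    async def _demo() -> None:
        # Start a pause, then send the resume signal from the same loop.
        pauser = asyncio.create_task(
            pause_task.ainvoke({"task_id": "demo-1", "reason": "demo pause"})
        )
        await asyncio.sleep(0.1)  # give pause_task time to register itself
        print(await resume_task.ainvoke({"task_id": "demo-1", "message": "go"}))
        # pause_task should return {"resumed": True, "task_id": "demo-1", "message": "go"}
        print(await pauser)

    asyncio.run(_demo())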
106
molecule_runtime/builtin_tools/medo.py
Normal file
@ -0,0 +1,106 @@
"""MeDo builtin tools — Baidu MeDo no-code AI platform integration.
|
||||
|
||||
MeDo (摩搭, moda.baidu.com) is Baidu's no-code AI application builder used in
|
||||
the Molecule AI hackathon integration (May 2026). Three core operations:
|
||||
create_medo_app — scaffold a new application from a template
|
||||
update_medo_app — push content / config changes to an existing app
|
||||
publish_medo_app — publish a draft app to a target environment
|
||||
|
||||
Authentication: set MEDO_API_KEY as a workspace secret.
|
||||
Override base URL via MEDO_BASE_URL (default: https://api.moda.baidu.com/v1).
|
||||
|
||||
Mock backend: when MEDO_API_KEY is absent the tools return a predictable stub
|
||||
response — safe for unit tests and local development.
|
||||
TODO: swap _mock_http_post for a real httpx.AsyncClient call once keys are live.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MEDO_BASE_URL = os.environ.get("MEDO_BASE_URL", "https://api.moda.baidu.com/v1")
|
||||
MEDO_API_KEY = os.environ.get("MEDO_API_KEY", "")
|
||||
|
||||
_VALID_TEMPLATES = ("blank", "chatbot", "form", "dashboard")
|
||||
_VALID_ENVS = ("production", "staging")
|
||||
|
||||
|
||||
async def _mock_http_post(path: str, payload: dict) -> dict:
|
||||
"""Stub HTTP call. TODO: replace with real httpx.AsyncClient once MEDO_API_KEY is live."""
|
||||
return {"status": "ok", "mock": True, "path": path, "payload_keys": list(payload.keys())}
|
||||
|
||||
|
||||
@tool
|
||||
async def create_medo_app(name: str, template: str = "blank", description: str = "") -> dict:
|
||||
"""Create a new MeDo application.
|
||||
|
||||
Args:
|
||||
name: Application name (required).
|
||||
template: Starting template — blank | chatbot | form | dashboard (default: blank).
|
||||
description: Short description of the application.
|
||||
|
||||
Returns:
|
||||
dict with 'app_id' and 'status' on success, 'error' key on failure.
|
||||
"""
|
||||
if not name:
|
||||
return {"error": "name is required"}
|
||||
if template not in _VALID_TEMPLATES:
|
||||
return {"error": f"template must be one of: {', '.join(_VALID_TEMPLATES)}"}
|
||||
try:
|
||||
result = await _mock_http_post("/apps", {"name": name, "template": template, "description": description})
|
||||
logger.info("MeDo create_app: name=%s template=%s → %s", name, template, result)
|
||||
return result
|
||||
except Exception as exc:
|
||||
logger.exception("MeDo create_app failed")
|
||||
return {"error": str(exc)}
|
||||
|
||||
|
||||
@tool
|
||||
async def update_medo_app(app_id: str, content: dict) -> dict:
|
||||
"""Push content or configuration changes to an existing MeDo application.
|
||||
|
||||
Args:
|
||||
app_id: The MeDo application ID returned by create_medo_app.
|
||||
content: Dict of fields to update (e.g. {"title": "...", "nodes": [...]}).
|
||||
|
||||
Returns:
|
||||
dict with 'status' on success, 'error' key on failure.
|
||||
"""
|
||||
if not app_id:
|
||||
return {"error": "app_id is required"}
|
||||
if not content:
|
||||
return {"error": "content must be a non-empty dict"}
|
||||
try:
|
||||
result = await _mock_http_post(f"/apps/{app_id}", content)
|
||||
logger.info("MeDo update_app: app_id=%s keys=%s → %s", app_id, list(content.keys()), result)
|
||||
return result
|
||||
except Exception as exc:
|
||||
logger.exception("MeDo update_app failed")
|
||||
return {"error": str(exc)}
|
||||
|
||||
|
||||
@tool
|
||||
async def publish_medo_app(app_id: str, environment: str = "production") -> dict:
|
||||
"""Publish a MeDo application to a target environment.
|
||||
|
||||
Args:
|
||||
app_id: The MeDo application ID to publish.
|
||||
environment: Target — production | staging (default: production).
|
||||
|
||||
Returns:
|
||||
dict with 'status' on success, 'error' key on failure.
|
||||
"""
|
||||
if not app_id:
|
||||
return {"error": "app_id is required"}
|
||||
if environment not in _VALID_ENVS:
|
||||
return {"error": f"environment must be one of: {', '.join(_VALID_ENVS)}"}
|
||||
try:
|
||||
result = await _mock_http_post(f"/apps/{app_id}/publish", {"environment": environment})
|
||||
logger.info("MeDo publish_app: app_id=%s env=%s → %s", app_id, environment, result)
|
||||
return result
|
||||
except Exception as exc:
|
||||
logger.exception("MeDo publish_app failed")
|
||||
return {"error": str(exc)}
|
||||
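

if __name__ == "__main__":
    # Illustrative round-trip against the mock backend, added for review only.
    # Runs without MEDO_API_KEY since every call hits _mock_http_post; the
    # "app-123" ID is made up, as the stub response carries no real app_id.
    import asyncio

    async def _demo() -> None:
        print(await create_medo_app.ainvoke({"name": "demo-app", "template": "chatbot"}))
        print(await update_medo_app.ainvoke({"app_id": "app-123", "content": {"title": "Hello"}}))
        print(await publish_medo_app.ainvoke({"app_id": "app-123", "environment": "staging"}))

    asyncio.run(_demo())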
468
molecule_runtime/builtin_tools/memory.py
Normal file
@ -0,0 +1,468 @@
"""HMA memory tools for agents.
|
||||
|
||||
Hierarchical Memory Architecture:
|
||||
- LOCAL: private to this workspace, invisible to others
|
||||
- TEAM: shared with parent + siblings (same team)
|
||||
- GLOBAL: readable by all, writable by root workspaces only
|
||||
|
||||
RBAC enforcement
|
||||
----------------
|
||||
``commit_memory`` requires the ``"memory.write"`` action.
|
||||
``search_memory`` requires the ``"memory.read"`` action.
|
||||
Roles are read from ``config.yaml`` under ``rbac.roles`` (default: operator).
|
||||
|
||||
Audit trail
|
||||
-----------
|
||||
Every memory operation appends a JSON Lines record to the audit log:
|
||||
|
||||
memory / memory.write / allowed — write permitted by RBAC
|
||||
memory / memory.write / success — write committed successfully
|
||||
memory / memory.write / failure — write failed (platform error)
|
||||
memory / memory.read / allowed — read permitted by RBAC
|
||||
memory / memory.read / success — search returned results
|
||||
memory / memory.read / failure — search failed (platform error)
|
||||
|
||||
RBAC denials emit ``rbac / rbac.deny / denied`` events instead.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
from builtin_tools.awareness_client import build_awareness_client
|
||||
from builtin_tools.audit import check_permission, get_workspace_roles, log_event
|
||||
from builtin_tools.telemetry import MEMORY_QUERY, MEMORY_SCOPE, WORKSPACE_ID_ATTR, get_tracer
|
||||
|
||||
try: # pragma: no cover - optional runtime dependency in lightweight test envs
|
||||
import httpx # type: ignore
|
||||
except ImportError: # pragma: no cover
|
||||
httpx = SimpleNamespace(AsyncClient=None)
|
||||
|
||||
PLATFORM_URL = os.environ.get("PLATFORM_URL", "http://platform:8080")
|
||||
WORKSPACE_ID = os.environ.get("WORKSPACE_ID", "")
|
||||
|
||||
|
||||
@tool
async def commit_memory(content: str, scope: str = "LOCAL") -> dict:
    """Store a fact in memory with a specific scope.

    Args:
        content: The fact or knowledge to remember.
        scope: Memory scope — LOCAL (private), TEAM (shared with team), or GLOBAL (company-wide, root only).
    """
    trace_id = str(uuid.uuid4())
    scope = scope.upper()
    if scope not in ("LOCAL", "TEAM", "GLOBAL"):
        return {"error": "scope must be LOCAL, TEAM, or GLOBAL"}

    # --- RBAC check -----------------------------------------------------------
    roles, custom_perms = get_workspace_roles()
    if not check_permission("memory.write", roles, custom_perms):
        log_event(
            event_type="rbac",
            action="rbac.deny",
            resource=scope,
            outcome="denied",
            trace_id=trace_id,
            attempted_action="memory.write",
            roles=roles,
        )
        return {
            "success": False,
            "error": (
                "RBAC: this workspace does not have the 'memory.write' permission. "
                f"Current roles: {roles}"
            ),
        }

    log_event(
        event_type="memory",
        action="memory.write",
        resource=scope,
        outcome="allowed",
        trace_id=trace_id,
        memory_scope=scope,
        content_length=len(content),
    )

    # ── OTEL: memory_write span ──────────────────────────────────────────────
    tracer = get_tracer()

    with tracer.start_as_current_span("memory_write") as mem_span:
        mem_span.set_attribute(WORKSPACE_ID_ATTR, WORKSPACE_ID)
        mem_span.set_attribute(MEMORY_SCOPE, scope)
        mem_span.set_attribute("memory.content_length", len(content))

        awareness_client = build_awareness_client()
        if awareness_client is not None:
            try:
                result = await awareness_client.commit(content, scope)
            except Exception as e:
                log_event(
                    event_type="memory",
                    action="memory.write",
                    resource=scope,
                    outcome="failure",
                    trace_id=trace_id,
                    memory_scope=scope,
                    error=str(e),
                )
                try:
                    mem_span.record_exception(e)
                except Exception:
                    pass
                return {"success": False, "error": str(e)}
        else:
            # #215-class bug: platform now gates /workspaces/:id/memories behind
            # workspace auth. Import auth_headers lazily (same pattern as the
            # activity-log path below) so test environments that don't ship
            # platform_auth still work.
            try:
                from platform_auth import auth_headers as _auth
                _headers = _auth()
            except Exception:
                _headers = {}
            async with httpx.AsyncClient(timeout=10.0) as client:
                try:
                    resp = await client.post(
                        f"{PLATFORM_URL}/workspaces/{WORKSPACE_ID}/memories",
                        json={"content": content, "scope": scope},
                        headers=_headers,
                    )
                    if resp.status_code == 201:
                        result = {"success": True, "id": resp.json().get("id"), "scope": scope}
                    else:
                        result = {"success": False, "error": resp.json().get("error", resp.text)}
                except Exception as e:
                    log_event(
                        event_type="memory",
                        action="memory.write",
                        resource=scope,
                        outcome="failure",
                        trace_id=trace_id,
                        memory_scope=scope,
                        error=str(e),
                    )
                    try:
                        mem_span.record_exception(e)
                    except Exception:
                        pass
                    return {"success": False, "error": str(e)}

        if result.get("success"):
            mem_span.set_attribute("memory.id", result.get("id") or "")
            mem_span.set_attribute("memory.success", True)
            log_event(
                event_type="memory",
                action="memory.write",
                resource=scope,
                outcome="success",
                trace_id=trace_id,
                memory_scope=scope,
                memory_id=result.get("id"),
            )
            # #125: surface memory writes in /activity so the Canvas
            # "Agent Comms" tab shows what an agent chose to remember.
            # Fire-and-forget — failure here must not poison the tool
            # response since the memory write itself already succeeded.
            await _record_memory_activity(scope, content, result.get("id"))
            await _maybe_log_skill_promotion(content, scope, result)
        else:
            mem_span.set_attribute("memory.success", False)
            log_event(
                event_type="memory",
                action="memory.write",
                resource=scope,
                outcome="failure",
                trace_id=trace_id,
                memory_scope=scope,
                error=result.get("error"),
            )

    return result


@tool
async def search_memory(query: str = "", scope: str = "") -> dict:
    """Search stored memories.

    Args:
        query: Text to search for (empty returns all).
        scope: Filter by scope — LOCAL, TEAM, GLOBAL, or empty for all accessible.
    """
    trace_id = str(uuid.uuid4())
    scope = scope.upper()
    if scope and scope not in ("LOCAL", "TEAM", "GLOBAL"):
        return {"error": "scope must be LOCAL, TEAM, GLOBAL, or empty"}

    # --- RBAC check -----------------------------------------------------------
    roles, custom_perms = get_workspace_roles()
    if not check_permission("memory.read", roles, custom_perms):
        log_event(
            event_type="rbac",
            action="rbac.deny",
            resource=scope or "all",
            outcome="denied",
            trace_id=trace_id,
            attempted_action="memory.read",
            roles=roles,
        )
        return {
            "success": False,
            "error": (
                "RBAC: this workspace does not have the 'memory.read' permission. "
                f"Current roles: {roles}"
            ),
        }

    log_event(
        event_type="memory",
        action="memory.read",
        resource=scope or "all",
        outcome="allowed",
        trace_id=trace_id,
        memory_scope=scope or "all",
        query_length=len(query),
    )

    # ── OTEL: memory_read span ───────────────────────────────────────────────
    tracer = get_tracer()

    with tracer.start_as_current_span("memory_read") as mem_span:
        mem_span.set_attribute(WORKSPACE_ID_ATTR, WORKSPACE_ID)
        mem_span.set_attribute(MEMORY_SCOPE, scope or "all")
        mem_span.set_attribute(MEMORY_QUERY, query[:256] if query else "")

        awareness_client = build_awareness_client()
        if awareness_client is not None:
            try:
                result = await awareness_client.search(query, scope)
                mem_span.set_attribute("memory.result_count", result.get("count", 0))
                mem_span.set_attribute("memory.success", result.get("success", False))
                log_event(
                    event_type="memory",
                    action="memory.read",
                    resource=scope or "all",
                    outcome="success" if result.get("success") else "failure",
                    trace_id=trace_id,
                    memory_scope=scope or "all",
                    result_count=result.get("count", 0),
                )
                return result
            except Exception as e:
                log_event(
                    event_type="memory",
                    action="memory.read",
                    resource=scope or "all",
                    outcome="failure",
                    trace_id=trace_id,
                    memory_scope=scope or "all",
                    error=str(e),
                )
                try:
                    mem_span.record_exception(e)
                except Exception:
                    pass
                return {"success": False, "error": str(e)}

        params = {}
        if query:
            params["q"] = query
        if scope:
            params["scope"] = scope.upper()

        # #215-class bug (search path): same fix as commit_memory above —
        # the platform gates GET /workspaces/:id/memories behind workspace
        # auth, so without auth_headers() every search silently 401s and the
        # agent thinks its backlog is empty (observed on Technical Researcher
        # idle-loop pilot 2026-04-15).
        try:
            from platform_auth import auth_headers as _auth
            _headers = _auth()
        except Exception:
            _headers = {}

        async with httpx.AsyncClient(timeout=10.0) as client:
            try:
                resp = await client.get(
                    f"{PLATFORM_URL}/workspaces/{WORKSPACE_ID}/memories",
                    params=params,
                    headers=_headers,
                )
                if resp.status_code == 200:
                    memories = resp.json()
                    mem_span.set_attribute("memory.result_count", len(memories))
                    mem_span.set_attribute("memory.success", True)
                    log_event(
                        event_type="memory",
                        action="memory.read",
                        resource=scope or "all",
                        outcome="success",
                        trace_id=trace_id,
                        memory_scope=scope or "all",
                        result_count=len(memories),
                    )
                    return {
                        "success": True,
                        "count": len(memories),
                        "memories": memories,
                    }
                mem_span.set_attribute("memory.success", False)
                log_event(
                    event_type="memory",
                    action="memory.read",
                    resource=scope or "all",
                    outcome="failure",
                    trace_id=trace_id,
                    memory_scope=scope or "all",
                    http_status=resp.status_code,
                )
                return {"success": False, "error": resp.json().get("error", resp.text)}
            except Exception as e:
                log_event(
                    event_type="memory",
                    action="memory.read",
                    resource=scope or "all",
                    outcome="failure",
                    trace_id=trace_id,
                    memory_scope=scope or "all",
                    error=str(e),
                )
                try:
                    mem_span.record_exception(e)
                except Exception:
                    pass
                return {"success": False, "error": str(e)}


def _parse_promotion_packet(content: str) -> dict[str, Any] | None:
    """Return a structured memory packet when content looks like promotion metadata."""
    text = content.strip()
    if not text.startswith("{"):
        return None

    try:
        payload = json.loads(text)
    except json.JSONDecodeError:
        return None

    if not isinstance(payload, dict):  # pragma: no cover
        return None
    if not payload.get("promote_to_skill"):
        return None

    return payload


async def _record_memory_activity(scope: str, content: str, memory_id: str | None) -> None:
    """Surface a successful memory write as an activity row so the Canvas
    "Agent Comms" tab can display what an agent chose to remember.
    Fire-and-forget — never raises. #125.

    The summary is intentionally short (scope tag + first 80 chars of
    content with a ``…`` ellipsis when truncated) so the activity table
    stays readable; full content lives in ``agent_memories``.
    """
    workspace_id = WORKSPACE_ID.strip()
    platform_url = PLATFORM_URL.strip().rstrip("/")
    if not workspace_id or not platform_url:
        return

    preview = content.strip().replace("\n", " ")
    if len(preview) > 80:
        preview = preview[:80] + "…"
    summary = f"[{scope}] {preview}"

    # NOTE: target_id is a UUID column scoped to workspace_id references —
    # cannot hold awareness/memory IDs (which are arbitrary strings).
    # We embed the memory_id in the summary instead so it's still searchable.
    if memory_id:
        summary = f"{summary} (id={memory_id[:24]})"
    payload: dict[str, Any] = {
        "workspace_id": workspace_id,
        "activity_type": "memory_write",
        "summary": summary,
        "status": "ok",
    }

    try:
        try:
            from platform_auth import auth_headers as _auth
            _headers = _auth()
        except Exception:
            _headers = {}
        async with httpx.AsyncClient(timeout=5.0) as client:
            await client.post(
                f"{platform_url}/workspaces/{workspace_id}/activity",
                json=payload,
                headers=_headers,
            )
    except Exception:
        # Activity logging is purely observability — never poison the
        # tool response on a failure here. We don't even log_event the
        # failure since the memory write itself succeeded and that's
        # what matters to the caller.
        pass


async def _maybe_log_skill_promotion(content: str, scope: str, memory_result: dict) -> None:
    """Best-effort activity log for durable memory entries that should become skills."""
    packet = _parse_promotion_packet(content)
    if packet is None:
        return

    workspace_id = WORKSPACE_ID.strip()
    platform_url = PLATFORM_URL.strip().rstrip("/")
    if not workspace_id or not platform_url:
        return

    repetition_signal = packet.get("repetition_signal")
    summary = (
        packet.get("summary")
        or packet.get("title")
        or packet.get("what changed")
        or "Repeatable workflow promoted to skill candidate"
    )
    metadata: dict[str, Any] = {
        "source": "memory-curation",
        "scope": scope,
        "memory_id": memory_result.get("id"),
        "promote_to_skill": True,
        "repetition_signal": repetition_signal,
        "memory_packet": packet,
    }

    payload = {
        "activity_type": "skill_promotion",
        "method": "memory/skill-promotion",
        "summary": summary,
        "status": "ok",
        "source_id": workspace_id,
        "request_body": packet,
        "metadata": metadata,
    }

    try:
        async with httpx.AsyncClient(timeout=5.0) as client:
            await client.post(
                f"{platform_url}/workspaces/{workspace_id}/activity",
                json=payload,
            )
            await client.post(
                f"{platform_url}/registry/heartbeat",
                json={
                    "workspace_id": workspace_id,
                    "error_rate": 0,
                    "sample_error": "",
                    "active_tasks": 1,
                    "uptime_seconds": 0,
                    "current_task": f"Skill promotion: {summary}",
                },
            )
    except Exception:
        # Best-effort observability only. Memory commits must never fail because
        # the promotion log could not be written.
        return
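

if __name__ == "__main__":
    # Illustrative check of the promotion-packet detector, added for review
    # only. _parse_promotion_packet is a pure function, so no platform or
    # network is needed, though running this file directly still assumes the
    # builtin_tools imports at the top of the module resolve.
    print(_parse_promotion_packet("plain prose fact"))  # None: not JSON
    print(_parse_promotion_packet('{"promote_to_skill": false}'))  # None: flag falsy
    print(_parse_promotion_packet(
        '{"promote_to_skill": true, "summary": "Weekly report workflow"}'
    ))  # the parsed dict: treated as a skill-promotion candidate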
281
molecule_runtime/builtin_tools/sandbox.py
Normal file
@ -0,0 +1,281 @@
"""Code sandbox tool for safe code execution.
|
||||
|
||||
Executes code in an isolated environment. Three backends are supported:
|
||||
|
||||
subprocess (default)
|
||||
Runs code locally via asyncio subprocess with a hard timeout.
|
||||
Best for Tier 1/2 agents where run_code is lightly used and the
|
||||
workspace container itself is the isolation boundary.
|
||||
|
||||
docker
|
||||
Throwaway Docker-in-Docker container: network disabled, memory capped,
|
||||
read-only filesystem. Requires Docker socket access inside the container.
|
||||
Best for Tier 3 on-prem deployments.
|
||||
|
||||
e2b
|
||||
Cloud-hosted microVM sandbox via E2B (https://e2b.dev).
|
||||
No local Docker required — code runs in E2B's isolated cloud VMs.
|
||||
Supports Python and JavaScript.
|
||||
Requires:
|
||||
- e2b-code-interpreter Python package (pinned in requirements.txt)
|
||||
- E2B_API_KEY workspace secret (set via canvas Secrets panel or API)
|
||||
Best for hosted/cloud Molecule AI deployments.
|
||||
|
||||
Backend is selected via the SANDBOX_BACKEND env var, which the provisioner
|
||||
sets from config.yaml → sandbox.backend. Default: "subprocess".
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SANDBOX_BACKEND = os.environ.get("SANDBOX_BACKEND", "subprocess")
|
||||
SANDBOX_TIMEOUT = int(os.environ.get("SANDBOX_TIMEOUT", "30"))
|
||||
SANDBOX_MEMORY_LIMIT = os.environ.get("SANDBOX_MEMORY_LIMIT", "256m")
|
||||
MAX_OUTPUT = 10_000
|
||||
|
||||
# E2B kernel names differ from internal language names.
|
||||
_E2B_KERNEL_MAP = {
|
||||
"python": "python3",
|
||||
"javascript": "js",
|
||||
"js": "js",
|
||||
}
|
||||
|
||||
|
||||
@tool
async def run_code(code: str, language: str = "python") -> dict:
    """Execute code in an isolated sandbox and return the output.

    Args:
        code: The code to execute.
        language: Programming language — python, javascript, or shell.
            The e2b backend supports python and javascript only.
    """
    if SANDBOX_BACKEND == "docker":
        return await _run_docker(code, language)
    elif SANDBOX_BACKEND == "e2b":
        return await _run_e2b(code, language)
    else:
        return await _run_subprocess(code, language)


async def _run_subprocess(code: str, language: str) -> dict:
    """Fallback: run code in a subprocess with timeout."""
    cmd_map = {
        "python": ["python3", "-c"],
        "javascript": ["node", "-e"],
        "shell": ["sh", "-c"],
        "bash": ["bash", "-c"],
    }

    cmd_prefix = cmd_map.get(language)
    if not cmd_prefix:
        return {"error": f"Unsupported language: {language}", "exit_code": -1}

    try:
        proc = await asyncio.create_subprocess_exec(
            *cmd_prefix, code,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )

        stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=SANDBOX_TIMEOUT)

        return {
            "exit_code": proc.returncode,
            "stdout": stdout.decode("utf-8", errors="replace")[:MAX_OUTPUT],
            "stderr": stderr.decode("utf-8", errors="replace")[:MAX_OUTPUT],
            "language": language,
            "backend": "subprocess",
        }
    except asyncio.TimeoutError:
        try:
            proc.kill()
            await proc.wait()
        except ProcessLookupError:
            pass
        return {"error": f"Timeout after {SANDBOX_TIMEOUT}s", "exit_code": -1}
    except Exception as e:
        return {"error": str(e), "exit_code": -1}


async def _run_docker(code: str, language: str) -> dict:
    """Run code in a throwaway Docker container via mounted temp file."""
    image_map = {
        "python": ("python:3.11-slim", ["python3", "/sandbox/code.py"]),
        "javascript": ("node:20-slim", ["node", "/sandbox/code.js"]),
        "shell": ("alpine:3.18", ["sh", "/sandbox/code.sh"]),
        "bash": ("alpine:3.18", ["sh", "/sandbox/code.sh"]),
    }

    entry = image_map.get(language)
    if not entry:
        return {"error": f"Unsupported language: {language}", "exit_code": -1}

    image, run_cmd = entry
    code_file = None

    try:
        # Write code to temp file — avoids shell metacharacter injection
        ext = {"python": ".py", "javascript": ".js", "shell": ".sh", "bash": ".sh"}.get(language, ".txt")
        fd, code_file = tempfile.mkstemp(suffix=ext, prefix="sandbox_")
        with os.fdopen(fd, "w") as f:
            f.write(code)

        cmd = [
            "docker", "run", "--rm",
            "--network", "none",
            "--memory", SANDBOX_MEMORY_LIMIT,
            "--cpus", "0.5",
            "--read-only",
            "--tmpfs", "/tmp:size=32m",
            "-v", f"{code_file}:/sandbox/code{ext}:ro",
            image,
        ] + run_cmd

        proc = await asyncio.create_subprocess_exec(
            *cmd,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )

        stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=SANDBOX_TIMEOUT)

        return {
            "exit_code": proc.returncode,
            "stdout": stdout.decode("utf-8", errors="replace")[:MAX_OUTPUT],
            "stderr": stderr.decode("utf-8", errors="replace")[:MAX_OUTPUT],
            "language": language,
            "backend": "docker",
            "image": image,
        }
    except asyncio.TimeoutError:
        return {"error": f"Timeout after {SANDBOX_TIMEOUT}s", "exit_code": -1}
    except Exception as e:
        return {"error": str(e), "exit_code": -1}
    finally:
        if code_file:
            try:
                os.unlink(code_file)
            except OSError:
                pass


async def _run_e2b(code: str, language: str) -> dict:
    """Run code in an E2B cloud microVM sandbox.

    Requires the e2b-code-interpreter package and an E2B_API_KEY secret.
    Each call creates a fresh sandbox, runs the code, and destroys the sandbox.
    Sandbox lifetime is bounded by SANDBOX_TIMEOUT seconds.

    Supported languages: python, javascript.
    """
    # Import lazily so the package is only required when the e2b backend is
    # actually configured — other backends work without it installed.
    try:
        from e2b_code_interpreter import Sandbox
    except ImportError:
        return {
            "error": (
                "e2b-code-interpreter is not installed. "
                "Add it to requirements.txt or switch to the docker/subprocess backend."
            ),
            "exit_code": -1,
        }

    api_key = os.environ.get("E2B_API_KEY")
    if not api_key:
        return {
            "error": (
                "E2B_API_KEY is not set. "
                "Add it as a workspace secret via the canvas Secrets panel or platform API."
            ),
            "exit_code": -1,
        }

    kernel = _E2B_KERNEL_MAP.get(language)
    if kernel is None:
        return {
            "error": (
                f"Language '{language}' is not supported by the e2b backend. "
                "Supported: python, javascript."
            ),
            "exit_code": -1,
        }

    sandbox = None
    try:
        # Create a fresh sandbox for this execution.
        # timeout controls the sandbox lifetime in seconds.
        sandbox = await asyncio.wait_for(
            asyncio.get_running_loop().run_in_executor(
                None,
                lambda: Sandbox(api_key=api_key, timeout=SANDBOX_TIMEOUT),
            ),
            timeout=SANDBOX_TIMEOUT,
        )

        # Execute code and collect results.
        execution = await asyncio.wait_for(
            asyncio.get_running_loop().run_in_executor(
                None,
                lambda: sandbox.run_code(code, language=kernel),
            ),
            timeout=SANDBOX_TIMEOUT,
        )

        # E2B returns a list of Result objects; collect text/error output.
        stdout_parts = []
        stderr_parts = []

        for result in execution.results:
            # result.text is the primary output (stdout equivalent)
            if hasattr(result, "text") and result.text:
                stdout_parts.append(str(result.text))
            # Some result types expose an error attribute
            if hasattr(result, "error") and result.error:
                stderr_parts.append(str(result.error))

        # Logs are stored separately in execution.logs
        if hasattr(execution, "logs"):
            logs = execution.logs
            if hasattr(logs, "stdout") and logs.stdout:
                stdout_parts.extend(logs.stdout)
            if hasattr(logs, "stderr") and logs.stderr:
                stderr_parts.extend(logs.stderr)

        combined_stdout = "".join(stdout_parts)[:MAX_OUTPUT]
        combined_stderr = "".join(stderr_parts)[:MAX_OUTPUT]

        # Treat any stderr output as a non-zero exit code (e2b doesn't expose
        # a numeric exit code at the sandbox level).
        exit_code = 1 if combined_stderr else 0

        return {
            "exit_code": exit_code,
            "stdout": combined_stdout,
            "stderr": combined_stderr,
            "language": language,
            "backend": "e2b",
        }

    except asyncio.TimeoutError:
        logger.warning("E2B sandbox timed out after %ds", SANDBOX_TIMEOUT)
        return {"error": f"Timeout after {SANDBOX_TIMEOUT}s", "exit_code": -1}
    except Exception as e:
        logger.exception("E2B sandbox error: %s", e)
        return {"error": str(e), "exit_code": -1}
    finally:
        # Always destroy the sandbox to avoid leaking E2B credits.
        if sandbox is not None:
            try:
                await asyncio.get_running_loop().run_in_executor(
                    None, sandbox.kill
                )
            except Exception:
                pass  # Best-effort cleanup
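

if __name__ == "__main__":
    # Illustrative smoke test, added for review only. With the default
    # subprocess backend (no SANDBOX_BACKEND set) this needs nothing but a
    # local python3 and should print exit_code 0 with stdout "4\n".
    async def _demo() -> None:
        result = await run_code.ainvoke({"code": "print(2 + 2)", "language": "python"})
        print(result)

    asyncio.run(_demo())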
344
molecule_runtime/builtin_tools/security_scan.py
Normal file
@ -0,0 +1,344 @@
"""Skill dependency security scanner — supply-chain risk management.
|
||||
|
||||
Scans a skill's ``requirements.txt`` for known CVEs before the skill is
|
||||
loaded into the workspace. Two scanners are supported:
|
||||
|
||||
Snyk CLI — ``snyk test --file=requirements.txt --json``
|
||||
Preferred; requires the ``snyk`` binary in PATH and
|
||||
a SNYK_TOKEN env var for authenticated scans.
|
||||
|
||||
pip-audit — ``pip-audit -r requirements.txt --json``
|
||||
Fallback; no authentication required.
|
||||
|
||||
The scanner is auto-selected: Snyk if available, pip-audit otherwise.
|
||||
If neither is present in PATH the scan is silently skipped with a log line.
|
||||
|
||||
Scan mode (``security_scan.mode`` in config.yaml):
|
||||
|
||||
block — raise ``SkillSecurityError`` when critical/high CVEs are found;
|
||||
the skill is *not* loaded.
|
||||
warn — log a WARNING + audit event; the skill is loaded anyway.
|
||||
off — skip scanning entirely; useful in air-gapped CI.
|
||||
|
||||
Audit trail
|
||||
-----------
|
||||
Every scan (pass or fail) is recorded via ``tools.audit.log_event`` with
|
||||
``event_type="security_scan"``, enabling compliance reports to prove that
|
||||
all loaded skills were checked before activation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import shutil
|
||||
import subprocess
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from builtin_tools.audit import log_event
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public exception
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class SkillSecurityError(RuntimeError):
    """Raised when a skill fails security scanning in ``block`` mode.

    The message contains the skill name, scanner used, and a summary of the
    critical/high findings so operators can act on it immediately.
    """


# ---------------------------------------------------------------------------
# Data models
# ---------------------------------------------------------------------------


@dataclass
class CVEFinding:
    """A single vulnerability finding from a security scanner."""

    vuln_id: str
    """CVE or advisory identifier, e.g. ``SNYK-PYTHON-REQUESTS-1234``."""
    package: str
    """Affected package name."""
    version: str
    """Installed version of the package."""
    severity: str
    """One of: critical | high | medium | low | unknown."""
    description: str
    """Short human-readable summary (≤ 200 chars)."""


@dataclass
class ScanResult:
    """Aggregated result of a single skill dependency scan."""

    skill_name: str
    scanner: str
    """Scanner used: ``"snyk"`` | ``"pip-audit"`` | ``"none"``."""
    requirements_file: Optional[str]
    """Absolute path to the scanned requirements.txt, or ``None``."""
    findings: list[CVEFinding] = field(default_factory=list)
    scan_error: Optional[str] = None
    """Non-fatal scanner error (e.g. timeout); findings may be incomplete."""

    @property
    def critical_or_high(self) -> list[CVEFinding]:
        return [f for f in self.findings if f.severity in ("critical", "high")]

    @property
    def has_critical_or_high(self) -> bool:
        return bool(self.critical_or_high)


# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------


def _find_requirements(skill_path: Path) -> Optional[Path]:
    """Return the first ``requirements.txt`` found in the skill tree."""
    for candidate in (
        skill_path / "requirements.txt",
        skill_path / "tools" / "requirements.txt",
    ):
        if candidate.exists():
            return candidate
    return None


def _run_scanner(cmd: list[str], timeout: int = 120) -> tuple[str, Optional[str]]:
    """Run a scanner subprocess and return ``(stdout, error_or_None)``."""
    try:
        result = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            timeout=timeout,
        )
        # Both Snyk and pip-audit exit 1 when vulns are found — not an error.
        # Exit 2 from Snyk means a genuine scan failure.
        if result.returncode == 2 and not result.stdout.strip():
            return "", f"scanner exited 2: {result.stderr.strip()[:200]}"
        return result.stdout, None
    except subprocess.TimeoutExpired:
        return "", f"scanner timed out after {timeout}s"
    except FileNotFoundError as exc:
        return "", str(exc)
    except Exception as exc:  # pylint: disable=broad-except
        return "", str(exc)


def _parse_snyk(stdout: str) -> tuple[list[CVEFinding], Optional[str]]:
    """Parse ``snyk test --json`` output."""
    if not stdout.strip():
        return [], "empty snyk output"
    try:
        data = json.loads(stdout)
    except json.JSONDecodeError as exc:
        return [], f"snyk JSON parse error: {exc}"

    vulns = data.get("vulnerabilities", [])
    findings = [
        CVEFinding(
            vuln_id=v.get("id", "UNKNOWN"),
            package=v.get("packageName", "?"),
            version=v.get("version", "?"),
            severity=v.get("severity", "unknown").lower(),
            description=(v.get("title", "") or "")[:200],
        )
        for v in vulns
        if isinstance(v, dict)
    ]
    return findings, None


def _parse_pip_audit(stdout: str) -> tuple[list[CVEFinding], Optional[str]]:
    """Parse ``pip-audit --json`` output.

    pip-audit does not always provide a CVSS severity level. When absent we
    conservatively classify the finding as ``"high"`` so it is not silently
    ignored in ``warn`` mode.
    """
    if not stdout.strip():
        return [], "empty pip-audit output"
    try:
        data = json.loads(stdout)
    except json.JSONDecodeError as exc:
        return [], f"pip-audit JSON parse error: {exc}"

    # pip-audit ≥ 2.x wraps results in {"dependencies": [...]}
    if isinstance(data, dict):
        deps = data.get("dependencies", [])
    else:
        deps = data  # older versions return a bare list

    findings: list[CVEFinding] = []
    for dep in deps:
        if not isinstance(dep, dict):
            continue
        for vuln in dep.get("vulns", []):
            # pip-audit usually omits a severity field; default to "high"
            # per the conservative policy in the docstring above.
            sev = (vuln.get("severity") or "high").lower()
            findings.append(
                CVEFinding(
                    vuln_id=vuln.get("id", "UNKNOWN"),
                    package=dep.get("name", "?"),
                    version=dep.get("version", "?"),
                    severity=sev,
                    description=(vuln.get("description", "") or "")[:200],
                )
            )
    return findings, None


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------


def scan_skill_dependencies(
    skill_name: str,
    skill_path: Path,
    mode: str,
    fail_open_if_no_scanner: bool = True,
) -> ScanResult:
    """Scan a skill's dependency file for known CVEs.

    Args:
        skill_name: Name of the skill (used in log messages and audit events).
        skill_path: Absolute path to the skill's root directory.
        mode: ``"block"`` | ``"warn"`` | ``"off"``
        fail_open_if_no_scanner:
            When *True* (default) silently skip scanning if neither snyk nor
            pip-audit is in PATH. When *False* and ``mode="block"``, raise
            :class:`SkillSecurityError` so operators know the gate is absent.
            Corresponds to ``security_scan.fail_open_if_no_scanner`` in
            config.yaml. Closes #268.

    Returns:
        A :class:`ScanResult` describing what was found.

    Raises:
        :class:`SkillSecurityError`: When ``mode="block"`` and one or more
            critical/high severity CVEs are found — OR when
            ``mode="block"`` and ``fail_open_if_no_scanner=False`` and no
            scanner is available.
    """
    if mode == "off":
        return ScanResult(skill_name=skill_name, scanner="none", requirements_file=None)

    req_file = _find_requirements(skill_path)
    if req_file is None:
        # No requirements file — nothing to scan; not a problem.
        return ScanResult(skill_name=skill_name, scanner="none", requirements_file=None)

    # ── Select scanner ────────────────────────────────────────────────────────
    scanner_name: str
    findings: list[CVEFinding]
    scan_error: Optional[str]

    if shutil.which("snyk"):
        scanner_name = "snyk"
        stdout, run_error = _run_scanner(
            ["snyk", "test", f"--file={req_file}", "--json"]
        )
        if run_error:
            findings, scan_error = [], run_error
        else:
            findings, scan_error = _parse_snyk(stdout)

    elif shutil.which("pip-audit"):
        scanner_name = "pip-audit"
        stdout, run_error = _run_scanner(
            ["pip-audit", "-r", str(req_file), "--json", "--progress-spinner=off"]
        )
        if run_error:
            findings, scan_error = [], run_error
        else:
            findings, scan_error = _parse_pip_audit(stdout)

    else:
        logger.info(
            "security_scan: no scanner (snyk, pip-audit) in PATH — skipping %s",
            skill_name,
        )
        log_event(
            event_type="security_scan",
            action="skill.security_scan",
            resource=skill_name,
            outcome="skipped",
            reason="no_scanner_in_path",
            requirements_file=str(req_file),
            mode=mode,
        )
        # #268: if fail_open_if_no_scanner=False and mode=block, the operator
        # explicitly opted in to "fail closed" — raise so the missing scanner
        # is visible rather than silently skipped.
        if not fail_open_if_no_scanner and mode == "block":
            raise SkillSecurityError(
                f"Skill '{skill_name}' blocked: no scanner (snyk or pip-audit) "
                f"found in PATH and fail_open_if_no_scanner=false"
            )
        return ScanResult(
            skill_name=skill_name,
            scanner="none",
            requirements_file=str(req_file),
            scan_error="No scanner (snyk or pip-audit) found in PATH",
        )

    result = ScanResult(
        skill_name=skill_name,
        scanner=scanner_name,
        requirements_file=str(req_file),
        findings=findings,
        scan_error=scan_error,
    )

    # ── Log scan outcome to audit trail ──────────────────────────────────────
    audit_outcome = "clean" if not result.has_critical_or_high else "vulnerable"
    log_event(
        event_type="security_scan",
        action="skill.security_scan",
        resource=skill_name,
        outcome=audit_outcome,
        scanner=scanner_name,
        requirements_file=str(req_file),
        total_findings=len(findings),
        critical_or_high_count=len(result.critical_or_high),
        scan_error=scan_error,
    )

    if scan_error:
        logger.warning(
            "security_scan: scanner error for skill '%s': %s", skill_name, scan_error
        )

    # ── Enforce mode ─────────────────────────────────────────────────────────
    if result.has_critical_or_high:
        summary = ", ".join(
            f"{f.vuln_id}({f.severity}) in {f.package}@{f.version}"
            for f in result.critical_or_high[:5]
        )
        if len(result.critical_or_high) > 5:
            summary += f" … and {len(result.critical_or_high) - 5} more"

        msg = (
            f"Skill '{skill_name}' has {len(result.critical_or_high)} "
            f"critical/high CVE(s) [{scanner_name}]: {summary}"
        )

        if mode == "block":
            logger.error("Blocking skill load — %s", msg)
            raise SkillSecurityError(msg)

        # warn mode — continue loading, but make noise
        logger.warning("Security warning — %s", msg)

    return result
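

if __name__ == "__main__":
    # Illustrative ad-hoc run, added for review only; not part of the module.
    # Assumes builtin_tools is importable (for log_event). With no scanner in
    # PATH this takes the fail-open branch and returns scanner="none"; with
    # snyk or pip-audit installed it should flag the deliberately old pin.
    import tempfile

    with tempfile.TemporaryDirectory() as tmp:
        skill_dir = Path(tmp)
        (skill_dir / "requirements.txt").write_text("requests==2.19.0\n")
        result = scan_skill_dependencies("demo-skill", skill_dir, mode="warn")
        print(result.scanner, len(result.findings), result.scan_error)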
418
molecule_runtime/builtin_tools/telemetry.py
Normal file
@ -0,0 +1,418 @@
"""OpenTelemetry (OTEL) instrumentation for the Molecule AI workspace runtime.
|
||||
|
||||
Architecture
|
||||
------------
|
||||
* One global ``TracerProvider`` is initialised at startup via ``setup_telemetry()``.
|
||||
* Up to three exporters are wired in:
|
||||
1. **OTLP/HTTP** — activated when ``OTEL_EXPORTER_OTLP_ENDPOINT`` is set.
|
||||
Point this at any compatible collector (Jaeger, Tempo, Grafana OTEL, …).
|
||||
2. **Langfuse OTLP bridge** — activated when the ``LANGFUSE_HOST``,
|
||||
``LANGFUSE_PUBLIC_KEY`` and ``LANGFUSE_SECRET_KEY`` env vars are all present.
|
||||
Langfuse ≥4 accepts OTLP/HTTP at ``<host>/api/public/otel``.
|
||||
This is a *second* exporter alongside the existing Langfuse LangChain
|
||||
callback handler in agent.py — both paths emit spans simultaneously.
|
||||
3. **Console** (debug) — activated when ``OTEL_DEBUG=1``.
|
||||
|
||||
* **W3C TraceContext** propagation (``traceparent`` / ``tracestate``) is used for
|
||||
cross-workspace context injection and extraction so A2A hops form a single
|
||||
distributed trace.
|
||||
|
||||
* ``make_trace_middleware()`` returns an ASGI middleware that extracts incoming
|
||||
trace context from HTTP headers and stores it in a ``ContextVar`` so the
|
||||
A2A executor can access it to parent its spans correctly.
|
||||
|
||||
GenAI semantic conventions
|
||||
--------------------------
|
||||
Attribute constants for ``gen_ai.*`` follow OpenTelemetry GenAI SemConv 1.26.
|
||||
|
||||
Usage example
|
||||
-------------
|
||||
# main.py — call once at startup
|
||||
from builtin_tools.telemetry import setup_telemetry, make_trace_middleware
|
||||
setup_telemetry(service_name=workspace_id)
|
||||
instrumented = make_trace_middleware(app.build())
|
||||
|
||||
# Any module
|
||||
from builtin_tools.telemetry import get_tracer
|
||||
tracer = get_tracer()
|
||||
with tracer.start_as_current_span("my_span") as span:
|
||||
span.set_attribute("key", "value")
|
||||
|
||||
# Outgoing HTTP — inject W3C headers
|
||||
from builtin_tools.telemetry import inject_trace_headers
|
||||
headers = inject_trace_headers({"Content-Type": "application/json"})
|
||||
await client.post(url, headers=headers, ...)
|
||||
|
||||
# Incoming HTTP — extract context (done automatically by middleware)
|
||||
from builtin_tools.telemetry import extract_trace_context
|
||||
ctx = extract_trace_context(dict(request.headers))
|
||||
"""
|
||||
|
||||
from __future__ import annotations

import base64
import logging
import os
from contextvars import ContextVar
from typing import Any, Optional

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# GenAI Semantic Convention attribute keys (OTel SemConv 1.26)
# https://opentelemetry.io/docs/specs/semconv/gen-ai/
# ---------------------------------------------------------------------------
GEN_AI_SYSTEM = "gen_ai.system"
GEN_AI_REQUEST_MODEL = "gen_ai.request.model"
GEN_AI_OPERATION_NAME = "gen_ai.operation.name"
GEN_AI_USAGE_INPUT_TOKENS = "gen_ai.usage.input_tokens"
GEN_AI_USAGE_OUTPUT_TOKENS = "gen_ai.usage.output_tokens"
GEN_AI_RESPONSE_FINISH_REASONS = "gen_ai.response.finish_reasons"

# ---------------------------------------------------------------------------
# Workspace / A2A attribute keys
# ---------------------------------------------------------------------------
WORKSPACE_ID_ATTR = "workspace.id"
A2A_SOURCE_WORKSPACE = "a2a.source_workspace_id"
A2A_TARGET_WORKSPACE = "a2a.target_workspace_id"
A2A_TASK_ID = "a2a.task_id"
MEMORY_SCOPE = "memory.scope"
MEMORY_QUERY = "memory.query"

# ---------------------------------------------------------------------------
# Module-level state
# ---------------------------------------------------------------------------
WORKSPACE_ID: str = os.environ.get("WORKSPACE_ID", "unknown")

_initialized: bool = False
_tracer: Any = None  # opentelemetry.trace.Tracer | _NoopTracer

# ContextVar that carries incoming trace context from the ASGI middleware to
# the A2A executor. Using a ContextVar (rather than a global) is safe with
# asyncio because each task inherits a copy of the context at creation time.
_incoming_trace_context: ContextVar[Optional[Any]] = ContextVar(
    "otel_incoming_trace_context", default=None
)


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------

def setup_telemetry(service_name: Optional[str] = None) -> None:
    """Initialise the global ``TracerProvider``. Safe to call multiple times.

    Reads configuration from environment variables:

    ``OTEL_EXPORTER_OTLP_ENDPOINT``
        Base URL of an OTLP-compatible collector (e.g. ``http://jaeger:4318``).
        Spans are sent to ``<endpoint>/v1/traces``.

    ``LANGFUSE_HOST`` + ``LANGFUSE_PUBLIC_KEY`` + ``LANGFUSE_SECRET_KEY``
        When all three are set, a second OTLP exporter is wired to Langfuse's
        ingest endpoint using HTTP Basic auth.

    ``OTEL_DEBUG``
        Set to ``1`` / ``true`` to also print spans to stdout.
    """
    global _initialized, _tracer

    if _initialized:
        return

    try:
        from opentelemetry import propagate, trace
        from opentelemetry.baggage.propagation import W3CBaggagePropagator
        from opentelemetry.propagators.composite import CompositePropagator
        from opentelemetry.sdk.resources import SERVICE_NAME as OTEL_SERVICE_NAME
        from opentelemetry.sdk.resources import Resource
        from opentelemetry.sdk.trace import TracerProvider
        from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter
        from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
    except ImportError as exc:
        logger.warning(
            "OTEL: opentelemetry packages not installed — telemetry disabled. "
            "Add opentelemetry-api, opentelemetry-sdk, "
            "opentelemetry-exporter-otlp-proto-http to requirements.txt. "
            "Error: %s",
            exc,
        )
        return

    svc = service_name or f"molecule-{WORKSPACE_ID}"

    resource = Resource.create(
        {
            OTEL_SERVICE_NAME: svc,
            "service.version": "1.0.0",
            WORKSPACE_ID_ATTR: WORKSPACE_ID,
        }
    )

    provider = TracerProvider(resource=resource)

    # -- Exporter 1: Generic OTLP/HTTP ----------------------------------------
    otlp_endpoint = os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT", "").rstrip("/")
    if otlp_endpoint:
        try:
            from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter

            exporter = OTLPSpanExporter(endpoint=f"{otlp_endpoint}/v1/traces")
            provider.add_span_processor(BatchSpanProcessor(exporter))
            logger.info("OTEL: OTLP/HTTP exporter → %s", otlp_endpoint)
        except ImportError:
            logger.warning(
                "OTEL: OTEL_EXPORTER_OTLP_ENDPOINT is set but "
                "opentelemetry-exporter-otlp-proto-http is not installed"
            )
        except Exception as exc:
            logger.warning("OTEL: OTLP exporter init failed: %s", exc)

    # -- Exporter 2: Langfuse OTLP bridge -------------------------------------
    # Langfuse ≥4 accepts OTLP at <host>/api/public/otel (Basic auth).
    lf_host = os.environ.get("LANGFUSE_HOST", "").rstrip("/")
    lf_public = os.environ.get("LANGFUSE_PUBLIC_KEY", "")
    lf_secret = os.environ.get("LANGFUSE_SECRET_KEY", "")

    if lf_host and lf_public and lf_secret:
        try:
            from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter

            lf_endpoint = f"{lf_host}/api/public/otel/v1/traces"
            token = base64.b64encode(f"{lf_public}:{lf_secret}".encode()).decode()
            lf_exporter = OTLPSpanExporter(
                endpoint=lf_endpoint,
                headers={"Authorization": f"Basic {token}"},
            )
            provider.add_span_processor(BatchSpanProcessor(lf_exporter))
            logger.info("OTEL: Langfuse OTLP bridge → %s", lf_endpoint)
        except ImportError:
            logger.warning(
                "OTEL: Langfuse env vars set but "
                "opentelemetry-exporter-otlp-proto-http is not installed"
            )
        except Exception as exc:
            logger.warning("OTEL: Langfuse OTLP bridge init failed: %s", exc)

    # -- Exporter 3: Console (debug) ------------------------------------------
    if os.environ.get("OTEL_DEBUG", "").lower() in ("1", "true", "yes"):
        provider.add_span_processor(BatchSpanProcessor(ConsoleSpanExporter()))
        logger.info("OTEL: console debug exporter enabled")

    # -- Register global provider + W3C propagators ---------------------------
    trace.set_tracer_provider(provider)
    propagate.set_global_textmap(
        CompositePropagator(
            [
                TraceContextTextMapPropagator(),
                W3CBaggagePropagator(),
            ]
        )
    )

    _tracer = trace.get_tracer(
        "molecule.workspace",
        schema_url="https://opentelemetry.io/schemas/1.26.0",
    )
    _initialized = True
    logger.info("OTEL: telemetry initialised for service '%s'", svc)


def get_tracer() -> Any:
|
||||
"""Return the global ``Tracer``. Lazily calls ``setup_telemetry()`` if needed.
|
||||
|
||||
Returns a no-op tracer when the opentelemetry packages are not installed so
|
||||
that instrumented code never raises ``ImportError``.
|
||||
"""
|
||||
global _tracer
|
||||
|
||||
if not _initialized:
|
||||
setup_telemetry()
|
||||
|
||||
if _tracer is None:
|
||||
# Packages unavailable — hand back a no-op implementation
|
||||
try:
|
||||
from opentelemetry import trace
|
||||
|
||||
return trace.get_tracer("molecule.noop")
|
||||
except ImportError:
|
||||
return _NoopTracer()
|
||||
|
||||
return _tracer
|
||||
|
||||
|
||||
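
# Illustrative usage sketch (not part of the original module): enabling
# telemetry via the environment and wrapping work in a span. The endpoint,
# service name, and attribute values below are assumptions for the example.
#
#     os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"] = "http://jaeger:4318"
#     setup_telemetry(service_name="molecule-demo")
#     tracer = get_tracer()
#     with tracer.start_as_current_span("task_receive") as span:
#         span.set_attribute("molecule.workspace.id", "demo-workspace")
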
def inject_trace_headers(headers: dict) -> dict:
    """Inject W3C ``traceparent`` / ``tracestate`` into *headers* and return it.

    Mutates the dict in-place so it can be used directly::

        headers = inject_trace_headers({"Content-Type": "application/json"})
        await client.post(url, headers=headers, ...)
    """
    try:
        from opentelemetry import propagate

        propagate.inject(headers)
    except Exception:
        pass  # Never let telemetry break the caller
    return headers


def extract_trace_context(carrier: dict) -> Any:
    """Extract W3C trace context from a header mapping.

    Returns an OpenTelemetry ``Context`` object suitable for::

        tracer.start_as_current_span("name", context=ctx)

    Returns ``None`` when packages are unavailable or no context is present.
    """
    try:
        from opentelemetry import propagate

        return propagate.extract(carrier)
    except Exception:
        return None

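
# Illustrative sketch (not part of the original module): one hop of W3C trace
# propagation using the two helpers above. The sender injects headers before
# the HTTP call; the receiver extracts them and parents its span on the remote
# context. Variable names are assumptions for the example.
#
#     # Sending side
#     headers = inject_trace_headers({"Content-Type": "application/json"})
#     # ... transmit the request; then, on the receiving side:
#     ctx = extract_trace_context(headers)
#     with get_tracer().start_as_current_span("task_receive", context=ctx):
#         ...  # this span is now a child of the sender's active span
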
def get_current_traceparent() -> Optional[str]:
    """Return the W3C ``traceparent`` string for the active span, or ``None``."""
    try:
        from opentelemetry import trace

        span = trace.get_current_span()
        ctx = span.get_span_context()
        if not ctx.is_valid:
            return None
        trace_id = format(ctx.trace_id, "032x")
        span_id = format(ctx.span_id, "016x")
        flags = "01" if ctx.trace_flags else "00"
        return f"00-{trace_id}-{span_id}-{flags}"
    except Exception:
        return None


def make_trace_middleware(asgi_app: Any) -> Any:
    """Wrap an ASGI application with W3C trace-context extraction middleware.

    The middleware reads ``traceparent`` / ``tracestate`` from every incoming
    HTTP request and stores the extracted ``Context`` in the
    ``_incoming_trace_context`` ContextVar. The A2A executor reads that
    ContextVar to parent its ``task_receive`` span correctly, forming an
    unbroken distributed trace across workspace hops.

    Usage::

        built = app.build()
        instrumented = make_trace_middleware(built)
        uvicorn.Config(instrumented, ...)
    """

    async def _middleware(scope: dict, receive: Any, send: Any) -> None:  # type: ignore[override]
        if scope.get("type") != "http":
            await asgi_app(scope, receive, send)
            return

        # Decode byte-headers from the ASGI scope (latin-1 per HTTP/1.1 spec)
        raw_headers: list[tuple[bytes, bytes]] = scope.get("headers", [])
        str_headers: dict[str, str] = {
            k.decode("latin-1"): v.decode("latin-1") for k, v in raw_headers
        }

        ctx = extract_trace_context(str_headers)
        token = _incoming_trace_context.set(ctx)
        try:
            await asgi_app(scope, receive, send)
        finally:
            _incoming_trace_context.reset(token)

    return _middleware


# ---------------------------------------------------------------------------
# Helpers for GenAI attributes
# ---------------------------------------------------------------------------

def gen_ai_system_from_model(model_str: str) -> str:
    """Map a ``provider:model`` string to a ``gen_ai.system`` value."""
    if ":" not in model_str:
        return "unknown"
    provider = model_str.split(":", 1)[0].lower()
    return {
        "anthropic": "anthropic",
        "openai": "openai",
        "openrouter": "openrouter",
        "groq": "groq",
        "google_genai": "google",
        "ollama": "ollama",
    }.get(provider, provider)


def record_llm_token_usage(span: Any, result: dict) -> None:
    """Extract token counts from a LangGraph ainvoke result and set span attrs.

    Handles both Anthropic (``usage``) and OpenAI (``token_usage``) metadata
    shapes. Silently skips if metadata is absent.
    """
    try:
        messages = result.get("messages", [])
        for msg in reversed(messages):
            meta = getattr(msg, "response_metadata", {}) or {}
            # Anthropic
            usage = meta.get("usage", {})
            if usage:
                inp = usage.get("input_tokens") or usage.get("prompt_tokens")
                out = usage.get("output_tokens") or usage.get("completion_tokens")
                if inp is not None:
                    span.set_attribute(GEN_AI_USAGE_INPUT_TOKENS, int(inp))
                if out is not None:
                    span.set_attribute(GEN_AI_USAGE_OUTPUT_TOKENS, int(out))
                return
            # OpenAI
            token_usage = meta.get("token_usage", {})
            if token_usage:
                inp = token_usage.get("prompt_tokens")
                out = token_usage.get("completion_tokens")
                if inp is not None:
                    span.set_attribute(GEN_AI_USAGE_INPUT_TOKENS, int(inp))
                if out is not None:
                    span.set_attribute(GEN_AI_USAGE_OUTPUT_TOKENS, int(out))
                return
    except Exception:
        pass  # Best-effort — never break the caller

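
# Illustrative sketch (not part of the original module): combining the two
# helpers above inside an LLM span. The model string and the agent object are
# assumptions; the GEN_AI_USAGE_* attributes are set by record_llm_token_usage().
#
#     with get_tracer().start_as_current_span("llm_call") as span:
#         span.set_attribute("gen_ai.system", gen_ai_system_from_model("anthropic:claude-sonnet"))
#         result = await agent.ainvoke({"messages": messages})
#         record_llm_token_usage(span, result)
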
# ---------------------------------------------------------------------------
# No-op fallbacks (used when opentelemetry packages are absent)
# ---------------------------------------------------------------------------

class _NoopSpan:
    """Transparent no-op span that satisfies the context-manager protocol."""

    def set_attribute(self, key: str, value: Any) -> None:  # noqa: ARG002
        pass

    def set_status(self, *args: Any, **kwargs: Any) -> None:
        pass

    def record_exception(self, exc: BaseException, *args: Any, **kwargs: Any) -> None:
        pass

    def add_event(self, name: str, *args: Any, **kwargs: Any) -> None:
        pass

    def __enter__(self) -> "_NoopSpan":
        return self

    def __exit__(self, *args: Any) -> None:
        pass


class _NoopTracer:
    """Transparent no-op tracer returned when the SDK is unavailable."""

    def start_as_current_span(self, name: str, *args: Any, **kwargs: Any) -> _NoopSpan:  # noqa: ARG002
        return _NoopSpan()

    def start_span(self, name: str, *args: Any, **kwargs: Any) -> _NoopSpan:  # noqa: ARG002
        return _NoopSpan()
515
molecule_runtime/builtin_tools/temporal_workflow.py
Normal file
@ -0,0 +1,515 @@
"""Temporal durable execution wrapper for Molecule AI A2A workspaces.

Architecture
------------
A co-located Temporal worker runs as an asyncio background task **inside the
same process** as the A2A server. This means worker activities share the same
memory space as the A2A handler, which lets us bridge non-serialisable objects
(LangGraph agent, EventQueue, RequestContext) through an in-process registry
without having to serialise them through Temporal's state store.

Workflow stages (names mirror the OTEL span names in a2a_executor.py):

    task_receive → llm_call → task_complete

task_receive   — durable checkpoint: task acknowledged, queued
llm_call       — durable checkpoint: LLM execution + SSE streaming (retryable)
task_complete  — durable checkpoint: execution finished, telemetry recorded

Crash-recovery behaviour
------------------------
If the process crashes while ``llm_call`` is running, Temporal retries the
activity on the restarted process. The in-process registry is empty after a
restart, so the activity detects a registry miss, logs a warning, and returns
an error result. The SSE client connection is already gone at that point so
no response can be delivered — but the task is permanently recorded in
Temporal's history and will not silently disappear.

Env vars
--------
TEMPORAL_HOST   Temporal gRPC endpoint (default: ``localhost:7233``)
                Set this to enable durable execution. Leave unset (or point
                at an unreachable host) to run in direct-execution mode.

Dependencies (optional)
-----------------------
temporalio>=1.7.0

Add to requirements.txt to enable. The module loads and the wrapper class
works without the package installed — all Temporal paths return early with a
graceful fallback to direct execution.
"""

from __future__ import annotations

import asyncio
import dataclasses
import logging
import os
import uuid
from datetime import timedelta
from typing import Any, Optional

logger = logging.getLogger(__name__)

# ─────────────────────────────────────────────────────────────────────────────
# Constants
# ─────────────────────────────────────────────────────────────────────────────

_TASK_QUEUE = "molecule-agent-tasks"
_WORKFLOW_EXECUTION_TIMEOUT = timedelta(minutes=30)
_ACTIVITY_START_TO_CLOSE_TIMEOUT = timedelta(minutes=10)

# ─────────────────────────────────────────────────────────────────────────────
# Serialisable data models
# These are the only objects that cross the Temporal serialisation boundary.
# ─────────────────────────────────────────────────────────────────────────────


@dataclasses.dataclass
class AgentTaskInput:
    """Serialisable snapshot of an incoming A2A task.

    All fields must be JSON-representable so that Temporal can persist them in
    its workflow history (used for crash recovery and replay).
    """

    task_id: str
    context_id: str
    user_input: str
    model: str
    workspace_id: str
    history: list  # [[role, content], ...] — tuples converted to lists


@dataclasses.dataclass
class LLMResult:
    """Serialisable execution result passed from ``llm_call`` to ``task_complete``."""

    final_text: str
    success: bool
    error: str = ""

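
# Illustrative sketch (not part of the original module): both models survive a
# plain-JSON round trip, which is what Temporal's default payload converter
# needs for history persistence and replay. ``json`` is imported here only for
# the example; the field values are assumptions.
#
#     import json
#
#     inp = AgentTaskInput(
#         task_id="t-1", context_id="c-1", user_input="hi",
#         model="anthropic:claude", workspace_id="ws-1",
#         history=[["user", "hi"]],
#     )
#     restored = AgentTaskInput(**json.loads(json.dumps(dataclasses.asdict(inp))))
#     assert restored == inp
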
# ─────────────────────────────────────────────────────────────────────────────
# In-process registry
#
# Maps task_id → {executor, context, event_queue, final_text}
# Activities look up non-serialisable objects here. The registry is
# populated by TemporalWorkflowWrapper.run() before the workflow starts and
# cleaned up in the finally block when the workflow completes.
# ─────────────────────────────────────────────────────────────────────────────

_task_registry: dict[str, dict[str, Any]] = {}


# ─────────────────────────────────────────────────────────────────────────────
# Temporal workflow + activities
# Loaded only when the temporalio package is installed. The surrounding
# try/except ensures the module imports cleanly without the package.
# ─────────────────────────────────────────────────────────────────────────────

_TEMPORAL_AVAILABLE = False

try:
    from temporalio import activity, workflow
    from temporalio.client import Client
    from temporalio.worker import Worker

    _TEMPORAL_AVAILABLE = True

    # ── Activities ────────────────────────────────────────────────────────── #

    @activity.defn(name="task_receive")
    async def task_receive_activity(inp: AgentTaskInput) -> dict:
        """Durable checkpoint: task received and queued for LLM execution.

        Mirrors the *task_receive* OTEL span opened in
        ``LangGraphA2AExecutor._core_execute()``. This activity is lightweight —
        it validates that the in-process registry entry exists and logs receipt.
        The actual A2A "working" signal (``updater.start_work()``) is emitted
        inside ``_core_execute()`` so that SSE timing is preserved.
        """
        logger.info(
            "Temporal[task_receive] task_id=%s context_id=%s workspace=%s model=%s",
            inp.task_id,
            inp.context_id,
            inp.workspace_id,
            inp.model,
        )
        if inp.task_id not in _task_registry:
            logger.warning(
                "Temporal[task_receive] task_id=%s not found in registry "
                "(crash recovery path — no SSE client connection available)",
                inp.task_id,
            )
            return {"task_id": inp.task_id, "status": "registry_miss"}

        return {"task_id": inp.task_id, "status": "received"}

    @activity.defn(name="llm_call")
    async def llm_call_activity(inp: AgentTaskInput) -> LLMResult:
        """Durable checkpoint: LLM execution with streaming to the event_queue.

        Mirrors the *llm_call* OTEL span in ``LangGraphA2AExecutor._core_execute()``.
        Calls ``executor._core_execute()`` which handles the full execution pipeline:
        SSE streaming, OTEL sub-spans, final message emission, and heartbeat updates.

        On crash recovery (empty registry): logs a warning and returns an error
        result. Temporal records the failure and will retry if configured to do so.
        The original SSE client connection is gone after a crash, so no response
        can be delivered, but the task is durably recorded in Temporal's history.
        """
        logger.info("Temporal[llm_call] task_id=%s", inp.task_id)

        entry = _task_registry.get(inp.task_id)
        if entry is None:
            msg = (
                f"task_id={inp.task_id} not in registry — "
                "process likely restarted; original SSE client connection is gone"
            )
            logger.warning("Temporal[llm_call] registry miss: %s", msg)
            return LLMResult(final_text="", success=False, error=msg)

        try:
            executor = entry["executor"]
            context = entry["context"]
            event_queue = entry["event_queue"]

            # _core_execute() is the renamed body of the original execute().
            # It handles: OTEL spans, SSE streaming, final message, heartbeat.
            final_text = await executor._core_execute(context, event_queue)

            # Cache for task_complete observability
            entry["final_text"] = final_text or ""
            return LLMResult(final_text=final_text or "", success=True)

        except Exception as exc:
            logger.error(
                "Temporal[llm_call] task_id=%s execution error: %s",
                inp.task_id,
                exc,
                exc_info=True,
            )
            return LLMResult(final_text="", success=False, error=str(exc))

    @activity.defn(name="task_complete")
    async def task_complete_activity(result: LLMResult) -> None:
        """Durable checkpoint: task execution finished.

        Mirrors the *task_complete* OTEL span in ``LangGraphA2AExecutor._core_execute()``.
        This activity records the outcome for Temporal observability. The actual
        OTEL task_complete span fires inside ``_core_execute()``; this activity
        provides a durable, queryable record in Temporal's workflow history.
        """
        if result.success:
            logger.info(
                "Temporal[task_complete] success=True final_text_len=%d",
                len(result.final_text),
            )
        else:
            logger.warning(
                "Temporal[task_complete] success=False error=%r",
                result.error,
            )

    # ── Workflow ──────────────────────────────────────────────────────────── #

    @workflow.defn
    class MoleculeAIAgentWorkflow:
        """Durable Temporal workflow for Molecule AI A2A agent task execution.

        Sequences three activities that mirror the OTEL span hierarchy in
        ``LangGraphA2AExecutor._core_execute()``:

            task_receive → llm_call → task_complete

        Each activity is a durable checkpoint: if the process crashes between
        activities, Temporal resumes from the last completed checkpoint on
        restart. If an activity fails (exception or timeout), Temporal can
        retry it according to the configured retry policy.
        """

        @workflow.run
        async def run(self, inp: AgentTaskInput) -> LLMResult:
            opts: dict[str, Any] = {
                "start_to_close_timeout": _ACTIVITY_START_TO_CLOSE_TIMEOUT,
            }

            # Stage 1 — acknowledge receipt (lightweight checkpoint)
            await workflow.execute_activity(task_receive_activity, inp, **opts)

            # Stage 2 — LLM execution (main work; retryable on crash/timeout)
            result: LLMResult = await workflow.execute_activity(
                llm_call_activity, inp, **opts
            )

            # Stage 3 — record completion (lightweight checkpoint)
            await workflow.execute_activity(task_complete_activity, result, **opts)

            return result

except ImportError:
    # temporalio not installed — the wrapper class below will gracefully fall
    # back to direct execution for every call.
    logger.debug(
        "Temporal: temporalio package not installed — "
        "durable execution disabled (add temporalio>=1.7.0 to requirements.txt)"
    )

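
# Illustrative sketch (not part of the original module): the workflow above
# passes only a start-to-close timeout, so activities retry under Temporal's
# default policy. An explicit per-activity policy could look like this —
# assumes temporalio's RetryPolicy API:
#
#     from temporalio.common import RetryPolicy
#
#     result = await workflow.execute_activity(
#         llm_call_activity,
#         inp,
#         start_to_close_timeout=_ACTIVITY_START_TO_CLOSE_TIMEOUT,
#         retry_policy=RetryPolicy(
#             initial_interval=timedelta(seconds=5),
#             backoff_coefficient=2.0,
#             maximum_attempts=3,
#         ),
#     )
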
# ─────────────────────────────────────────────────────────────────────────────
# TemporalWorkflowWrapper
# ─────────────────────────────────────────────────────────────────────────────


class TemporalWorkflowWrapper:
    """Wraps ``LangGraphA2AExecutor.execute()`` with Temporal durable execution.

    The wrapper intercepts each ``execute()`` call and routes it through a
    ``MoleculeAIAgentWorkflow`` Temporal workflow. If Temporal is unavailable
    for any reason, execution falls back transparently to the direct path
    (``executor._core_execute()``), so the A2A server never crashes due to
    Temporal issues.

    Lifecycle
    ---------
    1. ``create_wrapper()`` — instantiate and register the global singleton.
    2. ``await wrapper.start()`` — connect to Temporal, launch the background
       worker. No-op (with a log warning) if Temporal is unreachable.
    3. Normal operation — ``wrapper.run()`` is called from ``execute()``.
    4. ``await wrapper.stop()`` — cancel the background worker task on shutdown.

    Co-located worker pattern
    -------------------------
    The Temporal worker runs as an asyncio background task in the **same event
    loop** as the A2A server. This means:
    - No separate worker process to manage.
    - Activities share the process's memory (registry access works).
    - Worker and server share the same asyncio event loop.

    Env vars
    --------
    ``TEMPORAL_HOST``   Temporal gRPC address, e.g. ``localhost:7233`` or
                        ``temporal.internal:7233``. Defaults to
                        ``localhost:7233``. If Temporal is not reachable at
                        this address, the wrapper falls back to direct execution.
    """

    def __init__(self) -> None:
        self._host: str = os.environ.get("TEMPORAL_HOST", "localhost:7233")
        self._client: Optional[Any] = None
        self._worker: Optional[Any] = None
        self._worker_task: Optional[asyncio.Task] = None  # type: ignore[type-arg]
        self._available: bool = False

    # ── Lifecycle ─────────────────────────────────────────────────────────── #

    async def start(self) -> None:
        """Connect to Temporal and start the co-located background worker.

        Safe to call multiple times (idempotent after first success).
        Never raises — logs a warning and returns on any failure.
        """
        if not _TEMPORAL_AVAILABLE:
            logger.info(
                "Temporal: temporalio package not installed — "
                "all tasks will use direct execution. "
                "To enable durable execution: pip install temporalio>=1.7.0"
            )
            return

        if self._available:
            return  # already started

        # Connect to the Temporal server
        try:
            self._client = await Client.connect(self._host)  # type: ignore[name-defined]
            logger.info("Temporal: connected to %s", self._host)
        except Exception as exc:
            logger.warning(
                "Temporal: cannot connect to %s (%s) — "
                "all tasks will use direct execution (no durable state)",
                self._host,
                exc,
            )
            return

        # Start the worker as an asyncio background task
        try:
            self._worker = Worker(  # type: ignore[name-defined]
                self._client,
                task_queue=_TASK_QUEUE,
                workflows=[MoleculeAIAgentWorkflow],  # type: ignore[name-defined]
                activities=[
                    task_receive_activity,  # type: ignore[name-defined]
                    llm_call_activity,  # type: ignore[name-defined]
                    task_complete_activity,  # type: ignore[name-defined]
                ],
            )
            self._worker_task = asyncio.create_task(
                self._worker.run(),
                name="temporal-worker",
            )
            self._available = True
            logger.info(
                "Temporal: co-located worker started on task queue '%s'",
                _TASK_QUEUE,
            )
        except Exception as exc:
            logger.warning(
                "Temporal: worker initialisation failed (%s) — "
                "falling back to direct execution",
                exc,
            )

    async def stop(self) -> None:
        """Gracefully stop the Temporal worker background task."""
        self._available = False
        if self._worker_task and not self._worker_task.done():
            self._worker_task.cancel()
            try:
                await self._worker_task
            except (asyncio.CancelledError, Exception):
                pass
        logger.info("Temporal: worker stopped")

    # ── Public API ────────────────────────────────────────────────────────── #

    def is_available(self) -> bool:
        """Return ``True`` if Temporal is connected and the worker is running."""
        return self._available

    async def run(
        self,
        executor: Any,
        context: Any,
        event_queue: Any,
    ) -> None:
        """Route one A2A task execution through a Temporal durable workflow.

        Steps
        -----
        1. Build a serialisable ``AgentTaskInput`` from the A2A request context.
        2. Store non-serialisable state (executor, context, event_queue) in
           the in-process ``_task_registry`` keyed by task_id.
        3. Submit and await ``MoleculeAIAgentWorkflow`` on the Temporal server.
        4. Clean up the registry entry (always, via ``finally``).

        Falls back to ``executor._core_execute()`` if:
        - Temporal is not available (``is_available()`` is False).
        - Input extraction fails.
        - The workflow raises any exception.

        This guarantees that the A2A client always receives a response even
        when Temporal is misconfigured or temporarily unreachable.
        """
        if not self._available or self._client is None:
            # Temporal unavailable — silent direct fallback
            await executor._core_execute(context, event_queue)
            return

        task_id = getattr(context, "task_id", None) or str(uuid.uuid4())
        context_id = getattr(context, "context_id", None) or str(uuid.uuid4())

        # Build serialisable AgentTaskInput
        try:
            from adapters.shared_runtime import (
                extract_history as _extract_history,
                extract_message_text,
            )

            user_input = extract_message_text(context) or ""
            raw_history = _extract_history(context)
            # Convert (role, content) tuples → [role, content] lists (JSON-safe)
            history: list = [list(pair) for pair in raw_history]
        except Exception as exc:
            logger.warning(
                "Temporal: failed to extract serialisable task input (%s) — "
                "falling back to direct execution",
                exc,
            )
            await executor._core_execute(context, event_queue)
            return

        inp = AgentTaskInput(
            task_id=task_id,
            context_id=context_id,
            user_input=user_input,
            model=getattr(executor, "_model", "unknown"),
            workspace_id=os.environ.get("WORKSPACE_ID", "unknown"),
            history=history,
        )

        # Register non-serialisable in-process state for activities to access
        _task_registry[task_id] = {
            "executor": executor,
            "context": context,
            "event_queue": event_queue,
            "final_text": "",
        }

        try:
            logger.info(
                "Temporal: starting workflow molecule-%s on queue '%s'",
                task_id,
                _TASK_QUEUE,
            )
            await self._client.execute_workflow(
                MoleculeAIAgentWorkflow.run,  # type: ignore[name-defined]
                inp,
                id=f"molecule-{task_id}",
                task_queue=_TASK_QUEUE,
                execution_timeout=_WORKFLOW_EXECUTION_TIMEOUT,
            )
        except Exception as exc:
            logger.error(
                "Temporal: workflow molecule-%s failed (%s) — "
                "falling back to direct execution so client receives a response",
                task_id,
                exc,
                exc_info=True,
            )
            # Direct fallback ensures the SSE client is never left hanging
            await executor._core_execute(context, event_queue)
        finally:
            _task_registry.pop(task_id, None)


# ─────────────────────────────────────────────────────────────────────────────
# Module-level singleton helpers
# Used by a2a_executor.py and main.py
# ─────────────────────────────────────────────────────────────────────────────

_global_wrapper: Optional[TemporalWorkflowWrapper] = None


def get_wrapper() -> Optional[TemporalWorkflowWrapper]:
    """Return the global ``TemporalWorkflowWrapper``, or ``None`` if not set.

    Called from ``LangGraphA2AExecutor.execute()`` on every request.
    Returns ``None`` before ``create_wrapper()`` is called (direct-execution mode).
    """
    return _global_wrapper


def create_wrapper() -> TemporalWorkflowWrapper:
    """Create (or return the existing) global ``TemporalWorkflowWrapper``.

    Idempotent — safe to call multiple times. Call ``await wrapper.start()``
    after this to connect to Temporal and launch the background worker.

    Example (in main.py)::

        from builtin_tools.temporal_workflow import create_wrapper as create_temporal_wrapper
        temporal_wrapper = create_temporal_wrapper()
        await temporal_wrapper.start()  # connects + starts worker
        try:
            await server.serve()
        finally:
            await temporal_wrapper.stop()
    """
    global _global_wrapper
    if _global_wrapper is None:
        _global_wrapper = TemporalWorkflowWrapper()
    return _global_wrapper
449
molecule_runtime/claude_sdk_executor.py
Normal file
@ -0,0 +1,449 @@
"""SDK-based agent executor for Claude Code runtime.

Uses the official `claude-agent-sdk` Python package to invoke the Claude Code
engine programmatically — no subprocess, no stdout parsing, no zombie reap.

Replaces CLIAgentExecutor for the `claude-code` runtime only. Other CLI runtimes
(codex, ollama) keep using `cli_executor.py`.

Benefits over CLI subprocess:
- No per-message ~500ms startup overhead
- No stdout buffering issues
- Native Python session management (no JSON parsing of stdout)
- Real message stream — can surface tool calls in future for live UX
- Cooperative cancel (closes the query async generator on cancel())
- Same Claude Code engine, so plugins / skills / CLAUDE.md still apply

Concurrency model
-----------------
Turns are serialized per-executor via an asyncio.Lock. The old CLI executor
serialized implicitly by spawning one subprocess per message and awaiting it;
the SDK removes that, so we re-introduce serialization explicitly. This keeps
session_id updates race-free and makes cancel() well-defined (there's at most
one active stream at any given moment).
"""

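# Illustrative sketch (not part of the original module) of the serialization
# described above: two concurrent execute() calls on one executor contend on
# the same asyncio.Lock, so the second turn starts only after the first one
# finishes. The contexts and queues are assumptions for the example.
#
#     executor = ClaudeSDKExecutor(system_prompt=None, config_path="/configs", heartbeat=None)
#     await asyncio.gather(
#         executor.execute(ctx_a, queue_a),  # acquires _run_lock first
#         executor.execute(ctx_b, queue_b),  # queues behind it, runs second
#     )
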
from __future__ import annotations

import asyncio
import logging
import os
import sys
from collections.abc import AsyncIterator
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

import claude_agent_sdk as sdk

from a2a.server.agent_execution import AgentExecutor, RequestContext
from a2a.server.events import EventQueue
from a2a.utils import new_agent_text_message

from executor_helpers import (
    CONFIG_MOUNT,
    MEMORY_CONTENT_MAX_CHARS,
    WORKSPACE_MOUNT,
    brief_summary,
    commit_memory,
    extract_message_text,
    get_a2a_instructions,
    get_mcp_server_path,
    get_system_prompt,
    read_delegation_results,
    recall_memories,
    sanitize_agent_error,
    set_current_task,
)

if TYPE_CHECKING:
    from heartbeat import HeartbeatLoop

logger = logging.getLogger(__name__)

_NO_TEXT_MSG = "Error: message contained no text content."
_NO_RESPONSE_MSG = "(no response generated)"
_MAX_RETRIES = 3
_BASE_RETRY_DELAY_S = 5
# Cap for stderr captured from the CLI subprocess in the executor log. Keeps
# log lines bounded while still surfacing enough context to diagnose crashes.
# Fixes #66 (previously the executor logged nothing beyond the generic
# "Check stderr output for details" message).
_PROCESS_ERROR_STDERR_MAX_CHARS = 4096

# Substrings in error messages that indicate a transient failure worth retrying.
_RETRYABLE_PATTERNS = (
    "rate",
    "limit",
    "429",
    "overloaded",
    "capacity",
    "exit code 1",
    "try again",
)


_SWALLOWED_STDERR_MARKER = "Check stderr output for details"


def _probe_claude_cli_error() -> str:
    """Run ``claude --print`` directly and capture its stderr + stdout.

    Used as a fallback when the claude-agent-sdk raises a bare ``Exception``
    with the swallowed "Check stderr output for details" placeholder — that
    happens when the SDK wraps a stream error from the CLI subprocess and
    loses both the ``.stderr`` attribute and the exit code. At that point
    the only way to see the real failure reason (rate limit, auth error,
    network outage, missing token) is to run the CLI ourselves.

    Bounded by a 30s timeout so a hung CLI can't stall the error path.
    Every path returns a string: if the probe itself fails, a short
    ``<probe failed: ...>`` placeholder is returned rather than raising, so
    probe noise never corrupts the main error message.
    """
    try:
        import subprocess
        # --print reads stdin, prints response, exits. A short stdin gives the
        # CLI something to work with without triggering an actual model call
        # when it's going to fail anyway.
        proc = subprocess.run(
            ["claude", "--print"],
            input="probe",
            capture_output=True,
            text=True,
            timeout=30,
        )
        if proc.returncode == 0:
            # CLI succeeded — the original error was a transient state that
            # resolved between the SDK failure and our probe. Signal that.
            return "<cli probe succeeded — error was transient>"
        raw = (proc.stderr or "") + (proc.stdout or "")
        raw = raw.strip()
        if not raw:
            return f"<cli exited {proc.returncode} with empty output>"
        if len(raw) > _PROCESS_ERROR_STDERR_MAX_CHARS:
            raw = raw[:_PROCESS_ERROR_STDERR_MAX_CHARS] + "... [truncated]"
        return raw
    except Exception as probe_exc:  # pragma: no cover — best-effort diagnostic
        return f"<probe failed: {type(probe_exc).__name__}: {probe_exc}>"


def _format_process_error(exc: BaseException) -> str:
    """Render a Claude-SDK ProcessError (or any ClaudeSDKError) with its full
    captured context — exit code, stderr, exception type. Non-SDK exceptions
    fall back to plain ``str(exc)``.

    Bounded at _PROCESS_ERROR_STDERR_MAX_CHARS so a runaway CLI can't spam
    the log. Used by the executor's error path (fixes #66 — the SDK's
    ProcessError carries `.stderr`/`.exit_code` attributes that the previous
    code silently discarded, leaving every CLI crash with an identical
    "Check stderr output for details" message in the workspace log).

    Fixes #160: when the SDK raises a bare ``Exception`` containing the
    "Check stderr output for details" placeholder (which happens when the
    CLI subprocess emits a stream error the SDK can't categorize — rate
    limit, auth, network), there's no ``.stderr``/``.exit_code`` to read.
    In that case we fall back to running the CLI ourselves via
    ``_probe_claude_cli_error`` so the operator sees the real failure
    reason (e.g. ``You've hit your limit · resets Apr 17``) instead of
    chasing ghosts in the workspace logs.
    """
    parts = [f"{type(exc).__name__}: {exc}"]
    exit_code = getattr(exc, "exit_code", None)
    if exit_code is not None:
        parts.append(f"exit_code={exit_code}")
    stderr = getattr(exc, "stderr", None)
    if stderr:
        trimmed = stderr[:_PROCESS_ERROR_STDERR_MAX_CHARS]
        if len(stderr) > _PROCESS_ERROR_STDERR_MAX_CHARS:
            trimmed += f"... [{len(stderr) - _PROCESS_ERROR_STDERR_MAX_CHARS} more chars truncated]"
        parts.append(f"stderr={trimmed!r}")
    elif exit_code is None and _SWALLOWED_STDERR_MARKER in str(exc):
        # #160: generic exception with the swallowed-stderr placeholder.
        # Probe the CLI directly — this is the only way to surface the real
        # error when the SDK lost it in translation.
        probed = _probe_claude_cli_error()
        if probed:
            parts.append(f"probed_cli_error={probed!r}")
    return " | ".join(parts)


@dataclass
class QueryResult:
    """Outcome of a single `query()` stream.

    `text` is the canonical final response; `session_id` is the id the SDK
    reports in its ResultMessage (used for resume on the next turn).
    """
    text: str
    session_id: str | None


class ClaudeSDKExecutor(AgentExecutor):
    """Executes agent tasks via the claude-agent-sdk programmatic API."""

    def __init__(
        self,
        system_prompt: str | None,
        config_path: str,
        heartbeat: "HeartbeatLoop | None",
        model: str = "sonnet",
    ):
        self.system_prompt = system_prompt
        self.config_path = config_path
        self.heartbeat = heartbeat
        self.model = model
        self._session_id: str | None = None
        self._active_stream: AsyncIterator[Any] | None = None
        # Serializes concurrent execute() calls on the same executor so
        # session_id / _active_stream mutations stay race-free.
        self._run_lock = asyncio.Lock()

    # ------------------------------------------------------------------
    # Prompt + options builders
    # ------------------------------------------------------------------

    def _resolve_cwd(self) -> str:
        """Run in /workspace if it has been populated, otherwise /configs."""
        if os.path.isdir(WORKSPACE_MOUNT) and os.listdir(WORKSPACE_MOUNT):
            return WORKSPACE_MOUNT
        return CONFIG_MOUNT

    def _build_system_prompt(self) -> str | None:
        """Compose system prompt from file + A2A delegation instructions."""
        base = get_system_prompt(self.config_path, fallback=self.system_prompt)
        a2a = get_a2a_instructions(mcp=True)
        if base and a2a:
            return f"{base}\n\n{a2a}"
        return base or a2a

    def _prepare_prompt(self, user_input: str) -> str:
        """Prepend delegation results that arrived while idle."""
        delegation_context = read_delegation_results()
        if delegation_context:
            return (
                "[Delegation results received while you were idle]\n"
                f"{delegation_context}\n\n[New message]\n{user_input}"
            )
        return user_input

    async def _inject_memories_if_first_turn(self, prompt: str) -> str:
        if self._session_id:
            return prompt
        memories = await recall_memories()
        if not memories:
            return prompt
        return f"[Prior context from memory]\n{memories}\n\n{prompt}"

    def _build_options(self) -> Any:
        """Build ClaudeAgentOptions.

        No allowed_tools allowlist — bypassPermissions grants full access,
        matching the old CLI `--dangerously-skip-permissions` so Claude can
        use every built-in tool (Task, TodoWrite, NotebookEdit, BashOutput/
        KillShell, ExitPlanMode, etc.) plus all MCP tools.

        The MCP server launcher uses `sys.executable` so tests and alternate
        virtual-env layouts don't depend on a `python3` shim being on PATH.
        """
        mcp_servers = {
            "a2a": {
                "command": sys.executable,
                "args": [get_mcp_server_path()],
            }
        }
        return sdk.ClaudeAgentOptions(
            model=self.model,
            permission_mode="bypassPermissions",
            cwd=self._resolve_cwd(),
            mcp_servers=mcp_servers,
            system_prompt=self._build_system_prompt(),
            resume=self._session_id,
        )

    # ------------------------------------------------------------------
    # Query streaming
    # ------------------------------------------------------------------

    async def _run_query(self, prompt: str, options: Any) -> QueryResult:
        """Drive the SDK query stream and return a QueryResult.

        Prefers ResultMessage.result (the canonical final text — same field
        the CLI's --output-format json used) and only falls back to the
        concatenation of AssistantMessage TextBlocks when result is absent.
        Otherwise pre-tool reasoning and post-tool summary get double-emitted.

        Pure: does not mutate executor state other than setting / clearing
        `self._active_stream` so cancel() can reach in. The caller decides
        whether to persist the returned session_id.
        """
        assistant_chunks: list[str] = []
        result_text: str | None = None
        session_id: str | None = None
        self._active_stream = sdk.query(prompt=prompt, options=options)
        try:
            async for message in self._active_stream:
                if isinstance(message, sdk.AssistantMessage):
                    for block in message.content:
                        if isinstance(block, sdk.TextBlock):
                            assistant_chunks.append(block.text)
                elif isinstance(message, sdk.ResultMessage):
                    sid = getattr(message, "session_id", None)
                    if sid:
                        session_id = sid
                    result_text = getattr(message, "result", None)
        finally:
            self._active_stream = None
        text = result_text if result_text is not None else "".join(assistant_chunks)
        return QueryResult(text=text, session_id=session_id)

    # ------------------------------------------------------------------
    # AgentExecutor interface
    # ------------------------------------------------------------------

    async def execute(self, context: RequestContext, event_queue: EventQueue):
        """Run a turn through the Claude Agent SDK and emit the response.

        Serialized via `self._run_lock` — concurrent A2A messages to the same
        workspace queue rather than racing on `_session_id` / `_active_stream`.
        """
        user_input = extract_message_text(context.message)
        if not user_input:
            await event_queue.enqueue_event(new_agent_text_message(_NO_TEXT_MSG))
            return

        async with self._run_lock:
            response_text = await self._execute_locked(user_input)

        # Enqueue outside the lock so the next queued turn can start
        # preparing its prompt while this turn's response ships. Event
        # ordering is preserved per-queue by the A2A server, so no races.
        await event_queue.enqueue_event(new_agent_text_message(response_text))

    @staticmethod
    def _is_retryable(exc: BaseException) -> bool:
        """Check if an SDK exception looks like a transient rate-limit or
        capacity error that's worth retrying with backoff."""
        msg = str(exc).lower()
        return any(p in msg for p in _RETRYABLE_PATTERNS)

    def _reset_session_after_error(self, exc: BaseException) -> None:
        """Clear `_session_id` if the exception looks like a subprocess
        crash (#75). On the next `_build_options()` call `resume=None` is
        passed to the SDK, so the CLI boots a brand-new session instead of
        trying to resume one the previous subprocess left in an
        unrecoverable state.

        Kept in its own method so the policy can evolve (e.g. also clear
        on MessageParseError) without touching the retry loop. Logs at
        INFO when a session was actually cleared; silent when there was
        nothing to reset.
        """
        exc_name = type(exc).__name__
        # Conservative: reset only on subprocess-level failures. Pure
        # rate-limit / capacity errors don't leave the session in a bad
        # state — keep the session_id so the resumed turn preserves
        # conversational continuity.
        is_subprocess_error = (
            exc_name in ("ProcessError", "CLIConnectionError")
            or getattr(exc, "exit_code", None) is not None
            or "exit code" in str(exc).lower()
        )
        if not is_subprocess_error:
            return
        if self._session_id is None:
            return
        logger.info(
            "SDK session reset after %s: clearing session_id so the next "
            "attempt starts fresh (fixes #75 session contamination)",
            exc_name,
        )
        self._session_id = None

    async def _execute_locked(self, user_input: str) -> str:
        """Body of execute() that runs under the run lock.

        Retries transient errors (rate limits, capacity, exit-code-1) up to
        _MAX_RETRIES times with exponential backoff (5s, 10s, 20s).
        """
        # Keep a clean copy of the user's actual message for the memory record,
        # BEFORE any delegation or memory injection.
        original_input = user_input
        await set_current_task(self.heartbeat, brief_summary(user_input))
        logger.debug("SDK execute [claude-code]: %s", user_input[:200])

        prompt = self._prepare_prompt(user_input)
        prompt = await self._inject_memories_if_first_turn(prompt)

        response_text: str = ""
        try:
            for attempt in range(_MAX_RETRIES):
                options = self._build_options()
                try:
                    result = await self._run_query(prompt=prompt, options=options)
                    if result.session_id:
                        self._session_id = result.session_id
                    response_text = result.text
                    break  # success
                except Exception as exc:
                    formatted = _format_process_error(exc)
                    # #75: CLI subprocess crashes leave our _session_id
                    # referencing a session the next subprocess can't
                    # resume. Without this reset the next attempt would
                    # crash identically even when the underlying cause
                    # was transient, cascading into "crashed once →
                    # crashes forever until container restart." Clear
                    # the session_id so the next attempt (retry or
                    # next user turn) starts fresh.
                    self._reset_session_after_error(exc)
                    if attempt < _MAX_RETRIES - 1 and self._is_retryable(exc):
                        delay = _BASE_RETRY_DELAY_S * (2 ** attempt)
                        logger.warning(
                            "SDK agent [claude-code] transient error (attempt %d/%d), "
                            "retrying in %ds: %s",
                            attempt + 1, _MAX_RETRIES, delay, formatted,
                        )
                        await asyncio.sleep(delay)
                        continue
                    # Non-retryable or exhausted retries. Log exit_code +
                    # stderr explicitly (fixes #66) so operators don't have
                    # to reproduce the crash manually to find out why the
                    # subprocess died.
                    logger.error("SDK agent error [claude-code]: %s", formatted)
                    logger.exception("SDK agent error [claude-code] — full traceback follows")
                    response_text = sanitize_agent_error(exc)
                    break
        finally:
            await set_current_task(self.heartbeat, "")
            await commit_memory(
                f"Conversation: {original_input[:MEMORY_CONTENT_MAX_CHARS]}"
            )

        return response_text or _NO_RESPONSE_MSG

    async def cancel(self, context: RequestContext, event_queue: EventQueue):
        """Cooperatively cancel the currently running turn.

        cancel() targets whatever turn is in flight *right now*, not the
        specific turn the caller may have been looking at when they sent
        the cancel request. If turn A has finished and turn B is already
        running under the run lock by the time cancel arrives, turn B is
        the one that gets aborted. This matches how a "stop" button in a
        chat UI typically behaves (stop whatever is running) and is a
        conscious trade-off against per-turn bookkeeping.

        Implementation: the SDK's query() is an async generator; calling
        aclose() raises GeneratorExit inside the running turn and unwinds
        cleanly. We read `self._active_stream` into a local BEFORE calling
        aclose so the reference can't be reassigned by another turn
        mid-cancel. Best-effort — if no stream is active (cancel arrived
        between turns, or the stream has no aclose), this is a no-op.
        """
        stream = self._active_stream
        if stream is None:
            return
        aclose = getattr(stream, "aclose", None)
        if aclose is None:
            return
        try:
            await aclose()
        except Exception:
            logger.exception("SDK cancel: aclose() raised")
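

# Illustrative sketch (not part of the original module): the async-generator
# mechanism cancel() relies on. aclose() raises GeneratorExit at the suspended
# yield, so finally-blocks inside the stream unwind cleanly.
#
#     async def stream():
#         try:
#             while True:
#                 yield "chunk"
#         finally:
#             ...  # cleanup runs on aclose()
#
#     agen = stream()
#     await agen.__anext__()  # start the generator
#     await agen.aclose()     # raises GeneratorExit inside stream()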
456
molecule_runtime/cli_executor.py
Normal file
@ -0,0 +1,456 @@
"""CLI-based agent executor for A2A protocol.

Supports CLI agents that accept a prompt and output a response:
- OpenAI Codex: codex --print -p "..."
- Ollama: ollama run <model> "..."
- Custom: any command that reads stdin or accepts -p

NOTE: the `claude-code` runtime no longer routes here. It uses
ClaudeSDKExecutor (see claude_sdk_executor.py) which wraps the
claude-agent-sdk Python package. This executor is reserved for CLI-only
runtimes that don't yet have a programmatic SDK integration.

The runtime is selected via config.yaml:
    runtime: codex | ollama | custom
    runtime_config:
      command: "codex"          # for custom
      args: ["--extra-flag"]    # additional CLI args
      auth_token_env: "OPENAI_API_KEY"
      auth_token_file: ".auth-token"
      timeout: 300
      model: "sonnet"
"""

import asyncio
import atexit
import json
import logging
import os
import shlex
import shutil
import sys
import tempfile
from pathlib import Path
from typing import TYPE_CHECKING

from a2a.server.agent_execution import AgentExecutor, RequestContext
from a2a.server.events import EventQueue
from a2a.utils import new_agent_text_message

from config import RuntimeConfig
from executor_helpers import (
    CONFIG_MOUNT,
    MEMORY_CONTENT_MAX_CHARS,
    WORKSPACE_MOUNT,
    brief_summary,
    classify_subprocess_error,
    commit_memory,
    extract_message_text,
    get_a2a_instructions,
    get_mcp_server_path,
    get_system_prompt,
    read_delegation_results,
    recall_memories,
    sanitize_agent_error,
    set_current_task,
)

# Resolves the "HeartbeatLoop | None" string annotations below for type
# checkers (mirrors the TYPE_CHECKING import in claude_sdk_executor.py).
if TYPE_CHECKING:
    from heartbeat import HeartbeatLoop

logger = logging.getLogger(__name__)


# Built-in runtime presets.
# The `claude-code` runtime uses ClaudeSDKExecutor (claude_sdk_executor.py)
# and intentionally has no entry here.
RUNTIME_PRESETS: dict[str, dict] = {
    "codex": {
        "command": "codex",
        "base_args": ["--print", "--dangerously-skip-permissions"],
        "prompt_flag": "-p",
        "model_flag": "--model",
        "system_prompt_flag": "--system-prompt",
        "auth_pattern": "env",  # uses OPENAI_API_KEY env var
        "default_auth_env": "OPENAI_API_KEY",
        "default_auth_file": "",
    },
    "ollama": {
        "command": "ollama",
        "base_args": ["run"],
        "prompt_flag": None,  # prompt is positional
        "model_flag": None,  # model is positional after "run"
        "system_prompt_flag": "--system",
        "auth_pattern": None,  # no auth needed
        "default_auth_env": "",
        "default_auth_file": "",
    },
    # Gemini CLI (github.com/google-gemini/gemini-cli, Apache 2.0).
    # Auth via GEMINI_API_KEY env var; MCP is wired via ~/.gemini/settings.json
    # (not --mcp-config) — the adapter's setup() handles that step.
    # System prompt is seeded into GEMINI.md (equivalent of CLAUDE.md).
    "gemini-cli": {
        "command": "gemini",
        "base_args": ["--yolo"],  # auto-approve all tool calls (non-interactive)
        "prompt_flag": "-p",
        "model_flag": "--model",
        "system_prompt_flag": None,  # GEMINI.md used instead; seeded by adapter.setup()
        "auth_pattern": "env",  # GEMINI_API_KEY; also enables A2A MCP instructions
        "default_auth_env": "GEMINI_API_KEY",
        "default_auth_file": "",
        "mcp_via_settings": True,  # MCP injected into ~/.gemini/settings.json, not --mcp-config
    },
}

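
# Illustrative sketch (not part of the original file): what _build_command()
# below could assemble for the ollama preset with model "llama3" and prompt
# "hi" — assuming no system prompt applies and get_a2a_instructions() returns
# nothing for this non-MCP runtime:
#
#     executor._build_command("hi")
#     # -> ["ollama", "run", "llama3", "hi"]
#
# i.e. positional model right after "run", positional prompt last.
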
class CLIAgentExecutor(AgentExecutor):
|
||||
"""Executes agent tasks by invoking a CLI tool.
|
||||
|
||||
Works with any CLI agent that accepts a prompt and outputs text.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
runtime: str,
|
||||
runtime_config: RuntimeConfig,
|
||||
system_prompt: str | None = None,
|
||||
config_path: str = "/configs",
|
||||
heartbeat: "HeartbeatLoop | None" = None,
|
||||
):
|
||||
if runtime == "claude-code":
|
||||
# Defensive — the adapter should never construct a CLI executor
|
||||
# for claude-code. Fail loud rather than silently falling back.
|
||||
raise ValueError(
|
||||
"claude-code runtime is served by ClaudeSDKExecutor, not "
|
||||
"CLIAgentExecutor. Check adapters/claude_code/adapter.py."
|
||||
)
|
||||
self.runtime = runtime
|
||||
self.config = runtime_config
|
||||
self.system_prompt = system_prompt
|
||||
self.config_path = config_path
|
||||
self._heartbeat = heartbeat
|
||||
|
||||
# Resolve preset or use custom
|
||||
if runtime in RUNTIME_PRESETS:
|
||||
self.preset = RUNTIME_PRESETS[runtime]
|
||||
elif runtime == "custom":
|
||||
self.preset = {
|
||||
"command": runtime_config.command,
|
||||
"base_args": [], # args go in config.args, appended at end
|
||||
"prompt_flag": "-p",
|
||||
"model_flag": None,
|
||||
"system_prompt_flag": None,
|
||||
"auth_pattern": None,
|
||||
"default_auth_env": "",
|
||||
"default_auth_file": "",
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Unknown runtime: {runtime}. Use: {', '.join(RUNTIME_PRESETS.keys())}, custom")
|
||||
|
||||
# Resolve auth token
|
||||
self._auth_token = self._resolve_auth_token()
|
||||
self._auth_helper_path: str | None = None
|
||||
self._temp_files: list[str] = [] # Track temp files for cleanup
|
||||
|
||||
if self._auth_token and self.preset.get("auth_pattern") == "apiKeyHelper":
|
||||
self._auth_helper_path = self._create_auth_helper(self._auth_token)
|
||||
|
||||
# Create MCP config once (reuse across invocations)
|
||||
self._mcp_config_path: str | None = None
|
||||
if self.preset.get("auth_pattern") in ("apiKeyHelper", "env"):
|
||||
mcp_config = json.dumps({
|
||||
"mcpServers": {
|
||||
"a2a": {"command": sys.executable, "args": [get_mcp_server_path()]}
|
||||
}
|
||||
})
|
||||
fd, self._mcp_config_path = tempfile.mkstemp(suffix=".json", prefix="a2a-mcp-")
|
||||
self._temp_files.append(self._mcp_config_path) # Track immediately
|
||||
os.close(fd)
|
||||
with open(self._mcp_config_path, "w") as f:
|
||||
f.write(mcp_config)
|
||||
|
||||
# Register cleanup for reliable temp file removal (atexit is more reliable than __del__)
|
||||
atexit.register(self._cleanup_temp_files)
|
||||
|
||||
# Verify command exists
|
||||
cmd = self.config.command or self.preset["command"]
|
||||
if not shutil.which(cmd):
|
||||
logger.warning(f"CLI command '{cmd}' not found in PATH")
|
||||
|
||||
def _resolve_auth_token(self) -> str | None:
|
||||
"""Resolve auth token from env var or file.
|
||||
|
||||
Resolution order:
|
||||
1. required_env — first entry that exists in the environment
|
||||
2. auth_token_env (deprecated) — explicit env var name
|
||||
3. Preset default_auth_env — adapter-declared fallback
|
||||
4. auth_token_file (deprecated) — file on disk
|
||||
5. Preset default_auth_file — adapter-declared file fallback
|
||||
"""
|
||||
# 1. New path: required_env (first match wins)
|
||||
for env_name in (self.config.required_env or []):
|
||||
token = os.environ.get(env_name)
|
||||
if token:
|
||||
return token
|
||||
|
||||
# 2. Legacy: explicit env var from config
|
||||
env_name = self.config.auth_token_env or self.preset.get("default_auth_env", "")
|
||||
if env_name:
|
||||
token = os.environ.get(env_name)
|
||||
if token:
|
||||
return token
|
||||
|
||||
# 3. Legacy: token file from config
|
||||
file_name = self.config.auth_token_file or self.preset.get("default_auth_file", "")
|
||||
if file_name:
|
||||
token_path = Path(self.config_path) / file_name
|
||||
if token_path.exists():
|
||||
return token_path.read_text().strip()
|
||||
|
||||
return None
|
||||

    def _create_auth_helper(self, token: str) -> str:
        """Create a shell script that outputs the auth token (for apiKeyHelper pattern)."""
        fd, helper_path = tempfile.mkstemp(suffix=".sh", prefix="agent-auth-")
        self._temp_files.append(helper_path)  # Track immediately before any exception can leak
        os.close(fd)
        with open(helper_path, "w") as f:
            f.write(f"#!/bin/sh\necho {shlex.quote(token)}\n")
        os.chmod(helper_path, 0o700)
        return helper_path

    def _build_command(self, message: str) -> list[str]:
        """Build the full CLI command from preset + config + message."""
        cmd = self.config.command or self.preset["command"]
        args = list(self.preset.get("base_args", []))

        # Model
        model = self.config.model or None
        model_flag = self.preset.get("model_flag")
        if model and model_flag:
            args.extend([model_flag, model])
        elif model and self.runtime == "ollama":
            # Ollama: model is positional after "run"
            args.append(model)

        # System prompt (+ A2A instructions). The remaining CLI runtimes don't
        # support session resume, so we inject the system prompt on every call.
        system_prompt = get_system_prompt(self.config_path, fallback=self.system_prompt) or ""
        mcp_capable = self.preset.get("auth_pattern") in ("apiKeyHelper", "env")
        a2a_instructions = get_a2a_instructions(mcp=mcp_capable)
        if a2a_instructions:
            system_prompt = (
                f"{system_prompt}\n\n{a2a_instructions}" if system_prompt else a2a_instructions
            )
        system_flag = self.preset.get("system_prompt_flag")
        if system_prompt and system_flag:
            args.extend([system_flag, system_prompt])

        # Auth (apiKeyHelper pattern — reserved for future CLI runtimes)
        if self._auth_helper_path and self.preset.get("auth_pattern") == "apiKeyHelper":
            settings = json.dumps({"apiKeyHelper": self._auth_helper_path})
            args.extend(["--settings", settings])

        # A2A MCP server — inject for MCP-compatible runtimes (created once in __init__).
        # Runtimes that declare `mcp_via_settings: True` (e.g. gemini-cli) wire MCP
        # through their own settings file (adapter.setup()) instead of --mcp-config.
        if self._mcp_config_path and not self.preset.get("mcp_via_settings"):
            args.extend(["--mcp-config", self._mcp_config_path])

        # Extra args from config (before prompt so flags are parsed correctly)
        args.extend(self.config.args)

        # Prompt (must be last — some CLIs treat final arg as the prompt)
        prompt_flag = self.preset.get("prompt_flag")
        if prompt_flag:
            args.extend([prompt_flag, message])
        else:
            # Positional prompt (ollama)
            args.append(message)

        return [cmd] + args
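
    # Composition illustration (hypothetical preset, not one of the shipped
    # RUNTIME_PRESETS): with preset = {"command": "mycli", "base_args": ["run"],
    # "prompt_flag": "-p", "model_flag": "--model"}, config.model = "m1" and
    # config.args = ["--json"], _build_command("hi") yields
    #     ["mycli", "run", "--model", "m1", "--json", "-p", "hi"]
    # base args first, then model, system prompt / auth / MCP flags, extra
    # config args, and the prompt flag last so trailing-prompt CLIs parse it.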

    async def execute(self, context: RequestContext, event_queue: EventQueue):
        """Execute a task by invoking the CLI agent."""
        user_input = extract_message_text(context.message)
        if not user_input:
            await event_queue.enqueue_event(
                new_agent_text_message("Error: message contained no text content.")
            )
            return

        # Keep a clean copy of the user's actual message for memory BEFORE any
        # delegation or memory injection happens.
        original_input = user_input

        # Show current task on canvas — extract a brief one-line summary
        await set_current_task(self._heartbeat, brief_summary(user_input))

        logger.debug("CLI execute [%s]: %s", self.runtime, user_input[:200])

        # Inject delegation results that arrived since last message
        delegation_context = read_delegation_results()
        if delegation_context:
            user_input = f"[Delegation results received while you were idle]\n{delegation_context}\n\n[New message]\n{user_input}"

        # Auto-recall: inject prior memories into every prompt. (The CLI
        # runtimes don't keep a session, so there's no "first turn" concept.)
        memories = await recall_memories()
        if memories:
            user_input = f"[Prior context from memory]\n{memories}\n\n{user_input}"

        try:
            await self._run_cli(user_input, event_queue)
        finally:
            await set_current_task(self._heartbeat, "")
            # Auto-commit: save the original user request (not the memory-injected version)
            await commit_memory(
                f"Conversation: {original_input[:MEMORY_CONTENT_MAX_CHARS]}"
            )

    async def _run_cli(self, user_input: str, event_queue: EventQueue):
        """Run the CLI subprocess and enqueue the result."""
        cmd = self._build_command(user_input)
        timeout = self.config.timeout or None  # None = no timeout (wait until agent finishes)
        max_retries = 3
        base_delay = 5  # seconds

        # Build env — pass through auth env var if using env pattern
        env = dict(os.environ)
        if self._auth_token and self.preset.get("auth_pattern") == "env":
            # Use first required_env entry, or fall back to legacy auth_token_env
            auth_env = self.config.required_env[0] if self.config.required_env else None
            auth_env = auth_env or self.config.auth_token_env or self.preset.get("default_auth_env", "")
            if auth_env:
                env[auth_env] = self._auth_token

        for attempt in range(max_retries):
            proc = None
            try:
                # Run in /workspace if it exists and has content (cloned repo),
                # otherwise /configs (agent config files)
                cwd = (
                    WORKSPACE_MOUNT
                    if os.path.isdir(WORKSPACE_MOUNT) and os.listdir(WORKSPACE_MOUNT)
                    else CONFIG_MOUNT
                )

                proc = await asyncio.create_subprocess_exec(
                    *cmd,
                    stdout=asyncio.subprocess.PIPE,
                    stderr=asyncio.subprocess.PIPE,
                    env=env,
                    cwd=cwd,
                )
                if timeout:
                    stdout, stderr = await asyncio.wait_for(
                        proc.communicate(), timeout=timeout
                    )
                else:
                    stdout, stderr = await proc.communicate()

                stdout_text = stdout.decode().strip()
                stderr_text = stderr.decode().strip()

                if proc.returncode != 0:
                    logger.error("CLI agent [%s] exit=%d stdout=%s stderr=%s",
                                 self.runtime, proc.returncode,
                                 stdout_text[:200] if stdout_text else "(empty)",
                                 stderr_text[:500] if stderr_text else "(empty)")

                if proc.returncode == 0 or stdout_text:
                    # Success, or non-zero exit but produced output (some CLIs exit 1 with valid output)
                    result = stdout_text
                    if result:
                        await event_queue.enqueue_event(
                            new_agent_text_message(result)
                        )
                        return
                    else:
                        # Empty response — likely rate limited, retry with backoff
                        if attempt < max_retries - 1:
                            delay = base_delay * (2 ** attempt)
                            logger.warning("CLI agent [%s] returned empty (attempt %d/%d), retrying in %ds",
                                           self.runtime, attempt + 1, max_retries, delay)
                            await asyncio.sleep(delay)
                            continue
                        await event_queue.enqueue_event(
                            new_agent_text_message("(no response generated after retries)")
                        )
                        return
                else:
                    error_msg = stderr_text or f"Exit code {proc.returncode}"
                    # Classify once — used both for retry policy and the
                    # sanitized user-facing error message.
                    category = classify_subprocess_error(error_msg, proc.returncode)
                    if category in ("rate_limited", "session_error", "auth_failed") \
                            and attempt < max_retries - 1:
                        delay = base_delay * (2 ** attempt)
                        logger.warning(
                            "CLI agent [%s] %s (attempt %d/%d), retrying in %ds",
                            self.runtime, category, attempt + 1, max_retries, delay,
                        )
                        await asyncio.sleep(delay)
                        continue

                    # Log the full stderr (may contain paths/tokens); surface
                    # only the sanitized category to the user.
                    logger.error("CLI agent error [%s]: %s", self.runtime, error_msg[:500])
                    await event_queue.enqueue_event(
                        new_agent_text_message(sanitize_agent_error(category=category))
                    )
                    return

            except asyncio.TimeoutError:
                logger.error("CLI agent timeout [%s] after %ds", self.runtime, timeout)
                if proc:
                    # Kill and reap the process to prevent zombies
                    try:
                        proc.kill()
                    except ProcessLookupError:
                        pass  # already exited
                    except Exception as kill_err:
                        logger.warning("CLI kill error: %s", kill_err)
                    # Always await wait() to reap zombie, even if kill failed
                    try:
                        await asyncio.wait_for(proc.wait(), timeout=5)
                    except asyncio.TimeoutError:
                        logger.error("CLI agent: proc.wait() also timed out — possible zombie")
                    except Exception as wait_err:
                        logger.warning("CLI wait error: %s", wait_err)
                await event_queue.enqueue_event(
                    new_agent_text_message(sanitize_agent_error(category="timeout"))
                )
                return
            except Exception as exc:
                logger.exception("CLI agent exception [%s]", self.runtime)
                await event_queue.enqueue_event(
                    new_agent_text_message(sanitize_agent_error(exc))
                )
                return

    def _cleanup_temp_files(self):  # pragma: no cover
        """Clean up temp files. Called via atexit for reliable cleanup."""
        for f in self._temp_files:
            try:
                os.unlink(f)
            except OSError:
                pass
        if self._auth_helper_path:
            try:
                os.unlink(self._auth_helper_path)
            except OSError:
                pass

    def __del__(self):  # pragma: no cover
        """Clean up temp files (fallback — prefer atexit-registered _cleanup_temp_files)."""
        for f in getattr(self, "_temp_files", []):
            try:
                os.unlink(f)
            except OSError:
                pass
        if getattr(self, "_auth_helper_path", None):
            try:
                os.unlink(self._auth_helper_path)
            except OSError:
                pass

    async def cancel(self, context: RequestContext, event_queue: EventQueue):  # pragma: no cover
        """Cancel a running task."""
        pass
349
molecule_runtime/config.py
Normal file
@ -0,0 +1,349 @@
"""Load workspace configuration from config.yaml."""

import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional

import yaml


@dataclass
class RBACConfig:
    """Role-based access control settings for this workspace.

    ``roles`` declares what this workspace is *allowed* to do. Each role
    name maps to a set of permitted actions. Built-in roles are defined in
    ``tools/audit.ROLE_PERMISSIONS``; custom roles can be added via
    ``allowed_actions``.

    Built-in roles
    --------------
    admin            All actions (delegate, approve, memory.read, memory.write)
    operator         Same as admin — standard agent role (default)
    read-only        memory.read only
    no-delegation    approve + memory.read + memory.write
    no-approval      delegate + memory.read + memory.write
    memory-readonly  memory.read only

    Example config.yaml snippet::

        rbac:
          roles:
            - operator
          allowed_actions:
            analyst:
              - memory.read
              - memory.write
    """

    roles: list[str] = field(default_factory=lambda: ["operator"])
    """List of role names granted to this workspace."""

    allowed_actions: dict[str, list[str]] = field(default_factory=dict)
    """Custom role → [action, ...] overrides. Takes precedence over built-ins."""


@dataclass
class HITLConfig:
    """Human-In-The-Loop settings loaded from the ``hitl:`` block in config.yaml.

    Example config.yaml snippet::

        hitl:
          channels:
            - type: dashboard   # always active
            - type: slack
              webhook_url: https://hooks.slack.com/services/…
            - type: email
              smtp_host: smtp.example.com
              from: alerts@example.com
              to: ops@example.com
          default_timeout: 300  # seconds
          bypass_roles: [admin]
    """
    channels: list[dict] = field(default_factory=lambda: [{"type": "dashboard"}])
    default_timeout: float = 300.0
    bypass_roles: list[str] = field(default_factory=list)


@dataclass
class DelegationConfig:
    retry_attempts: int = 3
    retry_delay: float = 5.0
    timeout: float = 120.0
    escalate: bool = True


@dataclass
class A2AConfig:
    port: int = 8000
    streaming: bool = True
    push_notifications: bool = True


@dataclass
class SandboxConfig:
    backend: str = "subprocess"  # subprocess | docker
    memory_limit: str = "256m"
    timeout: int = 30


@dataclass
class RuntimeConfig:
    """Configuration for CLI-based agent runtimes (claude-code, codex, ollama, custom)."""
    command: str = ""  # e.g. "claude", "codex", "ollama" (model goes in model field)
    args: list[str] = field(default_factory=list)  # additional CLI args
    required_env: list[str] = field(default_factory=list)  # env vars required to run (e.g. ["CLAUDE_CODE_OAUTH_TOKEN"])
    timeout: int = 0  # seconds (0 = no timeout — agents wait until done)
    model: str = ""  # model override for the CLI
    # Deprecated — use required_env + secrets API instead. Kept for backward compat.
    auth_token_env: str = ""
    auth_token_file: str = ""


@dataclass
class GovernanceConfig:
    """Microsoft Agent Governance Toolkit integration settings.

    When ``enabled`` is True, Molecule AI's RBAC and audit trail are bridged
    to the Agent Governance Toolkit (agent-os-kernel) for policy evaluation.

    ``toolkit`` is reserved for future extensibility — only ``"microsoft"``
    is supported today.

    ``policy_mode`` controls enforcement:

        strict      RBAC *and* toolkit policy must both allow — strictest mode
        permissive  RBAC must allow; toolkit denials are logged but not enforced
        audit       RBAC only; toolkit evaluated and logged but never blocks

    ``policy_file`` is the path to a Rego (.rego), YAML (.yaml/.yml), or Cedar
    (.cedar) policy file, loaded into the PolicyEvaluator at startup.

    ``blocked_patterns`` is a list of regex patterns that the toolkit will
    always deny regardless of roles or policy.
    """

    enabled: bool = False
    toolkit: str = "microsoft"
    policy_endpoint: str = ""
    policy_mode: str = "audit"  # strict | permissive | audit
    policy_file: str = ""
    blocked_patterns: list[str] = field(default_factory=list)
    max_tool_calls_per_task: int = 50


@dataclass
class SecurityScanConfig:
    """Skill dependency security scanning settings.

    ``mode`` controls what happens when critical/high CVEs are found:

        block — raise ``SkillSecurityError``; the skill is NOT loaded.
        warn  — emit a WARNING + audit event; the skill is loaded anyway (default).
        off   — skip scanning entirely (air-gapped or CI environments).

    Scanners are tried in order: Snyk CLI (requires ``SNYK_TOKEN``), then
    pip-audit. If neither is available the scan is silently skipped.

    Example config.yaml snippet::

        security_scan: warn   # shorthand string form
        # or verbose form:
        security_scan:
          mode: block
    """

    mode: str = "warn"
    """One of: block | warn | off."""

    fail_open_if_no_scanner: bool = True
    """When True (default), silently skip scanning if no scanner (snyk/pip-audit)
    is in PATH. When False and mode='block', raise SkillSecurityError so that
    operators who require a CVE gate know the gate is absent. Closes #268."""


@dataclass
class ComplianceConfig:
    """OWASP Top 10 for Agentic Applications compliance settings.

    Set ``mode: owasp_agentic`` to enable all checks. When ``mode`` is
    empty or absent the compliance layer is a complete no-op.

    Example config.yaml snippet::

        compliance:
          mode: owasp_agentic
          prompt_injection: block   # detect | block (default: detect)
          max_tool_calls_per_task: 30
          max_task_duration_seconds: 180
    """

    mode: str = ""
    """Enable compliance mode. Set to ``owasp_agentic`` to activate."""

    prompt_injection: str = "detect"
    """``detect`` logs injection attempts; ``block`` raises PromptInjectionError."""

    max_tool_calls_per_task: int = 50
    """Maximum number of tool invocations per task before ExcessiveAgencyError."""

    max_task_duration_seconds: int = 300
    """Maximum wall-clock seconds per task before ExcessiveAgencyError."""


@dataclass
class WorkspaceConfig:
    name: str = "Workspace"
    description: str = ""
    version: str = "1.0.0"
    tier: int = 1
    model: str = "anthropic:claude-sonnet-4-6"
    runtime: str = "langgraph"  # langgraph | claude-code | codex | ollama | custom
    runtime_config: RuntimeConfig = field(default_factory=RuntimeConfig)
    initial_prompt: str = ""
    """Auto-sent as the first A2A message after startup. Default empty = no auto-message.
    Can be an inline string or a file reference (initial_prompt_file in yaml)."""
    idle_prompt: str = ""
    """Auto-sent every `idle_interval_seconds` while the workspace has no active
    task (heartbeat.active_tasks == 0). Default empty = no idle loop. This is
    the reflection-on-completion / backlog-pull pattern from the Hermes/Letta
    playbook: the workspace self-wakes when idle, runs a lightweight reflection
    prompt, and either picks up queued work or stops. Cost scales with useful
    activity (the prompt returns quickly if there's nothing to do). Can be
    inline or a file reference via `idle_prompt_file`."""
    idle_interval_seconds: int = 600
    """How often the idle loop checks in (seconds). Default 600 (10 min).
    Ignored when idle_prompt is empty."""
    skills: list[str] = field(default_factory=list)
    plugins: list[str] = field(default_factory=list)  # installed plugin names
    tools: list[str] = field(default_factory=list)
    prompt_files: list[str] = field(default_factory=list)
    shared_context: list[str] = field(default_factory=list)
    a2a: A2AConfig = field(default_factory=A2AConfig)
    delegation: DelegationConfig = field(default_factory=DelegationConfig)
    sandbox: SandboxConfig = field(default_factory=SandboxConfig)
    rbac: RBACConfig = field(default_factory=RBACConfig)
    hitl: HITLConfig = field(default_factory=HITLConfig)
    governance: GovernanceConfig = field(default_factory=GovernanceConfig)
    security_scan: SecurityScanConfig = field(default_factory=SecurityScanConfig)
    compliance: ComplianceConfig = field(default_factory=ComplianceConfig)
    sub_workspaces: list[dict] = field(default_factory=list)


def load_config(config_path: Optional[str] = None) -> WorkspaceConfig:
    """Load config from WORKSPACE_CONFIG_PATH or the given path."""
    if config_path is None:
        config_path = os.environ.get("WORKSPACE_CONFIG_PATH", "/configs")

    config_file = Path(config_path) / "config.yaml"
    if not config_file.exists():
        raise FileNotFoundError(f"Config file not found: {config_file}")

    with open(config_file) as f:
        raw = yaml.safe_load(f) or {}

    # Override model from env if provided
    model = os.environ.get("MODEL_PROVIDER", raw.get("model", "anthropic:claude-sonnet-4-6"))

    runtime = raw.get("runtime", "langgraph")
    runtime_raw = raw.get("runtime_config", {})

    a2a_raw = raw.get("a2a", {})
    delegation_raw = raw.get("delegation", {})
    sandbox_raw = raw.get("sandbox", {})
    rbac_raw = raw.get("rbac", {})
    hitl_raw = raw.get("hitl", {})
    governance_raw = raw.get("governance", {})
    # security_scan accepts both shorthand string ("warn") and dict ({"mode": "warn"})
    _ss_raw = raw.get("security_scan", {})
    security_scan_raw = _ss_raw if isinstance(_ss_raw, dict) else {"mode": str(_ss_raw)}
    compliance_raw = raw.get("compliance", {})

    # Resolve initial_prompt: inline string or file reference
    initial_prompt = raw.get("initial_prompt", "")
    initial_prompt_file = raw.get("initial_prompt_file", "")
    if not initial_prompt and initial_prompt_file:
        prompt_path = Path(config_path) / initial_prompt_file
        if prompt_path.exists():
            initial_prompt = prompt_path.read_text().strip()

    # Resolve idle_prompt: same pattern as initial_prompt
    idle_prompt = raw.get("idle_prompt", "")
    idle_prompt_file = raw.get("idle_prompt_file", "")
    if not idle_prompt and idle_prompt_file:
        idle_path = Path(config_path) / idle_prompt_file
        if idle_path.exists():
            idle_prompt = idle_path.read_text().strip()
    idle_interval_seconds = int(raw.get("idle_interval_seconds", 600))

    return WorkspaceConfig(
        name=raw.get("name", "Workspace"),
        description=raw.get("description", ""),
        version=raw.get("version", "1.0.0"),
        tier=int(raw.get("tier", 1)) if str(raw.get("tier", 1)).isdigit() else 1,
        model=model,
        runtime=runtime,
        initial_prompt=initial_prompt,
        idle_prompt=idle_prompt,
        idle_interval_seconds=idle_interval_seconds,
        runtime_config=RuntimeConfig(
            command=runtime_raw.get("command", ""),
            args=runtime_raw.get("args", []),
            required_env=runtime_raw.get("required_env", []),
            timeout=runtime_raw.get("timeout", 0),
            model=runtime_raw.get("model", ""),
            # Deprecated fields — kept for backward compat
            auth_token_env=runtime_raw.get("auth_token_env", ""),
            auth_token_file=runtime_raw.get("auth_token_file", ""),
        ),
        skills=raw.get("skills", []),
        plugins=raw.get("plugins", []),
        tools=raw.get("tools", []),
        prompt_files=raw.get("prompt_files", []),
        shared_context=raw.get("shared_context", []),
        a2a=A2AConfig(
            port=a2a_raw.get("port", 8000),
            streaming=a2a_raw.get("streaming", True),
            push_notifications=a2a_raw.get("push_notifications", True),
        ),
        delegation=DelegationConfig(
            retry_attempts=delegation_raw.get("retry_attempts", 3),
            retry_delay=delegation_raw.get("retry_delay", 5.0),
            timeout=delegation_raw.get("timeout", 120.0),
            escalate=delegation_raw.get("escalate", True),
        ),
        sandbox=SandboxConfig(
            backend=sandbox_raw.get("backend", "subprocess"),
            memory_limit=sandbox_raw.get("memory_limit", "256m"),
            timeout=sandbox_raw.get("timeout", 30),
        ),
        rbac=RBACConfig(
            roles=rbac_raw.get("roles", ["operator"]),
            allowed_actions=rbac_raw.get("allowed_actions", {}),
        ),
        hitl=HITLConfig(
            channels=hitl_raw.get("channels", [{"type": "dashboard"}]),
            default_timeout=float(hitl_raw.get("default_timeout", 300)),
            bypass_roles=hitl_raw.get("bypass_roles", []),
        ),
        governance=GovernanceConfig(
            enabled=governance_raw.get("enabled", False),
            toolkit=governance_raw.get("toolkit", "microsoft"),
            policy_endpoint=governance_raw.get("policy_endpoint", ""),
            policy_mode=governance_raw.get("policy_mode", "audit"),
            policy_file=governance_raw.get("policy_file", ""),
            blocked_patterns=governance_raw.get("blocked_patterns", []),
            max_tool_calls_per_task=governance_raw.get("max_tool_calls_per_task", 50),
        ),
        security_scan=SecurityScanConfig(
            mode=security_scan_raw.get("mode", "warn"),
            fail_open_if_no_scanner=security_scan_raw.get("fail_open_if_no_scanner", True),
        ),
        compliance=ComplianceConfig(
            mode=compliance_raw.get("mode", ""),
            prompt_injection=compliance_raw.get("prompt_injection", "detect"),
            max_tool_calls_per_task=int(compliance_raw.get("max_tool_calls_per_task", 50)),
            max_task_duration_seconds=int(compliance_raw.get("max_task_duration_seconds", 300)),
        ),
        sub_workspaces=raw.get("sub_workspaces", []),
    )
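
For orientation, a minimal sketch of loading this configuration. The directory and YAML values below are illustrative, not fixtures shipped with the package:

```python
# Hedged sketch — hypothetical config dir and values.
import tempfile
from pathlib import Path

from molecule_runtime.config import load_config

cfg_dir = Path(tempfile.mkdtemp())
(cfg_dir / "config.yaml").write_text(
    "name: Demo\n"
    "runtime: ollama\n"
    "runtime_config:\n"
    "  command: ollama\n"
    "  model: llama3\n"
    "idle_prompt: Review the backlog and pick up queued work.\n"
    "idle_interval_seconds: 900\n"
    "security_scan: block   # shorthand string form\n"
)
cfg = load_config(str(cfg_dir))
assert cfg.runtime == "ollama"
assert cfg.runtime_config.model == "llama3"
assert cfg.idle_interval_seconds == 900
assert cfg.security_scan.mode == "block"  # shorthand normalized to a dict
```
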
131
molecule_runtime/consolidation.py
Normal file
@ -0,0 +1,131 @@
"""Memory consolidation loop.

When an agent is idle (no active tasks for a configurable period),
the consolidation loop wakes up and summarizes noisy local memory
entries into dense, high-value knowledge facts.

Similar to human sleep consolidation — raw scratchpad entries get
compressed into reusable knowledge.
"""

import asyncio
import logging
import os

import httpx

from platform_auth import auth_headers

logger = logging.getLogger(__name__)

PLATFORM_URL = os.environ.get("PLATFORM_URL", "http://platform:8080")
WORKSPACE_ID = os.environ.get("WORKSPACE_ID", "")
CONSOLIDATION_INTERVAL = float(os.environ.get("CONSOLIDATION_INTERVAL", "300"))  # 5 min
CONSOLIDATION_THRESHOLD = int(os.environ.get("CONSOLIDATION_THRESHOLD", "10"))  # min memories before consolidating


class ConsolidationLoop:
    """Background loop that consolidates local memories when idle."""

    def __init__(self, agent=None):
        self.agent = agent
        self._running = False

    async def start(self):
        """Start the consolidation loop."""
        self._running = True
        logger.info("Memory consolidation loop started (interval=%ss, threshold=%d)",
                    CONSOLIDATION_INTERVAL, CONSOLIDATION_THRESHOLD)

        while self._running:
            await asyncio.sleep(CONSOLIDATION_INTERVAL)

            if not self._running:
                break

            try:
                await self._consolidate()
            except Exception as e:
                logger.warning("Consolidation error: %s", e)

    async def _consolidate(self):
        """Check if consolidation is needed and run it."""
        async with httpx.AsyncClient(timeout=10.0) as client:
            # Fetch local memories
            resp = await client.get(
                f"{PLATFORM_URL}/workspaces/{WORKSPACE_ID}/memories",
                params={"scope": "LOCAL"},
                headers=auth_headers(),
            )
            if resp.status_code != 200:
                return

            memories = resp.json()
            if len(memories) < CONSOLIDATION_THRESHOLD:
                return

            logger.info("Consolidating %d local memories", len(memories))

            # Build a summary of all local memories
            contents = [m["content"] for m in memories]
            summary_prompt = (
                "Summarize the following workspace memories into 3-5 key facts. "
                "Each fact should be a single, clear sentence capturing the most "
                "important and reusable knowledge:\n\n"
                + "\n".join(f"- {c}" for c in contents)
            )

            # Use the agent to generate the summary if available
            summary = ""
            if self.agent:
                try:
                    result = await self.agent.ainvoke(
                        {"messages": [("user", summary_prompt)]},
                        config={"configurable": {"thread_id": "consolidation"}},
                    )
                    messages = result.get("messages", [])
                    summary = ""
                    for msg in reversed(messages):
                        content = getattr(msg, "content", "")
                        if isinstance(content, str) and content.strip():
                            msg_type = getattr(msg, "type", "")
                            if msg_type != "human":
                                summary = content
                                break

                    if summary:
                        # Store consolidated summary as a TEAM memory — only delete originals if POST succeeds
                        resp = await client.post(
                            f"{PLATFORM_URL}/workspaces/{WORKSPACE_ID}/memories",
                            json={"content": f"[Consolidated] {summary}", "scope": "TEAM"},
                            headers=auth_headers(),
                        )
                        if resp.status_code in (200, 201):
                            # Safe to delete originals — consolidated version is saved
                            for m in memories:
                                await client.delete(
                                    f"{PLATFORM_URL}/workspaces/{WORKSPACE_ID}/memories/{m['id']}",
                                    headers=auth_headers(),
                                )
                            logger.info("Consolidated %d memories into team knowledge", len(memories))
                        else:
                            logger.warning("Consolidation POST failed (status %d) — keeping originals", resp.status_code)
                except Exception as e:
                    logger.error(
                        "CONSOLIDATION: Agent summarization failed (rate limit? model error?): %s. "
                        "Falling back to simple concatenation.", e
                    )
                    # Fall through to concatenation below

            # Fallback: concatenate without agent summarization
            if not (self.agent and summary):
                combined = " | ".join(contents[:20])
                await client.post(
                    f"{PLATFORM_URL}/workspaces/{WORKSPACE_ID}/memories",
                    json={"content": f"[Consolidated] {combined}", "scope": "TEAM"},
                    headers=auth_headers(),
                )
                logger.info("Consolidated %d memories via concatenation fallback", len(memories))

    def stop(self):
        self._running = False
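
A sketch of wiring the loop into a runtime's asyncio entry point. It assumes a container environment where `platform_auth` is importable; `agent=None` exercises the concatenation fallback:

```python
# Hedged wiring sketch — not the runtime's actual startup code.
import asyncio

from molecule_runtime.consolidation import ConsolidationLoop

async def main():
    loop = ConsolidationLoop(agent=None)  # no agent: concatenation fallback path
    task = asyncio.create_task(loop.start())
    try:
        ...  # serve requests here
    finally:
        loop.stop()    # flips _running; the loop exits after its next sleep
        task.cancel()  # or cancel to wake it immediately

asyncio.run(main())
```
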
136
molecule_runtime/coordinator.py
Normal file
@ -0,0 +1,136 @@
"""Coordinator pattern for team workspaces.

When a workspace is expanded into a team, the parent agent becomes a
coordinator that routes incoming tasks to the appropriate child workspace
based on the task content and children's capabilities.

The coordinator:
1. Fetches its children's Agent Cards (skills, capabilities)
2. Analyzes each incoming task to determine which child is best suited
3. Delegates to the chosen child via the delegation tool
4. Aggregates responses if a task requires multiple children
5. Falls back to handling the task itself if no child is appropriate
"""

import logging
import os

import httpx
from langchain_core.tools import tool

from adapters.shared_runtime import build_peer_section
from policies.routing import build_team_routing_payload

logger = logging.getLogger(__name__)

PLATFORM_URL = os.environ.get("PLATFORM_URL", "http://platform:8080")
WORKSPACE_ID = os.environ.get("WORKSPACE_ID", "")


async def get_parent_context() -> list[dict]:
    """Fetch shared context files from this workspace's parent.

    Returns a list of {"path": str, "content": str} dicts.
    Returns an empty list if no parent, parent unreachable, or no shared context.
    """
    parent_id = os.environ.get("PARENT_ID", "")
    if not parent_id:
        return []

    try:
        async with httpx.AsyncClient(timeout=10.0) as client:
            resp = await client.get(
                f"{PLATFORM_URL}/workspaces/{parent_id}/shared-context",
                headers={"X-Workspace-ID": WORKSPACE_ID},
            )
            if resp.status_code == 200:
                return resp.json()
    except Exception as e:
        logger.warning("Failed to fetch parent context: %s", e)
    return []


async def get_children() -> list[dict]:
    """Fetch this workspace's children from the platform."""
    try:
        async with httpx.AsyncClient(timeout=10.0) as client:
            resp = await client.get(
                f"{PLATFORM_URL}/registry/{WORKSPACE_ID}/peers",
                headers={"X-Workspace-ID": WORKSPACE_ID},
            )
            if resp.status_code == 200:
                peers = resp.json()
                # Filter to only children (parent_id == our ID)
                return [p for p in peers if p.get("parent_id") == WORKSPACE_ID]
    except Exception as e:
        logger.warning("Failed to fetch children: %s", e)
    return []


def build_children_description(children: list[dict]) -> str:
    """Build a description of children's capabilities for the coordinator prompt."""
    if not children:
        return ""

    team_section = build_peer_section(
        children,
        heading="## Your Team (sub-workspaces you coordinate)",
        instruction=(
            "Use the `delegate_to_workspace` tool to send tasks to the chosen member. "
            "Only delegate to members listed above."
        ),
    )

    return "\n".join(
        [
            team_section,
            "",
            "### Coordination Rules — MANDATORY",
            "1. You are a COORDINATOR. Your ONLY job is to delegate and synthesize. NEVER do the work yourself.",
            "2. For EVERY task, use `delegate_to_workspace` to send it to the appropriate team member(s). "
            "Do this BEFORE writing any analysis, code, or research yourself.",
            "3. If a task spans multiple members, delegate to ALL of them in parallel and aggregate results.",
            "4. If ALL members are offline/paused, tell the caller which members are unavailable. "
            "Do NOT attempt the work yourself — you lack the specialist context.",
            "5. If a delegation FAILS (error, timeout): try another member first. "
            "Only provide your own brief summary if NO member can respond. Never forward raw errors.",
            "6. Your response should be a SYNTHESIS of your team's work, not your own analysis.",
            "7. Always respond in the same language the caller uses.",
        ]
    )


@tool
async def route_task_to_team(
    task: str,
    preferred_member_id: str = "",
) -> dict:
    """Route a task to the most appropriate team member.

    As the team coordinator, analyze the task and delegate to the best-suited
    child workspace. If preferred_member_id is provided, delegate directly to
    that member.

    Args:
        task: The task description to route.
        preferred_member_id: Optional — directly delegate to this member.
    """
    from builtin_tools.delegation import delegate_to_workspace as delegate

    children = await get_children()
    decision = build_team_routing_payload(
        children,
        task=task,
        preferred_member_id=preferred_member_id,
    )

    if decision.get("action") == "delegate_to_preferred_member":
        # Async delegation — returns immediately with task_id
        result = await delegate.ainvoke(
            {
                "workspace_id": decision["preferred_member_id"],
                "task": task,
            }
        )
        return result

    return decision
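
Because `route_task_to_team` is a LangChain tool, callers invoke it with a dict payload. A sketch, assuming a workspace container where the runtime's dependencies resolve; the workspace ID is hypothetical:

```python
import asyncio

from molecule_runtime.coordinator import route_task_to_team

async def demo():
    result = await route_task_to_team.ainvoke(
        {"task": "Summarize this week's findings", "preferred_member_id": "ws-research"}
    )
    # With a preferred member, `result` is the delegation outcome (the async
    # delegation path returns immediately with a task_id); otherwise it is
    # the routing decision from build_team_routing_payload.
    print(result)

asyncio.run(demo())
```
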
96
molecule_runtime/events.py
Normal file
@ -0,0 +1,96 @@
"""WebSocket subscriber for platform events.

Subscribes to the platform WebSocket with X-Workspace-ID header
so the workspace only receives events about reachable peers.
Triggers system prompt rebuild on relevant peer changes.
"""

import asyncio
import json
import logging

import httpx

logger = logging.getLogger(__name__)

# Events that should trigger a system prompt rebuild
REBUILD_EVENTS = {
    "WORKSPACE_ONLINE",
    "WORKSPACE_OFFLINE",
    "WORKSPACE_EXPANDED",
    "WORKSPACE_COLLAPSED",
    "WORKSPACE_REMOVED",
    "AGENT_CARD_UPDATED",
}


class PlatformEventSubscriber:
    """Subscribes to platform WebSocket for peer events."""

    def __init__(
        self,
        platform_url: str,
        workspace_id: str,
        on_peer_change=None,
    ):
        self.ws_url = platform_url.replace("http://", "ws://").replace("https://", "wss://") + "/ws"
        self.workspace_id = workspace_id
        self.on_peer_change = on_peer_change
        self._running = False
        self._reconnect_delay = 1.0

    async def start(self):
        """Connect to platform WebSocket with exponential backoff reconnect."""
        self._running = True

        while self._running:
            try:
                await self._connect()
            except Exception as e:
                if not self._running:
                    break
                logger.warning("WebSocket disconnected: %s. Reconnecting in %.0fs...", e, self._reconnect_delay)
                await asyncio.sleep(self._reconnect_delay)
                self._reconnect_delay = min(self._reconnect_delay * 2, 30.0)

    async def _connect(self):
        """Establish WebSocket connection and process events."""
        try:
            import websockets
        except ImportError:
            logger.warning("websockets package not installed, skipping event subscription")
            self._running = False
            return

        # Fix D (Cycle 5): include bearer token in WebSocket upgrade so the
        # server's new auth check can validate this agent connection.
        # Graceful fallback for workspaces that have no token yet.
        headers = {"X-Workspace-ID": self.workspace_id}
        try:
            from platform_auth import auth_headers as _auth_headers
            headers.update(_auth_headers())
        except Exception:
            pass  # No token available — connect unauthenticated (grandfathered)
        logger.info("Connecting to platform WebSocket: %s", self.ws_url)

        async with websockets.connect(self.ws_url, additional_headers=headers) as ws:
            self._reconnect_delay = 1.0  # Reset on successful connect
            logger.info("Platform WebSocket connected")

            async for message in ws:
                try:
                    event = json.loads(message)
                    event_type = event.get("event", "")

                    if event_type in REBUILD_EVENTS:
                        logger.info("Peer event: %s for workspace %s",
                                    event_type, event.get("workspace_id", ""))
                        if self.on_peer_change:
                            await self.on_peer_change(event)
                except json.JSONDecodeError:
                    continue
                except Exception as e:
                    logger.warning("Error processing event: %s", e)

    def stop(self):
        self._running = False
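
A usage sketch for the subscriber; the platform URL and workspace ID are placeholders, and the callback simply logs each rebuild-worthy event:

```python
import asyncio

from molecule_runtime.events import PlatformEventSubscriber

async def on_peer_change(event: dict):
    print("peer event:", event.get("event"), event.get("workspace_id"))

async def main():
    sub = PlatformEventSubscriber(
        platform_url="http://platform:8080",  # placeholder in-cluster URL
        workspace_id="ws-demo",               # placeholder workspace ID
        on_peer_change=on_peer_change,
    )
    task = asyncio.create_task(sub.start())
    await asyncio.sleep(60)  # ... run the workspace ...
    sub.stop()
    task.cancel()

asyncio.run(main())
```
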
389
molecule_runtime/executor_helpers.py
Normal file
@ -0,0 +1,389 @@
"""Shared helpers for AgentExecutor implementations.

Used by both CLIAgentExecutor (codex, ollama) and ClaudeSDKExecutor (claude-code).
Provides:
- Memory recall/commit (HTTP to platform /memories endpoints)
- Delegation results consumption (atomic file rename)
- Current task heartbeat updates
- System prompt loading from /configs
- A2A instructions text for system prompt injection (MCP and CLI variants)
- Brief task summary extraction (markdown-aware)
- Error message sanitization (exception classes and subprocess categories)
- Shared workspace path constants and the MCP server path resolver
"""

from __future__ import annotations

import json
import logging
import os
import re
from pathlib import Path
from typing import TYPE_CHECKING, Any

import httpx

if TYPE_CHECKING:
    from heartbeat import HeartbeatLoop


logger = logging.getLogger(__name__)


# ========================================================================
# Constants — workspace container layout
# ========================================================================

WORKSPACE_MOUNT = "/workspace"
CONFIG_MOUNT = "/configs"
DEFAULT_MCP_SERVER_PATH = "/app/a2a_mcp_server.py"
DEFAULT_DELEGATION_RESULTS_FILE = "/tmp/delegation_results.jsonl"
PLATFORM_HTTP_TIMEOUT_S = 5.0
MEMORY_RECALL_LIMIT = 10
MEMORY_CONTENT_MAX_CHARS = 200
BRIEF_SUMMARY_MAX_LEN = 80


def get_mcp_server_path() -> str:
    """Return the path to the stdio MCP server script.

    Overridable via A2A_MCP_SERVER_PATH for tests and non-default layouts.
    """
    return os.environ.get("A2A_MCP_SERVER_PATH", DEFAULT_MCP_SERVER_PATH)


# ========================================================================
# HTTP client (shared, lazily initialised)
# ========================================================================

_http_client: httpx.AsyncClient | None = None


def get_http_client() -> httpx.AsyncClient:
    """Lazy-init a shared httpx client for platform API calls."""
    global _http_client
    if _http_client is None or _http_client.is_closed:
        _http_client = httpx.AsyncClient(timeout=PLATFORM_HTTP_TIMEOUT_S)
    return _http_client


def reset_http_client_for_tests() -> None:
    """Test helper — drop the shared client so the next call rebuilds it.

    Not for production use. Exposed so tests can guarantee a clean slate
    between cases without touching module internals.
    """
    global _http_client
    _http_client = None


# ========================================================================
# Memory recall + commit
# ========================================================================

async def recall_memories() -> str:
    """Recall recent memories from the platform API.

    Returns a newline-joined bullet list of up to MEMORY_RECALL_LIMIT most recent
    memories, or an empty string when the platform is unreachable / not configured
    / returns a non-200 / returns an unexpected payload shape.
    """
    workspace_id = os.environ.get("WORKSPACE_ID", "")
    platform_url = os.environ.get("PLATFORM_URL", "")
    if not workspace_id or not platform_url:
        return ""
    # Fix E (Cycle 5): send auth headers so the WorkspaceAuth middleware
    # (Fix A) allows access once the workspace has a live token on file.
    try:
        from platform_auth import auth_headers as _platform_auth
        _auth = _platform_auth()
    except Exception:
        _auth = {}
    try:
        resp = await get_http_client().get(
            f"{platform_url}/workspaces/{workspace_id}/memories",
            headers=_auth,
        )
        if not 200 <= resp.status_code < 300:
            logger.debug(
                "recall_memories: non-2xx response %s from platform",
                resp.status_code,
            )
            return ""
        data = resp.json()
    except Exception as exc:
        logger.debug("recall_memories: request failed: %s", exc)
        return ""
    if not isinstance(data, list) or not data:
        return ""
    lines = [
        f"- [{m.get('scope', '?')}] {m.get('content', '')}"
        for m in data[-MEMORY_RECALL_LIMIT:]
    ]
    return "\n".join(lines)
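
# Output-shape illustration (hypothetical memories): two stored entries
#   {"scope": "LOCAL", "content": "User prefers dark mode"}
#   {"scope": "TEAM", "content": "Deploys happen on Fridays"}
# are rendered as:
#   - [LOCAL] User prefers dark mode
#   - [TEAM] Deploys happen on Fridays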

async def commit_memory(content: str) -> None:
    """Save a memory to the platform API. Best-effort, no error propagation."""
    workspace_id = os.environ.get("WORKSPACE_ID", "")
    platform_url = os.environ.get("PLATFORM_URL", "")
    if not workspace_id or not platform_url or not content:
        return
    # Fix E (Cycle 5): include auth header so WorkspaceAuth middleware allows access.
    try:
        from platform_auth import auth_headers as _platform_auth
        _auth = _platform_auth()
    except Exception:
        _auth = {}
    try:
        await get_http_client().post(
            f"{platform_url}/workspaces/{workspace_id}/memories",
            json={"content": content, "scope": "LOCAL"},
            headers=_auth,
        )
    except Exception as exc:
        logger.debug("commit_memory: request failed: %s", exc)


# ========================================================================
# Delegation results — written by heartbeat loop, consumed atomically
# ========================================================================

def read_delegation_results() -> str:
    """Read and consume delegation results written by the heartbeat loop.

    Uses atomic rename to prevent races with the heartbeat writer.
    Returns formatted text suitable for prompt injection, or an empty string.
    """
    results_file = Path(
        os.environ.get("DELEGATION_RESULTS_FILE", DEFAULT_DELEGATION_RESULTS_FILE)
    )
    if not results_file.exists():
        return ""
    consumed = results_file.with_suffix(".consumed")
    try:
        results_file.rename(consumed)
    except OSError:
        return ""  # File disappeared between exists() and rename()
    try:
        raw = consumed.read_text(encoding="utf-8", errors="replace")
    except OSError:
        return ""
    finally:
        consumed.unlink(missing_ok=True)

    parts: list[str] = []
    for line in raw.strip().split("\n"):
        if not line.strip():
            continue
        try:
            record = json.loads(line)
        except json.JSONDecodeError:
            continue
        status = record.get("status", "?")
        summary = record.get("summary", "")
        preview = record.get("response_preview", "")
        parts.append(f"- [{status}] {summary}")
        if preview:
            parts.append(f"  Response: {preview[:200]}")
    return "\n".join(parts)
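
# Consumption illustration (hypothetical record): a JSONL line
#   {"status": "completed", "summary": "Audit finished", "response_preview": "3 issues found"}
# is rendered as:
#   - [completed] Audit finished
#     Response: 3 issues found
# and the file is consumed (renamed, read, unlinked), so each result is
# injected into at most one subsequent prompt.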

# ========================================================================
# Current task heartbeat update
# ========================================================================

async def set_current_task(heartbeat: "HeartbeatLoop | None", task: str) -> None:
    """Update current task on heartbeat and push immediately via platform API."""
    if heartbeat is not None:
        heartbeat.current_task = task
        heartbeat.active_tasks = 1 if task else 0
    workspace_id = os.environ.get("WORKSPACE_ID", "")
    platform_url = os.environ.get("PLATFORM_URL", "")
    if not (workspace_id and platform_url):
        return
    try:
        try:
            from platform_auth import auth_headers as _auth
            _headers = _auth()
        except Exception:
            _headers = {}
        await get_http_client().post(
            f"{platform_url}/registry/heartbeat",
            json={
                "workspace_id": workspace_id,
                "current_task": task,
                "active_tasks": 1 if task else 0,
                "error_rate": 0,
                "sample_error": "",
                "uptime_seconds": 0,
            },
            headers=_headers,
        )
    except Exception as exc:
        logger.debug("set_current_task: heartbeat push failed: %s", exc)


# ========================================================================
# System prompt loading
# ========================================================================

def get_system_prompt(config_path: str, fallback: str | None = None) -> str | None:
    """Read system-prompt.md from the config dir each call (supports hot-reload).

    Falls back to the provided string if the file doesn't exist.
    """
    prompt_file = Path(config_path) / "system-prompt.md"
    if prompt_file.exists():
        return prompt_file.read_text(encoding="utf-8", errors="replace").strip()
    return fallback


_A2A_INSTRUCTIONS_MCP = """## Inter-Agent Communication
You have MCP tools for communicating with other workspaces:
- list_peers: discover available peer workspaces (name, ID, status, role)
- delegate_task: send a task and WAIT for the response (for quick tasks)
- delegate_task_async: send a task and return immediately with a task_id (for long tasks)
- check_task_status: poll an async task's status and get results when done
- get_workspace_info: get your own workspace info

For quick questions, use delegate_task (synchronous).
For long-running work (building pages, running audits), use delegate_task_async + check_task_status.
Always use list_peers first to discover available workspace IDs.
Access control is enforced — you can only reach siblings and parent/children.

PROACTIVE MESSAGING: Use send_message_to_user to push messages to the user's chat at ANY time:
- Acknowledge tasks immediately: "Got it, delegating to the team now..."
- Send progress updates during long work: "Research Lead finished, waiting on Dev Lead..."
- Deliver follow-up results: "All teams reported back. Here's the synthesis: ..."
This lets you respond quickly ("I'll work on this") and come back later with results.

If delegate_task returns a DELEGATION FAILED message, do NOT forward the raw error to the user.
Instead: (1) try delegating to a different peer, (2) handle the task yourself, or
(3) tell the user which peer is unavailable and provide your own best answer."""


_A2A_INSTRUCTIONS_CLI = """## Inter-Agent Communication
You can delegate tasks to other workspaces using the a2a command:
    python3 /app/a2a_cli.py peers                                   # List available peers
    python3 /app/a2a_cli.py delegate <workspace_id> <task>          # Sync: wait for response
    python3 /app/a2a_cli.py delegate --async <workspace_id> <task>  # Async: return task_id
    python3 /app/a2a_cli.py status <workspace_id> <task_id>         # Check async task
    python3 /app/a2a_cli.py info                                    # Your workspace info

For quick questions, use sync delegate. For long tasks, use --async + status.
Only delegate to peers listed by the peers command (access control enforced)."""


def get_a2a_instructions(mcp: bool = True) -> str:
    """Return inter-agent communication instructions for system-prompt injection.

    Pass `mcp=True` (default) for MCP-capable runtimes (Claude Code via SDK,
    Codex). Pass `mcp=False` for CLI-only runtimes (Ollama, custom) that have
    to call a2a_cli.py as a subprocess.
    """
    return _A2A_INSTRUCTIONS_MCP if mcp else _A2A_INSTRUCTIONS_CLI


# ========================================================================
# Misc text helpers
# ========================================================================

_MARKDOWN_FENCE = "```"
_MARKDOWN_HR = "---"


_BRIEF_SUMMARY_MIN_LEN = 4  # 1 char + 3-char ellipsis


def brief_summary(text: str, max_len: int = BRIEF_SUMMARY_MAX_LEN) -> str:
    """Extract a one-line task summary for the canvas card display.

    Strips markdown headers (#, ##, ###), bold/italic markers (**, __),
    and skips code fences and horizontal rules. Returns the first meaningful
    line, truncated with an ellipsis when it exceeds `max_len`.

    `max_len` is clamped to at least 4 (one real character plus a 3-char
    ellipsis) so degenerate callers can't produce negative slice indices.
    """
    max_len = max(max_len, _BRIEF_SUMMARY_MIN_LEN)
    for raw_line in text.split("\n"):
        line = raw_line.strip()
        while line.startswith("#"):
            line = line[1:]
        line = line.strip()
        if not line or line.startswith(_MARKDOWN_FENCE) or line == _MARKDOWN_HR:
            continue
        line = line.replace("**", "").replace("__", "")
        if len(line) > max_len:
            return line[: max_len - 3] + "..."
        return line
    return text[:max_len]
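
# Examples (pure function):
#   brief_summary("## **Fix login bug**\nDetails follow")  ->  "Fix login bug"
#   brief_summary("x" * 100)  ->  77 "x"s + "..." (BRIEF_SUMMARY_MAX_LEN = 80)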

def extract_message_text(message: Any) -> str:
    """Extract text from an A2A message (handles both .text and .root.text patterns)."""
    parts = getattr(message, "parts", None) or []
    text_parts: list[str] = []
    for part in parts:
        text = getattr(part, "text", None)
        if text:
            text_parts.append(text)
            continue
        root = getattr(part, "root", None)
        if root is not None:
            root_text = getattr(root, "text", None)
            if root_text:
                text_parts.append(root_text)
    return " ".join(text_parts).strip()


# Word-boundary patterns for subprocess stderr classification. Using word
# boundaries avoids false positives like "author" matching "auth" or
# "generate" matching "rate".
_RATE_LIMIT_RE = re.compile(r"\brate\b|\b429\b|\boverloaded\b", re.IGNORECASE)
_AUTH_RE = re.compile(r"\bauth(?:entication|orization)?\b|\bapi[_-]?key\b", re.IGNORECASE)
_SESSION_RE = re.compile(r"\bsession\b|\bno conversation found\b", re.IGNORECASE)


def classify_subprocess_error(stderr_text: str, exit_code: int | None) -> str:
    """Map a subprocess stderr blob to a short, user-safe category tag.

    The full stderr goes to the workspace logs via `logger.error`; only the
    category is surfaced to the user to avoid leaking tokens, internal paths,
    or stack traces in the chat UI. Used with `sanitize_agent_error` to
    produce a user-facing message for subprocess failures.
    """
    if _RATE_LIMIT_RE.search(stderr_text):
        return "rate_limited"
    if _AUTH_RE.search(stderr_text):
        return "auth_failed"
    if _SESSION_RE.search(stderr_text):
        return "session_error"
    if exit_code is not None and exit_code != 0:
        return f"exit_{exit_code}"
    return "subprocess_error"


def sanitize_agent_error(
    exc: BaseException | None = None,
    category: str | None = None,
) -> str:
    """Render an agent-side failure into a user-safe error message.

    Either pass an exception (class name is used as the tag) or an explicit
    category string (e.g. from `classify_subprocess_error`). If both are
    given, `category` wins. If neither, the tag defaults to "unknown".

    The message body is deliberately dropped — exception messages and
    subprocess stderr frequently leak stack traces, paths, tokens, and
    API keys. Full detail is available in the workspace logs via
    `logger.exception()` / `logger.error()`.
    """
    if category:
        tag = category
    elif exc is not None:
        tag = type(exc).__name__
    else:
        tag = "unknown"
    return f"Agent error ({tag}) — see workspace logs for details."
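
These two helpers are pure functions, so their contract is easy to check directly; the stderr strings below are made up:

```python
from molecule_runtime.executor_helpers import (
    classify_subprocess_error,
    sanitize_agent_error,
)

assert classify_subprocess_error("HTTP 429: overloaded", 1) == "rate_limited"
assert classify_subprocess_error("authorization denied", 1) == "auth_failed"
assert classify_subprocess_error("boom", 3) == "exit_3"

# Only the category tag reaches the user; the raw stderr stays in the logs.
print(sanitize_agent_error(category="rate_limited"))
# Agent error (rate_limited) — see workspace logs for details.
```
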
291
molecule_runtime/heartbeat.py
Normal file
@ -0,0 +1,291 @@
|
||||
"""Heartbeat loop — alive signal + delegation status checker.
|
||||
|
||||
Every 30 seconds:
|
||||
1. Send heartbeat to platform (alive signal with current_task, error_rate)
|
||||
2. Check pending delegations — any results back?
|
||||
3. Store completed delegation results for the agent to pick up
|
||||
|
||||
Resilient: recreates HTTP client on failure, auto-restarts on crash.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
|
||||
from platform_auth import auth_headers
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
HEARTBEAT_INTERVAL = 30 # seconds
|
||||
MAX_CONSECUTIVE_FAILURES = 10
|
||||
MAX_SEEN_DELEGATION_IDS = 200
|
||||
SELF_MESSAGE_COOLDOWN = 60 # seconds — minimum between self-messages to prevent loops
|
||||
# Shared path — also used by cli_executor._read_delegation_results()
|
||||
DELEGATION_RESULTS_FILE = os.environ.get("DELEGATION_RESULTS_FILE", "/tmp/delegation_results.jsonl")
|
||||


class HeartbeatLoop:
    def __init__(self, platform_url: str, workspace_id: str):
        self.platform_url = platform_url
        self.workspace_id = workspace_id
        self.start_time = time.time()
        self.error_count = 0
        self.request_count = 0
        self.active_tasks = 0
        self.current_task = ""
        self.sample_error = ""
        self._task = None
        self._consecutive_failures = 0
        self._seen_delegation_ids: set[str] = set()
        self._last_self_message_time = 0.0
        self._parent_name: str | None = None  # Cached after first lookup

    @property
    def error_rate(self) -> float:
        if self.request_count == 0:
            return 0.0
        return self.error_count / self.request_count

    def record_error(self, error: str):
        self.error_count += 1
        self.request_count += 1
        self.sample_error = error

    def record_success(self):
        self.request_count += 1

    def start(self):
        self._task = asyncio.create_task(self._loop())
        self._task.add_done_callback(self._on_done)

    def _on_done(self, task):
        if not task.cancelled() and task.exception():
            logger.error("Heartbeat loop died: %s — restarting", task.exception())
            self._task = asyncio.create_task(self._loop())
            self._task.add_done_callback(self._on_done)

    async def stop(self):
        if self._task:
            self._task.cancel()
            try:
                await self._task
            except asyncio.CancelledError:
                pass

    async def _loop(self):
        while True:
            client = None
            try:
                client = httpx.AsyncClient(timeout=10.0)
                while True:
                    # 1. Send heartbeat (Phase 30.1: include auth header if token known)
                    try:
                        await client.post(
                            f"{self.platform_url}/registry/heartbeat",
                            json={
                                "workspace_id": self.workspace_id,
                                "error_rate": self.error_rate,
                                "sample_error": self.sample_error,
                                "active_tasks": self.active_tasks,
                                "current_task": self.current_task,
                                "uptime_seconds": int(time.time() - self.start_time),
                            },
                            headers=auth_headers(),
                        )
                        self.error_count = 0
                        self.request_count = 0
                        self._consecutive_failures = 0
                    except Exception as e:
                        self._consecutive_failures += 1
                        if self._consecutive_failures <= 3 or self._consecutive_failures % MAX_CONSECUTIVE_FAILURES == 0:
                            logger.warning("Heartbeat failed (%d consecutive): %s", self._consecutive_failures, e)
                        if self._consecutive_failures >= MAX_CONSECUTIVE_FAILURES:
                            logger.info("Heartbeat: recreating HTTP client after %d failures", self._consecutive_failures)
                            try:
                                await client.aclose()
                            except Exception:
                                pass
                            break

                    # 2. Check delegation status
                    try:
                        await self._check_delegations(client)
                    except Exception as e:
                        logger.debug("Delegation check failed: %s", e)

                    await asyncio.sleep(HEARTBEAT_INTERVAL)

            except asyncio.CancelledError:
                raise
            except Exception as e:
                logger.error("Heartbeat loop error: %s — retrying in 30s", e)
                await asyncio.sleep(HEARTBEAT_INTERVAL)
            finally:
                if client:
                    try:
                        await client.aclose()
                    except Exception:
                        pass

    async def _check_delegations(self, client: httpx.AsyncClient):
        """Check for completed delegations and store results for the agent."""
        try:
            resp = await client.get(
                f"{self.platform_url}/workspaces/{self.workspace_id}/delegations",
                headers=auth_headers(),
            )
            if resp.status_code != 200:
                return

            delegations = resp.json()
            if not isinstance(delegations, list):
                return

            new_results = []
            for d in delegations:
                did = d.get("delegation_id", "")
                status = d.get("status", "")

                if not did or did in self._seen_delegation_ids:
                    continue

                if status in ("completed", "failed"):
                    # Fix B (Cycle 5): validate source_id before accepting delegation
                    # results. Only process delegations that THIS workspace created
                    # (source_id == self.workspace_id). Attacker-crafted delegation
                    # records with a foreign source_id cannot inject instructions.
                    source_id = d.get("source_id", "")
                    if source_id != self.workspace_id:
                        logger.warning(
                            "Heartbeat: skipping delegation %s — source_id %r does not "
                            "match this workspace %r; possible injection attempt",
                            did, source_id, self.workspace_id,
                        )
                        self._seen_delegation_ids.add(did)  # mark seen so we don't warn again
                        continue

                    self._seen_delegation_ids.add(did)
                    new_results.append({
                        "delegation_id": did,
                        "target_id": d.get("target_id", ""),
                        "source_id": source_id,
                        "status": status,
                        "summary": d.get("summary", ""),
                        "response_preview": d.get("response_preview", ""),
                        "error": d.get("error", ""),
                        "timestamp": time.time(),
                    })

            # Evict old seen IDs if over limit
            if len(self._seen_delegation_ids) > MAX_SEEN_DELEGATION_IDS:
                # Drop an arbitrary half: set iteration order is unspecified,
                # so this bounds memory without preserving recency.
                self._seen_delegation_ids = set(list(self._seen_delegation_ids)[MAX_SEEN_DELEGATION_IDS // 2:])

            if new_results:
                # Append to results file for context injection on next message
                with open(DELEGATION_RESULTS_FILE, "a") as f:
                    for r in new_results:
                        f.write(json.dumps(r) + "\n")
                logger.info("Heartbeat: %d new delegation results — triggering self-message", len(new_results))

                # Build a summary message for the agent.
                # Fix B (Cycle 5): do NOT embed raw response_preview text in
                # user-role A2A messages — that is the prompt-injection vector.
                # Instead reference only the delegation ID and status; the agent
                # reads full content from DELEGATION_RESULTS_FILE which was
                # written above from trusted platform data.
                summary_lines = []
                for r in new_results:
                    line = f"- [{r['status']}] Delegation {r['delegation_id'][:8]}: {r['summary'][:80]}"
                    if r.get("error"):
                        line += f"\n  Error: {r['error'][:100]}"
                    summary_lines.append(line)

                # Look up parent workspace (cached after first call)
                if self._parent_name is None:
                    try:
                        parent_resp = await client.get(
                            f"{self.platform_url}/workspaces/{self.workspace_id}",
                            headers=auth_headers(),
                        )
                        if parent_resp.status_code == 200:
                            parent_id = parent_resp.json().get("parent_id", "")
                            if parent_id:
                                parent_info = await client.get(
                                    f"{self.platform_url}/workspaces/{parent_id}",
                                    headers=auth_headers(),
                                )
                                if parent_info.status_code == 200:
                                    self._parent_name = parent_info.json().get("name", "")
                        if self._parent_name is None:
                            self._parent_name = ""  # No parent — cache empty
                    except Exception:
                        pass  # Will retry next cycle
                parent_name = self._parent_name or ""

                report_instruction = ""
                if parent_name:
                    report_instruction = (
                        f"\n\nIMPORTANT: Report these results back to your parent '{parent_name}' "
                        f"by delegating a summary to them. Use delegate_task or delegate_task_async "
                        f"with a concise status report. Also use send_message_to_user to notify the user."
                    )
                else:
                    report_instruction = (
                        "\n\nReport results using send_message_to_user to notify the user."
                    )

                trigger_msg = (
                    "Delegation results are ready. Review them and take appropriate action:\n"
                    + "\n".join(summary_lines)
                    + report_instruction
                )

                # Send A2A self-message to wake the agent.
                # Minimum 60s between self-messages to avoid spam, but always send
                # when there are genuinely NEW results to process.
                now = time.time()
                if now - self._last_self_message_time < SELF_MESSAGE_COOLDOWN:
                    logger.debug("Heartbeat: self-message cooldown (60s), will retry next cycle")
                else:
                    self._last_self_message_time = now
                    try:
                        await client.post(
                            f"{self.platform_url}/workspaces/{self.workspace_id}/a2a",
                            json={
                                "method": "message/send",
                                "params": {
                                    "message": {
                                        "role": "user",
                                        "parts": [{"type": "text", "text": trigger_msg}],
                                    },
                                },
                            },
                            headers=auth_headers(),
                            timeout=120.0,
                        )
                        logger.info("Heartbeat: self-message sent to process delegation results")
                    except Exception as e:
                        logger.warning("Heartbeat: failed to send self-message: %s", e)

                # Also push notification to user via canvas
                for r in new_results:
                    try:
                        msg = f"Delegation {r['status']}: {r['summary'][:100]}"
                        if r.get("response_preview"):
                            msg += f"\nResult: {r['response_preview'][:200]}"
                        await client.post(
                            f"{self.platform_url}/workspaces/{self.workspace_id}/notify",
                            json={"message": msg, "type": "delegation_result"},
                            headers=auth_headers(),
                        )
                    except Exception:
                        pass

        except Exception as e:
            logger.debug("Delegation check error: %s", e)
51
molecule_runtime/initial_prompt.py
Normal file
@ -0,0 +1,51 @@
"""Helpers for the workspace's one-shot initial_prompt.

Kept as a standalone module (no heavy imports like uvicorn) so the marker
logic is unit-testable without standing up the full workspace runtime.

Background: the workspace runtime supports an `initial_prompt` that runs once
on first boot (clone the repo, set git hooks, read CLAUDE.md, commit_memory).
A marker file `.initial_prompt_done` prevents the prompt from re-running on
subsequent boots.

Prior behaviour wrote the marker AFTER the prompt completed successfully. If
the prompt crashed mid-execution (e.g. ProcessError from a stale Claude
session), the marker was never written; every subsequent container boot
replayed the same failing prompt, cascading into "every message crashes until
an operator intervenes." See GitHub issue #71.

Fix (2026-04-12): write the marker BEFORE firing the prompt. If the prompt
fails, operators re-send it manually via chat — cheap and available — instead
of trapping the workspace in a crash loop.
"""
from __future__ import annotations

import os
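
# Hedged end-to-end sketch of how main.py uses these helpers (see main.py,
# step 10b, for the real call sites):
#
#     marker = resolve_initial_prompt_marker("/configs")
#     if initial_prompt and not os.path.exists(marker):
#         mark_initial_prompt_attempted(marker)  # write BEFORE firing (#71)
#         ...  # send the one-shot prompt; a failure here is NOT replayed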


def resolve_initial_prompt_marker(config_path: str) -> str:
    """Return the path where the `.initial_prompt_done` marker should live.

    Prefers ``<config_path>/.initial_prompt_done`` when the directory is
    writable; falls back to ``/workspace/.initial_prompt_done`` for containers
    where ``/configs`` is read-only.
    """
    if os.access(config_path, os.W_OK):
        return os.path.join(config_path, ".initial_prompt_done")
    return "/workspace/.initial_prompt_done"


def mark_initial_prompt_attempted(marker_path: str) -> bool:
    """Write the marker best-effort. Return True on success, False on I/O error.

    Called BEFORE the initial-prompt self-message is sent. If the attempt
    later fails, the marker is still present — so the next container boot
    does NOT replay the same failing prompt. Operators retry manually via
    the chat interface instead of relying on auto-replay.
    """
    try:
        with open(marker_path, "w") as f:
            f.write("attempted")
        return True
    except OSError:
        return False
556
molecule_runtime/main.py
Normal file
@ -0,0 +1,556 @@
"""Workspace runtime entry point.

Loads config -> discovers adapter -> setup -> create executor -> wrap in A2A -> register -> heartbeat.
"""

import asyncio
import json
import os
import socket
import sys

# When running as the installed `molecule-runtime` console script, the flat
# module names (config, heartbeat, adapters, etc.) need to resolve to this
# package's submodules. We inject the package directory onto sys.path so that
# `from config import ...` resolves to `molecule_runtime/config.py`.
_PKG_DIR = os.path.dirname(__file__)
if _PKG_DIR not in sys.path:
    sys.path.insert(0, _PKG_DIR)

import httpx
import uvicorn
from a2a.server.apps import A2AStarletteApplication
from a2a.server.request_handlers import DefaultRequestHandler
from a2a.server.tasks import InMemoryTaskStore
from a2a.types import AgentCard, AgentCapabilities, AgentSkill

from adapters import get_adapter, AdapterConfig
from config import load_config
from heartbeat import HeartbeatLoop
from preflight import run_preflight, render_preflight_report
from builtin_tools.awareness_client import get_awareness_config
import uuid as _uuid

from builtin_tools.telemetry import setup_telemetry, make_trace_middleware
from policies.namespaces import resolve_awareness_namespace


from initial_prompt import (
    mark_initial_prompt_attempted,
    resolve_initial_prompt_marker,
)
from platform_auth import auth_headers


def get_machine_ip() -> str:  # pragma: no cover
    """Get the machine's IP for A2A discovery."""
    try:
        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        s.connect(("8.8.8.8", 80))
        ip = s.getsockname()[0]
        s.close()
        return ip
    except Exception:
        return "127.0.0.1"


# Re-exported from transcript_auth for the inline /transcript handler.
# Separate module keeps the security-critical gate import-light + unit-testable.
from transcript_auth import transcript_authorized as _transcript_authorized


async def main():  # pragma: no cover
    workspace_id = os.environ.get("WORKSPACE_ID", "workspace-default")
    config_path = os.environ.get("WORKSPACE_CONFIG_PATH", "/configs")
    platform_url = os.environ.get("PLATFORM_URL", "http://platform:8080")
    awareness_config = get_awareness_config()

    # 0. Initialise OpenTelemetry (no-op if packages not installed)
    setup_telemetry(service_name=workspace_id)

    # 1. Load config
    config = load_config(config_path)
    port = config.a2a.port
    preflight = run_preflight(config, config_path)
    render_preflight_report(preflight)
    if not preflight.ok:
        raise SystemExit(1)
    if awareness_config:
        awareness_namespace = resolve_awareness_namespace(
            workspace_id,
            awareness_config.get("namespace", ""),
        )
        print(f"Awareness enabled for namespace: {awareness_namespace}")

    # 1.5 Initialise governance adapter (no-op if disabled or package absent)
    from builtin_tools.governance import initialize_governance
    if config.governance.enabled:
        await initialize_governance(config.governance)
        print(f"Governance: Microsoft Agent Governance Toolkit enabled (mode={config.governance.policy_mode})")
    else:
        print("Governance: disabled (set governance.enabled: true in config.yaml to activate)")

    # 2. Create heartbeat (passed to adapter for task tracking)
    heartbeat = HeartbeatLoop(platform_url, workspace_id)

    # 3. Get adapter for this runtime
    runtime = config.runtime or "langgraph"
    adapter_cls = get_adapter(runtime)  # Raises KeyError if unknown — no silent fallback

    adapter = adapter_cls()
    print(f"Runtime: {runtime} ({adapter.display_name()})")

    # 4. Build adapter config
    adapter_config = AdapterConfig(
        model=config.model,
        system_prompt=None,  # Adapter builds its own prompt
        tools=config.skills,  # Skill names from config.yaml
        runtime_config=vars(config.runtime_config) if config.runtime_config else {},
        config_path=config_path,
        workspace_id=workspace_id,
        prompt_files=config.prompt_files,
        a2a_port=port,
        heartbeat=heartbeat,
    )

    # 5. Setup adapter and create executor
    # If setup fails, ensure heartbeat is stopped to prevent resource leak
    try:
        await adapter.setup(adapter_config)
        executor = await adapter.create_executor(adapter_config)
    except Exception:
        # heartbeat hasn't started yet but may have async tasks pending
        if hasattr(heartbeat, "stop"):
            try:
                await heartbeat.stop()
            except Exception:
                pass
        raise

    # 5.5. Initialise Temporal durable execution wrapper (optional)
    # Connects to TEMPORAL_HOST (default: localhost:7233) and starts a
    # co-located Temporal worker as a background asyncio task.
    # No-op with a warning log if Temporal is unreachable or temporalio
    # is not installed — all tasks fall back to direct execution transparently.
    from builtin_tools.temporal_workflow import create_wrapper as _create_temporal_wrapper
    temporal_wrapper = _create_temporal_wrapper()
    await temporal_wrapper.start()

    # Get loaded skills for agent card (adapter may have populated them)
    loaded_skills = getattr(adapter, "loaded_skills", [])

    # 6. Build Agent Card
    machine_ip = os.environ.get("HOSTNAME", get_machine_ip())
    workspace_url = f"http://{machine_ip}:{port}"

    agent_card = AgentCard(
        name=config.name,
        description=config.description or config.name,
        version=config.version,
        url=workspace_url,
        capabilities=AgentCapabilities(
            streaming=config.a2a.streaming,
            pushNotifications=config.a2a.push_notifications,
            stateTransitionHistory=True,
        ),
        skills=[
            AgentSkill(
                id=skill.metadata.id,
                name=skill.metadata.name,
                description=skill.metadata.description,
                tags=skill.metadata.tags,
                examples=skill.metadata.examples,
            )
            for skill in loaded_skills
        ],
        defaultInputModes=["text/plain", "application/json"],
        defaultOutputModes=["text/plain", "application/json"],
    )
    # 7. Wrap in A2A.
    #
    # Regression fix (#204): PR #198 tried to wire push_config_store +
    # push_sender to satisfy #175 (push notification capability), but
    # PushNotificationSender is an abstract base class in the a2a-sdk and
    # can't be instantiated directly. Passing it crashed main.py on startup
    # with `TypeError: Can't instantiate abstract class`. Dropped back to
    # DefaultRequestHandler's own defaults — the pushNotifications capability
    # in the AgentCard above is still advertised via AgentCapabilities so
    # clients know we COULD do pushes; actually implementing them requires
    # a concrete sender subclass, tracked as a Phase-H follow-up to #175.
    handler = DefaultRequestHandler(
        agent_executor=executor,
        task_store=InMemoryTaskStore(),
    )

    app = A2AStarletteApplication(
        agent_card=agent_card,
        http_handler=handler,
    )

    # 8. Register with platform
    agent_card_dict = {
        "name": config.name,
        "description": config.description,
        "version": config.version,
        "url": workspace_url,
        "skills": [
            {
                "id": s.metadata.id,
                "name": s.metadata.name,
                "description": s.metadata.description,
                "tags": s.metadata.tags,
            }
            for s in loaded_skills
        ],
        "capabilities": {
            "streaming": config.a2a.streaming,
            "pushNotifications": config.a2a.push_notifications,
        },
    }

    async with httpx.AsyncClient(timeout=10.0) as client:
        try:
            resp = await client.post(
                f"{platform_url}/registry/register",
                json={
                    "id": workspace_id,
                    "url": workspace_url,
                    "agent_card": agent_card_dict,
                },
                headers=auth_headers(),
            )
            print(f"Registered with platform: {resp.status_code}")
            # Phase 30.1 — capture the auth token issued at first register.
            # The platform only mints one on first register per workspace,
            # so a subsequent restart gets an empty auth_token and we
            # keep using the on-disk copy from the original issuance.
            if resp.status_code == 200:
                try:
                    body = resp.json()
                    tok = body.get("auth_token")
                    if tok:
                        from platform_auth import save_token
                        save_token(tok)
                        print(f"Saved workspace auth token (prefix={tok[:8]}…)")
                except Exception as parse_exc:
                    print(f"Warning: couldn't parse register response for token: {parse_exc}")
        except Exception as e:
            print(f"Warning: failed to register with platform: {e}")

    # 9. Start heartbeat
    heartbeat.start()

    # 9b. Start skills hot-reload watcher (background task)
    # When a skill file changes the watcher reloads the skill module and calls
    # back into the adapter so the next A2A request uses the updated tools.
    if config.skills:
        try:
            from skill_loader.watcher import SkillsWatcher

            def _on_skill_reload(updated_skill):
                """Rebuild the LangGraph agent when a skill changes in-place."""
                if not hasattr(adapter, "loaded_skills"):
                    return
                # Replace the matching skill in the adapter's skill list
                adapter.loaded_skills = [
                    updated_skill if s.metadata.id == updated_skill.metadata.id else s
                    for s in adapter.loaded_skills
                ]
                # Rebuild the agent's tool list from updated skills
                if hasattr(adapter, "all_tools") and hasattr(adapter, "system_prompt"):
                    from builtin_tools.approval import request_approval
                    from builtin_tools.delegation import delegate_to_workspace
                    from builtin_tools.memory import commit_memory, search_memory
                    from builtin_tools.sandbox import run_code
                    base_tools = [delegate_to_workspace, request_approval,
                                  commit_memory, search_memory, run_code]
                    skill_tools = []
                    for sk in adapter.loaded_skills:
                        skill_tools.extend(sk.tools)
                    adapter.all_tools = base_tools + skill_tools
                    # Rebuild compiled agent so next ainvoke picks up new tools
                    try:
                        from agent import create_agent
                        new_agent = create_agent(
                            config.model, adapter.all_tools, adapter.system_prompt
                        )
                        executor.agent = new_agent
                        print(f"Skills hot-reload: '{updated_skill.metadata.id}' reloaded — "
                              f"{len(updated_skill.tools)} tool(s)")
                    except Exception as rebuild_err:
                        print(f"Skills hot-reload: agent rebuild failed: {rebuild_err}")

            skills_watcher = SkillsWatcher(
                config_path=config_path,
                skill_names=config.skills,
                on_reload=_on_skill_reload,
            )
            asyncio.create_task(skills_watcher.start())
            print(f"Skills hot-reload enabled for: {config.skills}")
        except Exception as e:
            print(f"Warning: skills watcher could not start: {e}")

    # 10. Run A2A server
    print(f"Workspace {workspace_id} starting on port {port}")
    # Wrap the ASGI app with W3C TraceContext extraction middleware so incoming
    # A2A HTTP requests propagate their trace context into _incoming_trace_context.
    starlette_app = app.build()

    # Add /transcript route — exposes the most-recent agent session log
    # (claude-code reads ~/.claude/projects/<cwd>/<session>.jsonl). Other
    # runtimes return supported:false.
    from starlette.responses import JSONResponse

    async def _transcript_handler(request):
        # Require workspace bearer token — the same token issued at registration
        # and stored in /configs/.auth_token. Any container on molecule-monorepo-net
        # could otherwise read the full session log. Closes #287.
        #
        # #328: fail CLOSED when the token file is unavailable. get_token()
        # returns None during the bootstrap window (first register hasn't
        # completed), if /configs/.auth_token was deleted, or on OSError.
        # The old `if expected:` guard treated all three cases as "skip
        # auth" — an unauthenticated container on the same Docker network
        # could read the entire session log during that window. Deny
        # instead. The platform's TranscriptHandler acquires the token
        # during registration, so once the bootstrap completes it always
        # has a valid credential to present.
        from platform_auth import get_token
        if not _transcript_authorized(get_token(), request.headers.get("Authorization", "")):
            return JSONResponse({"error": "unauthorized"}, status_code=401)
        try:
            since = int(request.query_params.get("since", "0"))
            limit = int(request.query_params.get("limit", "100"))
        except (TypeError, ValueError):
            return JSONResponse({"error": "since and limit must be integers"}, status_code=400)
        result = await adapter.transcript_lines(since=since, limit=limit)
        return JSONResponse(result)

    starlette_app.add_route("/transcript", _transcript_handler, methods=["GET"])
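
    # Hedged usage sketch for querying the route from a client that holds the
    # workspace token (host and port are placeholders):
    #
    #     curl -H "Authorization: Bearer $(cat /configs/.auth_token)" \
    #          "http://<workspace-host>:<port>/transcript?since=0&limit=100"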

    built_app = make_trace_middleware(starlette_app)

    server_config = uvicorn.Config(
        built_app,
        host="0.0.0.0",
        port=port,
        log_level="info",
    )
    server = uvicorn.Server(server_config)

    # 10b. Schedule initial_prompt self-message after server is ready.
    # Only runs on first boot — creates a marker file to prevent re-execution on restart.
    initial_prompt_task = None
    initial_prompt_marker = resolve_initial_prompt_marker(config_path)
    if config.initial_prompt and not os.path.exists(initial_prompt_marker):
        # Write the marker UP FRONT (#71): if the prompt later crashes or
        # times out, we do NOT replay on next boot — that created a
        # ProcessError cascade where every message kept crashing. Operators
        # can always re-send via chat. Log loudly if the marker write
        # fails so the situation is visible.
        if not mark_initial_prompt_attempted(initial_prompt_marker):
            print(
                f"Initial prompt: WARNING — could not write marker at "
                f"{initial_prompt_marker}; this boot may replay if it crashes.",
                flush=True,
            )

        async def _send_initial_prompt():
            """Wait for server to be ready, then send initial_prompt as self-message."""
            # Wait for the A2A server to accept connections
            ready = False
            for attempt in range(30):
                await asyncio.sleep(1)
                try:
                    async with httpx.AsyncClient(timeout=5.0) as client:
                        resp = await client.get(f"http://127.0.0.1:{port}/.well-known/agent.json")
                        if resp.status_code == 200:
                            ready = True
                            break
                except Exception:
                    continue

            if not ready:
                print("Initial prompt: server not ready after 30s, skipping", flush=True)
                return

            # Send initial prompt through the platform A2A proxy (not directly to self).
            # The proxy logs an a2a_receive with source_id=NULL (canvas-style),
            # broadcasts A2A_RESPONSE via WebSocket so the chat shows both the
            # prompt (as user message) and the response (as agent message).
            # Uses urllib in a thread to avoid asyncio/httpx streaming hangs.
            import json as _json
            import urllib.request

            def _do_send_sync():
                import time as _time
                payload = _json.dumps({
                    "method": "message/send",
                    "params": {
                        "message": {
                            "role": "user",
                            "messageId": f"initial-{_uuid.uuid4().hex[:8]}",
                            "parts": [{"kind": "text", "text": config.initial_prompt}],
                        },
                    },
                }).encode()

                # #220: include platform bearer token so the request isn't
                # silently rejected once any workspace has a live token on
                # file. Without this, initial_prompt 401s in multi-tenant
                # mode exactly like /registry/register did in #215.
                headers = {"Content-Type": "application/json", **auth_headers()}

                # Retry with backoff — the platform proxy may not be able to
                # reach us yet (container networking takes a moment to settle).
                max_retries = 5
                for attempt in range(max_retries):
                    try:
                        req = urllib.request.Request(
                            f"{platform_url}/workspaces/{workspace_id}/a2a",
                            data=payload,
                            headers=headers,
                        )
                        with urllib.request.urlopen(req, timeout=600) as resp:
                            resp.read()
                        print(f"Initial prompt: completed (status={resp.status})", flush=True)
                        break
                    except Exception as e:
                        if attempt < max_retries - 1:
                            delay = 2 ** attempt  # 1, 2, 4, 8 seconds
print(f"Initial prompt: attempt {attempt + 1} failed ({e}), retrying in {delay}s...", flush=True)
|
||||
_time.sleep(delay)
|
||||
else:
|
||||
print(f"Initial prompt: failed after {max_retries} attempts — {e}", flush=True)
|
||||
return
|
||||
|
||||
# Marker was already written up front (#71). Nothing to do here.
|
||||
|
||||
print("Initial prompt: sending via platform proxy...", flush=True)
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.run_in_executor(None, _do_send_sync)
|
||||
|
||||
        initial_prompt_task = asyncio.create_task(_send_initial_prompt())

    # 10c. Idle loop — reflection-on-completion / backlog-pull pattern.
    # Fires config.idle_prompt every config.idle_interval_seconds while the
    # workspace has no active task. This turns every role from "waits for cron"
    # into "self-wakes when idle" — the Hermes/Letta shape from today's
    # multi-framework survey (see docs/ecosystem-watch.md). Cost collapses to
    # event-driven in practice: the idle check is local (no LLM call, just
    # heartbeat.active_tasks==0), and the prompt only fires when there's
    # actually nothing to do. Gated on idle_prompt being non-empty so existing
    # workspaces upgrade opt-in — set idle_prompt in org.yaml defaults or
    # per-workspace to enable.
    idle_loop_task = None
    if config.idle_prompt:
        # Idle-fire HTTP timeout. Kept tight relative to the fire cadence so a
        # hung platform doesn't accumulate dangling requests — a fire that
        # takes longer than the idle interval itself is almost certainly stuck.
        IDLE_FIRE_TIMEOUT_SECONDS = max(60, min(300, config.idle_interval_seconds))
        # Initial settle delay — capped at 60s so cold-start races don't
        # stall the first fire, and never longer than the configured
        # interval (short intervals shouldn't fire instantly on boot either).
        IDLE_INITIAL_SETTLE_SECONDS = min(config.idle_interval_seconds, 60)

        async def _run_idle_loop():
            """Self-sends config.idle_prompt periodically when the workspace is idle."""
            await asyncio.sleep(IDLE_INITIAL_SETTLE_SECONDS)

            import json as _json
            from urllib import request as _urlreq, error as _urlerr

            while True:
                try:
                    await asyncio.sleep(config.idle_interval_seconds)
                except asyncio.CancelledError:
                    return

                # Local idle check — no platform API call, no LLM call.
                # heartbeat.active_tasks == 0 means no in-flight work.
                if heartbeat.active_tasks > 0:
                    continue

                # Self-post the idle prompt via the platform A2A proxy (same
                # path as initial_prompt). The agent's own concurrency control
                # rejects if the workspace becomes busy between this check and
                # the post — that's the expected safety valve.
                payload = _json.dumps({
                    "method": "message/send",
                    "params": {
                        "message": {
                            "role": "user",
                            "messageId": f"idle-{_uuid.uuid4().hex[:8]}",
                            "parts": [{"kind": "text", "text": config.idle_prompt}],
                        },
                    },
                }).encode()

                def _post_sync():
                    # Returns (status_code, error_type) so the caller logs the
                    # actual outcome instead of a bare "post failed" line.
                    # #220: include auth_headers() on every idle fire. Without
                    # this, the idle loop 401s in multi-tenant mode.
                    headers = {"Content-Type": "application/json", **auth_headers()}
                    try:
                        req = _urlreq.Request(
                            f"{platform_url}/workspaces/{workspace_id}/a2a",
                            data=payload,
                            headers=headers,
                        )
                        with _urlreq.urlopen(req, timeout=IDLE_FIRE_TIMEOUT_SECONDS) as resp:
                            resp.read()
                            return resp.status, None
                    except _urlerr.HTTPError as e:
                        return e.code, type(e).__name__
                    except _urlerr.URLError as e:
                        return None, f"URLError: {e.reason}"
                    except Exception as e:  # pragma: no cover — catch-all safety net
                        return None, type(e).__name__

                print(
                    f"Idle loop: firing (active_tasks=0, interval={config.idle_interval_seconds}s, "
                    f"timeout={IDLE_FIRE_TIMEOUT_SECONDS}s)",
                    flush=True,
                )
                loop_ref = asyncio.get_running_loop()

                def _log_result(future):
                    try:
                        status, err = future.result()
                        if err:
                            print(
                                f"Idle loop: post failed — status={status} err={err}",
                                flush=True,
                            )
                        else:
                            print(f"Idle loop: post ok status={status}", flush=True)
                    except Exception as e:  # pragma: no cover
                        print(f"Idle loop: executor callback crashed — {e}", flush=True)

                fut = loop_ref.run_in_executor(None, _post_sync)
                fut.add_done_callback(_log_result)

        idle_loop_task = asyncio.create_task(_run_idle_loop())

    try:
        await server.serve()
    finally:
        # Cancel initial prompt if still running
        if initial_prompt_task and not initial_prompt_task.done():
            initial_prompt_task.cancel()
        # Cancel idle loop if running
        if idle_loop_task and not idle_loop_task.done():
            idle_loop_task.cancel()
        # Gracefully stop the Temporal worker background task on shutdown
        await temporal_wrapper.stop()


def main_sync():  # pragma: no cover
    """Synchronous entry point for the molecule-runtime console script."""
    asyncio.run(main())
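
# Hedged packaging note: the `molecule-runtime` console script resolves to
# main_sync via an entry point along these lines (the exact pyproject.toml
# wording is an assumption):
#
#     [project.scripts]
#     molecule-runtime = "molecule_runtime.main:main_sync"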


if __name__ == "__main__":  # pragma: no cover
    main_sync()
72
molecule_runtime/molecule_ai_status.py
Normal file
@ -0,0 +1,72 @@
#!/usr/bin/env python3
"""Update workspace task status on the canvas.

Usage (from any script, cron job, or shell inside the container):

    # Set current task (shows on canvas card)
    python3 /app/molecule_ai_status.py "Running weekly SEO audit..."

    # Clear task (removes banner from canvas)
    python3 /app/molecule_ai_status.py ""

    # Or use the shell alias:
    molecule-monorepo-status "Analyzing competitor data..."
    molecule-monorepo-status ""

The status appears as an amber banner on the workspace card in the canvas,
visible to the project owner in real-time.
"""

import os
import sys

import httpx

WORKSPACE_ID = os.environ.get("WORKSPACE_ID", "")
PLATFORM_URL = os.environ.get("PLATFORM_URL", "http://platform:8080")


def set_status(task: str):
    """Push current_task to platform via heartbeat."""
    try:
        try:
            from platform_auth import auth_headers as _auth
            _headers = _auth()
        except Exception:
            _headers = {}
        httpx.post(
            f"{PLATFORM_URL}/registry/heartbeat",
            json={
                "workspace_id": WORKSPACE_ID,
                "current_task": task,
                "active_tasks": 1 if task else 0,
                "error_rate": 0,
                "sample_error": "",
                "uptime_seconds": 0,
            },
            headers=_headers,
            timeout=5.0,
        )
        if task:
            # Also log as activity for traceability
            httpx.post(
                f"{PLATFORM_URL}/workspaces/{WORKSPACE_ID}/activity",
                json={
                    "activity_type": "task_update",
                    "source_id": WORKSPACE_ID,
                    "summary": task,
                    "status": "ok",
                },
                timeout=5.0,
            )
    except Exception as e:
        print(f"molecule-monorepo-status: failed to update: {e}", file=sys.stderr)


if __name__ == "__main__":  # pragma: no cover
    if len(sys.argv) < 2:
        print("Usage: molecule-monorepo-status 'task description'")
        print("       molecule-monorepo-status ''   # clear")
        sys.exit(1)

    set_status(sys.argv[1])
105
molecule_runtime/platform_auth.py
Normal file
@ -0,0 +1,105 @@
"""Workspace auth-token store (Phase 30.1).

Single source of truth for this workspace's authentication token. The
token is issued by the platform on the first successful
``POST /registry/register`` call and travels with every subsequent
heartbeat / update-card / (later) secrets-pull / A2A request.

The token is persisted to ``<configs>/.auth_token`` so it survives
restarts — we only expect to receive it once from the platform, since
``/registry/register`` no-ops token issuance for workspaces that already
have one on file.

Storage:
    ${CONFIGS_DIR}/.auth_token   # 0600, one line, no trailing newline

Callers interact with three functions:
    :func:`get_token` — returns the cached token or None
    :func:`save_token` — persists a freshly-issued token
    :func:`auth_headers` — builds the Authorization header dict for httpx
"""
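
# Hedged usage sketch (the token value is illustrative):
#
#     save_token("tok-abc123")                  # after first /registry/register
#     get_token()                               # -> "tok-abc123" (disk-backed cache)
#     httpx.post(url, headers=auth_headers())   # sends {} until a token exists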
from __future__ import annotations

import logging
import os
from pathlib import Path

logger = logging.getLogger(__name__)

# In-process cache so we don't hit disk on every heartbeat. The heartbeat
# loop fires on a short interval and reading a tiny file 10x per minute
# is wasteful. The file is the durable copy; this var is the hot path.
_cached_token: str | None = None


def _token_file() -> Path:
    """Path to the on-disk token file. Respects CONFIGS_DIR, falls back
    to /configs for the default container layout."""
    return Path(os.environ.get("CONFIGS_DIR", "/configs")) / ".auth_token"


def get_token() -> str | None:
    """Return the cached token, reading it from disk on first call."""
    global _cached_token
    if _cached_token is not None:
        return _cached_token
    path = _token_file()
    if not path.exists():
        return None
    try:
        tok = path.read_text().strip()
    except OSError as exc:
        logger.warning("platform_auth: failed to read %s: %s", path, exc)
        return None
    if not tok:
        return None
    _cached_token = tok
    return tok


def save_token(token: str) -> None:
    """Persist a newly-issued token. Creates the file with 0600 mode atomically.

    Uses ``os.open(O_CREAT, 0o600)`` so the file is never world-readable,
    even transiently. The previous ``write_text()`` + ``chmod()`` approach
    had a TOCTOU window where a concurrent reader could access the token
    between the two syscalls (M4 — flagged in security audit cycle 10).

    Idempotent — if an identical token is already on disk we skip the
    write so we don't churn the file's mtime or trigger spurious
    filesystem watchers."""
    global _cached_token
    token = token.strip()
    if not token:
        raise ValueError("platform_auth: refusing to save empty token")
    if get_token() == token:
        return
    path = _token_file()
    path.parent.mkdir(parents=True, exist_ok=True)
    # O_CREAT | O_WRONLY | O_TRUNC with mode=0o600 atomically creates (or
    # truncates) the file with restricted permissions in a single syscall,
    # eliminating the TOCTOU window.
    fd = os.open(str(path), os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    try:
        os.write(fd, token.encode())
    finally:
        os.close(fd)
    _cached_token = token


def auth_headers() -> dict[str, str]:
    """Return a header dict to merge into httpx calls. Empty if no token
    is available yet — callers send the request as-is and the platform's
    heartbeat handler grandfathers pre-token workspaces through until
    their next /registry/register issues one."""
    tok = get_token()
    if not tok:
        return {}
    return {"Authorization": f"Bearer {tok}"}


def clear_cache() -> None:
    """Reset the in-memory cache. Used by tests that write fresh token
    files between cases."""
    global _cached_token
    _cached_token = None
154
molecule_runtime/plugins.py
Normal file
@ -0,0 +1,154 @@
"""Plugin system for loading per-workspace and shared plugins.

Plugins provide skills, rules, and prompt fragments to agent workspaces.
Each plugin is a directory containing:
- plugin.yaml — manifest (name, version, description, skills, rules)
- rules/*.md — always-on guidelines injected into every prompt
- skills/ — skill directories with SKILL.md + tools/*.py
- *.md — prompt fragments (excluding README, CHANGELOG, etc.)

Loading priority:
1. Per-workspace: /configs/plugins/<name>/ (installed via API)
2. Shared fallback: /plugins/<name>/ (legacy bind mount)
Deduplication by name — per-workspace wins.
"""
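
# Hedged example of the expected on-disk layout (all names illustrative):
#
#     /configs/plugins/seo-toolkit/
#         plugin.yaml                      # name: seo-toolkit
#                                          # version: "1.2.0"
#                                          # skills: [keyword-audit]
#         rules/always.md                  # always-on guideline
#         skills/keyword-audit/SKILL.md    # skill unit
#         notes.md                         # prompt fragment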

import logging
import os
from pathlib import Path
from dataclasses import dataclass, field

import yaml

logger = logging.getLogger(__name__)

WORKSPACE_PLUGINS_DIR = "/configs/plugins"
SHARED_PLUGINS_DIR = os.environ.get("PLUGINS_DIR", "/plugins")


@dataclass
class PluginManifest:
    name: str = ""
    version: str = "0.0.0"
    description: str = ""
    author: str = ""
    tags: list[str] = field(default_factory=list)
    skills: list[str] = field(default_factory=list)
    rules: list[str] = field(default_factory=list)
    prompt_fragments: list[str] = field(default_factory=list)
    adapters: dict = field(default_factory=dict)
    runtimes: list[str] = field(default_factory=list)  # declared supported runtimes


@dataclass
class Plugin:
    name: str
    path: str
    manifest: PluginManifest = field(default_factory=PluginManifest)
    rules: list[str] = field(default_factory=list)  # rule content strings
    prompt_fragments: list[str] = field(default_factory=list)  # extra prompt content
    skills_dir: str = ""  # path to skills/ inside plugin


@dataclass
class LoadedPlugins:
    rules: list[str] = field(default_factory=list)
    prompt_fragments: list[str] = field(default_factory=list)
    skill_dirs: list[str] = field(default_factory=list)  # dirs to scan for extra skills
    plugin_names: list[str] = field(default_factory=list)
    plugins: list[Plugin] = field(default_factory=list)
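
# Hedged usage sketch: how a prompt builder might consume the result of
# load_plugins() below:
#
#     loaded = load_plugins()
#     extra_prompt = "\n\n".join(loaded.rules + loaded.prompt_fragments)
#     for skill_dir in loaded.skill_dirs:
#         ...  # hand the directory to the skill loader's scan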


def load_plugin_manifest(plugin_path: str) -> PluginManifest:
    """Parse plugin.yaml from a plugin directory. Returns empty manifest if not found."""
    manifest_file = os.path.join(plugin_path, "plugin.yaml")
    if not os.path.isfile(manifest_file):
        return PluginManifest(name=os.path.basename(plugin_path))
    try:
        with open(manifest_file) as f:
            raw = yaml.safe_load(f) or {}
        return PluginManifest(
            name=raw.get("name", os.path.basename(plugin_path)),
            version=raw.get("version", "0.0.0"),
            description=raw.get("description", ""),
            author=raw.get("author", ""),
            tags=raw.get("tags", []),
            skills=raw.get("skills", []),
            rules=raw.get("rules", []),
            prompt_fragments=raw.get("prompt_fragments", []),
            adapters=raw.get("adapters", {}),
            runtimes=raw.get("runtimes", []),
        )
    except Exception as e:
        logger.warning("Failed to parse plugin manifest %s: %s", manifest_file, e)
        return PluginManifest(name=os.path.basename(plugin_path))


def _load_single_plugin(plugin_path: str) -> Plugin:
    """Load a single plugin from a directory."""
    name = os.path.basename(plugin_path)
    manifest = load_plugin_manifest(plugin_path)
    plugin = Plugin(name=name, path=plugin_path, manifest=manifest)

    # Load rules
    rules_dir = os.path.join(plugin_path, "rules")
    if os.path.isdir(rules_dir):
        for rule_file in sorted(os.listdir(rules_dir)):
            if rule_file.endswith(".md"):
                content = Path(os.path.join(rules_dir, rule_file)).read_text().strip()
                if content:
                    plugin.rules.append(content)
                    logger.info("Plugin %s: loaded rule %s", name, rule_file)

    # Load prompt fragments (any .md in root of plugin)
    skip = {"readme.md", "changelog.md", "license.md", "contributing.md", "plugin.yaml"}
    for f in sorted(os.listdir(plugin_path)):
        if f.endswith(".md") and f.lower() not in skip and os.path.isfile(os.path.join(plugin_path, f)):
            content = Path(os.path.join(plugin_path, f)).read_text().strip()
            if content:
                plugin.prompt_fragments.append(content)
                logger.info("Plugin %s: loaded prompt fragment %s", name, f)

    # Register skills directory
    skills_dir = os.path.join(plugin_path, "skills")
    if os.path.isdir(skills_dir):
        plugin.skills_dir = skills_dir
        skill_count = len([d for d in os.listdir(skills_dir) if os.path.isdir(os.path.join(skills_dir, d))])
        logger.info("Plugin %s: found %d skills", name, skill_count)

    return plugin


def load_plugins(
    workspace_plugins_dir: str | None = None,
    shared_plugins_dir: str | None = None,
) -> LoadedPlugins:
    """Scan per-workspace plugins first, then shared plugins. Deduplicate by name."""
    ws_dir = workspace_plugins_dir or WORKSPACE_PLUGINS_DIR
    shared_dir = shared_plugins_dir or SHARED_PLUGINS_DIR
    result = LoadedPlugins()
    seen_names: set[str] = set()

    # Scan both dirs: per-workspace first (higher priority)
    for base_dir in [ws_dir, shared_dir]:
        if not os.path.isdir(base_dir):
            continue
        for entry in sorted(os.listdir(base_dir)):
            plugin_path = os.path.join(base_dir, entry)
            if not os.path.isdir(plugin_path) or entry in seen_names:
                continue

            plugin = _load_single_plugin(plugin_path)
            seen_names.add(entry)

            result.rules.extend(plugin.rules)
            result.prompt_fragments.extend(plugin.prompt_fragments)
            if plugin.skills_dir:
                result.skill_dirs.append(plugin.skills_dir)
            result.plugin_names.append(entry)
            result.plugins.append(plugin)

    if result.plugin_names:
        logger.info("Loaded %d plugins: %s", len(result.plugin_names), ", ".join(result.plugin_names))

    return result
135
molecule_runtime/plugins_registry/__init__.py
Normal file
@ -0,0 +1,135 @@
"""Per-runtime plugin adaptor registry with hybrid resolution.

Resolution order for ``(plugin_name, runtime)``:

1. Platform registry → ``workspace-template/plugins_registry/<plugin>/<runtime>.py``
2. Plugin-shipped   → ``<plugin_root>/adapters/<runtime>.py``
3. Raw filesystem   → :class:`RawDropAdaptor` (warns, drops files only)

Path #1 wins so the platform can override or hot-fix a third-party adaptor
without forking the upstream plugin repo. Path #2 is the SDK contract: a
single GitHub repo ships its own adaptors and is installable on day one.
Path #3 is the escape hatch — power users can still bring unsupported
plugins onto a workspace, they just don't get tools wired up.

A registered adaptor module must expose either:
- ``Adaptor`` class implementing :class:`PluginAdaptor`, OR
- ``def get_adaptor(plugin_name, runtime) -> PluginAdaptor``
"""
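
# Hedged example of a plugin-shipped adaptor module (path and body are
# illustrative; the binding contract is PluginAdaptor in .protocol):
#
#     # <plugin_root>/adapters/langgraph.py
#     class Adaptor:
#         def __init__(self, plugin_name, runtime):
#             self.plugin_name, self.runtime = plugin_name, runtime
#
#         async def install(self, ctx):
#             ...  # wire tools, return an InstallResult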

from __future__ import annotations

import importlib.util
import logging
from pathlib import Path
from typing import Optional

from .protocol import InstallContext, InstallResult, PluginAdaptor
from .raw_drop import RawDropAdaptor

logger = logging.getLogger(__name__)

# Where the platform-curated registry lives. Resolved relative to this file
# so it works regardless of CWD or how workspace-template is installed.
_REGISTRY_ROOT = Path(__file__).parent

__all__ = [
    "InstallContext",
    "InstallResult",
    "PluginAdaptor",
    "RawDropAdaptor",
    "resolve",
    "AdaptorSource",
]


class AdaptorSource:
    REGISTRY = "registry"
    PLUGIN = "plugin"
    RAW_DROP = "raw_drop"


def _load_module_from_path(module_name: str, path: Path):
    """Import a Python file by absolute path. Returns the module or None on failure."""
    spec = importlib.util.spec_from_file_location(module_name, path)
    if spec is None or spec.loader is None:
        return None
    module = importlib.util.module_from_spec(spec)
    try:
        spec.loader.exec_module(module)
    except Exception as exc:
        logger.warning("Failed to load adaptor module %s: %s", path, exc)
        return None
    return module


def _instantiate(module, plugin_name: str, runtime: str) -> Optional[PluginAdaptor]:
    """Build a PluginAdaptor from an adaptor module.

    Two conventions are supported so plugin authors can pick whichever fits:
    a class named ``Adaptor`` (zero-arg constructor or ``(plugin_name, runtime)``),
    or a factory function ``get_adaptor(plugin_name, runtime)``.
    """
    factory = getattr(module, "get_adaptor", None)
    if callable(factory):
        try:
            return factory(plugin_name, runtime)
        except Exception as exc:
            logger.warning("get_adaptor() failed for %s/%s: %s", plugin_name, runtime, exc)
            return None

    cls = getattr(module, "Adaptor", None)
    if cls is None:
        return None
    try:
        try:
            return cls(plugin_name, runtime)
        except TypeError:
            return cls()
    except Exception as exc:
        logger.warning("Adaptor() construction failed for %s/%s: %s", plugin_name, runtime, exc)
        return None


def _resolve_registry(plugin_name: str, runtime: str) -> Optional[PluginAdaptor]:
    path = _REGISTRY_ROOT / plugin_name / f"{runtime}.py"
    if not path.is_file():
        return None
    module = _load_module_from_path(f"plugins_registry.{plugin_name}.{runtime}", path)
    if module is None:
        return None
    return _instantiate(module, plugin_name, runtime)


def _resolve_plugin_shipped(plugin_root: Path, plugin_name: str, runtime: str) -> Optional[PluginAdaptor]:
    path = plugin_root / "adapters" / f"{runtime}.py"
    if not path.is_file():
        return None
    module = _load_module_from_path(f"_plugin_adaptor.{plugin_name}.{runtime}", path)
    if module is None:
        return None
    return _instantiate(module, plugin_name, runtime)


def resolve(
    plugin_name: str,
    runtime: str,
    plugin_root: Path,
) -> tuple[PluginAdaptor, str]:
    """Resolve the adaptor for ``(plugin_name, runtime)``.

    Returns ``(adaptor, source)`` where ``source`` is one of
    :class:`AdaptorSource` (``"registry"``, ``"plugin"``, ``"raw_drop"``).
    Always returns an adaptor — the raw-drop fallback ensures plugin installs
    never hard-fail on missing adaptors; instead the warning is surfaced via
    :class:`InstallResult.warnings`.
    """
    adaptor = _resolve_registry(plugin_name, runtime)
    if adaptor is not None:
        return adaptor, AdaptorSource.REGISTRY

    adaptor = _resolve_plugin_shipped(plugin_root, plugin_name, runtime)
    if adaptor is not None:
        return adaptor, AdaptorSource.PLUGIN

    return RawDropAdaptor(plugin_name, runtime), AdaptorSource.RAW_DROP
327
molecule_runtime/plugins_registry/builtins.py
Normal file
@ -0,0 +1,327 @@
"""Built-in plugin adaptors — one per agent shape.

The adapter layer is our extensibility surface. Each agent "shape" (form
of installable capability) gets its own named sub-type adapter. A plugin
picks which sub-type to use by importing it as ``Adaptor`` in its
per-runtime file:

.. code-block:: python

    # plugins/<name>/adapters/claude_code.py
    from plugins_registry.builtins import AgentskillsAdaptor as Adaptor

Shape taxonomy (one class per shape; add more as the ecosystem evolves):

* :class:`AgentskillsAdaptor` — skills in the `agentskills.io
  <https://agentskills.io>`_ format (``SKILL.md`` + ``scripts/`` +
  ``references/`` + ``assets/``), plus Molecule AI's optional ``rules/`` and
  root-level prompt fragments at the plugin level. Works on every runtime
  we support (the spec's filesystem layout makes activation trivial on
  Claude Code; our adapter code does the equivalent on DeepAgents /
  LangGraph / etc.). **This is the default and covers the common case.**

Planned as the ecosystem matures (none are implemented yet — rule of
three: promote a class here only after 3+ plugins ship the same custom
shape via their own ``adapters/<runtime>.py``):

* ``MCPServerAdaptor`` — install a plugin as an MCP server *(TODO)*
* ``DeepAgentsSubagentAdaptor`` — register a DeepAgents sub-agent
  (runtime-locked to deepagents) *(TODO)*
* ``LangGraphSubgraphAdaptor`` — install a LangGraph sub-graph *(TODO)*
* ``RAGPipelineAdaptor`` — wire a retriever + index *(TODO)*
* ``SwarmAdaptor`` — bind an OpenAI-swarm / AutoGen-swarm *(TODO)*
* ``WebhookAdaptor`` — register an event handler *(TODO)*

Plugins whose shape doesn't match any built-in ship their own adapter
class in ``plugins/<name>/adapters/<runtime>.py`` — full Python, no
constraint. When 3+ plugins ship the same custom pattern, we promote
the class into this module.
"""

from __future__ import annotations

import json
import os
import shutil
import subprocess
from pathlib import Path

from .protocol import SKILLS_SUBDIR, InstallContext, InstallResult

# Files at the plugin root that are never treated as prompt fragments,
# even if they're markdown. Module-level so tests and other adapters can
# import the set rather than re-declaring it.
SKIP_ROOT_MD = frozenset({"readme.md", "changelog.md", "license.md", "contributing.md"})


def _read_md_files(directory: Path) -> list[tuple[str, str]]:
    """Return [(filename, content)] for all *.md files in directory, sorted."""
    if not directory.is_dir():
        return []
    out: list[tuple[str, str]] = []
    for p in sorted(directory.iterdir()):
        if p.is_file() and p.suffix == ".md":
            out.append((p.name, p.read_text().strip()))
    return out


class AgentskillsAdaptor:
    """Sub-type adaptor for `agentskills.io <https://agentskills.io>`_-format skills.

    This is the default adapter for the "skills + rules" shape — the most
    common pattern. A plugin using this adapter ships:

    * ``skills/<name>/SKILL.md`` (+ optional ``scripts/``, ``references/``,
      ``assets/``) — each skill is a spec-compliant agentskills unit,
      portable to Claude Code, Cursor, Codex, and ~35 other skill-compatible
      tools without modification.
    * ``rules/*.md`` (optional, Molecule AI extension) — always-on prose that
      gets appended to the runtime's memory file (CLAUDE.md).
    * Root-level ``*.md`` (optional) — prompt fragments, also appended to
      memory.

    On ``install()``:
    1. Rules → append to ``/configs/<memory_filename>``, wrapped in a
       ``# Plugin: <name>`` marker for idempotent re-install.
    2. Prompt fragments (``*.md`` at plugin root, excl. README/CHANGELOG/etc.)
       → same treatment.
    3. Skills (``skills/<skill_name>/``) → copied to
       ``/configs/skills/<skill_name>/``. Runtimes with native agentskills
       activation (Claude Code) pick them up automatically; other runtimes'
       loaders scan the same path.

    Uninstall reverses the file copies and strips the rule/fragment block by
    marker (best-effort — if the user edited CLAUDE.md manually, only the
    marker line itself is removed).

    For shapes other than agentskills (MCP server, DeepAgents sub-agent,
    LangGraph sub-graph, RAG pipeline, swarm, webhook handler, etc.), see
    the module docstring for the planned sibling adapters, or ship a custom
    adapter class in the plugin's ``adapters/<runtime>.py``.
    """

    def __init__(self, plugin_name: str, runtime: str) -> None:
        self.plugin_name = plugin_name
        self.runtime = runtime

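    # Hedged usage sketch (plugin name, runtime, and ctx construction are
    # illustrative; the platform's install API builds the InstallContext):
    #
    #     adaptor = AgentskillsAdaptor("seo-toolkit", "claude_code")
    #     result = await adaptor.install(ctx)
    #     print(result.files_written, result.warnings)
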
    # ------------------------------------------------------------------
    # install
    # ------------------------------------------------------------------

    async def install(self, ctx: InstallContext) -> InstallResult:
        result = InstallResult(
            plugin_name=self.plugin_name,
            runtime=self.runtime,
            source="plugin",  # overridden by registry caller if source==registry
        )

        # 1. Rules — append to memory file.
        rules = _read_md_files(ctx.plugin_root / "rules")
        # 2. Prompt fragments — any *.md at plugin root except skip list.
        root_fragments: list[tuple[str, str]] = []
        if ctx.plugin_root.is_dir():
            for p in sorted(ctx.plugin_root.iterdir()):
                if p.is_file() and p.suffix == ".md" and p.name.lower() not in SKIP_ROOT_MD:
                    content = p.read_text().strip()
                    if content:
                        root_fragments.append((p.name, content))

        memory_blocks: list[str] = []
        for filename, content in rules:
            memory_blocks.append(f"# Plugin: {self.plugin_name} / rule: {filename}\n\n{content}")
        for filename, content in root_fragments:
            memory_blocks.append(f"# Plugin: {self.plugin_name} / fragment: {filename}\n\n{content}")

        if memory_blocks:
            joined = "\n\n".join(memory_blocks)
            ctx.append_to_memory(ctx.memory_filename, joined)
            ctx.logger.info(
                "%s: injected %d rule+fragment block(s) into %s",
                self.plugin_name, len(memory_blocks), ctx.memory_filename,
            )

        # 3. Skills — copy each skill dir to /configs/skills/.
        src_skills_dir = ctx.plugin_root / "skills"
        if src_skills_dir.is_dir():
            dst_skills_root = ctx.configs_dir / SKILLS_SUBDIR
            dst_skills_root.mkdir(parents=True, exist_ok=True)
            copied = 0
            for entry in sorted(src_skills_dir.iterdir()):
                if not entry.is_dir():
                    continue
                dst = dst_skills_root / entry.name
                if dst.exists():
                    ctx.logger.debug("%s: skill %s already present, skipping", self.plugin_name, entry.name)
                    continue
                shutil.copytree(entry, dst)
                copied += 1
                for p in dst.rglob("*"):
                    if p.is_file():
                        result.files_written.append(str(p.relative_to(ctx.configs_dir)))
            if copied:
                ctx.logger.info("%s: copied %d skill dir(s) to %s", self.plugin_name, copied, dst_skills_root)

        # 4. Setup script — run setup.sh if present (for npm/pip dependencies).
        # Mirrors sdk/python/molecule_plugin/builtins.py — must stay in sync
        # (drift guard: tests/test_plugins_builtins_drift.py).
        setup_script = ctx.plugin_root / "setup.sh"
        if setup_script.is_file():
            ctx.logger.info("%s: running setup.sh", self.plugin_name)
            try:
                proc = subprocess.run(
                    ["bash", str(setup_script)],
                    capture_output=True, text=True, timeout=120,
                    cwd=str(ctx.plugin_root),
                    env={**os.environ, "CONFIGS_DIR": str(ctx.configs_dir)},
                )
                if proc.returncode == 0:
                    ctx.logger.info("%s: setup.sh completed successfully", self.plugin_name)
                else:
                    result.warnings.append(f"setup.sh exited {proc.returncode}: {proc.stderr[:200]}")
                    ctx.logger.warning("%s: setup.sh failed: %s", self.plugin_name, proc.stderr[:200])
            except subprocess.TimeoutExpired:
                result.warnings.append("setup.sh timed out (120s)")
                ctx.logger.warning("%s: setup.sh timed out", self.plugin_name)

        # 5. Hooks — copy hooks/* into <configs>/.claude/hooks/ (Claude Code-
        # style harness hooks). No-op when the plugin doesn't ship any.
        # 6. Commands — copy commands/*.md into <configs>/.claude/commands/.
|
||||
# 7. settings-fragment.json — merge into <configs>/.claude/settings.json,
|
||||
# rewriting ${CLAUDE_DIR} to the absolute install path. Existing
|
||||
# user hooks are preserved (deep-merge by event).
|
||||
_install_claude_layer(ctx, result, self.plugin_name)
|
||||
|
||||
return result
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# uninstall
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def uninstall(self, ctx: InstallContext) -> None:
|
||||
# Remove copied skill dirs.
|
||||
src_skills_dir = ctx.plugin_root / "skills"
|
||||
if src_skills_dir.is_dir():
|
||||
for entry in src_skills_dir.iterdir():
|
||||
dst = ctx.configs_dir / SKILLS_SUBDIR / entry.name
|
||||
if dst.exists() and dst.is_dir():
|
||||
shutil.rmtree(dst)
|
||||
ctx.logger.info("%s: removed %s", self.plugin_name, dst)
|
||||
|
||||
# Best-effort strip of our markers from CLAUDE.md. Users can always
|
||||
# edit manually; we only guarantee the injected block's first line
|
||||
# is removed so re-install re-adds cleanly.
|
||||
memory_path = ctx.configs_dir / ctx.memory_filename
|
||||
if not memory_path.exists():
|
||||
return
|
||||
text = memory_path.read_text()
|
||||
prefix = f"# Plugin: {self.plugin_name} / "
|
||||
lines = text.splitlines(keepends=True)
|
||||
kept = [line for line in lines if not line.startswith(prefix)]
|
||||
if len(kept) != len(lines):
|
||||
memory_path.write_text("".join(kept))
|
||||
ctx.logger.info("%s: stripped markers from %s", self.plugin_name, ctx.memory_filename)
|
||||
|
||||
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# Claude Code layer — hooks, slash commands, settings.json fragments.
|
||||
# Promoted from the molecule-guardrails plugin so any plugin can ship
|
||||
# these by dropping the right files; no custom adapter needed.
|
||||
# ----------------------------------------------------------------------
|
||||
|
||||
def _install_claude_layer(ctx: InstallContext, result: InstallResult, plugin_name: str) -> None:
|
||||
claude_dir = ctx.configs_dir / ".claude"
|
||||
claude_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
_copy_dir_files(
|
||||
ctx.plugin_root / "hooks",
|
||||
claude_dir / "hooks",
|
||||
result,
|
||||
executable_suffix=".sh",
|
||||
)
|
||||
_copy_dir_files(
|
||||
ctx.plugin_root / "commands",
|
||||
claude_dir / "commands",
|
||||
result,
|
||||
only_suffix=".md",
|
||||
)
|
||||
_merge_settings_fragment(ctx, claude_dir, result, plugin_name)
|
||||
|
||||
|
||||
def _copy_dir_files(
|
||||
src: Path,
|
||||
dst: Path,
|
||||
result: InstallResult,
|
||||
executable_suffix: str | None = None,
|
||||
only_suffix: str | None = None,
|
||||
) -> None:
|
||||
if not src.is_dir():
|
||||
return
|
||||
dst.mkdir(parents=True, exist_ok=True)
|
||||
for f in src.iterdir():
|
||||
if not f.is_file():
|
||||
continue
|
||||
if only_suffix and f.suffix != only_suffix:
|
||||
# When copying hooks, allow .py companion files alongside .sh
|
||||
if not (executable_suffix and f.suffix == ".py"):
|
||||
continue
|
||||
target = dst / f.name
|
||||
shutil.copy2(f, target)
|
||||
if executable_suffix and f.suffix == executable_suffix:
|
||||
target.chmod(0o755)
|
||||
result.files_written.append(str(target.relative_to(target.parents[2])))
|
||||
|
||||
|
||||
def _merge_settings_fragment(
|
||||
ctx: InstallContext,
|
||||
claude_dir: Path,
|
||||
result: InstallResult,
|
||||
plugin_name: str,
|
||||
) -> None:
|
||||
fragment_path = ctx.plugin_root / "settings-fragment.json"
|
||||
if not fragment_path.is_file():
|
||||
return
|
||||
try:
|
||||
fragment = json.loads(fragment_path.read_text())
|
||||
except Exception as e:
|
||||
result.warnings.append(f"settings-fragment.json invalid: {e}")
|
||||
return
|
||||
|
||||
settings_path = claude_dir / "settings.json"
|
||||
if settings_path.is_file():
|
||||
try:
|
||||
existing = json.loads(settings_path.read_text())
|
||||
except Exception:
|
||||
existing = {}
|
||||
else:
|
||||
existing = {}
|
||||
|
||||
rewritten = _rewrite_hook_paths(fragment, claude_dir)
|
||||
merged = _deep_merge_hooks(existing, rewritten)
|
||||
settings_path.write_text(json.dumps(merged, indent=2) + "\n")
|
||||
result.files_written.append(str(settings_path.relative_to(ctx.configs_dir)))
|
||||
ctx.logger.info("%s: merged hook config into %s", plugin_name, settings_path)
|
||||
|
||||
|
||||
def _rewrite_hook_paths(fragment: dict, claude_dir: Path) -> dict:
|
||||
out = json.loads(json.dumps(fragment)) # deep copy via roundtrip
|
||||
for handlers in out.get("hooks", {}).values():
|
||||
for handler in handlers:
|
||||
for h in handler.get("hooks", []):
|
||||
cmd = h.get("command", "")
|
||||
h["command"] = cmd.replace("${CLAUDE_DIR}", str(claude_dir))
|
||||
return out
|
||||
|
||||
|
||||
def _deep_merge_hooks(existing: dict, fragment: dict) -> dict:
|
||||
out = dict(existing)
|
||||
out.setdefault("hooks", {})
|
||||
for event, handlers in fragment.get("hooks", {}).items():
|
||||
out["hooks"].setdefault(event, [])
|
||||
out["hooks"][event].extend(handlers)
|
||||
for key, val in fragment.items():
|
||||
if key == "hooks":
|
||||
continue
|
||||
out.setdefault(key, val)
|
||||
return out
|
||||
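For illustration, here is a minimal sketch of how a plugin's settings-fragment.json flows through the two merge helpers above; the fragment content and paths are hypothetical, not shipped by any real plugin:

```python
from pathlib import Path

# Hypothetical fragment a plugin might ship as settings-fragment.json.
fragment = {
    "hooks": {
        "PreToolUse": [
            {"hooks": [{"type": "command", "command": "${CLAUDE_DIR}/hooks/guard.sh"}]}
        ]
    },
    "model": "opus",
}

claude_dir = Path("/configs/.claude")
rewritten = _rewrite_hook_paths(fragment, claude_dir)
# ${CLAUDE_DIR} is replaced with the absolute install path:
assert rewritten["hooks"]["PreToolUse"][0]["hooks"][0]["command"] == "/configs/.claude/hooks/guard.sh"

# Hook lists are extended; existing user keys win over the fragment's:
merged = _deep_merge_hooks({"model": "sonnet"}, rewritten)
assert merged["model"] == "sonnet"
assert len(merged["hooks"]["PreToolUse"]) == 1
```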
104
molecule_runtime/plugins_registry/protocol.py
Normal file
@ -0,0 +1,104 @@
"""Protocol + context types for per-runtime plugin adaptors.

Each plugin ships (or has registered for it) a per-runtime adaptor implementing
``PluginAdaptor``. The platform resolves the adaptor for ``(plugin_name, runtime)``
via :func:`plugins_registry.resolve` and calls ``install(ctx)`` to wire the
plugin into a workspace.

The :class:`InstallContext` deliberately gives adaptors ONLY the hooks they
need (``register_tool``, ``register_subagent``, ``append_to_memory``) — it
does not leak runtime internals. This keeps adaptors thin and lets the
workspace runtime adapter (claude_code, deepagents, …) own its own state.
"""

from __future__ import annotations

import logging
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, Protocol, runtime_checkable


# Default filename for the runtime's long-lived memory file. Claude Code
# and DeepAgents both read CLAUDE.md natively; other runtimes override via
# BaseAdapter.memory_filename() and that value flows through
# InstallContext.memory_filename so adaptors don't hardcode the name.
DEFAULT_MEMORY_FILENAME = "CLAUDE.md"

# Subdirectory under /configs where skills get installed.
SKILLS_SUBDIR = "skills"


@dataclass
class InstallContext:
    """Hooks + state passed to every PluginAdaptor.install() call.

    Adaptors should treat unknown verbs as no-ops on runtimes that don't
    support them (e.g. ``register_subagent`` is a no-op on Claude Code).
    """

    configs_dir: Path
    """Workspace's /configs directory (where CLAUDE.md, plugins/, skills/ live)."""

    workspace_id: str
    """Workspace UUID — useful for per-workspace state or logging."""

    runtime: str
    """Runtime identifier (``claude_code``, ``deepagents``, …)."""

    plugin_root: Path
    """Path to the plugin's directory (where plugin.yaml + content lives)."""

    memory_filename: str = DEFAULT_MEMORY_FILENAME
    """Runtime's long-lived memory file (populated from
    :meth:`BaseAdapter.memory_filename`). Adaptors pass this to
    :attr:`append_to_memory` instead of hardcoding a filename so runtimes
    with non-standard memory files (e.g. ``AGENTS.md``) work unchanged."""

    register_tool: Callable[[str, Callable[..., Any]], None] = field(
        default=lambda name, fn: None
    )
    """Register a callable as a runtime tool. No-op on runtimes without a
    dynamic tool registry — those runtimes pick tools up at startup via
    filesystem scan instead."""

    register_subagent: Callable[[str, dict[str, Any]], None] = field(
        default=lambda name, spec: None
    )
    """Register a sub-agent specification (DeepAgents-only). No-op elsewhere."""

    append_to_memory: Callable[[str, str], None] = field(
        default=lambda filename, content: None
    )
    """Append text to a runtime memory file (e.g. CLAUDE.md). The default
    no-op lets adaptors run in test harnesses that don't have a real
    workspace filesystem."""

    logger: logging.Logger = field(default_factory=lambda: logging.getLogger(__name__))


@dataclass
class InstallResult:
    """Outcome of a PluginAdaptor.install() call."""

    plugin_name: str
    runtime: str
    source: str  # "registry" | "plugin" | "raw_drop"
    files_written: list[str] = field(default_factory=list)
    tools_registered: list[str] = field(default_factory=list)
    subagents_registered: list[str] = field(default_factory=list)
    warnings: list[str] = field(default_factory=list)


@runtime_checkable
class PluginAdaptor(Protocol):
    """Contract every per-runtime adaptor must implement."""

    plugin_name: str
    runtime: str

    async def install(self, ctx: InstallContext) -> InstallResult:
        ...

    async def uninstall(self, ctx: InstallContext) -> None:
        ...
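As a sketch of the contract, a hypothetical per-runtime adaptor that registers a single tool and records it in the result; the adaptor name and tool are illustrative, not part of the package:

```python
from molecule_runtime.plugins_registry.protocol import (
    InstallContext, InstallResult, PluginAdaptor,
)

class EchoAdaptor:
    """Hypothetical adaptor: wires one echo tool into the runtime."""

    def __init__(self, plugin_name: str, runtime: str) -> None:
        self.plugin_name = plugin_name
        self.runtime = runtime

    async def install(self, ctx: InstallContext) -> InstallResult:
        result = InstallResult(plugin_name=self.plugin_name, runtime=self.runtime, source="plugin")
        # No-op on runtimes without a dynamic tool registry (see InstallContext).
        ctx.register_tool("echo", lambda text: text)
        result.tools_registered.append("echo")
        return result

    async def uninstall(self, ctx: InstallContext) -> None:
        pass  # nothing persisted on disk

# Structural check, same pattern raw_drop.py uses below:
_check: PluginAdaptor = EchoAdaptor("echo-plugin", "deepagents")
```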
71
molecule_runtime/plugins_registry/raw_drop.py
Normal file
@ -0,0 +1,71 @@
"""Fallback adaptor used when no per-runtime adaptor is found.

Behaviour: copy the plugin's content into ``/configs/plugins/<name>/`` so a
user can still inspect or hand-wire it, then surface a warning that no tools
or sub-agents were registered.

This preserves the "power users can drop raw files" escape hatch without
silently breaking — the warning is propagated up via :class:`InstallResult`
so the API can surface it to the user.
"""

from __future__ import annotations

import shutil

from .protocol import InstallContext, InstallResult, PluginAdaptor


class RawDropAdaptor:
    """Filesystem-only fallback. Implements :class:`PluginAdaptor`."""

    def __init__(self, plugin_name: str, runtime: str) -> None:
        self.plugin_name = plugin_name
        self.runtime = runtime

    async def install(self, ctx: InstallContext) -> InstallResult:
        dst = ctx.configs_dir / "plugins" / self.plugin_name
        files_written: list[str] = []

        if ctx.plugin_root.exists() and ctx.plugin_root.is_dir():
            dst.parent.mkdir(parents=True, exist_ok=True)
            if dst.exists():
                # Idempotent — leave existing copy alone.
                ctx.logger.info(
                    "raw_drop: %s already present at %s, skipping copy",
                    self.plugin_name, dst,
                )
            else:
                shutil.copytree(ctx.plugin_root, dst)
                for p in dst.rglob("*"):
                    if p.is_file():
                        files_written.append(str(p.relative_to(ctx.configs_dir)))
                ctx.logger.info(
                    "raw_drop: copied %s → %s (%d files)",
                    self.plugin_name, dst, len(files_written),
                )

        warning = (
            f"plugin '{self.plugin_name}' has no adaptor for runtime "
            f"'{self.runtime}' — files dropped at /configs/plugins/{self.plugin_name} "
            f"but no tools/sub-agents were wired in"
        )
        ctx.logger.warning(warning)

        return InstallResult(
            plugin_name=self.plugin_name,
            runtime=self.runtime,
            source="raw_drop",
            files_written=files_written,
            warnings=[warning],
        )

    async def uninstall(self, ctx: InstallContext) -> None:
        dst = ctx.configs_dir / "plugins" / self.plugin_name
        if dst.exists():
            shutil.rmtree(dst)
            ctx.logger.info("raw_drop: removed %s", dst)


# Static check: RawDropAdaptor satisfies PluginAdaptor.
_: PluginAdaptor = RawDropAdaptor("_", "_")
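Because `InstallContext` defaults every hook to a no-op, the fallback is easy to exercise in a bare test harness; a minimal sketch (the tmp-path layout is illustrative):

```python
import asyncio
import tempfile
from pathlib import Path

from molecule_runtime.plugins_registry.protocol import InstallContext
from molecule_runtime.plugins_registry.raw_drop import RawDropAdaptor

async def demo() -> None:
    with tempfile.TemporaryDirectory() as tmp:
        plugin_root = Path(tmp) / "my-plugin"
        plugin_root.mkdir()
        (plugin_root / "notes.md").write_text("hand-wired content")

        ctx = InstallContext(
            configs_dir=Path(tmp) / "configs",
            workspace_id="ws-test",
            runtime="claude_code",
            plugin_root=plugin_root,
        )
        result = await RawDropAdaptor("my-plugin", "claude_code").install(ctx)
        assert result.source == "raw_drop"
        assert result.warnings  # the "no adaptor" warning is surfaced

asyncio.run(demo())
```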
11
molecule_runtime/policies/__init__.py
Normal file
@ -0,0 +1,11 @@
"""Policy helpers for routing and execution decisions."""

from .namespaces import resolve_awareness_namespace, workspace_awareness_namespace
from .routing import build_team_routing_payload, summarize_children

__all__ = [
    "build_team_routing_payload",
    "resolve_awareness_namespace",
    "summarize_children",
    "workspace_awareness_namespace",
]
18
molecule_runtime/policies/namespaces.py
Normal file
@ -0,0 +1,18 @@
"""Canonical namespace helpers for workspace-scoped resources."""

from __future__ import annotations


def workspace_awareness_namespace(workspace_id: str) -> str:
    """Return the default awareness namespace for a workspace."""
    workspace_id = workspace_id.strip()
    return f"workspace:{workspace_id}" if workspace_id else "workspace:unknown"


def resolve_awareness_namespace(
    workspace_id: str,
    configured_namespace: str | None = None,
) -> str:
    """Return the configured namespace, or the workspace default when unset."""
    namespace = (configured_namespace or "").strip()
    return namespace or workspace_awareness_namespace(workspace_id)
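A quick sanity sketch of the two helpers (IDs are made up):

```python
assert workspace_awareness_namespace(" ws-42 ") == "workspace:ws-42"
assert workspace_awareness_namespace("") == "workspace:unknown"
assert resolve_awareness_namespace("ws-42") == "workspace:ws-42"
assert resolve_awareness_namespace("ws-42", "team:alpha") == "team:alpha"
```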
98
molecule_runtime/policies/routing.py
Normal file
@ -0,0 +1,98 @@
"""Explicit routing policy for coordinator workspaces."""

from __future__ import annotations

import json
from typing import Any


def _load_agent_card(agent_card: Any) -> dict[str, Any]:
    if isinstance(agent_card, str):
        try:
            loaded = json.loads(agent_card)
        except json.JSONDecodeError:
            return {}
        return loaded if isinstance(loaded, dict) else {}
    return agent_card if isinstance(agent_card, dict) else {}


def summarize_children(children: list[dict]) -> list[dict[str, Any]]:
    """Return the minimal child summary needed for routing and prompts."""
    members: list[dict[str, Any]] = []
    for child in children:
        card = _load_agent_card(child.get("agent_card", {}))
        members.append(
            {
                "id": child.get("id"),
                "name": child.get("name"),
                "status": child.get("status"),
                "skills": [
                    s.get("name", s.get("id", ""))
                    for s in card.get("skills", [])
                    if isinstance(s, dict)
                ],
            }
        )
    return members


def build_team_routing_payload(
    children: list[dict],
    task: str,
    preferred_member_id: str = "",
) -> dict[str, Any]:
    """Return the deterministic routing payload for coordinator tasks."""
    if preferred_member_id:
        return {
            "success": True,
            "action": "delegate_to_preferred_member",
            "preferred_member_id": preferred_member_id,
            "task": task,
        }

    members = summarize_children(children)
    if not members:
        return {
            "success": False,
            "error": "No team members available. Handle this task yourself.",
            "task": task,
            "members": [],
        }

    return {
        "success": True,
        "action": "choose_member",
        "message": (
            f"You have {len(members)} team members. "
            "Choose the best one for this task and call delegate_to_workspace with their ID."
        ),
        "task": task,
        "members": members,
    }


def decide_team_route(
    children: list[dict],
    *,
    task: str,
    preferred_member_id: str = "",
) -> dict[str, Any]:
    """Compatibility wrapper for older callers."""
    return build_team_routing_payload(
        children,
        task=task,
        preferred_member_id=preferred_member_id,
    )


def build_team_route_decision(
    children: list[dict],
    task: str,
    preferred_member_id: str = "",
) -> dict[str, Any]:
    """Compatibility wrapper for tests and older imports."""
    return build_team_routing_payload(
        children,
        task=task,
        preferred_member_id=preferred_member_id,
    )
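For example, with two children and no preferred member (child data is illustrative; note that agent cards arrive either as dicts or JSON strings and both parse):

```python
children = [
    {"id": "ws-a", "name": "researcher", "status": "online",
     "agent_card": {"skills": [{"name": "web_search"}]}},
    {"id": "ws-b", "name": "coder", "status": "online",
     "agent_card": '{"skills": [{"id": "code_review"}]}'},  # JSON-string card
]
payload = build_team_routing_payload(children, task="summarize the repo")
assert payload["action"] == "choose_member"
assert payload["members"][0]["skills"] == ["web_search"]
assert payload["members"][1]["skills"] == ["code_review"]
```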
143
molecule_runtime/preflight.py
Normal file
@ -0,0 +1,143 @@
"""Startup preflight checks for workspace runtime configs."""

import os
from dataclasses import dataclass, field
from pathlib import Path

from config import WorkspaceConfig

SUPPORTED_RUNTIMES = {
    "langgraph",
    "claude-code",
    "codex",
    "ollama",
    "custom",
    "crewai",
    "autogen",
    "deepagents",
    "openclaw",
}


@dataclass
class PreflightIssue:
    severity: str
    title: str
    detail: str
    fix: str = ""


@dataclass
class PreflightReport:
    warnings: list[PreflightIssue] = field(default_factory=list)
    failures: list[PreflightIssue] = field(default_factory=list)

    @property
    def ok(self) -> bool:
        return not self.failures


def run_preflight(config: WorkspaceConfig, config_path: str) -> PreflightReport:
    """Check the workspace config for obvious startup blockers."""
    report = PreflightReport()
    config_dir = Path(config_path)

    if config.runtime not in SUPPORTED_RUNTIMES:
        report.failures.append(
            PreflightIssue(
                severity="fail",
                title="Runtime",
                detail=f"Unsupported runtime '{config.runtime}'",
                fix="Choose one of the supported runtimes or install the matching adapter.",
            )
        )

    if not 1 <= int(config.a2a.port) <= 65535:
        report.failures.append(
            PreflightIssue(
                severity="fail",
                title="A2A port",
                detail=f"Invalid A2A port: {config.a2a.port}",
                fix="Set a2a.port to a value between 1 and 65535.",
            )
        )

    # Check required environment variables (e.g. CLAUDE_CODE_OAUTH_TOKEN, OPENAI_API_KEY).
    # These are declared per-runtime in config.yaml and injected via the secrets API.
    required_env = getattr(config.runtime_config, "required_env", []) or []
    for env_var in required_env:
        if not os.environ.get(env_var):
            report.failures.append(
                PreflightIssue(
                    severity="fail",
                    title="Required env",
                    detail=f"Missing required environment variable: {env_var}",
                    fix=f"Set {env_var} via the secrets API (global or workspace-level).",
                )
            )

    # Backward compat: if legacy auth_token_file is set, don't block as long
    # as the token is available via required_env or auth_token_env; only fail
    # when no alternative token source is present.
    token_file = getattr(config.runtime_config, "auth_token_file", "")
    if token_file:
        token_path = config_dir / token_file
        if not token_path.exists():
            token_env = getattr(config.runtime_config, "auth_token_env", "")
            env_has_token = bool(token_env and os.environ.get(token_env))
            # Also check if any required_env is set (covers the new path)
            if not env_has_token and required_env:
                env_has_token = all(os.environ.get(e) for e in required_env)

            if not env_has_token:
                report.failures.append(
                    PreflightIssue(
                        severity="fail",
                        title="Auth token",
                        detail=f"Missing auth token file: {token_file}",
                        fix="Remove auth_token_file and use required_env + secrets API instead.",
                    )
                )

    prompt_files = config.prompt_files or ["system-prompt.md"]
    for prompt_file in prompt_files:
        prompt_path = config_dir / prompt_file
        if not prompt_path.exists():
            report.warnings.append(
                PreflightIssue(
                    severity="warn",
                    title="Prompt file",
                    detail=f"Missing prompt file: {prompt_file}",
                    fix="Add the file or remove it from prompt_files.",
                )
            )

    skills_dir = config_dir / "skills"
    for skill_name in config.skills:
        skill_path = skills_dir / skill_name / "SKILL.md"
        if not skill_path.exists():
            report.warnings.append(
                PreflightIssue(
                    severity="warn",
                    title="Skill",
                    detail=f"Missing skill package: {skill_name}",
                    fix="Restore the skill folder or remove it from config.yaml.",
                )
            )

    return report


def render_preflight_report(report: PreflightReport) -> None:
    """Print a concise startup report."""
    if not report.warnings and not report.failures:
        return

    print("Preflight checks:")
    for issue in report.failures:
        print(f"[FAIL] {issue.title}: {issue.detail}")
        if issue.fix:
            print(f" Fix: {issue.fix}")
    for issue in report.warnings:
        print(f"[WARN] {issue.title}: {issue.detail}")
        if issue.fix:
            print(f" Fix: {issue.fix}")
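A minimal startup sketch. The real `WorkspaceConfig` comes from the package's `config` module; the duck-typed `SimpleNamespace` stand-in here is purely for illustration:

```python
from types import SimpleNamespace

cfg = SimpleNamespace(
    runtime="claude-code",
    a2a=SimpleNamespace(port=8080),
    runtime_config=SimpleNamespace(required_env=["CLAUDE_CODE_OAUTH_TOKEN"]),
    prompt_files=["system-prompt.md"],
    skills=[],
)
report = run_preflight(cfg, "/configs")
render_preflight_report(report)
if not report.ok:
    raise SystemExit(1)  # block startup on hard failures, warnings pass through
```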
132
molecule_runtime/prompt.py
Normal file
@ -0,0 +1,132 @@
"""Build the system prompt for the workspace agent."""

from pathlib import Path

from skill_loader.loader import LoadedSkill
from adapters.shared_runtime import build_peer_section

DEFAULT_MEMORY_SNAPSHOT_FILES = ("MEMORY.md", "USER.md")


async def get_peer_capabilities(platform_url: str, workspace_id: str) -> list[dict]:
    """Fetch peer workspace capabilities from the platform."""
    try:
        import httpx

        async with httpx.AsyncClient(timeout=10.0) as client:
            resp = await client.get(
                f"{platform_url}/registry/{workspace_id}/peers",
                headers={"X-Workspace-ID": workspace_id},
            )
            if resp.status_code == 200:
                return resp.json()
    except Exception as e:
        print(f"Warning: could not fetch peers: {e}")
    return []


def build_system_prompt(
    config_path: str,
    workspace_id: str,
    loaded_skills: list[LoadedSkill],
    peers: list[dict],
    prompt_files: list[str] | None = None,
    plugin_rules: list[str] | None = None,
    plugin_prompts: list[str] | None = None,
    parent_context: list[dict] | None = None,
) -> str:
    """Build the complete system prompt.

    Loads prompt files in order from config_path. If prompt_files is specified
    in config.yaml, those files are loaded in order. Otherwise falls back to
    system-prompt.md for backwards compatibility.

    If MEMORY.md or USER.md exist alongside the config, they are appended as a
    frozen memory snapshot without needing to list them explicitly.

    This allows different agent frameworks to use their own file structures:
    - OpenClaw: SOUL.md, BOOTSTRAP.md, AGENTS.md, HEARTBEAT.md, TOOLS.md, USER.md
    - Claude Code: CLAUDE.md
    - Default: system-prompt.md
    """
    parts = []

    # Load prompt files in order
    files_to_load = list(prompt_files or [])
    if not files_to_load:
        # Backwards compatible: fall back to system-prompt.md
        files_to_load = ["system-prompt.md"]

    seen_files = set(files_to_load)

    for filename in files_to_load:
        file_path = Path(config_path) / filename
        if file_path.exists():
            content = file_path.read_text().strip()
            if content:
                parts.append(content)
        else:
            print(f"Warning: prompt file not found: {file_path}")

    # Hermes-style memory snapshot files: load automatically when present.
    # These stay as thin markdown files so the runtime does not need a new storage layer.
    for filename in DEFAULT_MEMORY_SNAPSHOT_FILES:
        if filename in seen_files:
            continue
        file_path = Path(config_path) / filename
        if file_path.exists():
            content = file_path.read_text().strip()
            if content:
                parts.append(content)

    # Inject parent's shared context (if this workspace is a child)
    if parent_context:
        parts.append("\n## Parent Context\n")
        parts.append("The following context was shared by your parent workspace:\n")
        for ctx_file in parent_context:
            path = ctx_file.get("path", "unknown")
            content = ctx_file.get("content", "")
            if content.strip():
                parts.append(f"### {path}")
                parts.append(content.strip())
                parts.append("")

    # Inject plugin rules (always-on guidelines from ECC, Superpowers, etc.)
    if plugin_rules:
        parts.append("\n## Platform Rules\n")
        for rule in plugin_rules:
            parts.append(rule)
        parts.append("")

    # Inject plugin prompt fragments
    if plugin_prompts:
        parts.append("\n## Platform Guidelines\n")
        for fragment in plugin_prompts:
            parts.append(fragment)
        parts.append("")

    # Add skill instructions
    if loaded_skills:
        parts.append("\n## Your Skills\n")
        for skill in loaded_skills:
            parts.append(f"### {skill.metadata.name}")
            if skill.metadata.description:
                parts.append(skill.metadata.description)
            parts.append(skill.instructions)
            parts.append("")

    # Add peer capabilities with a single shared renderer.
    peer_section = build_peer_section(peers)
    if peer_section:
        parts.append(peer_section)

    # Add delegation failure handling
    parts.append("""
## Handling delegation failures
If a delegation fails:
1. Check if the task is blocking — if not, continue other work
2. Retry transient failures (connection errors) after 30 seconds
3. For persistent failures, report to the caller with context
4. Never silently drop a failed task
""")

    return "\n".join(parts)
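A hypothetical invocation for an OpenClaw-style layout; the paths, IDs, and file names are illustrative only:

```python
prompt = build_system_prompt(
    config_path="/configs",
    workspace_id="ws-123",
    loaded_skills=[],
    peers=[],
    prompt_files=["SOUL.md", "AGENTS.md"],  # loaded in this order
)
print(prompt[:200])
```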
0
molecule_runtime/skill_loader/__init__.py
Normal file
191
molecule_runtime/skill_loader/loader.py
Normal file
@ -0,0 +1,191 @@
"""Load skill packages from the workspace config directory."""

import importlib.util
import logging
import os
import sys
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any

import yaml

logger = logging.getLogger(__name__)

try:
    from builtin_tools.security_scan import SkillSecurityError, scan_skill_dependencies
    _SECURITY_SCAN_AVAILABLE = True
except ImportError:  # lightweight test environments without tools/ on sys.path
    _SECURITY_SCAN_AVAILABLE = False


@dataclass
class SkillMetadata:
    id: str
    name: str
    description: str
    tags: list[str] = field(default_factory=list)
    examples: list[str] = field(default_factory=list)


@dataclass
class LoadedSkill:
    metadata: SkillMetadata
    instructions: str
    tools: list[Any] = field(default_factory=list)


def parse_skill_frontmatter(skill_md_path: Path) -> tuple[dict, str]:
    """Parse YAML frontmatter from a SKILL.md file.

    Runtime-side: tolerant of malformed frontmatter (returns ``({}, body)``
    so the skill loads with empty metadata rather than crashing the
    workspace at startup). The SDK's :func:`molecule_plugin.parse_skill_md`
    is the authoring-time strict validator that surfaces the same errors.
    Keep behaviour aligned: if you change acceptance rules here, mirror
    them in the SDK's parser.
    """
    content = skill_md_path.read_text()

    if not content.startswith("---"):
        return {}, content

    parts = content.split("---", 2)
    if len(parts) < 3:
        return {}, content

    try:
        frontmatter = yaml.safe_load(parts[1]) or {}
    except yaml.YAMLError:
        logger.warning("SKILL.md at %s has malformed frontmatter; loading with empty metadata", skill_md_path)
        frontmatter = {}
    if not isinstance(frontmatter, dict):
        logger.warning("SKILL.md at %s frontmatter is not a mapping; ignoring", skill_md_path)
        frontmatter = {}

    body = parts[2].strip()
    return frontmatter, body


def load_skill_tools(scripts_dir: Path) -> list[Any]:
    """Dynamically load tool functions from a skill's scripts/ directory.

    Follows the agentskills.io spec layout: each skill's executable code
    lives under ``scripts/``. Returns an empty list if the directory
    doesn't exist.
    """
    tools = []
    if not scripts_dir.exists():
        return tools

    # Import langchain only when we actually have scripts to process.
    # Keeps test environments (and empty skills) from needing langchain.
    from langchain_core.tools import BaseTool

    # Sensitive env vars that must not be readable by skill scripts.
    # Fix C (Cycle 5): scrub before exec_module() so a malicious skill cannot
    # exfiltrate credentials even if it somehow bypasses the POST /plugins
    # auth gate (defence in depth).
    _SCRUB_KEYS = (
        "CLAUDE_CODE_OAUTH_TOKEN",
        "ANTHROPIC_API_KEY",
        "OPENAI_API_KEY",
        "WORKSPACE_AUTH_TOKEN",
        "GITHUB_TOKEN",
        "GH_TOKEN",
    )

    for py_file in sorted(scripts_dir.glob("*.py")):
        if py_file.name.startswith("_"):
            continue

        # Verify the script is actually inside the expected scripts directory
        # (path traversal guard — glob shouldn't produce outside paths, but
        # belt-and-suspenders for symlink attacks).
        try:
            py_file.resolve().relative_to(scripts_dir.resolve())
        except ValueError:
            logger.warning("skill_loader: rejecting script outside scripts_dir: %s", py_file)
            continue

        module_name = f"skill_tool_{py_file.stem}"
        spec = importlib.util.spec_from_file_location(module_name, py_file)
        if spec is None or spec.loader is None:
            continue

        module = importlib.util.module_from_spec(spec)
        sys.modules[module_name] = module

        # Temporarily remove sensitive env vars before running skill code.
        _saved_env = {k: os.environ.pop(k) for k in _SCRUB_KEYS if k in os.environ}
        try:
            spec.loader.exec_module(module)
        finally:
            # Always restore so the rest of the agent process retains them.
            os.environ.update(_saved_env)

        # Look for functions decorated with @tool (BaseTool instances)
        for attr_name in dir(module):
            attr = getattr(module, attr_name)
            if isinstance(attr, BaseTool):
                tools.append(attr)

    return tools


def load_skills(config_path: str, skill_names: list[str]) -> list[LoadedSkill]:
    """Load all skills specified in the config."""
    skills_dir = Path(config_path) / "skills"
    loaded = []

    # Resolve security scan mode once before the loop
    scan_mode = "warn"
    fail_open_if_no_scanner = True  # safe default matches security_scan.py default
    if _SECURITY_SCAN_AVAILABLE:
        try:
            from config import load_config
            _cfg = load_config(config_path)
            scan_mode = _cfg.security_scan.mode
            fail_open_if_no_scanner = _cfg.security_scan.fail_open_if_no_scanner
        except Exception:
            pass  # use defaults — never block on config error

    for skill_name in skill_names:
        skill_path = skills_dir / skill_name
        skill_md = skill_path / "SKILL.md"

        if not skill_md.exists():
            logger.warning("SKILL.md not found for %s, skipping", skill_name)
            continue

        # --- Security scan before loading any code from the skill ------------
        if _SECURITY_SCAN_AVAILABLE and scan_mode != "off":
            try:
                scan_skill_dependencies(
                    skill_name, skill_path, scan_mode,
                    fail_open_if_no_scanner=fail_open_if_no_scanner,
                )
            except SkillSecurityError as exc:
                logger.warning("Skipping skill '%s': blocked by security scan — %s", skill_name, exc)
                continue

        frontmatter, instructions = parse_skill_frontmatter(skill_md)

        metadata = SkillMetadata(
            id=skill_name,
            name=frontmatter.get("name", skill_name),
            description=frontmatter.get("description", ""),
            tags=frontmatter.get("tags", []),
            examples=frontmatter.get("examples", []),
        )

        # Executables live under scripts/ per the agentskills.io spec.
        tools = load_skill_tools(skill_path / "scripts")

        loaded.append(LoadedSkill(
            metadata=metadata,
            instructions=instructions,
            tools=tools,
        ))

    return loaded
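For reference, a sketch of the frontmatter shape the tolerant parser accepts; the skill content below is made up:

```python
from pathlib import Path

skill_md = Path("/tmp/SKILL.md")
skill_md.write_text(
    "---\n"
    "name: web_search\n"
    "description: Search the web and cite sources\n"
    "tags: [research]\n"
    "---\n"
    "Use the search tool, then summarize with citations.\n"
)
frontmatter, body = parse_skill_frontmatter(skill_md)
assert frontmatter["name"] == "web_search"
assert body.startswith("Use the search tool")

# Malformed frontmatter degrades to empty metadata instead of raising:
skill_md.write_text("---\n: not yaml :\n---\nbody\n")
frontmatter, body = parse_skill_frontmatter(skill_md)
assert frontmatter == {} and body == "body"
```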
227
molecule_runtime/skill_loader/watcher.py
Normal file
@ -0,0 +1,227 @@
"""Skills hot-reload watcher.

Monitors the workspace's ``skills/`` directory for file changes and reloads
affected skill modules in-place — no coordinator restart required.

Architecture
------------
``SkillsWatcher`` runs as a background asyncio task alongside the agent. It
polls the skill directories every ``POLL_INTERVAL`` seconds (default 3 s),
computes SHA-256 hashes of every file, and fires ``_reload_skill()`` when any
file inside a skill's folder changes.

``_reload_skill()`` calls ``load_skills()`` from ``skill_loader.loader`` for
the changed skill and passes the fresh ``LoadedSkill`` to every registered
``on_reload`` callback. Adapters register a callback that rebuilds the
LangGraph agent with the updated tool set, so the change takes effect on
the very next incoming A2A task — zero downtime.

Audit event
-----------
Every successful reload emits::

    event_type    : "skill_reload"
    action        : "reload"
    resource      : "<skill_name>"
    outcome       : "success" | "failure"
    changed_files : [list of relative paths that triggered the reload]

Usage::

    watcher = SkillsWatcher(
        config_path="/configs",
        skill_names=["web_search", "code_review"],
        on_reload=lambda skill: rebuild_agent_with_skill(skill),
    )
    asyncio.create_task(watcher.start())
"""

from __future__ import annotations

import asyncio
import hashlib
import logging
import sys
from pathlib import Path
from typing import Callable

logger = logging.getLogger(__name__)

POLL_INTERVAL = 3.0   # seconds between filesystem polls
DEBOUNCE_SECS = 1.5   # wait for writes to settle before reloading


class SkillsWatcher:
    """Watches skill directories and reloads changed skills without restarting.

    Args:
        config_path: Path to the workspace config directory (contains ``skills/``).
        skill_names: List of skill IDs to watch (subfolder names under ``skills/``).
        on_reload: Async or sync callable invoked with a fresh ``LoadedSkill``
            every time a skill is reloaded. May be called concurrently
            for multiple skills if several change at once.
    """

    def __init__(
        self,
        config_path: str,
        skill_names: list[str],
        on_reload: Callable | None = None,
    ) -> None:
        self.config_path = config_path
        self.skill_names = list(skill_names)
        self.on_reload = on_reload
        self._hashes: dict[str, str] = {}  # rel_path → sha256 hex
        self._running = False

    # ------------------------------------------------------------------
    # Public interface
    # ------------------------------------------------------------------

    async def start(self) -> None:
        """Start the poll loop in the current event loop. Runs until ``stop()``."""
        self._running = True
        self._hashes = self._scan()
        logger.info(
            "SkillsWatcher: monitoring %d skill(s) in %s",
            len(self.skill_names), self.config_path,
        )

        while self._running:
            await asyncio.sleep(POLL_INTERVAL)
            await self._tick()

    def stop(self) -> None:
        self._running = False

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _skills_root(self) -> Path:
        return Path(self.config_path) / "skills"

    def _hash_file(self, path: Path) -> str:
        try:
            # H1: SHA-256 replaces MD5 for file-integrity change detection.
            return hashlib.sha256(path.read_bytes()).hexdigest()
        except OSError:
            return ""

    def _scan(self) -> dict[str, str]:
        """Return {relative_path: sha256} for every file in watched skill dirs."""
        hashes: dict[str, str] = {}
        root = self._skills_root()
        for skill_name in self.skill_names:
            skill_dir = root / skill_name
            if not skill_dir.is_dir():
                continue
            for fpath in skill_dir.rglob("*"):
                if fpath.is_file() and not fpath.name.startswith("."):
                    rel = str(fpath.relative_to(root))
                    hashes[rel] = self._hash_file(fpath)
        return hashes

    def _changed_skills(self, new_hashes: dict[str, str]) -> dict[str, list[str]]:
        """Return {skill_name: [changed_file, …]} for skills with file changes."""
        changed: dict[str, list[str]] = {}

        all_paths = set(new_hashes) | set(self._hashes)
        for rel_path in all_paths:
            old = self._hashes.get(rel_path, "")
            new = new_hashes.get(rel_path, "")
            if old != new:
                # rel_path is like "web_search/SKILL.md" or "web_search/tools/foo.py"
                skill_name = rel_path.split("/")[0]
                if skill_name in self.skill_names:
                    changed.setdefault(skill_name, []).append(rel_path)

        return changed

    async def _tick(self) -> None:
        """One poll cycle: detect changes, debounce, reload."""
        new_hashes = self._scan()
        changed = self._changed_skills(new_hashes)

        if not changed:
            return

        logger.info("SkillsWatcher: changes detected in %s", list(changed.keys()))
        await asyncio.sleep(DEBOUNCE_SECS)

        # Re-scan after debounce to absorb any writes still in-flight
        new_hashes = self._scan()
        changed = self._changed_skills(new_hashes)

        self._hashes = new_hashes  # commit new baseline

        for skill_name, files in changed.items():
            await self._reload_skill(skill_name, files)

    async def _reload_skill(self, skill_name: str, changed_files: list[str]) -> None:
        """Reload *skill_name*'s modules and notify the callback."""
        logger.info("SkillsWatcher: reloading skill '%s' (changed: %s)", skill_name, changed_files)

        # Evict stale module entries so importlib loads fresh copies
        stale = [k for k in sys.modules if k.startswith("skill_tool_")]
        for key in stale:
            del sys.modules[key]

        try:
            from skill_loader.loader import load_skills
            loaded = load_skills(self.config_path, [skill_name])

            if loaded:
                skill = loaded[0]
                logger.info(
                    "SkillsWatcher: skill '%s' reloaded — %d tool(s)",
                    skill_name, len(skill.tools),
                )

                # Audit event
                try:
                    from builtin_tools.audit import log_event
                    log_event(
                        event_type="skill_reload",
                        action="reload",
                        resource=skill_name,
                        outcome="success",
                        changed_files=changed_files,
                        tool_count=len(skill.tools),
                    )
                except Exception:
                    pass

                # Notify adapter callback
                if self.on_reload is not None:
                    try:
                        result = self.on_reload(skill)
                        if asyncio.iscoroutine(result):
                            await result
                    except Exception as exc:
                        logger.error(
                            "SkillsWatcher: on_reload callback failed for '%s': %s",
                            skill_name, exc,
                        )
            else:
                logger.warning("SkillsWatcher: no LoadedSkill returned for '%s'", skill_name)
                self._audit_failure(skill_name, changed_files, "no_skill_returned")

        except Exception as exc:
            logger.error("SkillsWatcher: reload failed for '%s': %s", skill_name, exc)
            self._audit_failure(skill_name, changed_files, str(exc))

    @staticmethod
    def _audit_failure(skill_name: str, changed_files: list[str], error: str) -> None:
        try:
            from builtin_tools.audit import log_event
            log_event(
                event_type="skill_reload",
                action="reload",
                resource=skill_name,
                outcome="failure",
                changed_files=changed_files,
                error=error,
            )
        except Exception:
            pass
30
molecule_runtime/transcript_auth.py
Normal file
@ -0,0 +1,30 @@
"""Auth gate for the /transcript Starlette route.

Extracted from main.py so the security-critical logic is unit-testable
without standing up the full uvicorn/a2a/httpx import stack.

#328: the route must fail CLOSED when the expected token is unavailable
(bootstrap window, missing file, OSError). The previous implementation
treated a missing token as "skip auth entirely" — any container on the
same Docker network could read the session log during provisioning.
"""


def transcript_authorized(expected_token: str | None, auth_header: str) -> bool:
    """Return True iff /transcript should serve the request.

    Args:
        expected_token: the workspace's registered bearer token, or None
            if `/configs/.auth_token` is absent / unreadable.
        auth_header: raw value of the Authorization request header.

    Behavior:
        - None/empty expected → fail closed (401). This is the #328 fix;
          a missing token file is an auth failure, not a bypass.
        - Non-empty expected: strict equality check against "Bearer <tok>".
          Bearer prefix is case-sensitive (matches the platform's
          wsauth.BearerTokenFromHeader contract).
    """
    if not expected_token:
        return False
    return auth_header == f"Bearer {expected_token}"
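The fail-closed contract is easy to pin down in a few assertions (token values are made up):

```python
assert not transcript_authorized(None, "Bearer abc")   # bootstrap window → 401
assert not transcript_authorized("", "Bearer abc")     # unreadable token file → 401
assert not transcript_authorized("abc", "bearer abc")  # Bearer prefix is case-sensitive
assert transcript_authorized("abc", "Bearer abc")
```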
120
molecule_runtime/watcher.py
Normal file
@ -0,0 +1,120 @@
"""File watcher for hot-reloading skills and config changes.

Monitors the config directory for file changes and triggers
agent rebuild + Agent Card update broadcast.
"""

import asyncio
import hashlib
import logging
import os
from pathlib import Path

import httpx

logger = logging.getLogger(__name__)

DEBOUNCE_SECONDS = 2.0
POLL_INTERVAL = 3.0  # seconds between filesystem checks


class ConfigWatcher:
    """Watches the config directory for changes and triggers reload callbacks."""

    def __init__(
        self,
        config_path: str,
        platform_url: str,
        workspace_id: str,
        on_reload=None,
    ):
        self.config_path = config_path
        self.platform_url = platform_url
        self.workspace_id = workspace_id
        self.on_reload = on_reload
        self._file_hashes: dict[str, str] = {}
        self._running = False

    def _hash_file(self, path: str) -> str:
        try:
            # H1: SHA-256 replaces MD5 for file-integrity change detection.
            # MD5 is collision-prone; using SHA-256 prevents a crafted config
            # file from producing the same hash as a benign one, which would
            # silently suppress the hot-reload callback.
            return hashlib.sha256(Path(path).read_bytes()).hexdigest()
        except (OSError, IOError):
            return ""

    def _scan_hashes(self) -> dict[str, str]:
        """Scan all files in config directory and return hash map."""
        hashes = {}
        for root, _, files in os.walk(self.config_path):
            for fname in files:
                if fname.startswith("."):
                    continue
                fpath = os.path.join(root, fname)
                rel = os.path.relpath(fpath, self.config_path)
                hashes[rel] = self._hash_file(fpath)
        return hashes

    def _detect_changes(self) -> list[str]:
        """Compare current state with cached hashes, return changed files."""
        current = self._scan_hashes()
        changed = []

        for path, h in current.items():
            if path not in self._file_hashes or self._file_hashes[path] != h:
                changed.append(path)

        for path in self._file_hashes:
            if path not in current:
                changed.append(path)

        self._file_hashes = current
        return changed

    async def _notify_platform(self, agent_card: dict):
        """Push updated Agent Card to the platform."""
        try:
            async with httpx.AsyncClient(timeout=10.0) as client:
                await client.post(
                    f"{self.platform_url}/registry/update-card",
                    json={
                        "workspace_id": self.workspace_id,
                        "agent_card": agent_card,
                    },
                )
            logger.info("Agent Card updated via platform")
        except Exception as e:
            logger.warning("Failed to update Agent Card: %s", e)

    async def start(self):
        """Start watching for changes in a background loop."""
        self._running = True
        self._file_hashes = self._scan_hashes()
        logger.info("Config watcher started for %s", self.config_path)

        while self._running:
            await asyncio.sleep(POLL_INTERVAL)

            changed = self._detect_changes()
            if not changed:
                continue

            logger.info("Config changes detected: %s", changed)

            # Debounce — wait for writes to settle
            await asyncio.sleep(DEBOUNCE_SECONDS)

            # Re-scan after debounce (more changes may have occurred)
            self._detect_changes()

            # Trigger reload callback
            if self.on_reload:
                try:
                    await self.on_reload()
                except Exception as e:
                    logger.error("Reload callback failed: %s", e)

    def stop(self):
        self._running = False
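A minimal wiring sketch; the callback body, URL, and workspace ID are hypothetical:

```python
import asyncio

async def rebuild_agent() -> None:
    # Hypothetical adapter hook: re-read config.yaml, rebuild tools,
    # then push a fresh Agent Card via the watcher's _notify_platform.
    print("config changed; rebuilding agent")

async def main() -> None:
    watcher = ConfigWatcher(
        config_path="/configs",
        platform_url="http://platform:8000",  # illustrative
        workspace_id="ws-123",                # illustrative
        on_reload=rebuild_agent,
    )
    task = asyncio.create_task(watcher.start())
    await asyncio.sleep(30)  # let a few poll cycles run
    watcher.stop()
    await task  # loop exits at the next poll after stop()

asyncio.run(main())
```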
35
pyproject.toml
Normal file
@ -0,0 +1,35 @@
[build-system]
requires = ["setuptools>=68.0", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "molecule-ai-workspace-runtime"
version = "0.1.0"
description = "Molecule AI workspace runtime — shared infrastructure for all agent adapters"
requires-python = ">=3.11"
license = {text = "BSL-1.1"}
readme = "README.md"
# Don't pin heavy deps — each adapter adds its own
dependencies = [
    "a2a-sdk[http-server]>=0.3.25",
    "httpx>=0.27.0",
    "uvicorn>=0.30.0",
    "starlette>=0.38.0",
    "websockets>=12.0",
    "pyyaml>=6.0",
    "langchain-core>=0.3.0",
    "opentelemetry-api>=1.24.0",
    "opentelemetry-sdk>=1.24.0",
    "opentelemetry-exporter-otlp-proto-http>=1.24.0",
    "temporalio>=1.7.0",
]

[project.scripts]
molecule-runtime = "molecule_runtime.main:main_sync"

[tool.setuptools.packages.find]
where = ["."]
include = ["molecule_runtime*"]

[tool.setuptools.package-data]
"molecule_runtime" = ["py.typed"]