molecule-sdk-python/molecule_agent/client.py
Molecule AI SDK Lead f8578aeac3 fix(sdk): repair broken __all__ list — new constants/func escaped list boundary
The PR accidentally embedded the new module-level constants and
make_idempotency_key function inside the __all__ list literal,
causing a SyntaxError. Fix closes the list after the existing exports
and adds make_idempotency_key + verify_plugin_sha256 as proper string
entries (compute_plugin_sha256 was removed in this PR so it is omitted).
2026-04-23 19:25:45 +00:00

964 lines
39 KiB
Python

"""RemoteAgentClient — blocking HTTP client for the Phase 30 remote-agent flow.
The client is deliberately dependency-light (``requests`` only) so a remote
agent author can drop it into any runtime. All methods correspond 1:1 to
a Phase 30 endpoint:
* :py:meth:`register` → ``POST /registry/register`` (30.1)
* :py:meth:`pull_secrets` → ``GET /workspaces/:id/secrets/values`` (30.2)
* :py:meth:`poll_state` → ``GET /workspaces/:id/state`` (30.4)
* :py:meth:`heartbeat` → ``POST /registry/heartbeat`` (30.1)
* :py:meth:`run_heartbeat_loop` — drives heartbeat + state-poll on a timer,
returns when the platform reports the workspace paused or deleted.
No inbound A2A server is bundled here yet — that requires hosting an HTTP
endpoint the platform's proxy can reach, which is network-dependent. A
future 30.8b iteration will add an optional ``start_a2a_server()`` helper.
"""
from __future__ import annotations
import fcntl
import hashlib
import json
import logging
import math
import os
import random
import stat
import subprocess
import tarfile
import time
import uuid
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
import requests
logger = logging.getLogger(__name__)
# Polling cadence defaults. Chosen to align with the platform's 60-second
# Redis TTL — a 30-second cadence refreshes the TTL twice per window, so a
# single dropped heartbeat doesn't expire the entry; state-poll at the same
# cadence is cheap (tiny GET) and gives ≤30s reaction time on
# pause / delete. Overridable via RemoteAgentClient constructor kwargs.
DEFAULT_HEARTBEAT_INTERVAL: float = 30.0  # seconds
DEFAULT_STATE_POLL_INTERVAL: float = 30.0  # seconds
# Phase 30.6 — sibling URL cache TTL. Cached URLs expire after this many
# seconds, forcing a re-discovery call. Short enough that a sibling that
# moved (restart with new port) is picked up quickly; long enough that
# we don't hit the discovery endpoint on every A2A call.
DEFAULT_URL_CACHE_TTL: float = 300.0  # 5 minutes
# Retry-on-429 defaults for idempotent GET calls.
# Matches the behaviour of the TypeScript MCP server's platformGet().
DEFAULT_GET_MAX_RETRIES: int = 3  # retry up to 3 times on 429
_RETRY_BASE_DELAY: float = 1.0  # seconds — first delay
_RETRY_MAX_DELAY: float = 30.0  # seconds — cap
_RETRY_JITTER_FRAC: float = 0.25  # ±25% jitter around base delay
# KI-002 — idempotency key granularity: round to the current minute so
# that concurrent restarts within the same 60-second window produce the
# same key, while distinct tasks or distinct minutes produce distinct keys.
_IDEMPOTENCY_ROUND_SECONDS = 60


def make_idempotency_key(task_text: str) -> str:
    """Derive a deterministic idempotency key for a delegation task.

    The key is ``SHA-256("<task_text>:<minute_bucket>")`` where
    ``minute_bucket`` is the current Unix time floored to the nearest
    minute. Two container restarts inside the same 60-second window that
    delegate the same task string therefore share one key — letting the
    platform drop the duplicate — while a different minute or a different
    task string always yields a distinct key.

    Args:
        task_text: The task description string being delegated.

    Returns:
        A 64-character hex string (SHA-256 digest).
    """
    # Floor wall-clock time to the minute bucket; after the minute rolls
    # over the key changes so a genuinely new task is always treated as new.
    bucket = int(time.time()) // _IDEMPOTENCY_ROUND_SECONDS * _IDEMPOTENCY_ROUND_SECONDS
    return hashlib.sha256(f"{task_text}:{bucket}".encode("utf-8")).hexdigest()
def _safe_extract_tar(tf: tarfile.TarFile, dest: Path) -> None:
"""Extract a tarfile, refusing entries that would escape `dest`
and logging skipped symlinks/hardlinks.
Tar archives can include ``..`` and absolute paths. Without explicit
rejection we'd risk the classic "tar slip" CVE — a malicious plugin
source could overwrite the agent's home files. We resolve every
entry's target to an absolute path, verify it lives inside dest, and
extract one-by-one so the symlink-skip actually takes effect (a bare
``extractall`` would still write the symlinks we marked as skipped).
Symlinks and hardlinks are skipped and a ``logger.warning`` is emitted
for each one so operators can audit what was dropped.
"""
dest_abs = dest.resolve()
for member in tf.getmembers():
# Symlinks and hardlinks could point outside the staged tree;
# skip them entirely (matches the platform-side tar producer) but
# warn so the operator knows what was dropped.
if member.issym() or member.islnk():
link_target = member.linkname
logger.warning(
"skipping symlink in plugin tarball (not supported for security): %s -> %s",
member.name,
link_target,
)
continue
target = (dest / member.name).resolve()
try:
target.relative_to(dest_abs)
except ValueError as exc:
raise ValueError(
f"refusing tar entry escaping dest: {member.name!r} -> {target}"
) from exc
tf.extract(member, dest)
def _rmtree_quiet(path: Path) -> None:
"""rm -rf <path> swallowing missing-file errors. Used for atomic
install rollback where we sometimes call this on a non-existent
staging dir."""
import shutil
try:
shutil.rmtree(path)
except FileNotFoundError:
pass
except Exception as exc:
logger.warning("rmtree(%s) failed: %s", path, exc)
def _walk_files(root: Path) -> list[str]:
"""Yield relative file paths under ``root`` (directories excluded)."""
rel: list[str] = []
for p in root.rglob("*"):
if p.is_file():
rel.append(p.relative_to(root).as_posix())
return rel
def _sha256_file(path: Path) -> str:
"""Return the SHA256 hex digest of ``path``."""
h = hashlib.sha256()
with path.open("rb") as f:
for chunk in iter(lambda: f.read(65536), b""):
h.update(chunk)
return h.hexdigest()
def verify_plugin_sha256(plugin_dir: Path, expected: str) -> bool:
    """Check the content of ``plugin_dir`` against an expected SHA-256.

    The comparison value is a *content-addressed manifest hash*: hash the
    JSON encoding of the sorted ``(relative_path, SHA256(file_content))``
    pairs for every file in the tree. Because the pair list is sorted, the
    result is independent of extraction order, timestamps, and directory
    ordering — two byte-identical plugin trees always agree.

    ``plugin.yaml`` is left out of the manifest: it carries the ``sha256``
    field itself, so including it would be circular. Its content is still
    checked elsewhere during plugin installation.

    The ``expected`` value lives in the plugin's ``plugin.yaml`` under the
    ``sha256`` key and must be regenerated whenever the plugin content
    changes (e.g. via ``molecule plugin hash`` or a CI step).

    Args:
        plugin_dir: Path to the unpacked plugin directory.
        expected: 64-character lowercase hex SHA256 of the content manifest.

    Returns:
        ``True`` if the manifest hash matches ``expected``.

    Raises:
        ValueError: if ``expected`` is not a valid 64-char hex string.
    """
    if not isinstance(expected, str) or len(expected) != 64 or not _is_hex(expected):
        raise ValueError(
            f"sha256 must be a 64-character lowercase hex string, got {expected!r}"
        )
    # Skip plugin.yaml — it declares the sha256 field, so hashing it into
    # its own manifest would be circular. Everything else is covered.
    entries: list[tuple[str, str]] = [
        (rel, _sha256_file(plugin_dir / rel))
        for rel in sorted(_walk_files(plugin_dir))
        if rel != "plugin.yaml"
    ]
    manifest = json.dumps(entries, sort_keys=True).encode()
    return hashlib.sha256(manifest).hexdigest() == expected
def _is_hex(value: str) -> bool:
if not isinstance(value, str):
return False
try:
int(value, 16)
return True
except ValueError:
return False
@dataclass
class WorkspaceState:
    """Point-in-time view of a remote workspace's platform-side state."""

    workspace_id: str
    # e.g. "online" / "paused" / "degraded" / "removed" / "offline" / ...
    status: str
    paused: bool
    deleted: bool

    @property
    def should_stop(self) -> bool:
        """Whether the agent should exit its run loop.

        True once the platform has paused or hard-deleted the workspace.
        The agent may be restarted later — we just don't keep heartbeating
        against a dead row.
        """
        return self.paused or self.deleted
@dataclass
class PeerInfo:
    """A sibling or parent workspace that this agent can communicate with."""

    id: str  # workspace UUID of the peer
    name: str  # human-readable workspace name as reported by the platform
    url: str  # A2A endpoint; may be "" or a non-HTTP placeholder when the peer has no inbound endpoint
    role: str = ""  # role string from the platform row (may be empty)
    tier: int = 2  # hierarchy tier; 2 is the fallback when the platform omits it
    status: str = "unknown"  # last-known registry status (e.g. "online"); "unknown" when absent
    agent_card: dict[str, Any] = field(default_factory=dict)  # peer's A2A agent card payload (may be empty)
class RemoteAgentClient:
    """Blocking HTTP client for a Phase 30 remote agent.

    Args:
        workspace_id: UUID of the workspace this agent represents. The
            workspace row must exist on the platform (created via
            ``POST /workspaces`` or ``POST /org/import``) — the agent
            claims that identity when it registers.
        platform_url: Base URL of the platform, e.g.
            ``https://molecule.example.com``. No trailing slash; the
            client adds paths.
        agent_card: A2A agent card payload. Minimal: ``{"name": "..."}``.
            Full schema matches what an in-container agent would report
            (skills list, capabilities, etc.).
        reported_url: Optional externally-reachable URL at which siblings
            can call this agent's A2A endpoint. If omitted, the agent is
            reachable only via the platform's proxy (which won't be able
            to initiate calls to the agent either — that's a limitation
            of remote agents today, resolved by 30.6/30.7 or by exposing
            an inbound endpoint yourself).
        token_dir: Where to cache the workspace auth token on disk.
            Defaults to ``~/.molecule/<workspace_id>/``. Created with
            0700 permissions if missing.
        heartbeat_interval: Seconds between heartbeats in the run loop.
        state_poll_interval: Seconds between state polls in the run loop.
        url_cache_ttl: Seconds a discovered peer URL stays cached before
            a fresh discovery round-trip is forced.
        session: Optional pre-built ``requests.Session`` (handy for tests
            or custom adapters). A new session is created when omitted.
    """

    def __init__(
        self,
        workspace_id: str,
        platform_url: str,
        agent_card: dict[str, Any] | None = None,
        reported_url: str = "",
        token_dir: Path | None = None,
        heartbeat_interval: float = DEFAULT_HEARTBEAT_INTERVAL,
        state_poll_interval: float = DEFAULT_STATE_POLL_INTERVAL,
        url_cache_ttl: float = DEFAULT_URL_CACHE_TTL,
        session: requests.Session | None = None,
    ) -> None:
        self.workspace_id = workspace_id
        self.platform_url = platform_url.rstrip("/")
        self.agent_card = agent_card or {"name": f"remote-agent-{workspace_id[:8]}"}
        self.reported_url = reported_url
        self.heartbeat_interval = heartbeat_interval
        self.state_poll_interval = state_poll_interval
        self.url_cache_ttl = url_cache_ttl
        # Phase 30.6 — sibling URL cache keyed by workspace id. Values are
        # (url, expires_at_unix_seconds). Process-memory only; we re-fetch
        # on restart because agent lifetimes are short enough that
        # persisting doesn't buy much.
        self._url_cache: dict[str, tuple[str, float]] = {}
        self._session = session or requests.Session()
        self._token_dir = token_dir or (
            Path.home() / ".molecule" / workspace_id
        )
        self._token: str | None = None
        self._start_time = time.time()

    # ------------------------------------------------------------------
    # Token persistence
    # ------------------------------------------------------------------
    @property
    def token_file(self) -> Path:
        """Path of the on-disk token cache (``<token_dir>/.auth_token``)."""
        return self._token_dir / ".auth_token"

    def load_token(self) -> str | None:
        """Load a cached token from disk if present. Populates the
        in-memory cache on success.

        Uses a shared (SH) file lock so it does not block concurrent readers
        or conflict with a simultaneous save_token() call. Non-blocking —
        returns None on lock failure rather than waiting."""
        if self._token is not None:
            return self._token
        if not self.token_file.exists():
            return None
        try:
            # Shared lock: multiple readers allowed; blocks only while a writer
            # holds the exclusive lock. LOCK_NB → fail immediately if contended.
            with self.token_file.open("r") as fh:
                try:
                    fcntl.flock(fh.fileno(), fcntl.LOCK_SH | fcntl.LOCK_NB)
                except OSError:
                    # Lock unavailable (contended or non-blocking); treat as absent.
                    return None
                try:
                    tok = fh.read().strip()
                finally:
                    fcntl.flock(fh.fileno(), fcntl.LOCK_UN)
        except OSError as exc:
            logger.warning("failed to read %s: %s", self.token_file, exc)
            return None
        if not tok:
            return None
        self._token = tok
        return tok

    def save_token(self, token: str) -> None:
        """Persist a freshly-issued token to disk. Creates the parent
        directory with 0700 and the file with 0600 to keep the credential
        off other users' prying eyes.

        Uses an exclusive (EX) file lock so simultaneous save_token calls from
        concurrent instances of RemoteAgentClient for the same workspace are
        serialised — no partial-overwrite races. Non-blocking: on contention
        the write is skipped (only the in-memory token is updated).

        Raises:
            ValueError: if ``token`` is empty after stripping.
        """
        token = token.strip()
        if not token:
            raise ValueError("refusing to save empty token")
        self._token_dir.mkdir(parents=True, exist_ok=True)
        try:
            os.chmod(self._token_dir, 0o700)
        except OSError:
            pass  # non-fatal — best-effort on unusual filesystems
        # Open in APPEND mode, not "w": flock is advisory, so a "w" open
        # would truncate the file BEFORE we hold the lock — and if another
        # writer holds it, we'd have destroyed its content and then bailed
        # out. With "a" nothing is touched until the exclusive lock is
        # ours; only then do we truncate and rewrite. LOCK_NB → give up
        # immediately if another writer holds the lock.
        with self.token_file.open("a") as fh:
            try:
                fcntl.flock(fh.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
            except OSError:
                # Another process holds the lock; give up rather than wait.
                # The in-memory _token is still updated so this instance works.
                logger.warning("token file %s is locked by another process; skipping write", self.token_file)
                self._token = token
                return
            try:
                fh.truncate(0)
                # O_APPEND writes land at end-of-file, which is offset 0
                # right after the truncate — so the token starts the file.
                fh.write(token)
            finally:
                fcntl.flock(fh.fileno(), fcntl.LOCK_UN)
        try:
            os.chmod(self.token_file, 0o600)
        except OSError:
            pass
        self._token = token

    def _auth_headers(self) -> dict[str, str]:
        """Bearer-token header dict, or empty when no token is available."""
        tok = self.load_token()
        if not tok:
            return {}
        return {"Authorization": f"Bearer {tok}"}

    def _get_with_retry(
        self,
        url: str,
        headers: dict[str, str] | None = None,
        max_retries: int = DEFAULT_GET_MAX_RETRIES,
    ) -> requests.Response:
        """Issue a GET with automatic retry on 429 (Too Many Requests).

        Retries up to ``max_retries`` times, honouring a numeric
        ``Retry-After`` header (whole seconds, rounded up) when present.
        A non-numeric ``Retry-After`` (the HTTP-date form) or a missing
        header falls back to exponential back-off with ±25 % jitter
        around the current base: 1 s → 2 s → 4 s, capped at 30 s.

        After exhausting retries returns the final 429 response (caller's
        ``raise_for_status()`` will turn it into an ``HTTPError``).

        Only use for idempotent GET calls. POST/DELETE callers must not
        retry, as the server may have already processed the request.
        """
        attempt = 0
        while True:
            resp = self._session.get(url, headers=headers or {}, timeout=10.0)
            if resp.status_code != 429 or attempt >= max_retries:
                return resp
            attempt += 1
            delay_s: float | None = None
            retry_after = resp.headers.get("Retry-After")
            if retry_after is not None:
                try:
                    delay_s = math.ceil(float(retry_after))
                except ValueError:
                    # Retry-After may be an HTTP-date (RFC 9110) rather
                    # than delay-seconds; fall back to back-off instead
                    # of crashing the retry loop.
                    delay_s = None
            if delay_s is None:
                base = _RETRY_BASE_DELAY * (2 ** (attempt - 1))
                # Jitter scales with the CURRENT base (±25 % of base).
                # Scaling off the first delay — the old behaviour — made
                # the jitter relatively negligible on later retries, so
                # synchronised clients would re-collide.
                jitter = base * _RETRY_JITTER_FRAC * (random.random() * 2 - 1)
                delay_s = base + jitter
            delay_ms = min(int(delay_s * 1000), int(_RETRY_MAX_DELAY * 1000))
            time.sleep(delay_ms / 1000.0)

    # ------------------------------------------------------------------
    # Endpoints
    # ------------------------------------------------------------------
    def register(self) -> str:
        """Register with the platform and cache the issued auth token.

        Returns the token (also persisted to disk). If the platform has
        already issued a token for this workspace (identified by the
        cached file), register will still succeed but the response will
        not include a new ``auth_token`` — the client keeps using the
        on-disk copy. Returns ``""`` when no token exists at all.

        Raises :class:`requests.HTTPError` on non-2xx responses.
        """
        # The platform's RegisterPayload requires a non-empty url. A remote
        # agent that doesn't expose inbound A2A yet still needs a placeholder
        # — we use "remote://no-inbound" so the platform can distinguish it
        # from a real HTTP URL and not try to reach the agent.
        reported = self.reported_url or "remote://no-inbound"
        resp = self._session.post(
            f"{self.platform_url}/registry/register",
            json={
                "id": self.workspace_id,
                "url": reported,
                "agent_card": self.agent_card,
            },
            timeout=10.0,
        )
        resp.raise_for_status()
        data = resp.json()
        tok = data.get("auth_token", "")
        if tok:
            self.save_token(tok)
            logger.info("registered and saved new token (prefix=%s…)", tok[:8])
        else:
            # Already-tokened workspace — keep using the cached one.
            existing = self.load_token()
            if not existing:
                logger.warning(
                    "register returned no auth_token and no cached token exists — "
                    "authenticated calls will 401 until a token is minted"
                )
        return self._token or ""

    def pull_secrets(self) -> dict[str, str]:
        """Fetch the merged decrypted secrets via the 30.2 endpoint.

        Returns an empty dict when the platform has no secrets configured
        for this workspace. Raises on network errors and on 401 (which
        means the token is missing / invalid — call :py:meth:`register`
        first).
        """
        resp = self._get_with_retry(
            f"{self.platform_url}/workspaces/{self.workspace_id}/secrets/values",
            headers=self._auth_headers(),
        )
        resp.raise_for_status()
        return resp.json() or {}

    def poll_state(self) -> WorkspaceState:
        """Fetch the workspace's current state via the 30.4 endpoint.

        A 404 response is the platform's hard-delete signal; it is mapped
        to a :class:`WorkspaceState` with ``status="removed"`` and
        ``deleted=True`` (``should_stop`` is then True) — callers
        typically exit their run loop in that case. Raises on other HTTP
        errors. (This method never returns None, contrary to what an
        earlier docstring claimed.)
        """
        resp = self._get_with_retry(
            f"{self.platform_url}/workspaces/{self.workspace_id}/state",
            headers=self._auth_headers(),
        )
        if resp.status_code == 404:
            # Platform signals hard-delete via 404
            return WorkspaceState(
                workspace_id=self.workspace_id,
                status="removed",
                paused=False,
                deleted=True,
            )
        resp.raise_for_status()
        data = resp.json()
        return WorkspaceState(
            workspace_id=self.workspace_id,
            status=str(data.get("status", "unknown")),
            paused=bool(data.get("paused", False)),
            deleted=bool(data.get("deleted", False)),
        )

    def heartbeat(
        self,
        current_task: str = "",
        active_tasks: int = 0,
        error_rate: float = 0.0,
        sample_error: str = "",
    ) -> None:
        """Send a single heartbeat. Safe to call repeatedly — the platform
        treats it as idempotent state-refresh. Raises on non-2xx."""
        uptime = int(time.time() - self._start_time)
        resp = self._session.post(
            f"{self.platform_url}/registry/heartbeat",
            headers=self._auth_headers(),
            json={
                "workspace_id": self.workspace_id,
                "current_task": current_task,
                "active_tasks": active_tasks,
                "error_rate": error_rate,
                "sample_error": sample_error,
                "uptime_seconds": uptime,
            },
            timeout=10.0,
        )
        resp.raise_for_status()

    # ------------------------------------------------------------------
    # Peer discovery + cache (Phase 30.6)
    # ------------------------------------------------------------------
    def get_peers(self) -> list[PeerInfo]:
        """Fetch the list of peer workspaces this agent can communicate with.

        Hits ``GET /registry/:id/peers`` with the bearer token. The returned
        list includes siblings (same parent) and, if applicable, the parent.
        Each peer's URL is seeded into the local cache so subsequent calls
        to :py:meth:`discover_peer` short-circuit without hitting the
        platform.

        Raises on 401 (stale/missing token → call :py:meth:`register`) and
        other non-2xx.
        """
        resp = self._get_with_retry(
            f"{self.platform_url}/registry/{self.workspace_id}/peers",
            headers={
                **self._auth_headers(),
                "X-Workspace-ID": self.workspace_id,
            },
        )
        resp.raise_for_status()
        data = resp.json() or []
        peers: list[PeerInfo] = []
        now = time.time()
        for row in data:
            pid = str(row.get("id", ""))
            url = str(row.get("url", ""))
            if not pid:
                continue
            peer = PeerInfo(
                id=pid,
                name=str(row.get("name", "")),
                url=url,
                role=str(row.get("role", "")),
                tier=int(row.get("tier", 2) or 2),
                status=str(row.get("status", "unknown")),
                agent_card=row.get("agent_card") or {},
            )
            peers.append(peer)
            # Seed the cache so a subsequent call_peer doesn't need a
            # discover round-trip. Only cache HTTP-shaped URLs; skip the
            # "remote://no-inbound" placeholder and empty strings.
            if url.startswith(("http://", "https://")):
                self._url_cache[pid] = (url, now + self.url_cache_ttl)
        return peers

    def discover_peer(self, target_id: str) -> str | None:
        """Resolve a peer's URL, using the cache when fresh.

        Returns the URL string, or None if the platform has no usable URL
        for this target. On 401/403 the caller should re-authenticate or
        verify the hierarchy rule; those are raised as ``HTTPError``.

        Cache semantics: a cached entry is returned immediately if its TTL
        hasn't expired; otherwise the platform is hit and the cache
        refreshed. Call :py:meth:`invalidate_peer_url` to drop an entry
        that was stale (connection error, 5xx) so the next discover
        re-fetches instead of returning the dead URL again.
        """
        cached = self._url_cache.get(target_id)
        if cached is not None:
            url, expires_at = cached
            if time.time() < expires_at:
                return url
            # Expired — drop and fall through to refresh
            self._url_cache.pop(target_id, None)
        resp = self._get_with_retry(
            f"{self.platform_url}/registry/discover/{target_id}",
            headers={
                **self._auth_headers(),
                "X-Workspace-ID": self.workspace_id,
            },
        )
        if resp.status_code == 404:
            return None
        resp.raise_for_status()
        url = str((resp.json() or {}).get("url", ""))
        if url.startswith(("http://", "https://")):
            self._url_cache[target_id] = (url, time.time() + self.url_cache_ttl)
            return url
        return None

    def invalidate_peer_url(self, target_id: str) -> None:
        """Drop a peer's cached URL. Call this after a direct-call failure
        so the next call_peer performs a fresh discover."""
        self._url_cache.pop(target_id, None)

    def call_peer(
        self,
        target_id: str,
        message: str,
        prefer_direct: bool = True,
    ) -> dict[str, Any]:
        """Send an A2A ``message/send`` to a peer.

        Preferred path (``prefer_direct=True``, default):
        1. Resolve target URL via :py:meth:`discover_peer` (cache-hot
           path when we've seen this peer before).
        2. POST the JSON-RPC envelope directly to the peer's URL.
        3. On connection error / 5xx, invalidate the cache and retry
           via the platform proxy — graceful fallback so a stale URL
           doesn't brick inter-agent communication.

        Proxy-only path (``prefer_direct=False``):
        Always routes through ``POST /workspaces/:id/a2a`` — useful
        when both agents are behind NAT and can't reach each other
        directly, but the platform can reach both.

        Returns the full JSON-RPC response dict so callers can inspect
        ``result`` vs ``error`` without us flattening the envelope.
        """
        body = {
            "jsonrpc": "2.0",
            "id": str(uuid.uuid4()),
            "method": "message/send",
            "params": {
                "message": {
                    "role": "user",
                    "messageId": str(uuid.uuid4()),
                    "parts": [{"kind": "text", "text": message}],
                }
            },
        }
        headers = {
            **self._auth_headers(),
            "X-Workspace-ID": self.workspace_id,
            "Content-Type": "application/json",
        }
        if prefer_direct:
            url = self.discover_peer(target_id)
            if url:
                try:
                    resp = self._session.post(url, json=body, headers=headers, timeout=30.0)
                    resp.raise_for_status()
                    return resp.json()
                except Exception as exc:
                    logger.warning(
                        "direct A2A to %s (%s) failed: %s — invalidating cache, falling back to proxy",
                        target_id, url, exc,
                    )
                    self.invalidate_peer_url(target_id)
        # Proxy fallback (or prefer_direct=False)
        resp = self._session.post(
            f"{self.platform_url}/workspaces/{target_id}/a2a",
            json=body, headers=headers, timeout=30.0,
        )
        resp.raise_for_status()
        return resp.json()

    # ------------------------------------------------------------------
    # Delegation — KI-002 idempotency guard
    # ------------------------------------------------------------------
    def delegate(
        self,
        task: str,
        target_id: str,
        idempotency_key: str | None = None,
        timeout: float = 300.0,
    ) -> dict[str, Any]:
        """Delegate a task to a peer workspace via the platform proxy.

        KI-002: To prevent duplicate execution when a container restarts mid-
        delegation, an idempotency key is computed from ``task + current
        minute`` and sent as ``idempotency_key`` in the request body. The
        platform deduplicates requests sharing the same key within the
        minute window. Pass an explicit ``idempotency_key`` to override the
        auto-computed value (useful for callers that manage their own key
        scheme).

        Args:
            task: Human-readable task description sent to the target.
            target_id: Workspace ID of the peer to delegate to.
            idempotency_key: Optional override for the idempotency key. If
                omitted, one is auto-generated from the task text + current
                wall-clock minute.
            timeout: Request timeout in seconds. Default 300 s.

        Returns:
            The platform's JSON response dict.

        Raises:
            ``requests.HTTPError`` on non-2xx responses.
        """
        key = idempotency_key if idempotency_key else make_idempotency_key(task)
        resp = self._session.post(
            f"{self.platform_url}/workspaces/{target_id}/delegate",
            headers={
                **self._auth_headers(),
                "X-Workspace-ID": self.workspace_id,
                "Content-Type": "application/json",
            },
            json={
                "task": task,
                "idempotency_key": key,
            },
            timeout=timeout,
        )
        resp.raise_for_status()
        return resp.json()

    # ------------------------------------------------------------------
    # Plugin install (Phase 30.3)
    # ------------------------------------------------------------------
    @property
    def plugins_dir(self) -> Path:
        """Where pulled plugins are unpacked. Lives under the same
        per-workspace directory as the auth token (``~/.molecule/<id>/``)
        so a single rm cleans the agent's local state."""
        return self._token_dir / "plugins"

    def install_plugin(
        self,
        name: str,
        source: str | None = None,
        run_setup_sh: bool = True,
        report_to_platform: bool = True,
    ) -> Path:
        """Pull a plugin tarball from the platform, unpack it, optionally
        run its ``setup.sh``, and report success.

        Phase 30.3 contract:
        1. Stream ``GET /workspaces/:id/plugins/:name/download[?source=…]``
        2. Atomically extract into ``~/.molecule/<workspace>/plugins/<name>/``
           via a sibling-tempdir + rename so a partial extract never
           leaves the directory in a half-installed state.
        3. If ``setup.sh`` exists in the unpacked tree, run it (bash) so
           pip/npm deps land on the agent's machine. Failures from
           ``setup.sh`` are logged but don't prevent the install record
           — the agent author can re-run setup manually.
        3b. If ``plugin.yaml`` declares a ``sha256`` field, verify the unpacked
            content matches before running ``setup.sh`` (blocks supply-chain
            tampering after the platform-level pinned-ref check from PR #1019).
        4. POST ``/workspaces/:id/plugins`` with the source string so the
           platform's ``workspace_plugins`` table reflects the install.

        Returns the path to the unpacked plugin directory.
        Raises ``requests.HTTPError`` on download failure (401 / 404 / etc.).
        Raises ``ValueError`` if a declared ``sha256`` does not match.
        """
        target = self.plugins_dir / name
        staging = self.plugins_dir / f".staging-{name}-{uuid.uuid4().hex[:8]}"
        self.plugins_dir.mkdir(parents=True, exist_ok=True)
        url = f"{self.platform_url}/workspaces/{self.workspace_id}/plugins/{name}/download"
        params: dict[str, str] = {}
        if source:
            params["source"] = source
        # We pull the whole tarball into memory before extracting. The
        # platform side (PLUGIN_INSTALL_MAX_DIR_BYTES) caps plugin size
        # at 100 MiB by default, which is comfortable to hold in process
        # memory and lets us use tarfile's seekable r:gz mode (the
        # streaming r|gz mode is sequential-only and breaks on tarballs
        # whose entries weren't sorted by name, which we don't enforce).
        # If the cap is ever raised above ~500 MiB, switch to a temp
        # file: tarfile.open(fileobj=open(temp, "rb"), mode="r:gz").
        with self._session.get(
            url, headers=self._auth_headers(), params=params,
            timeout=60.0,
        ) as resp:
            resp.raise_for_status()
            staging.mkdir(parents=True)
            try:
                import io as _io
                with tarfile.open(fileobj=_io.BytesIO(resp.content), mode="r:gz") as tf:
                    _safe_extract_tar(tf, staging)
            except Exception:
                # Roll back the partial extract
                _rmtree_quiet(staging)
                raise
        # Atomic swap: remove old dir if present, rename staging into place.
        if target.exists():
            _rmtree_quiet(target)
        staging.rename(target)
        logger.info("plugin %s unpacked to %s", name, target)
        # 3b. Content integrity — verify sha256 declared in plugin.yaml if present.
        # This closes the SDK-side gap identified in SDK-Dev review of PR #1019:
        # the platform enforces source-pinned refs (SHA/tag) but cannot detect
        # content tampering after the tarball is served. The SDK verifies the
        # delivered content against a manifest hash in plugin.yaml.
        manifest_path = target / "plugin.yaml"
        if manifest_path.exists():
            try:
                import yaml as _yaml  # lazy — not a top-level dep
                manifest = _yaml.safe_load(manifest_path.read_text())
                expected_sha: str | None = manifest.get("sha256") if isinstance(manifest, dict) else None
                if expected_sha:
                    if not verify_plugin_sha256(target, expected_sha):
                        _rmtree_quiet(target)
                        raise ValueError(
                            f"plugin {name} sha256 mismatch — expected {expected_sha}. "
                            f"Plugin directory may have been tampered with after download. "
                            f"setup.sh was NOT run."
                        )
                    logger.info("plugin %s sha256 verified", name)
            except ValueError:
                raise  # re-raise our own ValueError (sha256 mismatch)
            except Exception as exc:
                # YAML read / parse failures are non-fatal — skip verification
                # and fall through to setup.sh. We log the error so operators
                # can see why verification was skipped.
                logger.warning("plugin %s sha256 verification skipped: %s", name, exc)
        # 3. setup.sh — best-effort. We never raise on its failure because
        # the plugin files are now correctly installed; setup is just for
        # heavy deps that the agent author can rerun manually.
        if run_setup_sh:
            setup = target / "setup.sh"
            if setup.is_file():
                try:
                    proc = subprocess.run(
                        ["bash", str(setup)],
                        cwd=str(target),
                        capture_output=True, text=True, timeout=120,
                    )
                    if proc.returncode == 0:
                        logger.info("plugin %s setup.sh ok", name)
                    else:
                        logger.warning(
                            "plugin %s setup.sh exit=%d stderr=%s",
                            name, proc.returncode, proc.stderr[:200],
                        )
                except subprocess.TimeoutExpired:
                    logger.warning("plugin %s setup.sh timed out (120s)", name)
                except FileNotFoundError:
                    logger.warning("plugin %s setup.sh present but bash not found", name)
        # 4. Report to platform — write a workspace_plugins row so List
        # reflects the install. Best-effort; the local files are correct
        # regardless of whether this POST succeeds.
        if report_to_platform:
            try:
                report_source = source or f"local://{name}"
                self._session.post(
                    f"{self.platform_url}/workspaces/{self.workspace_id}/plugins",
                    headers=self._auth_headers(),
                    json={"source": report_source},
                    timeout=10.0,
                )
            except Exception as exc:
                logger.warning("plugin %s install record POST failed: %s", name, exc)
        return target

    # ------------------------------------------------------------------
    # Run loop
    # ------------------------------------------------------------------
    def run_heartbeat_loop(
        self,
        max_iterations: int | None = None,
        task_supplier: "callable | None" = None,
    ) -> str:
        """Drive heartbeat + state-poll on a timer. Returns the terminal
        status when the loop exits (``"paused"``, ``"removed"``, or
        ``"max_iterations"``).

        Args:
            max_iterations: Stop after N loop iterations. None = run until
                the workspace is paused / deleted. Useful for tests and
                smoke scripts.
            task_supplier: Optional zero-arg callable returning a dict with
                any of ``current_task`` (str), ``active_tasks`` (int),
                ``error_rate`` (float), ``sample_error`` (str), fetched
                each iteration. Lets the agent report what it's doing.
                (``error_rate`` / ``sample_error`` were previously ignored;
                they are now forwarded to :py:meth:`heartbeat`.)

        The loop sends one heartbeat + one state poll per iteration; the
        next iteration sleeps for ``heartbeat_interval`` seconds. Errors
        from either call are logged and the loop continues — we deliberately
        do NOT re-raise because a transient platform hiccup shouldn't take
        a remote agent offline.
        """
        i = 0
        while True:
            if max_iterations is not None and i >= max_iterations:
                return "max_iterations"
            i += 1
            report: dict[str, Any] = {}
            if task_supplier is not None:
                try:
                    report = task_supplier() or {}
                except Exception as exc:
                    logger.warning("task_supplier raised: %s", exc)
            try:
                self.heartbeat(
                    current_task=str(report.get("current_task", "")),
                    active_tasks=int(report.get("active_tasks", 0)),
                    error_rate=float(report.get("error_rate", 0.0)),
                    sample_error=str(report.get("sample_error", "")),
                )
            except Exception as exc:
                logger.warning("heartbeat failed: %s — continuing", exc)
            try:
                state = self.poll_state()
            except Exception as exc:
                logger.warning("state poll failed: %s — continuing", exc)
                state = None
            if state is not None and state.should_stop:
                logger.info(
                    "platform reports workspace %s (paused=%s deleted=%s) — exiting",
                    state.status, state.paused, state.deleted,
                )
                return state.status
            time.sleep(self.heartbeat_interval)
# Public API surface of this module — keep in sync when adding new
# exported helpers or constants (see the repair in the commit header:
# entries must be string literals INSIDE this list).
__all__ = [
    "RemoteAgentClient",
    "WorkspaceState",
    "PeerInfo",
    "DEFAULT_HEARTBEAT_INTERVAL",
    "DEFAULT_STATE_POLL_INTERVAL",
    "DEFAULT_URL_CACHE_TTL",
    "verify_plugin_sha256",
    "make_idempotency_key",
]