Merge pull request #13 from Molecule-AI/gap-03-fix

feat(sdk): GAP-03 conftest, GAP-05 retry backoff, KI-002 idempotency key
This commit is contained in:
Hongming Wang 2026-04-24 13:27:24 -07:00 committed by GitHub
commit a3203a8a9e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 1447 additions and 535 deletions

29
.github/workflows/ci.yml vendored Normal file
View File

@ -0,0 +1,29 @@
# CI workflow: run the SDK test-suite and ruff lint on every push / PR
# targeting main, across all supported Python versions.
name: Test
on:
  push:
    branches: [main]
  pull_request:
    branches: [main]
jobs:
  test:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        # Supported interpreter versions — keep in sync with pyproject.
        python-version: ['3.11', '3.12', '3.13']
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: pip install -e ".[test]"
      - name: Run tests
        run: python -m pytest tests/
      # Lint both packages; ruff is installed ad hoc rather than via extras.
      - name: Lint
        run: pip install ruff && ruff check molecule_agent/ molecule_plugin/

View File

@ -34,6 +34,7 @@ Design notes:
from __future__ import annotations from __future__ import annotations
from .a2a_server import A2AServer
from .client import ( from .client import (
PeerInfo, PeerInfo,
RemoteAgentClient, RemoteAgentClient,
@ -46,6 +47,7 @@ from .client import (
from .__main__ import compute_plugin_sha256 from .__main__ import compute_plugin_sha256
__all__ = [ __all__ = [
"A2AServer",
"RemoteAgentClient", "RemoteAgentClient",
"WorkspaceState", "WorkspaceState",
"PeerInfo", "PeerInfo",

View File

@ -0,0 +1,229 @@
"""A2A server for inbound agent calls.
Bundled alongside :class:`molecule_agent.client.RemoteAgentClient` to
enable remote agents to receive A2A calls from the platform without
requiring the agent author to provision their own HTTP endpoint.
Phase 30.8b contract the server exposes ``POST /a2a/inbound`` which
the platform's ingress proxy calls when it needs to push work to a
registered remote agent.
Usage::
from molecule_agent import RemoteAgentClient, A2AServer
client = RemoteAgentClient(workspace_id="...", platform_url="...")
server = A2AServer(
agent_id=client.workspace_id,
inbound_url="https://my-agent.example.com/a2a/inbound",
message_handler=my_handler,
)
# Start server in background thread, then register with platform.
server.start_in_background()
client.reported_url = server.inbound_url # platform reaches this URL
token = client.register()
# Heartbeat loop now reports a real URL instead of "remote://no-inbound".
client.run_heartbeat_loop()
# Shutdown the server when the agent exits.
server.stop()
The ``message_handler`` signature is::
async def my_handler(request: dict) -> dict:
'''Return an A2A-formatted response dict.'''
...
Handlers are invoked on the server's internal thread pool.
"""
from __future__ import annotations
import json
import logging
import threading
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import Any, Callable, Awaitable
from urllib.parse import urlparse
logger = logging.getLogger(__name__)

# Module-level HTTPServer instance so the handler can access server state.
# NOTE(review): a single module-level slot means only one A2AServer per
# process is tracked here; starting a second overwrites the reference.
_server: HTTPServer | None = None
# Guards reads/writes of ``_server`` across threads.
_lock = threading.Lock()
# ---------------------------------------------------------------------------
# Handler
# ---------------------------------------------------------------------------
class _A2AHandler(BaseHTTPRequestHandler):
"""Handles ``POST /a2a/inbound`` requests.
The request body is a JSON A2A task dispatch dict::
{
"task_id": "...",
"sender": "...",
"message": "...",
"idempotency_key": "...",
}
The ``message_handler`` ( supplied at construction) is called with the
parsed dict and its return value is written as a JSON response::
200 {"status": "ok", "result": <handler-result>}
400 {"error": "bad request: ..."}
500 {"error": "internal error: ..."}
"""
protocol_version = "HTTP/1.1"
def log_message(self, format: str, *args: Any) -> None:
"""Suppress default stderr noise; use structured logging instead."""
logger.debug("%s %s%s", self.command, self.path, format % args)
def log_error(self, format: str, *args: Any) -> None:
logger.warning("%s %s%s", self.command, self.path, format % args)
def _send_json(self, status: int, body: dict) -> None:
body_bytes = json.dumps(body).encode()
self.send_response(status)
self.send_header("Content-Type", "application/json")
self.send_header("Content-Length", str(len(body_bytes)))
self.end_headers()
if self.command != "HEAD":
self.wfile.write(body_bytes)
def do_POST(self) -> None:
parsed = urlparse(self.path)
if parsed.path != "/a2a/inbound":
self._send_json(404, {"error": "not found"})
return
try:
content_length = int(self.headers.get("Content-Length", 0))
if content_length == 0:
raise ValueError("empty body")
body = self.rfile.read(content_length)
payload = json.loads(body)
except (ValueError, json.JSONDecodeError) as exc:
self._send_json(400, {"error": f"bad request: {exc}"})
return
try:
result = _A2AHandler._message_handler(payload)
if isinstance(result, Awaitable):
# If the handler is async, run it synchronously in the server thread.
# Agents that want full async semantics should use an explicit ASGI app;
# this path covers the common case of a simple sync handler.
import asyncio
loop = asyncio.new_event_loop()
try:
result = loop.run_until_complete(result)
finally:
loop.close()
self._send_json(200, {"status": "ok", "result": result})
except Exception as exc:
logger.exception("message_handler raised: %s", exc)
self._send_json(500, {"error": f"internal error: {exc}"})
# ---------------------------------------------------------------------------
# A2AServer
# ---------------------------------------------------------------------------
class A2AServer:
    """HTTP server that receives inbound A2A calls and dispatches them to a
    handler running alongside :class:`~molecule_agent.client.RemoteAgentClient`.

    Args:
        agent_id: The workspace / agent identifier. Used in log messages.
        inbound_url: The URL the platform's ingress proxy uses to reach this
            server. Must be a reachable host:port (or a publicly accessible
            URL if a tunnel is in front). The value is typically assigned to
            ``RemoteAgentClient.reported_url`` before registration so the
            platform knows where to deliver inbound calls.
        message_handler: Callable that receives a parsed A2A task dict and
            returns a dict response. May be ``async def`` or regular ``def``.
        host: Address to bind the HTTP server to. Defaults to ``"0.0.0.0"``
            (all interfaces); bind to ``"127.0.0.1"`` if behind a reverse
            proxy or tunnel.
        port: TCP port to listen on. ``0`` picks an available ephemeral port
            (useful when the real public URL is managed by a proxy/tunnel).
    """

    def __init__(
        self,
        agent_id: str,
        inbound_url: str,
        message_handler: Callable[[dict], dict | Awaitable[dict]],
        host: str = "0.0.0.0",
        port: int = 0,
    ) -> None:
        self.agent_id = agent_id
        self.inbound_url = inbound_url
        self.host = host
        self.port = port
        self._handler = message_handler
        # Populated by start_in_background(); cleared again by stop().
        self._server: HTTPServer | None = None
        self._thread: threading.Thread | None = None
        self._stop_event = threading.Event()

    # -------------------------------------------------------------------------
    # Lifecycle
    # -------------------------------------------------------------------------
    def start_in_background(self) -> None:
        """Start the HTTP server in a daemon thread and return immediately.

        Call :py:meth:`stop` to shut it down cleanly.
        """
        global _server
        with _lock:
            http_server = HTTPServer((self.host, self.port), _A2AHandler)
            self._server = http_server
            _server = http_server
            # Wire the handler class up to this server instance so request
            # handlers can reach the user-supplied callback.
            _A2AHandler._server = self  # type: ignore[attr-defined]
            _A2AHandler._message_handler = self._handler  # type: ignore[attr-defined]
            bound_host, bound_port = http_server.server_address[0], http_server.server_address[1]
            logger.info(
                "A2AServer for %s listening on %s:%s (inbound_url=%s)",
                self.agent_id, bound_host, bound_port, self.inbound_url,
            )
        worker = threading.Thread(target=self._serve_forever, daemon=True)
        self._thread = worker
        worker.start()

    def _serve_forever(self) -> None:
        """Serve requests one at a time until the stop event is set."""
        assert self._server is not None
        srv = self._server
        while not self._stop_event.is_set():
            try:
                # Short timeout so the loop re-checks the stop event promptly.
                srv.timeout = 0.5
                srv.handle_request()
            except Exception as exc:
                if self._stop_event.is_set():
                    continue
                logger.warning("A2AServer handle_request raised: %s", exc)

    def stop(self, timeout: float = 5.0) -> None:
        """Stop the HTTP server and join the background thread.

        Idempotent — safe to call multiple times.
        """
        self._stop_event.set()
        worker, self._thread = self._thread, None
        if worker is not None:
            worker.join(timeout=timeout)
        http_server, self._server = self._server, None
        if http_server is not None:
            try:
                http_server.server_close()
            except Exception as exc:
                logger.warning("A2AServer server_close raised: %s", exc)
        global _server
        with _lock:
            _server = None


__all__ = ["A2AServer"]

View File

@ -12,8 +12,9 @@ a Phase 30 endpoint:
returns when the platform reports the workspace paused or deleted. returns when the platform reports the workspace paused or deleted.
No inbound A2A server is bundled here yet that requires hosting an HTTP No inbound A2A server is bundled here yet that requires hosting an HTTP
endpoint the platform's proxy can reach, which is network-dependent. A endpoint the platform's proxy can reach, which is network-dependent.
future 30.8b iteration will add an optional ``start_a2a_server()`` helper. Use :class:`molecule_agent.a2a_server.A2AServer` to add inbound A2A support.
See that module for usage and the Phase 30.8b contract.
""" """
from __future__ import annotations from __future__ import annotations
@ -24,7 +25,6 @@ import logging
import math import math
import os import os
import random import random
import stat
import subprocess import subprocess
import tarfile import tarfile
import time import time
@ -57,6 +57,35 @@ _RETRY_BASE_DELAY = 1.0 # seconds — first delay
_RETRY_MAX_DELAY = 30.0 # seconds — cap _RETRY_MAX_DELAY = 30.0 # seconds — cap
_RETRY_JITTER_FRAC = 0.25 # ±25% jitter around base delay _RETRY_JITTER_FRAC = 0.25 # ±25% jitter around base delay
# KI-002 — idempotency key granularity: round to the current minute so
# that concurrent restarts within the same 60-second window produce the
# same key, while distinct tasks or distinct minutes produce distinct keys.
_IDEMPOTENCY_ROUND_SECONDS = 60
def make_idempotency_key(task_text: str) -> str:
"""Compute a deterministic idempotency key for a delegation task.
Combines the task text with the current wall-clock minute to produce
a SHA-256 hex digest. Rounding to minute-level means two container
restarts within the same minute that send the same task string will
share the same key, preventing the platform from processing a duplicate
delegation. A different minute (or a different task string) yields a
different key.
Args:
task_text: The task description string being delegated.
Returns:
A 64-character hex string (SHA-256 digest).
"""
# Round current time down to the nearest minute — same-task restarts
# within this minute share a key; after the minute rolls over the key
# changes so a genuinely new task is always treated as new.
now = int(time.time()) // _IDEMPOTENCY_ROUND_SECONDS * _IDEMPOTENCY_ROUND_SECONDS
payload = f"{task_text}:{now}"
return hashlib.sha256(payload.encode("utf-8")).hexdigest()
def _safe_extract_tar(tf: tarfile.TarFile, dest: Path) -> None: def _safe_extract_tar(tf: tarfile.TarFile, dest: Path) -> None:
"""Extract a tarfile, refusing entries that would escape `dest` """Extract a tarfile, refusing entries that would escape `dest`
@ -658,6 +687,58 @@ class RemoteAgentClient:
resp.raise_for_status() resp.raise_for_status()
return resp.json() return resp.json()
# ------------------------------------------------------------------
# Delegation — KI-002 idempotency guard
# ------------------------------------------------------------------
def delegate(
    self,
    task: str,
    target_id: str,
    idempotency_key: str | None = None,
    timeout: float = 300.0,
) -> dict[str, Any]:
    """Delegate a task to a peer workspace via the platform proxy.

    KI-002: To prevent duplicate execution when a container restarts
    mid-delegation, an idempotency key derived from ``task`` + the
    current wall-clock minute is sent as ``idempotency_key`` in the
    request body; the platform deduplicates requests sharing the same
    key within that minute window. Callers managing their own key
    scheme can pass ``idempotency_key`` explicitly to bypass the
    auto-computed value.

    Args:
        task: Human-readable task description sent to the target.
        target_id: Workspace ID of the peer to delegate to.
        idempotency_key: Optional override for the idempotency key. If
            omitted (or falsy), one is auto-generated from the task
            text + current wall-clock minute.
        timeout: Request timeout in seconds. Default 300 s.

    Returns:
        The platform's JSON response dict.

    Raises:
        ``requests.HTTPError`` on non-2xx responses.
    """
    key = idempotency_key or make_idempotency_key(task)
    request_headers = dict(self._auth_headers())
    request_headers["X-Workspace-ID"] = self.workspace_id
    request_headers["Content-Type"] = "application/json"
    response = self._session.post(
        f"{self.platform_url}/workspaces/{target_id}/delegate",
        headers=request_headers,
        json={
            "task": task,
            "idempotency_key": key,
        },
        timeout=timeout,
    )
    response.raise_for_status()
    return response.json()
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Plugin install (Phase 30.3) # Plugin install (Phase 30.3)
# ------------------------------------------------------------------ # ------------------------------------------------------------------
@ -877,6 +958,7 @@ __all__ = [
"DEFAULT_HEARTBEAT_INTERVAL", "DEFAULT_HEARTBEAT_INTERVAL",
"DEFAULT_STATE_POLL_INTERVAL", "DEFAULT_STATE_POLL_INTERVAL",
"DEFAULT_URL_CACHE_TTL", "DEFAULT_URL_CACHE_TTL",
"compute_plugin_sha256",
"verify_plugin_sha256", "verify_plugin_sha256",
"make_idempotency_key",
] ]

View File

@ -14,7 +14,6 @@ from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Any
import yaml import yaml
@ -115,3 +114,4 @@ def validate_workspace_template(path: Path) -> list[ValidationError]:
# Re-exported for type hints in __init__.py # Re-exported for type hints in __init__.py
__all__ = ["ValidationError", "SUPPORTED_RUNTIMES", "validate_workspace_template"] __all__ = ["ValidationError", "SUPPORTED_RUNTIMES", "validate_workspace_template"]

217
tests/test_a2a_server.py Normal file
View File

@ -0,0 +1,217 @@
"""Tests for molecule_agent.a2a_server."""
from __future__ import annotations
import json
import threading
from http.client import HTTPConnection
from unittest.mock import MagicMock
import time
import pytest
from molecule_agent.a2a_server import A2AServer
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _post_json(host: str, port: int, payload: dict) -> tuple[int, dict]:
    """POST ``payload`` as JSON to ``/a2a/inbound``; return (status, body dict)."""
    conn = HTTPConnection(host, port, timeout=5)
    body = json.dumps(payload).encode()
    conn.request("POST", "/a2a/inbound", body=body, headers={"Content-Type": "application/json"})
    resp = conn.getresponse()
    return resp.status, json.loads(resp.read())


# ---------------------------------------------------------------------------
# A2AServer tests
# ---------------------------------------------------------------------------


def test_start_stop() -> None:
    """Server starts, binds an ephemeral port, and shuts down cleanly."""
    handler = MagicMock(return_value={"ack": True})
    server = A2AServer(
        agent_id="test-agent",
        inbound_url="https://example.com/a2a/inbound",
        message_handler=handler,
    )
    server.start_in_background()
    try:
        # port=0 (the default) asks the OS for an ephemeral port, so the
        # bound port must come back as a positive int.
        host, port = server._server.server_address  # type: ignore[union-attr]
        assert host in ("0.0.0.0", "127.0.0.1", "::")
        assert isinstance(port, int) and port > 0
    finally:
        server.stop()


def test_stop_idempotent() -> None:
    """stop() called twice does not raise."""
    handler = MagicMock()
    server = A2AServer(
        agent_id="test-agent",
        inbound_url="https://example.com/a2a/inbound",
        message_handler=handler,
    )
    server.start_in_background()
    server.stop()
    server.stop()  # must not raise


def test_inbound_call_routes_to_handler() -> None:
    """POST /a2a/inbound calls message_handler and returns 200."""
    handler = MagicMock(return_value={"task_id": "reply-123"})
    server = A2AServer(
        agent_id="test-agent",
        inbound_url="https://example.com/a2a/inbound",
        message_handler=handler,
    )
    server.start_in_background()
    try:
        host, port = server._server.server_address  # type: ignore[union-attr]
        status, body = _post_json(host, port, {"task_id": "req-1", "message": "ping"})
        assert status == 200
        assert body["status"] == "ok"
        assert body["result"] == {"task_id": "reply-123"}
        # The handler must receive the parsed request dict verbatim.
        handler.assert_called_once_with({"task_id": "req-1", "message": "ping"})
    finally:
        server.stop()


def test_non_json_body_returns_400() -> None:
    """Malformed JSON body returns 400 with error detail."""
    handler = MagicMock()
    server = A2AServer(
        agent_id="test-agent",
        inbound_url="https://example.com/a2a/inbound",
        message_handler=handler,
    )
    server.start_in_background()
    try:
        host, port = server._server.server_address  # type: ignore[union-attr]
        conn = HTTPConnection(host, port, timeout=5)
        conn.request("POST", "/a2a/inbound", body=b"not json{", headers={"Content-Type": "application/json"})
        resp = conn.getresponse()
        assert resp.status == 400
        body = json.loads(resp.read())
        assert "error" in body
    finally:
        server.stop()


def test_empty_body_returns_400() -> None:
    """Empty body returns 400."""
    handler = MagicMock()
    server = A2AServer(
        agent_id="test-agent",
        inbound_url="https://example.com/a2a/inbound",
        message_handler=handler,
    )
    server.start_in_background()
    try:
        host, port = server._server.server_address  # type: ignore[union-attr]
        conn = HTTPConnection(host, port, timeout=5)
        # Content-Length: 0 is the explicit "no body" case the server rejects.
        conn.request("POST", "/a2a/inbound", body=b"", headers={"Content-Length": "0"})
        resp = conn.getresponse()
        assert resp.status == 400
    finally:
        server.stop()


def test_wrong_path_returns_404() -> None:
    """A POST to any path other than /a2a/inbound returns 404."""
    handler = MagicMock()
    server = A2AServer(
        agent_id="test-agent",
        inbound_url="https://example.com/a2a/inbound",
        message_handler=handler,
    )
    server.start_in_background()
    try:
        host, port = server._server.server_address  # type: ignore[union-attr]
        conn = HTTPConnection(host, port, timeout=5)
        conn.request("POST", "/other/path", body=b"{}")
        resp = conn.getresponse()
        assert resp.status == 404
        # The handler must never run for an unknown path.
        handler.assert_not_called()
    finally:
        server.stop()


def test_handler_exception_returns_500() -> None:
    """Handler raising an exception returns 500, not crashing the server."""
    handler = MagicMock(side_effect=RuntimeError("boom"))
    server = A2AServer(
        agent_id="test-agent",
        inbound_url="https://example.com/a2a/inbound",
        message_handler=handler,
    )
    server.start_in_background()
    try:
        host, port = server._server.server_address  # type: ignore[union-attr]
        status, body = _post_json(host, port, {"task_id": "req-1"})
        assert status == 500
        assert "error" in body
    finally:
        server.stop()


def test_async_handler_runs_sync() -> None:
    """An async handler is run to completion synchronously."""
    async_calls: list = []

    async def async_handler(payload: dict) -> dict:
        async_calls.append(payload)
        return {"async": True}

    server = A2AServer(
        agent_id="test-agent",
        inbound_url="https://example.com/a2a/inbound",
        message_handler=async_handler,
    )
    server.start_in_background()
    try:
        host, port = server._server.server_address  # type: ignore[union-attr]
        status, body = _post_json(host, port, {"task_id": "async-req"})
        assert status == 200
        assert body["result"] == {"async": True}
        assert len(async_calls) == 1
    finally:
        server.stop()


def test_concurrent_requests() -> None:
    """Multiple simultaneous POSTs are handled without crashing the server.

    NOTE(review): the server processes requests on a single background
    thread, so these five requests are serialized; the assertion checks
    completion of all of them, not parallel execution.
    """
    call_count = {"count": 0}
    lock = threading.Lock()

    def counting_handler(payload: dict) -> dict:
        with lock:
            call_count["count"] += 1
        time.sleep(0.05)  # simulate light processing
        return {"received": payload.get("task_id")}

    server = A2AServer(
        agent_id="test-agent",
        inbound_url="https://example.com/a2a/inbound",
        message_handler=counting_handler,
    )
    server.start_in_background()
    try:
        host, port = server._server.server_address  # type: ignore[union-attr]

        def send(n: int) -> tuple[int, dict]:
            return _post_json(host, port, {"task_id": f"concurrent-{n}"})

        threads = [threading.Thread(target=send, args=(i,)) for i in range(5)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()
        assert call_count["count"] == 5
    finally:
        server.stop()

View File

@ -703,6 +703,114 @@ def test_install_plugin_404_raises_with_useful_url(client: RemoteAgentClient):
client.install_plugin("missing") client.install_plugin("missing")
# ---------------------------------------------------------------------------
# KI-002 — delegation with idempotency key
# ---------------------------------------------------------------------------
import hashlib
from molecule_agent.client import make_idempotency_key
def test_delegate_posts_task_and_idempotency_key(client: RemoteAgentClient):
    """delegate() sends task + auto-generated idempotency_key to /delegate."""
    client.save_token("tok")
    client._session.post.return_value = FakeResponse(200, {"status": "ok"})
    result = client.delegate(task="index the docs", target_id="peer-ws")
    assert result["status"] == "ok"
    # First positional argument of session.post is the URL.
    url = client._session.post.call_args[0][0]
    assert url == "http://platform.test/workspaces/peer-ws/delegate"
    body = client._session.post.call_args[1]["json"]
    assert body["task"] == "index the docs"
    assert body["idempotency_key"] is not None
    assert len(body["idempotency_key"]) == 64  # SHA-256 hex


def test_delegate_sends_explicit_idempotency_key(client: RemoteAgentClient):
    """Passing an explicit idempotency_key overrides auto-generation."""
    client.save_token("tok")
    client._session.post.return_value = FakeResponse(200, {})
    client.delegate(task="build", target_id="peer-ws", idempotency_key="my-key-abc")
    body = client._session.post.call_args[1]["json"]
    assert body["idempotency_key"] == "my-key-abc"


def test_delegate_sends_bearer_and_workspace_headers(client: RemoteAgentClient):
    """delegate() authenticates with Bearer token + X-Workspace-ID headers."""
    client.save_token("secret-tok")
    client._session.post.return_value = FakeResponse(200, {})
    client.delegate(task="do work", target_id="ws-x")
    kwargs = client._session.post.call_args[1]
    assert kwargs["headers"]["Authorization"] == "Bearer secret-tok"
    assert kwargs["headers"]["X-Workspace-ID"] == "ws-abc-123"


def test_delegate_raises_on_http_error(client: RemoteAgentClient):
    """A non-2xx platform response propagates as an exception."""
    client.save_token("tok")
    client._session.post.return_value = FakeResponse(500, {"error": "boom"})
    with pytest.raises(Exception):
        client.delegate(task="test", target_id="peer-ws")


def test_delegate_default_timeout_is_300(client: RemoteAgentClient):
    """Default request timeout is 300 seconds."""
    client.save_token("tok")
    client._session.post.return_value = FakeResponse(200, {})
    client.delegate(task="x", target_id="y")
    assert client._session.post.call_args[1]["timeout"] == 300.0


def test_delegate_allows_custom_timeout(client: RemoteAgentClient):
    """An explicit timeout is forwarded to the HTTP call."""
    client.save_token("tok")
    client._session.post.return_value = FakeResponse(200, {})
    client.delegate(task="x", target_id="y", timeout=60.0)
    assert client._session.post.call_args[1]["timeout"] == 60.0


# ---------------------------------------------------------------------------
# make_idempotency_key()
# ---------------------------------------------------------------------------


def test_make_idempotency_key_returns_64_char_hex():
    """Key is a SHA-256 hex digest: 64 lowercase hex characters."""
    key = make_idempotency_key("do the thing")
    assert len(key) == 64
    assert all(c in "0123456789abcdef" for c in key)


def test_make_idempotency_key_same_text_same_minute_gives_same_key():
    """Two calls with identical text within the same minute must be equal.

    NOTE(review): can flake if the wall-clock minute rolls over between
    the two calls — vanishingly rare but possible.
    """
    key1 = make_idempotency_key("do the thing")
    key2 = make_idempotency_key("do the thing")
    assert key1 == key2


def test_make_idempotency_key_different_text_gives_different_key():
    """Distinct task text yields a distinct key (hash input differs)."""
    key1 = make_idempotency_key("do the thing")
    key2 = make_idempotency_key("do another thing")
    assert key1 != key2
def test_make_idempotency_key_deterministic():
    """The key for a given (text, minute) pair is always the same.

    The original version imported ``time`` without using it and admitted
    it could not pin the minute window; retrying until both calls land in
    the same wall-clock minute removes the rollover flake while still
    exercising the real clock path.
    """
    import time

    while True:
        minute = int(time.time()) // 60
        a = make_idempotency_key("same task")
        b = make_idempotency_key("same task")
        if int(time.time()) // 60 == minute:
            # Both calls happened inside one minute window — the keys
            # must agree; otherwise the minute rolled over, so retry.
            break
    assert a == b
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# _safe_extract_tar # _safe_extract_tar
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------

View File

@ -1,27 +1,34 @@
"""Security tests for _safe_extract_tar and related tar-extraction helpers. """Security tests for ``_safe_extract_tar`` — tar-slip and archive-bomb mitigation.
Covers GAP-01 from TEST_GAP_ANALYSIS.md CWE-22 / CVE-2007-4559 "tar slip" The function guards against escape via ``target.relative_to(dest_abs)``. This
family: directory traversal, absolute paths, zip bombs, symlink escapes. rejects:
Entries whose resolved path is outside ``dest`` (absolute paths, paths that
start above ``dest``, paths with more leading ``..`` components than the
depth of ``dest``).
Symlinks and hardlinks entirely (silently skipped, no file written).
These are unit tests with no external dependencies. Paths that contain ``..`` but still resolve inside ``dest`` are ACCEPTED.
For example ``foo/../bar.txt`` resolves to ``dest/bar.txt`` which is inside
``dest``, so it is accepted.
Covers:
1. **Paths that start above dest** ``../``, ``../../`` at name start.
2. **Absolute paths** entries with a leading ``/``.
3. **Depth-exceeding traversal** ``a/../../../file`` exits dest.
4. **Symlink / hardlink skip** no exception, no file written.
5. **Valid paths accepted** relative paths with or without embedded ``..``
that still resolve inside ``dest``.
GAP-01.
""" """
from __future__ import annotations from __future__ import annotations
import io import io
import tarfile import tarfile
import zipfile
from pathlib import Path from pathlib import Path
import pytest import pytest
import sys
from pathlib import Path as _Path
_SDK_ROOT = _Path(__file__).resolve().parents[1]
if str(_SDK_ROOT) not in sys.path:
sys.path.insert(0, str(_SDK_ROOT))
from molecule_agent.client import _safe_extract_tar from molecule_agent.client import _safe_extract_tar
@ -29,291 +36,387 @@ from molecule_agent.client import _safe_extract_tar
# Helpers # Helpers
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def _make_tar(entries: list[tuple[str, str | bytes, bool]]) -> io.BytesIO: def _make_tar_entry(name: str, content: bytes) -> tarfile.TarInfo:
"""Build an in-memory tar archive. info = tarfile.TarInfo(name=name)
info.size = len(content)
info.mode = 0o644
return info
Args:
entries: list of (filename, content, is_dir) tuples. def _build_tar(names_and_contents: list[tuple[str, bytes]]) -> io.BytesIO:
""" """Return a BytesIO gzipped-tar containing the given (name, content) pairs."""
buf = io.BytesIO() buf = io.BytesIO()
with tarfile.open(fileobj=buf, mode="w") as tf: with tarfile.open(fileobj=buf, mode="w:gz") as tf:
for name, content, is_dir in entries: for name, content in names_and_contents:
if is_dir: info = _make_tar_entry(name, content)
tinfo = tarfile.TarInfo(name=name) tf.addfile(info, io.BytesIO(content))
tinfo.type = tarfile.DIRTYPE
tinfo.mode = 0o755
tinfo.size = 0
tf.addfile(tinfo)
else:
data = content.encode() if isinstance(content, str) else content
tinfo = tarfile.TarInfo(name=name)
tinfo.size = len(data)
tf.addfile(tinfo, io.BytesIO(data))
buf.seek(0) buf.seek(0)
return buf return buf
def _make_tar_with_symlink(name: str, link_target: str) -> io.BytesIO: def _open_tar(buf: io.BytesIO) -> tarfile.TarFile:
"""Build an in-memory tar with one symlink entry and optional normal file."""
buf = io.BytesIO()
with tarfile.open(fileobj=buf, mode="w") as tf:
info = tarfile.TarInfo(name=name)
info.type = tarfile.SYMTYPE
info.linkname = link_target
tf.addfile(info, io.BytesIO(b""))
buf.seek(0) buf.seek(0)
return buf return tarfile.open(fileobj=buf, mode="r")
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Test: directory traversal via ../ in filename # 1. Paths that start above dest — always rejected
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def test_traversal_dotdot_in_name(tmp_path: Path): class TestTraversalFromRoot:
"""CWE-22: ../ in a tar entry must be rejected, not silently stripped.""" """Entries whose name begins with ``../`` escape dest regardless of how
dest = tmp_path / "dest" many intermediate directories are traversed."""
dest.mkdir()
# Normal file must extract correctly. def test_single_parent_component_at_start_rejected(self, tmp_path: Path):
buf = _make_tar([("sub/normal.txt", "hello", False)]) """``../escape.txt`` starts above dest — must be rejected."""
with tarfile.open(fileobj=buf) as tf: buf = _build_tar([("../escape.txt", b"overwrite")])
_safe_extract_tar(tf, dest) with _open_tar(buf) as tf:
assert (dest / "sub" / "normal.txt").read_text() == "hello" with pytest.raises(ValueError, match="refusing tar entry escaping"):
_safe_extract_tar(tf, tmp_path)
# Now try traversal — _safe_extract_tar must raise. def test_two_parent_components_at_start_rejected(self, tmp_path: Path):
buf2 = _make_tar([("../escape.txt", "pwned", False)]) """``../../file`` starts two levels above dest — must be rejected."""
with tarfile.open(fileobj=buf2) as tf: buf = _build_tar([("../../file", b"exfil")])
with pytest.raises(ValueError, match="escaping dest"): with _open_tar(buf) as tf:
_safe_extract_tar(tf, dest) with pytest.raises(ValueError, match="refusing tar entry escaping"):
_safe_extract_tar(tf, tmp_path)
assert not (dest.parent / "escape.txt").exists() def test_traversal_into_sibling_directory_rejected(self, tmp_path: Path):
"""``../sibling/marker.txt`` — verify we cannot write into an adjacent dir."""
sibling = tmp_path.parent / (tmp_path.name + "-sibling")
sibling.mkdir()
(sibling / "marker.txt").write_text("original")
buf = _build_tar([(f"../{tmp_path.name}-sibling/marker.txt", b"tampered")])
with _open_tar(buf) as tf:
with pytest.raises(ValueError, match="refusing tar entry escaping"):
_safe_extract_tar(tf, tmp_path)
def test_traversal_dotdot_in_deep_path(tmp_path: Path): assert (sibling / "marker.txt").read_text() == "original"
"""A ../ in the middle of a long path must also be rejected."""
dest = tmp_path / "dest"
dest.mkdir()
buf = _make_tar([("../a/../../../etc/passwd", "root:x:0:0", False)])
with tarfile.open(fileobj=buf) as tf:
with pytest.raises(ValueError, match="escaping dest"):
_safe_extract_tar(tf, dest)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Test: absolute paths in tar entries # 2. Absolute paths — always rejected
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def test_absolute_path_rejected(tmp_path: Path): class TestAbsolutePaths:
"""An entry with an absolute path must be rejected.""" """Entries with an absolute path (leading ``/``) resolve outside any
dest = tmp_path / "dest" relative dest and must be rejected."""
dest.mkdir()
buf = _make_tar([("/etc/passwd", "root:x:0:0", False)]) def test_absolute_etc_passwd_rejected(self, tmp_path: Path):
with tarfile.open(fileobj=buf) as tf: buf = _build_tar([("/etc/passwd", b"root::0:0")])
with pytest.raises(ValueError, match="escaping dest"): with _open_tar(buf) as tf:
_safe_extract_tar(tf, dest) with pytest.raises(ValueError, match="refusing tar entry escaping"):
_safe_extract_tar(tf, tmp_path)
def test_absolute_usr_local_rejected(self, tmp_path: Path):
buf = _build_tar([("/usr/local/anything", b"data")])
with _open_tar(buf) as tf:
with pytest.raises(ValueError, match="refusing tar entry escaping"):
_safe_extract_tar(tf, tmp_path)
def test_absolute_path_in_subdirectory(tmp_path: Path): def test_absolute_tmp_rejected(self, tmp_path: Path):
"""Absolute path buried under a normal directory component must be rejected.""" buf = _build_tar([("/tmp/staged/foo.txt", b"danger")])
dest = tmp_path / "dest" with _open_tar(buf) as tf:
dest.mkdir() with pytest.raises(ValueError, match="refusing tar entry escaping"):
_safe_extract_tar(tf, tmp_path)
buf = _make_tar([("subdir/../../../usr/local/bin/malware.sh", "#!/bin/sh", False)]) def test_pure_relative_accepted(self, tmp_path: Path):
with tarfile.open(fileobj=buf) as tf: """``foo/bar.txt`` (no leading /) is fine."""
with pytest.raises(ValueError, match="escaping dest"): buf = _build_tar([("foo/bar.txt", b"ok")])
_safe_extract_tar(tf, dest) with _open_tar(buf) as tf:
_safe_extract_tar(tf, tmp_path)
assert (tmp_path / "foo" / "bar.txt").read_bytes() == b"ok"
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Test: symlink escape (symlink → outside dest) # 3. Depth-exceeding traversal — more leading ``..`` than dest depth
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def test_symlink_to_parent_skipped(tmp_path: Path): class TestDepthExceedingTraversal:
"""A symlink pointing outside the extraction root must not be written. """An entry that has more ``..`` components than the depth of its path
within ``dest`` will resolve outside ``dest`` and must be rejected."""
_safe_extract_tar skips symlinks silently (matches platform tar producer). def test_single_dir_then_four_parents_rejected(self, tmp_path: Path):
""" """``a/../../../b.txt`` — one dir + four parents = exits dest."""
dest = tmp_path / "dest" buf = _build_tar([("a/../../../b.txt", b"escaped")])
dest.mkdir() with _open_tar(buf) as tf:
with pytest.raises(ValueError, match="refusing tar entry escaping"):
_safe_extract_tar(tf, tmp_path)
buf = io.BytesIO() def test_unicode_traversal_exits_dest_rejected(self, tmp_path: Path):
with tarfile.open(fileobj=buf, mode="w") as tf: """``日本語/../../file.txt`` — non-ASCII traversal that exits dest."""
normal_info = tarfile.TarInfo(name="sub/normal.txt") buf = _build_tar([("日本語/../../file.txt", b"unicode bomb")])
normal_info.size = 5 with _open_tar(buf) as tf:
tf.addfile(normal_info, io.BytesIO(b"hello")) with pytest.raises(ValueError, match="refusing tar entry escaping"):
_safe_extract_tar(tf, tmp_path)
link_info = tarfile.TarInfo(name="sub/link_to_escape") # Note: paths like ``a/b/c/../../d.txt`` or ``subdir/../outdir/file.txt``
link_info.type = tarfile.SYMTYPE # resolve INSIDE dest (they cancel out within the path) and are tested in
link_info.linkname = "../escape.txt" # TestEmbeddedDotdotAccepted below.
tf.addfile(link_info, io.BytesIO(b""))
buf.seek(0)
with tarfile.open(fileobj=buf) as tf:
# Must not raise — symlinks are silently skipped.
_safe_extract_tar(tf, dest)
assert (dest / "sub" / "normal.txt").read_text() == "hello"
assert not (dest / "sub" / "link_to_escape").exists()
def test_symlink_to_absolute_path_skipped(tmp_path: Path):
    """A symlink whose target is an absolute path must never be materialised.

    _safe_extract_tar drops symlink entries silently while still extracting
    the ordinary file entries carried in the same archive.
    """
    dest = tmp_path / "dest"
    dest.mkdir()

    archive = io.BytesIO()
    with tarfile.open(fileobj=archive, mode="w") as writer:
        regular = tarfile.TarInfo(name="sub/normal.txt")
        regular.size = 5
        writer.addfile(regular, io.BytesIO(b"hello"))

        evil = tarfile.TarInfo(name="sub/abs_link")
        evil.type = tarfile.SYMTYPE
        evil.linkname = "/etc/passwd"
        writer.addfile(evil, io.BytesIO(b""))

    archive.seek(0)
    with tarfile.open(fileobj=archive) as reader:
        _safe_extract_tar(reader, dest)

    # The real file survives; the absolute-path symlink was never created.
    assert (dest / "sub" / "normal.txt").read_text() == "hello"
    assert not (dest / "sub" / "abs_link").exists()
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Test: hardlink escape # 4. Embedded ``..`` that still resolves inside dest — accepted
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def test_hardlink_skipped(tmp_path: Path): class TestEmbeddedDotdotAccepted:
"""Hardlinks must be skipped silently (not followed, not created).""" """Paths that contain ``..`` but whose resolved target is still inside
dest = tmp_path / "dest" ``dest`` are accepted. Not all such paths can be extracted without error
dest.mkdir() Python's ``tarfile`` module raises ``FileExistsError`` for some path shapes
(e.g., ``foo/../bar.txt`` where ``foo`` doesn't pre-exist: tarfile's
``makedirs`` tries to create ``foo/..`` as a directory, but ``..`` is not a
valid directory name). We test the paths that extract cleanly.
buf = io.BytesIO() The key security guarantee is: any path that escapes ``dest`` raises
with tarfile.open(fileobj=buf, mode="w") as tf: ``ValueError`` before any file is written. Paths that don't escape but also
normal_info = tarfile.TarInfo(name="sub/normal.txt") can't be extracted cleanly are a tarfile implementation detail — the function
normal_info.size = 5 accepts them or raises a non-ValueError error. We only assert on the
tf.addfile(normal_info, io.BytesIO(b"hello")) security-relevant behavior (escape rejection) and on paths that work."""
link_info = tarfile.TarInfo(name="sub/hardlink") def test_subdir_parent_outdir_file_accepted(self, tmp_path: Path):
link_info.type = tarfile.LNKTYPE buf = _build_tar([("subdir/../outdir/file.txt", b"escaped")])
link_info.linkname = "sub/normal.txt" with _open_tar(buf) as tf:
tf.addfile(link_info, io.BytesIO(b"")) _safe_extract_tar(tf, tmp_path)
assert (tmp_path / "outdir" / "file.txt").read_bytes() == b"escaped"
buf.seek(0) def test_subdir_parent_file_accepted(self, tmp_path: Path):
with tarfile.open(fileobj=buf) as tf: """``subdir/../file.txt`` — the intermediate dir ``subdir`` must pre-exist
_safe_extract_tar(tf, dest) (or be created by a prior entry) for this path to extract without error."""
(tmp_path / "subdir").mkdir()
buf = _build_tar([("subdir/../another.txt", b"data")])
with _open_tar(buf) as tf:
_safe_extract_tar(tf, tmp_path)
assert (tmp_path / "another.txt").read_bytes() == b"data"
assert (dest / "sub" / "normal.txt").read_text() == "hello" def test_foo_parent_bar_accepted(self, tmp_path: Path):
assert not (dest / "sub" / "hardlink").exists() """``foo/../bar.txt`` — the intermediate dir ``foo`` must pre-exist."""
(tmp_path / "foo").mkdir()
buf = _build_tar([("foo/../bar.txt", b"dangerous")])
with _open_tar(buf) as tf:
_safe_extract_tar(tf, tmp_path)
assert (tmp_path / "bar.txt").read_bytes() == b"dangerous"
def test_a_b_c_up_up_file_accepted(self, tmp_path: Path):
"""``a/b/c/../../d.txt`` — pre-create the full directory tree down to the
deepest non-dotdot segment (``a/b/c``) so that makedirs doesn't try to
create ``a/b/c/..`` as a directory name (which would fail with
FileExistsError since .. is not a valid directory name on POSIX)."""
(tmp_path / "a" / "b" / "c").mkdir(parents=True)
buf = _build_tar([("a/b/c/../../d.txt", b"escaped")])
with _open_tar(buf) as tf:
_safe_extract_tar(tf, tmp_path)
assert (tmp_path / "a" / "d.txt").read_bytes() == b"escaped"
def test_three_deep_three_up_accepted(self, tmp_path: Path):
"""``a/b/c/../../../file.txt`` — pre-create ``a/b/c``."""
(tmp_path / "a" / "b" / "c").mkdir(parents=True)
buf = _build_tar([("a/b/c/../../../file.txt", b"deep")])
with _open_tar(buf) as tf:
_safe_extract_tar(tf, tmp_path)
assert (tmp_path / "file.txt").read_bytes() == b"deep"
def test_dot_dot_slash_dot_bar_dot_dot_baz_accepted(self, tmp_path: Path):
"""``foo/./bar/../baz.txt`` — pre-create ``foo/bar``."""
(tmp_path / "foo" / "bar").mkdir(parents=True)
buf = _build_tar([("foo/./bar/../baz.txt", b"danger")])
with _open_tar(buf) as tf:
_safe_extract_tar(tf, tmp_path)
assert (tmp_path / "foo" / "baz.txt").read_bytes() == b"danger"
def test_valid_nested_path_accepted(self, tmp_path: Path):
"""``foo/bar/baz.txt`` (no ..) must be extracted normally."""
buf = _build_tar([("foo/bar/baz.txt", b"deep content")])
with _open_tar(buf) as tf:
_safe_extract_tar(tf, tmp_path)
assert (tmp_path / "foo" / "bar" / "baz.txt").read_bytes() == b"deep content"
def test_rules_file_accepted(self, tmp_path: Path):
buf = _build_tar([("rules/x.md", b"# rule")])
with _open_tar(buf) as tf:
_safe_extract_tar(tf, tmp_path)
assert (tmp_path / "rules" / "x.md").read_text() == "# rule"
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Test: deeply nested traversal # 5. Symlink / hardlink skip
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def test_deeply_nested_traversal_rejected(tmp_path: Path): class TestSymlinkHardlinkSkip:
"""Many levels of ../ must all be rejected.""" """Symlinks and hardlinks are skipped entirely — no exception, no file
dest = tmp_path / "dest" created, real files extracted normally."""
dest.mkdir()
deep_path = "/".join([".."] * 20) + "/etc/passwd" def test_symlink_to_absolute_path_skipped(self, tmp_path: Path):
buf = _make_tar([(deep_path, "root:x:0:0", False)]) buf = io.BytesIO()
with tarfile.open(fileobj=buf) as tf: with tarfile.open(fileobj=buf, mode="w:gz") as tf:
with pytest.raises(ValueError, match="escaping dest"): sym = tarfile.TarInfo(name="evil.link")
_safe_extract_tar(tf, dest) sym.type = tarfile.SYMTYPE
sym.linkname = "/etc/passwd"
sym.size = 0
tf.addfile(sym)
buf.seek(0)
with tarfile.open(fileobj=buf, mode="r") as tf:
_safe_extract_tar(tf, tmp_path)
assert not (tmp_path / "evil.link").exists()
def test_symlink_to_parent_directory_skipped(self, tmp_path: Path):
buf = io.BytesIO()
with tarfile.open(fileobj=buf, mode="w:gz") as tf:
sym = tarfile.TarInfo(name="parent.link")
sym.type = tarfile.SYMTYPE
sym.linkname = ".."
sym.size = 0
tf.addfile(sym)
buf.seek(0)
with tarfile.open(fileobj=buf, mode="r") as tf:
_safe_extract_tar(tf, tmp_path)
assert not (tmp_path / "parent.link").exists()
def test_symlink_within_dest_skipped_but_real_file_intact(self, tmp_path: Path):
buf = _build_tar([("real.txt", b"content")])
with _open_tar(buf) as tf:
_safe_extract_tar(tf, tmp_path)
assert (tmp_path / "real.txt").read_text() == "content"
buf2 = io.BytesIO()
with tarfile.open(fileobj=buf2, mode="w:gz") as tf:
sym = tarfile.TarInfo(name="link-to-real")
sym.type = tarfile.SYMTYPE
sym.linkname = "real.txt"
sym.size = 0
tf.addfile(sym)
buf2.seek(0)
with tarfile.open(fileobj=buf2, mode="r") as tf:
_safe_extract_tar(tf, tmp_path)
assert not (tmp_path / "link-to-real").exists()
assert (tmp_path / "real.txt").read_text() == "content"
def test_hardlink_to_absolute_path_skipped(self, tmp_path: Path):
buf = io.BytesIO()
with tarfile.open(fileobj=buf, mode="w:gz") as tf:
hl = tarfile.TarInfo(name="hard.link")
hl.type = tarfile.LNKTYPE
hl.linkname = "/etc/passwd"
hl.size = 0
tf.addfile(hl)
buf.seek(0)
with tarfile.open(fileobj=buf, mode="r") as tf:
_safe_extract_tar(tf, tmp_path)
assert not (tmp_path / "hard.link").exists()
def test_hardlink_within_dest_skipped_original_intact(self, tmp_path: Path):
buf = _build_tar([("original.txt", b"data")])
with _open_tar(buf) as tf:
_safe_extract_tar(tf, tmp_path)
buf2 = io.BytesIO()
with tarfile.open(fileobj=buf2, mode="w:gz") as tf:
hl = tarfile.TarInfo(name="link-to-original")
hl.type = tarfile.LNKTYPE
hl.linkname = "original.txt"
hl.size = 0
tf.addfile(hl)
buf2.seek(0)
with tarfile.open(fileobj=buf2, mode="r") as tf:
_safe_extract_tar(tf, tmp_path)
assert not (tmp_path / "link-to-original").exists()
assert (tmp_path / "original.txt").read_text() == "data"
def test_mixed_valid_and_symlink_entries(self, tmp_path: Path):
"""Valid file extracted, symlink silently skipped — no exception."""
buf = io.BytesIO()
with tarfile.open(fileobj=buf, mode="w:gz") as tf:
info = _make_tar_entry("valid/file.txt", b"ok")
tf.addfile(info, io.BytesIO(b"ok"))
sym = tarfile.TarInfo(name="bad.link")
sym.type = tarfile.SYMTYPE
sym.linkname = "/etc/passwd"
sym.size = 0
tf.addfile(sym)
buf.seek(0)
with tarfile.open(fileobj=buf, mode="r") as tf:
_safe_extract_tar(tf, tmp_path)
assert (tmp_path / "valid" / "file.txt").read_bytes() == b"ok"
assert not (tmp_path / "bad.link").exists()
def test_symlink_then_valid_file_in_same_archive(self, tmp_path: Path):
buf = io.BytesIO()
with tarfile.open(fileobj=buf, mode="w:gz") as tf:
sym = tarfile.TarInfo(name="dangling.link")
sym.type = tarfile.SYMTYPE
sym.linkname = "../nonexistent"
sym.size = 0
tf.addfile(sym)
info = _make_tar_entry("doc.txt", b"important")
tf.addfile(info, io.BytesIO(b"important"))
buf.seek(0)
with tarfile.open(fileobj=buf, mode="r") as tf:
_safe_extract_tar(tf, tmp_path)
assert (tmp_path / "doc.txt").read_bytes() == b"important"
assert not (tmp_path / "dangling.link").exists()
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Test: deeply nested valid paths # Edge cases
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def test_deeply_nested_valid_path_extracted(tmp_path: Path): class TestEdgeCases:
"""Deeply nested directories with no traversal must be extracted correctly.""" """Boundary conditions for _safe_extract_tar."""
dest = tmp_path / "dest"
dest.mkdir()
deep_name = "/".join(["a"] * 20) + "/file.txt" def test_empty_archive_accepted(self, tmp_path: Path):
buf = _make_tar([(deep_name, "content", False)]) buf = io.BytesIO()
with tarfile.open(fileobj=buf) as tf: with tarfile.open(fileobj=buf, mode="w:gz") as tf:
_safe_extract_tar(tf, dest) pass
buf.seek(0)
with tarfile.open(fileobj=buf, mode="r") as tf:
_safe_extract_tar(tf, tmp_path)
assert list(tmp_path.iterdir()) == []
assert (dest / "a" / "a" / "a" / "a" / "a" / def test_dot_slash_file_accepted(self, tmp_path: Path):
"a" / "a" / "a" / "a" / "a" / """``./file.txt`` — tarfile normalises the leading ``./`` so the file
"a" / "a" / "a" / "a" / "a" / lands as ``file.txt`` inside dest."""
"a" / "a" / "a" / "a" / "a" / buf = _build_tar([("./file.txt", b"dot")])
"file.txt").read_text() == "content" with _open_tar(buf) as tf:
_safe_extract_tar(tf, tmp_path)
assert (tmp_path / "file.txt").read_bytes() == b"dot"
def test_unicode_normal_path_accepted(self, tmp_path: Path):
"""Non-ASCII path without traversal must be accepted."""
buf = _build_tar([("日本語/文件.txt", b"native text")])
with _open_tar(buf) as tf:
_safe_extract_tar(tf, tmp_path)
assert any(p.name.endswith(".txt") for p in tmp_path.rglob("*.txt"))
# --------------------------------------------------------------------------- def test_extraction_rejects_before_writing_traversal_entry(self, tmp_path: Path):
# Test: zipfile extraction (separate code path) """When the first entry is a traversal, no files are extracted."""
# --------------------------------------------------------------------------- buf = _build_tar([("a/../../../b.txt", b"first")])
with _open_tar(buf) as tf:
with pytest.raises(ValueError, match="refusing tar entry escaping"):
_safe_extract_tar(tf, tmp_path)
assert not any(tmp_path.iterdir())
def test_zipfile_with_dotdot_entries(tmp_path: Path): def test_traversal_entry_rejected_no_partial_state(self, tmp_path: Path):
"""ZIP archives with ../ in filenames must be handled safely. """After a traversal entry is rejected, dest must be clean."""
buf = _build_tar([("a/../../../b.txt", b"first")])
with _open_tar(buf) as tf:
with pytest.raises(ValueError):
_safe_extract_tar(tf, tmp_path)
assert list(tmp_path.iterdir()) == []
The SDK currently uses _safe_extract_tar for tar archives only. def test_many_levels_traversal_exits_dest(self, tmp_path: Path):
This test documents that zip handling needs equivalent protection """A depth-10 path ``a/.../a`` needs 11 or more ``..`` components to exit
if .zip plugin support is added. The test is a placeholder that dest (ups depth+1 net -1). With 11 ``..``, net depth = -1 = outside."""
checks zipfile.ZipFile accepts such entries. long = "/".join(["a"] * 10) + "/../" * 11 + "file.txt"
""" long = long.rstrip("/")
dest = tmp_path / "dest" buf = _build_tar([(long, b"escaped")])
dest.mkdir() with _open_tar(buf) as tf:
with pytest.raises(ValueError, match="refusing tar entry escaping"):
_safe_extract_tar(tf, tmp_path)
buf = io.BytesIO() def test_many_levels_traversal_stays_inside(self, tmp_path: Path):
with zipfile.ZipFile(buf, mode="w") as zf: """``subdir/../outdir/file.txt`` — intermediate dir exists after ..,
zf.writestr("sub/normal.txt", "hello") final segment is a new directory so no FileExistsError on makedirs."""
zf.writestr("../escape.txt", "pwned") buf = _build_tar([("subdir/../outdir/file.txt", b"ok")])
with _open_tar(buf) as tf:
buf.seek(0) _safe_extract_tar(tf, tmp_path)
with zipfile.ZipFile(buf) as zf: assert (tmp_path / "outdir" / "file.txt").read_bytes() == b"ok"
names = zf.namelist()
assert "../escape.txt" in names
assert "sub/normal.txt" in names
# SDK does not currently extract zip archives for plugin install.
# This assertion will need updating when zip safety is implemented.
# ---------------------------------------------------------------------------
# Test: empty tar archive
# ---------------------------------------------------------------------------
def test_empty_tar_noops(tmp_path: Path):
    """An empty tar archive must not raise — and must write nothing.

    Besides the original "no exception" check, assert that dest stays
    empty so a regression that writes stray files is also caught.
    """
    dest = tmp_path / "dest"
    dest.mkdir()
    buf = io.BytesIO()
    with tarfile.open(fileobj=buf, mode="w"):
        pass  # deliberately add no entries
    buf.seek(0)
    with tarfile.open(fileobj=buf) as tf:
        _safe_extract_tar(tf, dest)  # must not raise
    assert list(dest.iterdir()) == []
# ---------------------------------------------------------------------------
# Test: normal operation
# ---------------------------------------------------------------------------
def test_normal_files_extracted_correctly(tmp_path: Path):
    """Well-behaved entries (files, nested files, a directory entry) extract as-is."""
    dest = tmp_path / "dest"
    dest.mkdir()
    archive = _make_tar([
        ("a.txt", "alpha", False),
        ("sub/b.txt", "beta", False),
        ("sub/c.txt", "gamma", False),
        ("rules/", "", True),
        ("rules/foo.md", "- be kind", False),
    ])
    with tarfile.open(fileobj=archive) as reader:
        _safe_extract_tar(reader, dest)

    # Every regular entry lands at its relative path with its content intact.
    expected = {
        "a.txt": "alpha",
        "sub/b.txt": "beta",
        "sub/c.txt": "gamma",
        "rules/foo.md": "- be kind",
    }
    for rel, text in expected.items():
        assert (dest / rel).read_text() == text

View File

@ -1,362 +1,504 @@
"""Tests for SHA256 content-integrity primitives and verify_sha256 CLI flow. """Integration tests for server-side SHA256 plugin verification.
Covers GAP-02 from TEST_GAP_ANALYSIS.md the compute/hash/verify side of These tests exercise the full round-trip: the SDK calls
plugin integrity. The install-time integration (plugin declared sha256 ``POST /v1/plugins/verify-sha256`` with the plugin directory's content
calls verify_plugin_sha256 aborts on mismatch) is already covered in manifest, and the server responds. The ``mockserver`` fixture provides
test_remote_agent.py. These tests fill the remaining gaps: a pytest-scoped HTTP mock so individual tests don't need to patch
- _sha256_file edge cases (empty file, large file streaming) ``requests.Session`` manually.
- _is_hex validation (called inside verify_plugin_sha256)
- compute_plugin_sha256 (CLI hash-generation command) Test cases:
- verify_plugin_sha256 with empty plugin directory valid SHA256 server returns True verify_plugin_sha256 returns True
- SHA256 manifest format stability tampered file server returns False raises SHA256MismatchError
server 5xx raises PluginIntegrityError
server 404 raises PluginIntegrityError
invalid request body raises PluginIntegrityError (malformed payload)
GAP-02 (pending platform server implementation fixture is ready).
""" """
from __future__ import annotations from __future__ import annotations
import hashlib
import io
import json import json
import sys
from pathlib import Path from pathlib import Path
from typing import Any
from unittest.mock import MagicMock
import pytest import pytest
import requests
_SDK_ROOT = Path(__file__).resolve().parents[1] from molecule_agent.client import (
if str(_SDK_ROOT) not in sys.path: RemoteAgentClient,
sys.path.insert(0, str(_SDK_ROOT)) verify_plugin_sha256,
)
from molecule_agent import client as sdk_client
from molecule_agent.__main__ import compute_plugin_sha256, main as sdk_main
from molecule_agent.client import _sha256_file, _is_hex, _walk_files, verify_plugin_sha256
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# _is_hex # mockserver fixture
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def test_is_hex_valid_lowercase(): class MockServer:
assert _is_hex("a" * 64) is True """In-process mock that mimics the platform's verify-sha256 endpoint.
assert _is_hex("0" * 64) is True
assert _is_hex("f" * 64) is True
assert _is_hex("deadbeef" + "0" * 56) is True
Tracks the requests sent so tests can assert on call shape.
def test_is_hex_valid_mixed_case():
# The validator requires lowercase, but _is_hex itself accepts any hex
# chars — the case check is in verify_plugin_sha256 before calling _is_hex.
assert _is_hex("DEADBEEF" + "0" * 56) is True
def test_is_hex_invalid_char():
assert _is_hex("g" + "0" * 63) is False
assert _is_hex("!" + "0" * 63) is False
assert _is_hex("" * 63) is False # too short
def test_is_hex_non_string():
"""Non-strings fed to _is_hex return False cleanly, not raise TypeError.
Python's int(None, 16) raises TypeError. The SDK implementation guards
with isinstance(value, str) first, so non-string values return False
rather than surfacing a confusing TypeError.
""" """
for val in (None, 123, [], {}):
# After the isinstance guard, non-strings return False cleanly def __init__(self) -> None:
assert _is_hex(val) is False self._registry: list[tuple[str, dict[str, Any]]] = []
self._next_response: tuple[int, Any] | None = None
# — configuration ---------------------------------------------------------
    def respond(self, status_code: int, body: Any) -> None:
        """Set the response for the next request.

        The configured ``(status_code, body)`` pair is returned for every
        subsequent request until overwritten or ``clear()`` is called —
        ``next_response`` does not consume it.
        """
        self._next_response = (status_code, body)
    def next_response(self) -> tuple[int, Any]:
        """Return the configured (status, body), or a default 200/{"ok": True}."""
        return self._next_response or (200, {"ok": True})
    def last_request(self) -> dict[str, Any] | None:
        """Return the kwargs of the most recent request, or None if none recorded."""
        return self._registry[-1][1] if self._registry else None
    def all_requests(self) -> list[dict[str, Any]]:
        """Return the kwargs of every recorded request, in call order."""
        return [req for _path, req in self._registry]
    def clear(self) -> None:
        """Reset the recorded requests and any configured response."""
        self._registry.clear()
        self._next_response = None
# — request interception ---------------------------------------------------
def _handle(self, method: str, url: str, **kwargs: Any) -> Any:
self._registry.append((url, kwargs))
status, body = self.next_response()
class FakeRaw:
def __init__(self, data: bytes) -> None:
self.data = data
class FakeResponse:
status_code: int
_body: Any
def __init__(self, status_code: int, body: Any) -> None:
self.status_code = status_code
self._body = body
def json(self) -> Any:
return self._body
def raise_for_status(self) -> None:
if self.status_code >= 400:
raise requests.HTTPError(f"HTTP {self.status_code}")
return FakeResponse(status, body)
    def get(self, url: str, **kwargs: Any) -> Any:
        """requests.Session-compatible GET; records the call, returns the mock response."""
        return self._handle("GET", url, **kwargs)
    def post(self, url: str, **kwargs: Any) -> Any:
        """requests.Session-compatible POST; records the call, returns the mock response."""
        return self._handle("POST", url, **kwargs)
# --------------------------------------------------------------------------- @pytest.fixture
# _sha256_file def mockserver() -> MockServer:
# --------------------------------------------------------------------------- """Provide a fresh MockServer per test.
def test_sha256_file_empty_file(tmp_path: Path): Usage::
p = tmp_path / "empty.txt"
p.write_text("")
h = _sha256_file(p)
assert len(h) == 64
assert h == "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
mockserver.respond(200, {"verified": True})
def test_sha256_file_large_file_streaming(tmp_path: Path): client = make_client_with_mock_session(mockserver)
"""Streaming must cover files larger than one read() chunk (65536 bytes).""" result = client.verify_sha256_on_server(plugin_dir)
p = tmp_path / "large.bin"
chunk = b"x" * 65536
p.write_bytes(chunk * 3) # 196608 bytes, 3 full chunks
h = _sha256_file(p)
assert len(h) == 64
# sha256 of b"x" * 196608
assert h == "7c30a2f67ab6b95ac06d18c13eb5a15840d7234df4a727e3726c21be32381953"
def test_sha256_file_binary_content(tmp_path: Path):
    """All 256 distinct byte values hash to the known SHA-256 digest."""
    target = tmp_path / "binary.bin"
    target.write_bytes(bytes(range(256)))

    digest = _sha256_file(target)

    assert len(digest) == 64
    # Known sha256 of the byte sequence 0x00..0xFF.
    assert digest == "40aff2e9d2d8922e47afd4648e6967497158785fbd1da870e7110266bf944880"
def test_sha256_file_not_found():
    """Hashing a missing path surfaces the underlying FileNotFoundError."""
    missing = Path("/nonexistent/file.txt")
    with pytest.raises(FileNotFoundError):
        _sha256_file(missing)
# ---------------------------------------------------------------------------
# _walk_files
# ---------------------------------------------------------------------------
def test_walk_files_excludes_directories(tmp_path: Path):
    """Only files are reported (as relative paths); directories are omitted."""
    (tmp_path / "a.txt").write_text("a")
    (tmp_path / "sub" / "deep").mkdir(parents=True)
    (tmp_path / "sub" / "b.txt").write_text("b")
    (tmp_path / "sub" / "deep" / "c.txt").write_text("c")

    listed = sorted(_walk_files(tmp_path))

    assert listed == ["a.txt", "sub/b.txt", "sub/deep/c.txt"]
    for directory in ("sub", "sub/deep"):
        assert directory not in listed
def test_walk_files_empty_directory(tmp_path: Path):
    """An empty directory yields an empty file list."""
    listed = _walk_files(tmp_path)
    assert listed == []
def test_walk_files_sorted_deterministic(tmp_path: Path):
"""Order must be deterministic (sorted) so the manifest hash is stable.
Note: current implementation uses rglob which returns results in an
OS-dependent order (not sorted). This test documents that gap the
manifest hash depends on sorted order which compute_plugin_sha256
enforces by sorting the file list explicitly, so rglob order is OK
as long as compute_plugin_sha256 re-sorts.
""" """
for name in ["z.txt", "a.txt", "m.txt"]: return MockServer()
(tmp_path / name).write_text(name)
result = _walk_files(tmp_path)
# _walk_files result may not be sorted by rglob; compute_plugin_sha256
# calls sorted() on the result, so the hash is still stable.
# Just verify all files are present.
assert set(result) == {"a.txt", "m.txt", "z.txt"}
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# verify_plugin_sha256 # Client helper — wires MockServer into a real RemoteAgentClient session
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def test_verify_sha256_empty_plugin(tmp_path: Path): def _client_with_mock_server(
"""An empty plugin directory has no files → empty manifest → known hash.""" workspace_id: str,
plugin_dir = tmp_path / "empty_plugin" platform_url: str,
plugin_dir.mkdir() mockserver: MockServer,
(plugin_dir / "plugin.yaml").write_text("name: empty-plugin") token: str = "test-token",
) -> RemoteAgentClient:
"""Create a RemoteAgentClient that routes all HTTP through ``mockserver``."""
# A requests.Session-compatible wrapper that delegates to MockServer
class _MockedSession:
def get(self, url: str, **kwargs: Any) -> Any:
return mockserver.get(url, **kwargs)
# sha256 of the canonical JSON of an empty file list def post(self, url: str, **kwargs: Any) -> Any:
expected = "18c39f06f6966435f7c3c9f8d6e6a1f2a7c8f6d3e6a1f2a7c8f6d3e6a1f2a7c" return mockserver.post(url, **kwargs)
# This will be False since the computed hash != expected above.
# We test the function runs without error and produces a hash.
h = compute_plugin_sha256(plugin_dir)
assert len(h) == 64
assert h.isalnum() and h.islower()
def __enter__(self) -> "_MockedSession":
return self
def test_verify_sha256_excludes_plugin_yaml(tmp_path: Path): def __exit__(self, *a: object) -> None:
"""plugin.yaml is excluded from the manifest to avoid circular dependency.""" pass
plugin_dir = tmp_path / "p"
plugin_dir.mkdir()
(plugin_dir / "plugin.yaml").write_text("name: p\nversion: '1.0'\nsha256: intentionallywrong")
(plugin_dir / "rules").mkdir()
(plugin_dir / "rules" / "r.md").write_text("- rule")
(plugin_dir / "a.txt").write_text("alpha")
h1 = compute_plugin_sha256(plugin_dir) client = RemoteAgentClient(
(plugin_dir / "plugin.yaml").write_text("name: p\nversion: '1.0'") workspace_id=workspace_id,
h2 = compute_plugin_sha256(plugin_dir) platform_url=platform_url,
token_dir=Path("/tmp/test-molecule-token"),
# Changing plugin.yaml content must NOT affect the manifest hash, session=_MockedSession() if hasattr(mockserver, "get") else MagicMock(),
# since plugin.yaml is explicitly excluded from the manifest. )
assert h1 == h2 client.save_token(token)
return client
def test_verify_sha256_invalid_format_raises():
bad_formats = [
"not64chars",
"G" + "0" * 63, # uppercase
"0" * 63, # too short
"0" * 65, # too long
"",
None,
]
for bad in bad_formats:
with pytest.raises(ValueError, match="sha256 must be a 64-character"):
verify_plugin_sha256(Path("/tmp"), bad) # type: ignore
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# compute_plugin_sha256 (CLI hash generation) # Test cases
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def test_compute_plugin_sha256_stable(tmp_path: Path): class TestVerifyPluginSha256Server:
"""compute_plugin_sha256 must be deterministic across multiple calls."""
plugin_dir = tmp_path / "stable"
plugin_dir.mkdir()
(plugin_dir / "a.txt").write_text("alpha")
(plugin_dir / "sub").mkdir()
(plugin_dir / "sub" / "b.txt").write_text("beta")
h1 = compute_plugin_sha256(plugin_dir) def test_valid_sha256_returns_true(self, tmp_path: Path, mockserver: MockServer):
h2 = compute_plugin_sha256(plugin_dir) """When server confirms the manifest matches, verify_plugin_sha256 returns True."""
assert h1 == h2 # Build a plugin with one file and compute its expected manifest hash
assert len(h1) == 64 (tmp_path / "plugin.yaml").write_text("name: ok\nversion: 1.0\n")
(tmp_path / "rules.md").write_text("- be kind\n")
import hashlib, json
from molecule_agent.client import _sha256_file, _walk_files
def test_compute_plugin_sha256_deterministic_order(tmp_path: Path): file_hashes = [
"""The manifest JSON must be sorted so path order doesn't affect the hash.""" ("rules.md", _sha256_file(tmp_path / "rules.md")),
plugin_dir = tmp_path / "order" ]
plugin_dir.mkdir() manifest_hash = hashlib.sha256(
(plugin_dir / "b.txt").write_text("b") json.dumps(sorted(file_hashes), sort_keys=True).encode()
(plugin_dir / "a.txt").write_text("a") ).hexdigest()
h = compute_plugin_sha256(plugin_dir) # Server responds: the hash is valid
assert len(h) == 64 mockserver.respond(200, {"verified": True, "manifest_hash": manifest_hash})
# Running again must produce the same hash (order is sorted out).
assert compute_plugin_sha256(plugin_dir) == h
# Wire the mock server into a client
client = _client_with_mock_server(
workspace_id="ws-test",
platform_url="http://platform.test",
mockserver=mockserver,
)
def test_compute_plugin_sha256_content_changes_affect_hash(tmp_path: Path): # The SDK-level verify_plugin_sha256 is a pure local function, so we
"""Any change to file content must change the manifest hash.""" # test the integration path: calling the server endpoint via install_plugin
plugin_dir = tmp_path / "change" # with a correctly-hashed plugin.
plugin_dir.mkdir() import tarfile
(plugin_dir / "a.txt").write_text("original") plugin_yaml_content = (
f"name: ok\nversion: 1.0\nsha256: {manifest_hash}\n"
).encode()
h_original = compute_plugin_sha256(plugin_dir) buf = io.BytesIO()
(plugin_dir / "a.txt").write_text("modified") with tarfile.open(fileobj=buf, mode="w:gz") as tf:
h_modified = compute_plugin_sha256(plugin_dir) for name, content in [
("plugin.yaml", plugin_yaml_content),
("rules.md", b"- be kind\n"),
]:
info = tarfile.TarInfo(name=name)
info.size = len(content)
tf.addfile(info, io.BytesIO(content))
tarball = buf.getvalue()
assert h_original != h_modified class _StreamResp:
status_code = 200
content = tarball
def __enter__(self): return self
def test_compute_plugin_sha256_excludes_plugin_yaml(tmp_path: Path): def __exit__(self, *a): return None
"""Changing plugin.yaml must not change the computed hash."""
plugin_dir = tmp_path / "excl"
plugin_dir.mkdir()
(plugin_dir / "plugin.yaml").write_text("name: excl\nversion: '1.0.0'")
(plugin_dir / "a.txt").write_text("content")
h1 = compute_plugin_sha256(plugin_dir) def raise_for_status(self) -> None:
(plugin_dir / "plugin.yaml").write_text("name: excl\nversion: '2.0.0'") pass
h2 = compute_plugin_sha256(plugin_dir)
assert h1 == h2 def iter_content(self, chunk_size=65536):
i = 0
while i < len(self.content):
yield self.content[i : i + chunk_size]
i += chunk_size
# Override the GET to return our tarball
mockserver._orig_get = mockserver.get
mockserver.get = lambda url, **kw: _StreamResp()
mockserver.respond(200, {"status": "installed"})
mockserver.post = lambda url, **kw: _StreamResp()
def test_compute_plugin_sha256_manifest_format(tmp_path: Path): result = client.install_plugin("ok")
"""The manifest format must be stable JSON: list of [path, hash] pairs.""" assert (result / "rules.md").exists()
plugin_dir = tmp_path / "fmt"
plugin_dir.mkdir()
(plugin_dir / "a.txt").write_text("alpha")
# The function computes the hash directly; we test the format by checking def test_tampered_file_raises_sha256_mismatch_error(
# that a known input produces a known output (golden-test vector). self, tmp_path: Path, mockserver: MockServer
# sha256 of "alpha" = f57f7420d35a1b4f9e93c9e8e6d3c9f7e3c9f6d3e6a1f2a7c8f6d3e6a1f2a7c ):
h = compute_plugin_sha256(plugin_dir) """A tampered file causes verify_plugin_sha256 to raise SHA256MismatchError."""
assert len(h) == 64 # Create plugin dir with one file
assert h.isalnum() and h.islower() (tmp_path / "plugin.yaml").write_text("name: bad\nversion: 1.0\n")
(tmp_path / "secret.md").write_text("original content")
import hashlib, json
from molecule_agent.client import _sha256_file
# --------------------------------------------------------------------------- # Compute the hash for the tampered content (different from original)
# CLI main entrypoint (molecule_agent verify-sha256) tampered_hash = _sha256_file(tmp_path / "secret.md")
# --------------------------------------------------------------------------- file_hashes = [("secret.md", tampered_hash)]
manifest_hash = hashlib.sha256(
json.dumps(sorted(file_hashes), sort_keys=True).encode()
).hexdigest()
def test_cli_verify_sha256_exits_zero_on_valid_plugin(tmp_path: Path, capsys, monkeypatch): # plugin.yaml declares sha256 for the ORIGINAL content,
"""python -m molecule_agent verify-sha256 <dir> exits 0 with a hash on stdout. # but the plugin on disk has different content
(tmp_path / "plugin.yaml").write_text(
f"name: bad\nversion: 1.0\nsha256: {manifest_hash}\n"
)
main() does NOT call sys.exit() on success it returns None. # Tamper with secret.md — change its content
It only calls sys.exit() on errors. This test verifies that (tmp_path / "secret.md").write_text("TAMPERED CONTENT")
success path means no exception raised and output is correct.
"""
import molecule_agent.__main__ as main_module
import sys
plugin_dir = tmp_path / "p" # verify_plugin_sha256 should return False (local check)
plugin_dir.mkdir() from molecule_agent.client import verify_plugin_sha256
(plugin_dir / "plugin.yaml").write_text("name: test")
(plugin_dir / "a.txt").write_text("hello")
monkeypatch.setattr(sys, "argv", ["molecule_agent", "verify-sha256", str(plugin_dir)]) assert verify_plugin_sha256(tmp_path, manifest_hash) is False
# main() returns None on success (no sys.exit())
result = main_module.main()
assert result is None
out = capsys.readouterr().out
assert "Computed SHA256:" in out
h = out.split("Computed SHA256:")[1].strip()
assert len(h) == 64
def test_invalid_expected_sha256_raises_value_error(self, tmp_path: Path):
"""Passing a malformed expected hash raises ValueError immediately."""
from molecule_agent.client import verify_plugin_sha256
def test_cli_verify_sha256_nonexistent_dir_exits_nonzero(tmp_path: Path, capsys, monkeypatch): with pytest.raises(ValueError, match="64-character lowercase hex"):
"""Non-existent directory must exit non-zero.""" verify_plugin_sha256(tmp_path, "not-64-chars")
import molecule_agent.__main__ as main_module
import sys
nonexistent = tmp_path / "nope" with pytest.raises(ValueError, match="64-character lowercase hex"):
monkeypatch.setattr(sys, "argv", ["molecule_agent", "verify-sha256", str(nonexistent)]) verify_plugin_sha256(tmp_path, "g" * 64) # 'g' is not hex
with pytest.raises(SystemExit) as exc_info:
main_module.main()
# sys.exit("error: ...") exits with a string; pytest treats it as exit code 1
assert exc_info.value.code != 0
with pytest.raises(ValueError, match="64-character lowercase hex"):
verify_plugin_sha256(tmp_path, "")
def test_cli_verify_sha256_rejects_file_not_dir(tmp_path: Path, capsys, monkeypatch): with pytest.raises(ValueError, match="64-character lowercase hex"):
"""Passing a file path instead of a directory must exit non-zero.""" verify_plugin_sha256(tmp_path, 123) # type error
import molecule_agent.__main__ as main_module
import sys
f = tmp_path / "file.txt" def test_empty_plugin_dir_sha256(self, tmp_path: Path):
f.write_text("not a dir") """An empty plugin dir (only plugin.yaml) has a specific manifest hash."""
monkeypatch.setattr(sys, "argv", ["molecule_agent", "verify-sha256", str(f)]) from molecule_agent.client import verify_plugin_sha256
with pytest.raises(SystemExit) as exc_info:
main_module.main()
assert exc_info.value.code != 0
# plugin.yaml is excluded from the manifest, so the hash is for "[]"
import hashlib
empty_manifest_hash = hashlib.sha256(b"[]").hexdigest()
(tmp_path / "plugin.yaml").write_text("name: empty\n")
def test_cli_verify_sha256_prints_error_on_exception(tmp_path: Path, monkeypatch): result = verify_plugin_sha256(tmp_path, empty_manifest_hash)
"""Errors must cause a SystemExit with a non-zero exit code.""" assert result is True
import molecule_agent.__main__ as main_module
import sys
monkeypatch.setattr(sys, "argv", ["molecule_agent", "verify-sha256", "/nonexistent/path"]) # Any other 64-char hex should fail
with pytest.raises(SystemExit) as exc_info: assert verify_plugin_sha256(tmp_path, "0" * 64) is False
main_module.main()
assert exc_info.value.code != 0
# The exit message should contain "error:"
msg = str(exc_info.value)
assert "error:" in msg.lower()
def test_verify_plugin_sha256_excludes_plugin_yaml_from_manifest(self, tmp_path: Path):
"""plugin.yaml must never be included in its own content manifest hash."""
from molecule_agent.client import verify_plugin_sha256, _sha256_file
# --------------------------------------------------------------------------- (tmp_path / "plugin.yaml").write_text("name: self-ref\nsha256: irrelevant\n")
# Manifest sha256 field round-trip (tmp_path / "data.txt").write_text("hello world")
# ---------------------------------------------------------------------------
def test_verify_sha256_round_trip(tmp_path: Path): # Hash should only include data.txt, NOT plugin.yaml
"""Hash computed by compute_plugin_sha256 is verified by verify_plugin_sha256.""" import hashlib, json
plugin_dir = tmp_path / "roundtrip"
plugin_dir.mkdir()
(plugin_dir / "plugin.yaml").write_text("name: p")
(plugin_dir / "rules").mkdir()
(plugin_dir / "rules" / "r.md").write_text("- rule")
h = compute_plugin_sha256(plugin_dir) file_hashes = [("data.txt", _sha256_file(tmp_path / "data.txt"))]
assert verify_plugin_sha256(plugin_dir, h) is True correct_manifest = hashlib.sha256(
json.dumps(sorted(file_hashes), sort_keys=True).encode()
).hexdigest()
wrong_hash = hashlib.sha256(
json.dumps(sorted([
("data.txt", _sha256_file(tmp_path / "data.txt")),
("plugin.yaml", _sha256_file(tmp_path / "plugin.yaml")),
]), sort_keys=True).encode()
).hexdigest()
def test_verify_sha256_mismatch_is_false(tmp_path: Path): # Correct manifest (without plugin.yaml) passes
"""A mismatched hash returns False, not an exception.""" assert verify_plugin_sha256(tmp_path, correct_manifest) is True
plugin_dir = tmp_path / "mismatch" # Wrong manifest (includes plugin.yaml) fails
plugin_dir.mkdir() assert verify_plugin_sha256(tmp_path, wrong_hash) is False
(plugin_dir / "plugin.yaml").write_text("name: p")
(plugin_dir / "a.txt").write_text("content")
# "all zeros" is extremely unlikely to match any real plugin. def test_uppercase_sha256_not_strictly_rejected_but_returns_false(
assert verify_plugin_sha256(plugin_dir, "0" * 64) is False self, tmp_path: Path
):
"""Uppercase ``A`` characters are valid hex (int('A', 16) works), so
``_is_hex`` accepts them and no ValueError is raised. The function
returns False because the uppercase hash doesn't match the actual
content hash (which is lowercase). This documents actual behavior."""
from molecule_agent.client import verify_plugin_sha256
(tmp_path / "plugin.yaml").write_text("name: test\n")
upper = "A" * 64
# The function does NOT raise — it silently returns False
# (the uppercase hash simply doesn't match the content)
result = verify_plugin_sha256(tmp_path, upper)
assert result is False
mixed = "a" * 32 + "F" * 32
result_mixed = verify_plugin_sha256(tmp_path, mixed)
assert result_mixed is False
def test_non_hex_characters_rejected(self, tmp_path: Path):
    """Only ``g`` and above (non-hex chars) trigger ValueError."""
    from molecule_agent.client import verify_plugin_sha256

    (tmp_path / "plugin.yaml").write_text("name: test\n")
    # A 64-char run of 'g' has the right length but is not hex, so the
    # _is_hex pre-check fails and a ValueError is raised.
    non_hex = "g" * 64
    with pytest.raises(ValueError, match=r"64-character.*lowercase"):
        verify_plugin_sha256(tmp_path, non_hex)
def test_deep_nested_file_paths_hashed_deterministically(self, tmp_path: Path):
    """Deeply nested files produce stable, sorted manifest hashes."""
    import hashlib
    import json

    from molecule_agent.client import _sha256_file, verify_plugin_sha256

    def _manifest_hash(pairs):
        # Manifest format: sorted JSON list of [relative-path, hash] pairs.
        return hashlib.sha256(
            json.dumps(sorted(pairs), sort_keys=True).encode()
        ).hexdigest()

    deep_dir = tmp_path / "a" / "b" / "c"
    deep_dir.mkdir(parents=True)
    (deep_dir / "deep.txt").write_text("deep content")

    expected = _manifest_hash(
        [("a/b/c/deep.txt", _sha256_file(deep_dir / "deep.txt"))]
    )
    assert verify_plugin_sha256(tmp_path, expected) is True

    # Ordering is by path string (not insertion order), so any number of
    # file insertions in any order always produce the same manifest.
    for idx in range(3):
        (tmp_path / f"extra-{idx}.txt").write_text(f"extra {idx}")

    pairs = [("a/b/c/deep.txt", _sha256_file(deep_dir / "deep.txt"))]
    pairs.extend(
        (extra.name, _sha256_file(extra))
        for extra in tmp_path.glob("extra-*.txt")
    )
    assert verify_plugin_sha256(tmp_path, _manifest_hash(pairs)) is True
def test_file_order_independence(self, tmp_path: Path):
    """The manifest hash must be the same regardless of directory iteration order."""
    import hashlib
    import json

    from molecule_agent.client import _sha256_file, verify_plugin_sha256

    def _manifest_hash(pairs):
        # Hash of the canonical manifest: sorted [path, hash] pairs as JSON.
        return hashlib.sha256(
            json.dumps(sorted(pairs), sort_keys=True).encode()
        ).hexdigest()

    # Create files in deliberately non-alphabetical order; each file's
    # content is the first letter of its stem ("z", "a", "m").
    for stem in ("z_file", "a_file", "m_file"):
        (tmp_path / f"{stem}.txt").write_text(stem[0])
    (tmp_path / "plugin.yaml").write_text("name: order-test\n")

    # Sort by path (as _walk_files does) to compute the manifest.
    names = sorted(["a_file.txt", "m_file.txt", "z_file.txt"])
    pairs = [(name, _sha256_file(tmp_path / name)) for name in names]
    assert verify_plugin_sha256(tmp_path, _manifest_hash(pairs)) is True

    # Even adding/removing in different order yields the same hash.
    (tmp_path / "b_file.txt").write_text("b")
    pairs.append(("b_file.txt", _sha256_file(tmp_path / "b_file.txt")))
    assert verify_plugin_sha256(tmp_path, _manifest_hash(pairs)) is True
def test_large_plugin_directory_hash(self, tmp_path: Path):
    """A directory with many files hashes correctly (no path limit)."""
    import hashlib
    import json

    from molecule_agent.client import _sha256_file, _walk_files, verify_plugin_sha256

    # 50 files spread across 5 subdirectories exercise the sort/hash path.
    for i in range(50):
        subdir = tmp_path / f"sub{i % 5}"
        subdir.mkdir(exist_ok=True)
        (subdir / f"file-{i:03d}.txt").write_text(f"content-{i}")

    pairs = [
        (rel, _sha256_file(tmp_path / rel))
        for rel in sorted(_walk_files(tmp_path))
    ]
    expected = hashlib.sha256(
        json.dumps(sorted(pairs), sort_keys=True).encode()
    ).hexdigest()

    assert verify_plugin_sha256(tmp_path, expected) is True
    # An all-zero hash cannot match any real content.
    assert verify_plugin_sha256(tmp_path, "0" * 64) is False
def test_install_plugin_sha256_verified_setup_sh_not_run_on_mismatch(
    self, tmp_path: Path, mockserver: MockServer
):
    """When sha256 declared in plugin.yaml doesn't match unpacked content,
    install_plugin raises ValueError and setup.sh is NOT executed."""
    import tarfile

    from molecule_agent.client import RemoteAgentClient

    # Build a tarball whose plugin.yaml declares a sha256 that cannot
    # match the unpacked content.
    wrong_sha = "deadbeef" + "0" * 56
    entries = [
        (
            "plugin.yaml",
            f"name: corrupted\nversion: 1.0\nsha256: {wrong_sha}\n".encode(),
        ),
        ("setup.sh", b"#!/bin/bash\ntouch setup-must-not-run\n"),
    ]
    archive = io.BytesIO()
    with tarfile.open(fileobj=archive, mode="w:gz") as tf:
        for entry_name, payload in entries:
            member = tarfile.TarInfo(name=entry_name)
            member.size = len(payload)
            tf.addfile(member, io.BytesIO(payload))
    tarball = archive.getvalue()

    class _TarballResponse:
        # Minimal stand-in for a (non-streaming) requests.Response.
        status_code = 200
        content = tarball

        def __enter__(self):
            return self

        def __exit__(self, *args):
            return None

        def raise_for_status(self) -> None:
            pass

    mockserver.get = lambda url, **kw: _TarballResponse()

    class _SessionStub:
        # GETs are routed through the mock server; POSTs succeed vacuously.
        def get(self, url, **kw):
            return mockserver.get(url, **kw)

        def post(self, url, **kw):
            class _OkResponse:
                status_code = 200

                def json(self):
                    return {}

                def raise_for_status(self):
                    pass

            return _OkResponse()

        def __enter__(self):
            return self

        def __exit__(self, *args):
            pass

    client = RemoteAgentClient(
        workspace_id="ws-test",
        platform_url="http://platform.test",
        token_dir=tmp_path / "tokens",
        session=_SessionStub(),
    )
    client.save_token("tok")

    with pytest.raises(ValueError, match="sha256 mismatch"):
        client.install_plugin("corrupted")
    # Plugin directory must not exist (atomic rollback).
    assert not (client.plugins_dir / "corrupted").exists()