Merge pull request #13 from Molecule-AI/gap-03-fix
feat(sdk): GAP-03 conftest, GAP-05 retry backoff, KI-002 idempotency key
This commit is contained in:
commit
a3203a8a9e
29
.github/workflows/ci.yml
vendored
Normal file
29
.github/workflows/ci.yml
vendored
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
name: Test
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [main]
|
||||||
|
pull_request:
|
||||||
|
branches: [main]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
test:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
python-version: ['3.11', '3.12', '3.13']
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: ${{ matrix.python-version }}
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: pip install -e ".[test]"
|
||||||
|
|
||||||
|
- name: Run tests
|
||||||
|
run: python -m pytest tests/
|
||||||
|
|
||||||
|
- name: Lint
|
||||||
|
run: pip install ruff && ruff check molecule_agent/ molecule_plugin/
|
||||||
@ -34,6 +34,7 @@ Design notes:
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from .a2a_server import A2AServer
|
||||||
from .client import (
|
from .client import (
|
||||||
PeerInfo,
|
PeerInfo,
|
||||||
RemoteAgentClient,
|
RemoteAgentClient,
|
||||||
@ -46,6 +47,7 @@ from .client import (
|
|||||||
from .__main__ import compute_plugin_sha256
|
from .__main__ import compute_plugin_sha256
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"A2AServer",
|
||||||
"RemoteAgentClient",
|
"RemoteAgentClient",
|
||||||
"WorkspaceState",
|
"WorkspaceState",
|
||||||
"PeerInfo",
|
"PeerInfo",
|
||||||
|
|||||||
229
molecule_agent/a2a_server.py
Normal file
229
molecule_agent/a2a_server.py
Normal file
@ -0,0 +1,229 @@
|
|||||||
|
"""A2A server for inbound agent calls.
|
||||||
|
|
||||||
|
Bundled alongside :class:`molecule_agent.client.RemoteAgentClient` to
|
||||||
|
enable remote agents to receive A2A calls from the platform without
|
||||||
|
requiring the agent author to provision their own HTTP endpoint.
|
||||||
|
|
||||||
|
Phase 30.8b contract — the server exposes ``POST /a2a/inbound`` which
|
||||||
|
the platform's ingress proxy calls when it needs to push work to a
|
||||||
|
registered remote agent.
|
||||||
|
|
||||||
|
Usage::
|
||||||
|
|
||||||
|
from molecule_agent import RemoteAgentClient, A2AServer
|
||||||
|
|
||||||
|
client = RemoteAgentClient(workspace_id="...", platform_url="...")
|
||||||
|
server = A2AServer(
|
||||||
|
agent_id=client.workspace_id,
|
||||||
|
inbound_url="https://my-agent.example.com/a2a/inbound",
|
||||||
|
message_handler=my_handler,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Start server in background thread, then register with platform.
|
||||||
|
server.start_in_background()
|
||||||
|
client.reported_url = server.inbound_url # platform reaches this URL
|
||||||
|
token = client.register()
|
||||||
|
|
||||||
|
# Heartbeat loop now reports a real URL instead of "remote://no-inbound".
|
||||||
|
client.run_heartbeat_loop()
|
||||||
|
|
||||||
|
# Shutdown the server when the agent exits.
|
||||||
|
server.stop()
|
||||||
|
|
||||||
|
The ``message_handler`` signature is::
|
||||||
|
|
||||||
|
async def my_handler(request: dict) -> dict:
|
||||||
|
'''Return an A2A-formatted response dict.'''
|
||||||
|
...
|
||||||
|
|
||||||
|
Handlers are invoked on the server's internal thread pool.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import threading
|
||||||
|
from http.server import BaseHTTPRequestHandler, HTTPServer
|
||||||
|
from typing import Any, Callable, Awaitable
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Module-level HTTPServer instance so the handler can access server state.
|
||||||
|
_server: HTTPServer | None = None
|
||||||
|
_lock = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Handler
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class _A2AHandler(BaseHTTPRequestHandler):
|
||||||
|
"""Handles ``POST /a2a/inbound`` requests.
|
||||||
|
|
||||||
|
The request body is a JSON A2A task dispatch dict::
|
||||||
|
|
||||||
|
{
|
||||||
|
"task_id": "...",
|
||||||
|
"sender": "...",
|
||||||
|
"message": "...",
|
||||||
|
"idempotency_key": "...",
|
||||||
|
}
|
||||||
|
|
||||||
|
The ``message_handler`` ( supplied at construction) is called with the
|
||||||
|
parsed dict and its return value is written as a JSON response::
|
||||||
|
|
||||||
|
200 {"status": "ok", "result": <handler-result>}
|
||||||
|
400 {"error": "bad request: ..."}
|
||||||
|
500 {"error": "internal error: ..."}
|
||||||
|
"""
|
||||||
|
|
||||||
|
protocol_version = "HTTP/1.1"
|
||||||
|
|
||||||
|
def log_message(self, format: str, *args: Any) -> None:
|
||||||
|
"""Suppress default stderr noise; use structured logging instead."""
|
||||||
|
logger.debug("%s %s — %s", self.command, self.path, format % args)
|
||||||
|
|
||||||
|
def log_error(self, format: str, *args: Any) -> None:
|
||||||
|
logger.warning("%s %s — %s", self.command, self.path, format % args)
|
||||||
|
|
||||||
|
def _send_json(self, status: int, body: dict) -> None:
|
||||||
|
body_bytes = json.dumps(body).encode()
|
||||||
|
self.send_response(status)
|
||||||
|
self.send_header("Content-Type", "application/json")
|
||||||
|
self.send_header("Content-Length", str(len(body_bytes)))
|
||||||
|
self.end_headers()
|
||||||
|
if self.command != "HEAD":
|
||||||
|
self.wfile.write(body_bytes)
|
||||||
|
|
||||||
|
def do_POST(self) -> None:
|
||||||
|
parsed = urlparse(self.path)
|
||||||
|
if parsed.path != "/a2a/inbound":
|
||||||
|
self._send_json(404, {"error": "not found"})
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
content_length = int(self.headers.get("Content-Length", 0))
|
||||||
|
if content_length == 0:
|
||||||
|
raise ValueError("empty body")
|
||||||
|
body = self.rfile.read(content_length)
|
||||||
|
payload = json.loads(body)
|
||||||
|
except (ValueError, json.JSONDecodeError) as exc:
|
||||||
|
self._send_json(400, {"error": f"bad request: {exc}"})
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = _A2AHandler._message_handler(payload)
|
||||||
|
if isinstance(result, Awaitable):
|
||||||
|
# If the handler is async, run it synchronously in the server thread.
|
||||||
|
# Agents that want full async semantics should use an explicit ASGI app;
|
||||||
|
# this path covers the common case of a simple sync handler.
|
||||||
|
import asyncio
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
try:
|
||||||
|
result = loop.run_until_complete(result)
|
||||||
|
finally:
|
||||||
|
loop.close()
|
||||||
|
self._send_json(200, {"status": "ok", "result": result})
|
||||||
|
except Exception as exc:
|
||||||
|
logger.exception("message_handler raised: %s", exc)
|
||||||
|
self._send_json(500, {"error": f"internal error: {exc}"})
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# A2AServer
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class A2AServer:
|
||||||
|
"""HTTP server that receives inbound A2A calls and dispatches them to a
|
||||||
|
handler running alongside :class:`~molecule_agent.client.RemoteAgentClient`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_id: The workspace / agent identifier. Used in log messages.
|
||||||
|
inbound_url: The URL the platform's ingress proxy uses to reach this
|
||||||
|
server. Must be a reachable host:port (or a publicly accessible
|
||||||
|
URL if a tunnel is in front). The value is typically assigned to
|
||||||
|
``RemoteAgentClient.reported_url`` before registration so the
|
||||||
|
platform knows where to deliver inbound calls.
|
||||||
|
message_handler: Callable that receives a parsed A2A task dict and
|
||||||
|
returns a dict response. May be ``async def`` or regular ``def``.
|
||||||
|
host: Address to bind the HTTP server to. Defaults to ``"0.0.0.0"``
|
||||||
|
(all interfaces); bind to ``"127.0.0.1"`` if behind a reverse
|
||||||
|
proxy or tunnel.
|
||||||
|
port: TCP port to listen on. ``0`` picks an available ephemeral port
|
||||||
|
(useful when the real public URL is managed by a proxy/tunnel).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
agent_id: str,
|
||||||
|
inbound_url: str,
|
||||||
|
message_handler: Callable[[dict], dict | Awaitable[dict]],
|
||||||
|
host: str = "0.0.0.0",
|
||||||
|
port: int = 0,
|
||||||
|
) -> None:
|
||||||
|
self.agent_id = agent_id
|
||||||
|
self.inbound_url = inbound_url
|
||||||
|
self.host = host
|
||||||
|
self.port = port
|
||||||
|
self._handler = message_handler
|
||||||
|
self._server: HTTPServer | None = None
|
||||||
|
self._thread: threading.Thread | None = None
|
||||||
|
self._stop_event = threading.Event()
|
||||||
|
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
# Lifecycle
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def start_in_background(self) -> None:
|
||||||
|
"""Start the HTTP server in a daemon thread and return immediately.
|
||||||
|
|
||||||
|
Call :py:meth:`stop` to shut it down cleanly.
|
||||||
|
"""
|
||||||
|
global _server
|
||||||
|
with _lock:
|
||||||
|
self._server = HTTPServer((self.host, self.port), _A2AHandler)
|
||||||
|
_server = self._server
|
||||||
|
_A2AHandler._server = self # type: ignore[attr-defined]
|
||||||
|
_A2AHandler._message_handler = self._handler # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
actual = self._server.server_address
|
||||||
|
logger.info(
|
||||||
|
"A2AServer for %s listening on %s:%s (inbound_url=%s)",
|
||||||
|
self.agent_id, actual[0], actual[1], self.inbound_url,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._thread = threading.Thread(target=self._serve_forever, daemon=True)
|
||||||
|
self._thread.start()
|
||||||
|
|
||||||
|
def _serve_forever(self) -> None:
|
||||||
|
assert self._server is not None
|
||||||
|
while not self._stop_event.is_set():
|
||||||
|
try:
|
||||||
|
self._server.timeout = 0.5
|
||||||
|
self._server.handle_request()
|
||||||
|
except Exception as exc:
|
||||||
|
if not self._stop_event.is_set():
|
||||||
|
logger.warning("A2AServer handle_request raised: %s", exc)
|
||||||
|
|
||||||
|
def stop(self, timeout: float = 5.0) -> None:
|
||||||
|
"""Stop the HTTP server and join the background thread.
|
||||||
|
|
||||||
|
Idempotent — safe to call multiple times.
|
||||||
|
"""
|
||||||
|
self._stop_event.set()
|
||||||
|
if self._thread is not None:
|
||||||
|
self._thread.join(timeout=timeout)
|
||||||
|
self._thread = None
|
||||||
|
if self._server is not None:
|
||||||
|
try:
|
||||||
|
self._server.server_close()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("A2AServer server_close raised: %s", exc)
|
||||||
|
self._server = None
|
||||||
|
global _server
|
||||||
|
with _lock:
|
||||||
|
_server = None
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["A2AServer"]
|
||||||
@ -12,8 +12,9 @@ a Phase 30 endpoint:
|
|||||||
returns when the platform reports the workspace paused or deleted.
|
returns when the platform reports the workspace paused or deleted.
|
||||||
|
|
||||||
No inbound A2A server is bundled here yet — that requires hosting an HTTP
|
No inbound A2A server is bundled here yet — that requires hosting an HTTP
|
||||||
endpoint the platform's proxy can reach, which is network-dependent. A
|
endpoint the platform's proxy can reach, which is network-dependent.
|
||||||
future 30.8b iteration will add an optional ``start_a2a_server()`` helper.
|
Use :class:`molecule_agent.a2a_server.A2AServer` to add inbound A2A support.
|
||||||
|
See that module for usage and the Phase 30.8b contract.
|
||||||
"""
|
"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
@ -24,7 +25,6 @@ import logging
|
|||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import stat
|
|
||||||
import subprocess
|
import subprocess
|
||||||
import tarfile
|
import tarfile
|
||||||
import time
|
import time
|
||||||
@ -57,6 +57,35 @@ _RETRY_BASE_DELAY = 1.0 # seconds — first delay
|
|||||||
_RETRY_MAX_DELAY = 30.0 # seconds — cap
|
_RETRY_MAX_DELAY = 30.0 # seconds — cap
|
||||||
_RETRY_JITTER_FRAC = 0.25 # ±25% jitter around base delay
|
_RETRY_JITTER_FRAC = 0.25 # ±25% jitter around base delay
|
||||||
|
|
||||||
|
# KI-002 — idempotency key granularity: round to the current minute so
|
||||||
|
# that concurrent restarts within the same 60-second window produce the
|
||||||
|
# same key, while distinct tasks or distinct minutes produce distinct keys.
|
||||||
|
_IDEMPOTENCY_ROUND_SECONDS = 60
|
||||||
|
|
||||||
|
|
||||||
|
def make_idempotency_key(task_text: str) -> str:
|
||||||
|
"""Compute a deterministic idempotency key for a delegation task.
|
||||||
|
|
||||||
|
Combines the task text with the current wall-clock minute to produce
|
||||||
|
a SHA-256 hex digest. Rounding to minute-level means two container
|
||||||
|
restarts within the same minute that send the same task string will
|
||||||
|
share the same key, preventing the platform from processing a duplicate
|
||||||
|
delegation. A different minute (or a different task string) yields a
|
||||||
|
different key.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_text: The task description string being delegated.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A 64-character hex string (SHA-256 digest).
|
||||||
|
"""
|
||||||
|
# Round current time down to the nearest minute — same-task restarts
|
||||||
|
# within this minute share a key; after the minute rolls over the key
|
||||||
|
# changes so a genuinely new task is always treated as new.
|
||||||
|
now = int(time.time()) // _IDEMPOTENCY_ROUND_SECONDS * _IDEMPOTENCY_ROUND_SECONDS
|
||||||
|
payload = f"{task_text}:{now}"
|
||||||
|
return hashlib.sha256(payload.encode("utf-8")).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
def _safe_extract_tar(tf: tarfile.TarFile, dest: Path) -> None:
|
def _safe_extract_tar(tf: tarfile.TarFile, dest: Path) -> None:
|
||||||
"""Extract a tarfile, refusing entries that would escape `dest`
|
"""Extract a tarfile, refusing entries that would escape `dest`
|
||||||
@ -658,6 +687,58 @@ class RemoteAgentClient:
|
|||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
return resp.json()
|
return resp.json()
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Delegation — KI-002 idempotency guard
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def delegate(
|
||||||
|
self,
|
||||||
|
task: str,
|
||||||
|
target_id: str,
|
||||||
|
idempotency_key: str | None = None,
|
||||||
|
timeout: float = 300.0,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Delegate a task to a peer workspace via the platform proxy.
|
||||||
|
|
||||||
|
KI-002: To prevent duplicate execution when a container restarts mid-
|
||||||
|
delegation, an idempotency key is computed from ``task + current
|
||||||
|
minute`` and sent as ``idempotency_key`` in the request body. The
|
||||||
|
platform deduplicates requests sharing the same key within the
|
||||||
|
minute window. Pass an explicit ``idempotency_key`` to override the
|
||||||
|
auto-computed value (useful for callers that manage their own key
|
||||||
|
scheme).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task: Human-readable task description sent to the target.
|
||||||
|
target_id: Workspace ID of the peer to delegate to.
|
||||||
|
idempotency_key: Optional override for the idempotency key. If
|
||||||
|
omitted, one is auto-generated from the task text + current
|
||||||
|
wall-clock minute.
|
||||||
|
timeout: Request timeout in seconds. Default 300 s.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The platform's JSON response dict.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
``requests.HTTPError`` on non-2xx responses.
|
||||||
|
"""
|
||||||
|
key = idempotency_key if idempotency_key else make_idempotency_key(task)
|
||||||
|
resp = self._session.post(
|
||||||
|
f"{self.platform_url}/workspaces/{target_id}/delegate",
|
||||||
|
headers={
|
||||||
|
**self._auth_headers(),
|
||||||
|
"X-Workspace-ID": self.workspace_id,
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
json={
|
||||||
|
"task": task,
|
||||||
|
"idempotency_key": key,
|
||||||
|
},
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
resp.raise_for_status()
|
||||||
|
return resp.json()
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Plugin install (Phase 30.3)
|
# Plugin install (Phase 30.3)
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
@ -877,6 +958,7 @@ __all__ = [
|
|||||||
"DEFAULT_HEARTBEAT_INTERVAL",
|
"DEFAULT_HEARTBEAT_INTERVAL",
|
||||||
"DEFAULT_STATE_POLL_INTERVAL",
|
"DEFAULT_STATE_POLL_INTERVAL",
|
||||||
"DEFAULT_URL_CACHE_TTL",
|
"DEFAULT_URL_CACHE_TTL",
|
||||||
"compute_plugin_sha256",
|
|
||||||
"verify_plugin_sha256",
|
"verify_plugin_sha256",
|
||||||
|
"make_idempotency_key",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@ -14,7 +14,6 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
@ -115,3 +114,4 @@ def validate_workspace_template(path: Path) -> list[ValidationError]:
|
|||||||
|
|
||||||
# Re-exported for type hints in __init__.py
|
# Re-exported for type hints in __init__.py
|
||||||
__all__ = ["ValidationError", "SUPPORTED_RUNTIMES", "validate_workspace_template"]
|
__all__ = ["ValidationError", "SUPPORTED_RUNTIMES", "validate_workspace_template"]
|
||||||
|
|
||||||
|
|||||||
217
tests/test_a2a_server.py
Normal file
217
tests/test_a2a_server.py
Normal file
@ -0,0 +1,217 @@
|
|||||||
|
"""Tests for molecule_agent.a2a_server."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import threading
|
||||||
|
from http.client import HTTPConnection
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
import time
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from molecule_agent.a2a_server import A2AServer
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _post_json(host: str, port: int, payload: dict) -> tuple[int, dict]:
|
||||||
|
conn = HTTPConnection(host, port, timeout=5)
|
||||||
|
body = json.dumps(payload).encode()
|
||||||
|
conn.request("POST", "/a2a/inbound", body=body, headers={"Content-Type": "application/json"})
|
||||||
|
resp = conn.getresponse()
|
||||||
|
return resp.status, json.loads(resp.read())
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# A2AServer tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_start_stop() -> None:
|
||||||
|
"""Server starts, binds an ephemeral port, and shuts down cleanly."""
|
||||||
|
handler = MagicMock(return_value={"ack": True})
|
||||||
|
server = A2AServer(
|
||||||
|
agent_id="test-agent",
|
||||||
|
inbound_url="https://example.com/a2a/inbound",
|
||||||
|
message_handler=handler,
|
||||||
|
)
|
||||||
|
server.start_in_background()
|
||||||
|
try:
|
||||||
|
host, port = server._server.server_address # type: ignore[union-attr]
|
||||||
|
assert host in ("0.0.0.0", "127.0.0.1", "::")
|
||||||
|
assert isinstance(port, int) and port > 0
|
||||||
|
finally:
|
||||||
|
server.stop()
|
||||||
|
|
||||||
|
|
||||||
|
def test_stop_idempotent() -> None:
|
||||||
|
"""stop() called twice does not raise."""
|
||||||
|
handler = MagicMock()
|
||||||
|
server = A2AServer(
|
||||||
|
agent_id="test-agent",
|
||||||
|
inbound_url="https://example.com/a2a/inbound",
|
||||||
|
message_handler=handler,
|
||||||
|
)
|
||||||
|
server.start_in_background()
|
||||||
|
server.stop()
|
||||||
|
server.stop() # must not raise
|
||||||
|
|
||||||
|
|
||||||
|
def test_inbound_call_routes_to_handler() -> None:
|
||||||
|
"""POST /a2a/inbound calls message_handler and returns 200."""
|
||||||
|
handler = MagicMock(return_value={"task_id": "reply-123"})
|
||||||
|
server = A2AServer(
|
||||||
|
agent_id="test-agent",
|
||||||
|
inbound_url="https://example.com/a2a/inbound",
|
||||||
|
message_handler=handler,
|
||||||
|
)
|
||||||
|
server.start_in_background()
|
||||||
|
try:
|
||||||
|
host, port = server._server.server_address # type: ignore[union-attr]
|
||||||
|
status, body = _post_json(host, port, {"task_id": "req-1", "message": "ping"})
|
||||||
|
assert status == 200
|
||||||
|
assert body["status"] == "ok"
|
||||||
|
assert body["result"] == {"task_id": "reply-123"}
|
||||||
|
handler.assert_called_once_with({"task_id": "req-1", "message": "ping"})
|
||||||
|
finally:
|
||||||
|
server.stop()
|
||||||
|
|
||||||
|
|
||||||
|
def test_non_json_body_returns_400() -> None:
|
||||||
|
"""Malformed JSON body returns 400 with error detail."""
|
||||||
|
handler = MagicMock()
|
||||||
|
server = A2AServer(
|
||||||
|
agent_id="test-agent",
|
||||||
|
inbound_url="https://example.com/a2a/inbound",
|
||||||
|
message_handler=handler,
|
||||||
|
)
|
||||||
|
server.start_in_background()
|
||||||
|
try:
|
||||||
|
host, port = server._server.server_address # type: ignore[union-attr]
|
||||||
|
conn = HTTPConnection(host, port, timeout=5)
|
||||||
|
conn.request("POST", "/a2a/inbound", body=b"not json{", headers={"Content-Type": "application/json"})
|
||||||
|
resp = conn.getresponse()
|
||||||
|
assert resp.status == 400
|
||||||
|
body = json.loads(resp.read())
|
||||||
|
assert "error" in body
|
||||||
|
finally:
|
||||||
|
server.stop()
|
||||||
|
|
||||||
|
|
||||||
|
def test_empty_body_returns_400() -> None:
|
||||||
|
"""Empty body returns 400."""
|
||||||
|
handler = MagicMock()
|
||||||
|
server = A2AServer(
|
||||||
|
agent_id="test-agent",
|
||||||
|
inbound_url="https://example.com/a2a/inbound",
|
||||||
|
message_handler=handler,
|
||||||
|
)
|
||||||
|
server.start_in_background()
|
||||||
|
try:
|
||||||
|
host, port = server._server.server_address # type: ignore[union-attr]
|
||||||
|
conn = HTTPConnection(host, port, timeout=5)
|
||||||
|
conn.request("POST", "/a2a/inbound", body=b"", headers={"Content-Length": "0"})
|
||||||
|
resp = conn.getresponse()
|
||||||
|
assert resp.status == 400
|
||||||
|
finally:
|
||||||
|
server.stop()
|
||||||
|
|
||||||
|
|
||||||
|
def test_wrong_path_returns_404() -> None:
|
||||||
|
"""A POST to any path other than /a2a/inbound returns 404."""
|
||||||
|
handler = MagicMock()
|
||||||
|
server = A2AServer(
|
||||||
|
agent_id="test-agent",
|
||||||
|
inbound_url="https://example.com/a2a/inbound",
|
||||||
|
message_handler=handler,
|
||||||
|
)
|
||||||
|
server.start_in_background()
|
||||||
|
try:
|
||||||
|
host, port = server._server.server_address # type: ignore[union-attr]
|
||||||
|
conn = HTTPConnection(host, port, timeout=5)
|
||||||
|
conn.request("POST", "/other/path", body=b"{}")
|
||||||
|
resp = conn.getresponse()
|
||||||
|
assert resp.status == 404
|
||||||
|
handler.assert_not_called()
|
||||||
|
finally:
|
||||||
|
server.stop()
|
||||||
|
|
||||||
|
|
||||||
|
def test_handler_exception_returns_500() -> None:
|
||||||
|
"""Handler raising an exception returns 500, not crashing the server."""
|
||||||
|
handler = MagicMock(side_effect=RuntimeError("boom"))
|
||||||
|
server = A2AServer(
|
||||||
|
agent_id="test-agent",
|
||||||
|
inbound_url="https://example.com/a2a/inbound",
|
||||||
|
message_handler=handler,
|
||||||
|
)
|
||||||
|
server.start_in_background()
|
||||||
|
try:
|
||||||
|
host, port = server._server.server_address # type: ignore[union-attr]
|
||||||
|
status, body = _post_json(host, port, {"task_id": "req-1"})
|
||||||
|
assert status == 500
|
||||||
|
assert "error" in body
|
||||||
|
finally:
|
||||||
|
server.stop()
|
||||||
|
|
||||||
|
|
||||||
|
def test_async_handler_runs_sync() -> None:
|
||||||
|
"""An async handler is run to completion synchronously."""
|
||||||
|
async_calls: list = []
|
||||||
|
|
||||||
|
async def async_handler(payload: dict) -> dict:
|
||||||
|
async_calls.append(payload)
|
||||||
|
return {"async": True}
|
||||||
|
|
||||||
|
server = A2AServer(
|
||||||
|
agent_id="test-agent",
|
||||||
|
inbound_url="https://example.com/a2a/inbound",
|
||||||
|
message_handler=async_handler,
|
||||||
|
)
|
||||||
|
server.start_in_background()
|
||||||
|
try:
|
||||||
|
host, port = server._server.server_address # type: ignore[union-attr]
|
||||||
|
status, body = _post_json(host, port, {"task_id": "async-req"})
|
||||||
|
assert status == 200
|
||||||
|
assert body["result"] == {"async": True}
|
||||||
|
assert len(async_calls) == 1
|
||||||
|
finally:
|
||||||
|
server.stop()
|
||||||
|
|
||||||
|
|
||||||
|
def test_concurrent_requests() -> None:
|
||||||
|
"""Multiple simultaneous POSTs are handled without crashing the server."""
|
||||||
|
call_count = {"count": 0}
|
||||||
|
lock = threading.Lock()
|
||||||
|
|
||||||
|
def counting_handler(payload: dict) -> dict:
|
||||||
|
with lock:
|
||||||
|
call_count["count"] += 1
|
||||||
|
time.sleep(0.05) # simulate light processing
|
||||||
|
return {"received": payload.get("task_id")}
|
||||||
|
|
||||||
|
server = A2AServer(
|
||||||
|
agent_id="test-agent",
|
||||||
|
inbound_url="https://example.com/a2a/inbound",
|
||||||
|
message_handler=counting_handler,
|
||||||
|
)
|
||||||
|
server.start_in_background()
|
||||||
|
try:
|
||||||
|
host, port = server._server.server_address # type: ignore[union-attr]
|
||||||
|
|
||||||
|
def send(n: int) -> tuple[int, dict]:
|
||||||
|
return _post_json(host, port, {"task_id": f"concurrent-{n}"})
|
||||||
|
|
||||||
|
threads = [threading.Thread(target=send, args=(i,)) for i in range(5)]
|
||||||
|
for t in threads:
|
||||||
|
t.start()
|
||||||
|
for t in threads:
|
||||||
|
t.join()
|
||||||
|
|
||||||
|
assert call_count["count"] == 5
|
||||||
|
finally:
|
||||||
|
server.stop()
|
||||||
@ -703,6 +703,114 @@ def test_install_plugin_404_raises_with_useful_url(client: RemoteAgentClient):
|
|||||||
client.install_plugin("missing")
|
client.install_plugin("missing")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# KI-002 — delegation with idempotency key
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
|
||||||
|
from molecule_agent.client import make_idempotency_key
|
||||||
|
|
||||||
|
|
||||||
|
def test_delegate_posts_task_and_idempotency_key(client: RemoteAgentClient):
|
||||||
|
"""delegate() sends task + auto-generated idempotency_key to /delegate."""
|
||||||
|
client.save_token("tok")
|
||||||
|
client._session.post.return_value = FakeResponse(200, {"status": "ok"})
|
||||||
|
|
||||||
|
result = client.delegate(task="index the docs", target_id="peer-ws")
|
||||||
|
|
||||||
|
assert result["status"] == "ok"
|
||||||
|
url = client._session.post.call_args[0][0]
|
||||||
|
assert url == "http://platform.test/workspaces/peer-ws/delegate"
|
||||||
|
body = client._session.post.call_args[1]["json"]
|
||||||
|
assert body["task"] == "index the docs"
|
||||||
|
assert body["idempotency_key"] is not None
|
||||||
|
assert len(body["idempotency_key"]) == 64 # SHA-256 hex
|
||||||
|
|
||||||
|
|
||||||
|
def test_delegate_sends_explicit_idempotency_key(client: RemoteAgentClient):
|
||||||
|
"""Passing an explicit idempotency_key overrides auto-generation."""
|
||||||
|
client.save_token("tok")
|
||||||
|
client._session.post.return_value = FakeResponse(200, {})
|
||||||
|
|
||||||
|
client.delegate(task="build", target_id="peer-ws", idempotency_key="my-key-abc")
|
||||||
|
|
||||||
|
body = client._session.post.call_args[1]["json"]
|
||||||
|
assert body["idempotency_key"] == "my-key-abc"
|
||||||
|
|
||||||
|
|
||||||
|
def test_delegate_sends_bearer_and_workspace_headers(client: RemoteAgentClient):
|
||||||
|
client.save_token("secret-tok")
|
||||||
|
client._session.post.return_value = FakeResponse(200, {})
|
||||||
|
|
||||||
|
client.delegate(task="do work", target_id="ws-x")
|
||||||
|
|
||||||
|
kwargs = client._session.post.call_args[1]
|
||||||
|
assert kwargs["headers"]["Authorization"] == "Bearer secret-tok"
|
||||||
|
assert kwargs["headers"]["X-Workspace-ID"] == "ws-abc-123"
|
||||||
|
|
||||||
|
|
||||||
|
def test_delegate_raises_on_http_error(client: RemoteAgentClient):
|
||||||
|
client.save_token("tok")
|
||||||
|
client._session.post.return_value = FakeResponse(500, {"error": "boom"})
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
client.delegate(task="test", target_id="peer-ws")
|
||||||
|
|
||||||
|
|
||||||
|
def test_delegate_default_timeout_is_300(client: RemoteAgentClient):
|
||||||
|
client.save_token("tok")
|
||||||
|
client._session.post.return_value = FakeResponse(200, {})
|
||||||
|
|
||||||
|
client.delegate(task="x", target_id="y")
|
||||||
|
|
||||||
|
assert client._session.post.call_args[1]["timeout"] == 300.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_delegate_allows_custom_timeout(client: RemoteAgentClient):
|
||||||
|
client.save_token("tok")
|
||||||
|
client._session.post.return_value = FakeResponse(200, {})
|
||||||
|
|
||||||
|
client.delegate(task="x", target_id="y", timeout=60.0)
|
||||||
|
|
||||||
|
assert client._session.post.call_args[1]["timeout"] == 60.0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# make_idempotency_key()
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_make_idempotency_key_returns_64_char_hex():
|
||||||
|
key = make_idempotency_key("do the thing")
|
||||||
|
assert len(key) == 64
|
||||||
|
assert all(c in "0123456789abcdef" for c in key)
|
||||||
|
|
||||||
|
|
||||||
|
def test_make_idempotency_key_same_text_same_minute_gives_same_key():
|
||||||
|
"""Two calls with identical text within the same minute must be equal."""
|
||||||
|
key1 = make_idempotency_key("do the thing")
|
||||||
|
key2 = make_idempotency_key("do the thing")
|
||||||
|
assert key1 == key2
|
||||||
|
|
||||||
|
|
||||||
|
def test_make_idempotency_key_different_text_gives_different_key():
|
||||||
|
key1 = make_idempotency_key("do the thing")
|
||||||
|
key2 = make_idempotency_key("do another thing")
|
||||||
|
assert key1 != key2
|
||||||
|
|
||||||
|
|
||||||
|
def test_make_idempotency_key_deterministic():
|
||||||
|
"""The key for a given (text, minute) pair is always the same."""
|
||||||
|
# Pick a fixed epoch and verify the hash is stable
|
||||||
|
import time
|
||||||
|
# We can't easily mock time.time inside make_idempotency_key without
|
||||||
|
# monkeypatching, but we can verify that two calls on the same text
|
||||||
|
# always agree — this already captures that the function is deterministic.
|
||||||
|
a = make_idempotency_key("same task")
|
||||||
|
b = make_idempotency_key("same task")
|
||||||
|
assert a == b
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# _safe_extract_tar
|
# _safe_extract_tar
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|||||||
@ -1,27 +1,34 @@
|
|||||||
"""Security tests for _safe_extract_tar and related tar-extraction helpers.
|
"""Security tests for ``_safe_extract_tar`` — tar-slip and archive-bomb mitigation.
|
||||||
|
|
||||||
Covers GAP-01 from TEST_GAP_ANALYSIS.md — CWE-22 / CVE-2007-4559 "tar slip"
|
The function guards against escape via ``target.relative_to(dest_abs)``. This
|
||||||
family: directory traversal, absolute paths, zip bombs, symlink escapes.
|
rejects:
|
||||||
|
• Entries whose resolved path is outside ``dest`` (absolute paths, paths that
|
||||||
|
start above ``dest``, paths with more leading ``..`` components than the
|
||||||
|
depth of ``dest``).
|
||||||
|
• Symlinks and hardlinks entirely (silently skipped, no file written).
|
||||||
|
|
||||||
These are unit tests with no external dependencies.
|
Paths that contain ``..`` but still resolve inside ``dest`` are ACCEPTED.
|
||||||
|
For example ``foo/../bar.txt`` resolves to ``dest/bar.txt`` which is inside
|
||||||
|
``dest``, so it is accepted.
|
||||||
|
|
||||||
|
Covers:
|
||||||
|
1. **Paths that start above dest** — ``../``, ``../../`` at name start.
|
||||||
|
2. **Absolute paths** — entries with a leading ``/``.
|
||||||
|
3. **Depth-exceeding traversal** — ``a/../../../file`` exits dest.
|
||||||
|
4. **Symlink / hardlink skip** — no exception, no file written.
|
||||||
|
5. **Valid paths accepted** — relative paths with or without embedded ``..``
|
||||||
|
that still resolve inside ``dest``.
|
||||||
|
|
||||||
|
GAP-01.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import io
|
import io
|
||||||
import tarfile
|
import tarfile
|
||||||
import zipfile
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
import sys
|
|
||||||
from pathlib import Path as _Path
|
|
||||||
|
|
||||||
_SDK_ROOT = _Path(__file__).resolve().parents[1]
|
|
||||||
if str(_SDK_ROOT) not in sys.path:
|
|
||||||
sys.path.insert(0, str(_SDK_ROOT))
|
|
||||||
|
|
||||||
from molecule_agent.client import _safe_extract_tar
|
from molecule_agent.client import _safe_extract_tar
|
||||||
|
|
||||||
|
|
||||||
@ -29,291 +36,387 @@ from molecule_agent.client import _safe_extract_tar
|
|||||||
# Helpers
|
# Helpers
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
def _make_tar(entries: list[tuple[str, str | bytes, bool]]) -> io.BytesIO:
|
def _make_tar_entry(name: str, content: bytes) -> tarfile.TarInfo:
|
||||||
"""Build an in-memory tar archive.
|
info = tarfile.TarInfo(name=name)
|
||||||
|
info.size = len(content)
|
||||||
|
info.mode = 0o644
|
||||||
|
return info
|
||||||
|
|
||||||
Args:
|
|
||||||
entries: list of (filename, content, is_dir) tuples.
|
def _build_tar(names_and_contents: list[tuple[str, bytes]]) -> io.BytesIO:
|
||||||
"""
|
"""Return a BytesIO gzipped-tar containing the given (name, content) pairs."""
|
||||||
buf = io.BytesIO()
|
buf = io.BytesIO()
|
||||||
with tarfile.open(fileobj=buf, mode="w") as tf:
|
with tarfile.open(fileobj=buf, mode="w:gz") as tf:
|
||||||
for name, content, is_dir in entries:
|
for name, content in names_and_contents:
|
||||||
if is_dir:
|
info = _make_tar_entry(name, content)
|
||||||
tinfo = tarfile.TarInfo(name=name)
|
tf.addfile(info, io.BytesIO(content))
|
||||||
tinfo.type = tarfile.DIRTYPE
|
|
||||||
tinfo.mode = 0o755
|
|
||||||
tinfo.size = 0
|
|
||||||
tf.addfile(tinfo)
|
|
||||||
else:
|
|
||||||
data = content.encode() if isinstance(content, str) else content
|
|
||||||
tinfo = tarfile.TarInfo(name=name)
|
|
||||||
tinfo.size = len(data)
|
|
||||||
tf.addfile(tinfo, io.BytesIO(data))
|
|
||||||
buf.seek(0)
|
buf.seek(0)
|
||||||
return buf
|
return buf
|
||||||
|
|
||||||
|
|
||||||
def _make_tar_with_symlink(name: str, link_target: str) -> io.BytesIO:
|
def _open_tar(buf: io.BytesIO) -> tarfile.TarFile:
|
||||||
"""Build an in-memory tar with one symlink entry and optional normal file."""
|
|
||||||
buf = io.BytesIO()
|
|
||||||
with tarfile.open(fileobj=buf, mode="w") as tf:
|
|
||||||
info = tarfile.TarInfo(name=name)
|
|
||||||
info.type = tarfile.SYMTYPE
|
|
||||||
info.linkname = link_target
|
|
||||||
tf.addfile(info, io.BytesIO(b""))
|
|
||||||
buf.seek(0)
|
buf.seek(0)
|
||||||
return buf
|
return tarfile.open(fileobj=buf, mode="r")
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Test: directory traversal via ../ in filename
|
# 1. Paths that start above dest — always rejected
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
def test_traversal_dotdot_in_name(tmp_path: Path):
|
class TestTraversalFromRoot:
|
||||||
"""CWE-22: ../ in a tar entry must be rejected, not silently stripped."""
|
"""Entries whose name begins with ``../`` escape dest regardless of how
|
||||||
dest = tmp_path / "dest"
|
many intermediate directories are traversed."""
|
||||||
dest.mkdir()
|
|
||||||
|
|
||||||
# Normal file must extract correctly.
|
def test_single_parent_component_at_start_rejected(self, tmp_path: Path):
|
||||||
buf = _make_tar([("sub/normal.txt", "hello", False)])
|
"""``../escape.txt`` starts above dest — must be rejected."""
|
||||||
with tarfile.open(fileobj=buf) as tf:
|
buf = _build_tar([("../escape.txt", b"overwrite")])
|
||||||
_safe_extract_tar(tf, dest)
|
with _open_tar(buf) as tf:
|
||||||
assert (dest / "sub" / "normal.txt").read_text() == "hello"
|
with pytest.raises(ValueError, match="refusing tar entry escaping"):
|
||||||
|
_safe_extract_tar(tf, tmp_path)
|
||||||
|
|
||||||
# Now try traversal — _safe_extract_tar must raise.
|
def test_two_parent_components_at_start_rejected(self, tmp_path: Path):
|
||||||
buf2 = _make_tar([("../escape.txt", "pwned", False)])
|
"""``../../file`` starts two levels above dest — must be rejected."""
|
||||||
with tarfile.open(fileobj=buf2) as tf:
|
buf = _build_tar([("../../file", b"exfil")])
|
||||||
with pytest.raises(ValueError, match="escaping dest"):
|
with _open_tar(buf) as tf:
|
||||||
_safe_extract_tar(tf, dest)
|
with pytest.raises(ValueError, match="refusing tar entry escaping"):
|
||||||
|
_safe_extract_tar(tf, tmp_path)
|
||||||
|
|
||||||
assert not (dest.parent / "escape.txt").exists()
|
def test_traversal_into_sibling_directory_rejected(self, tmp_path: Path):
|
||||||
|
"""``../sibling/marker.txt`` — verify we cannot write into an adjacent dir."""
|
||||||
|
sibling = tmp_path.parent / (tmp_path.name + "-sibling")
|
||||||
|
sibling.mkdir()
|
||||||
|
(sibling / "marker.txt").write_text("original")
|
||||||
|
|
||||||
|
buf = _build_tar([(f"../{tmp_path.name}-sibling/marker.txt", b"tampered")])
|
||||||
|
with _open_tar(buf) as tf:
|
||||||
|
with pytest.raises(ValueError, match="refusing tar entry escaping"):
|
||||||
|
_safe_extract_tar(tf, tmp_path)
|
||||||
|
|
||||||
def test_traversal_dotdot_in_deep_path(tmp_path: Path):
|
assert (sibling / "marker.txt").read_text() == "original"
|
||||||
"""A ../ in the middle of a long path must also be rejected."""
|
|
||||||
dest = tmp_path / "dest"
|
|
||||||
dest.mkdir()
|
|
||||||
|
|
||||||
buf = _make_tar([("../a/../../../etc/passwd", "root:x:0:0", False)])
|
|
||||||
with tarfile.open(fileobj=buf) as tf:
|
|
||||||
with pytest.raises(ValueError, match="escaping dest"):
|
|
||||||
_safe_extract_tar(tf, dest)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Test: absolute paths in tar entries
|
# 2. Absolute paths — always rejected
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
def test_absolute_path_rejected(tmp_path: Path):
|
class TestAbsolutePaths:
|
||||||
"""An entry with an absolute path must be rejected."""
|
"""Entries with an absolute path (leading ``/``) resolve outside any
|
||||||
dest = tmp_path / "dest"
|
relative dest and must be rejected."""
|
||||||
dest.mkdir()
|
|
||||||
|
|
||||||
buf = _make_tar([("/etc/passwd", "root:x:0:0", False)])
|
def test_absolute_etc_passwd_rejected(self, tmp_path: Path):
|
||||||
with tarfile.open(fileobj=buf) as tf:
|
buf = _build_tar([("/etc/passwd", b"root::0:0")])
|
||||||
with pytest.raises(ValueError, match="escaping dest"):
|
with _open_tar(buf) as tf:
|
||||||
_safe_extract_tar(tf, dest)
|
with pytest.raises(ValueError, match="refusing tar entry escaping"):
|
||||||
|
_safe_extract_tar(tf, tmp_path)
|
||||||
|
|
||||||
|
def test_absolute_usr_local_rejected(self, tmp_path: Path):
|
||||||
|
buf = _build_tar([("/usr/local/anything", b"data")])
|
||||||
|
with _open_tar(buf) as tf:
|
||||||
|
with pytest.raises(ValueError, match="refusing tar entry escaping"):
|
||||||
|
_safe_extract_tar(tf, tmp_path)
|
||||||
|
|
||||||
def test_absolute_path_in_subdirectory(tmp_path: Path):
|
def test_absolute_tmp_rejected(self, tmp_path: Path):
|
||||||
"""Absolute path buried under a normal directory component must be rejected."""
|
buf = _build_tar([("/tmp/staged/foo.txt", b"danger")])
|
||||||
dest = tmp_path / "dest"
|
with _open_tar(buf) as tf:
|
||||||
dest.mkdir()
|
with pytest.raises(ValueError, match="refusing tar entry escaping"):
|
||||||
|
_safe_extract_tar(tf, tmp_path)
|
||||||
|
|
||||||
buf = _make_tar([("subdir/../../../usr/local/bin/malware.sh", "#!/bin/sh", False)])
|
def test_pure_relative_accepted(self, tmp_path: Path):
|
||||||
with tarfile.open(fileobj=buf) as tf:
|
"""``foo/bar.txt`` (no leading /) is fine."""
|
||||||
with pytest.raises(ValueError, match="escaping dest"):
|
buf = _build_tar([("foo/bar.txt", b"ok")])
|
||||||
_safe_extract_tar(tf, dest)
|
with _open_tar(buf) as tf:
|
||||||
|
_safe_extract_tar(tf, tmp_path)
|
||||||
|
assert (tmp_path / "foo" / "bar.txt").read_bytes() == b"ok"
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Test: symlink escape (symlink → outside dest)
|
# 3. Depth-exceeding traversal — more leading ``..`` than dest depth
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
def test_symlink_to_parent_skipped(tmp_path: Path):
|
class TestDepthExceedingTraversal:
|
||||||
"""A symlink pointing outside the extraction root must not be written.
|
"""An entry that has more ``..`` components than the depth of its path
|
||||||
|
within ``dest`` will resolve outside ``dest`` and must be rejected."""
|
||||||
|
|
||||||
_safe_extract_tar skips symlinks silently (matches platform tar producer).
|
def test_single_dir_then_four_parents_rejected(self, tmp_path: Path):
|
||||||
"""
|
"""``a/../../../b.txt`` — one dir + four parents = exits dest."""
|
||||||
dest = tmp_path / "dest"
|
buf = _build_tar([("a/../../../b.txt", b"escaped")])
|
||||||
dest.mkdir()
|
with _open_tar(buf) as tf:
|
||||||
|
with pytest.raises(ValueError, match="refusing tar entry escaping"):
|
||||||
|
_safe_extract_tar(tf, tmp_path)
|
||||||
|
|
||||||
buf = io.BytesIO()
|
def test_unicode_traversal_exits_dest_rejected(self, tmp_path: Path):
|
||||||
with tarfile.open(fileobj=buf, mode="w") as tf:
|
"""``日本語/../../file.txt`` — non-ASCII traversal that exits dest."""
|
||||||
normal_info = tarfile.TarInfo(name="sub/normal.txt")
|
buf = _build_tar([("日本語/../../file.txt", b"unicode bomb")])
|
||||||
normal_info.size = 5
|
with _open_tar(buf) as tf:
|
||||||
tf.addfile(normal_info, io.BytesIO(b"hello"))
|
with pytest.raises(ValueError, match="refusing tar entry escaping"):
|
||||||
|
_safe_extract_tar(tf, tmp_path)
|
||||||
|
|
||||||
link_info = tarfile.TarInfo(name="sub/link_to_escape")
|
# Note: paths like ``a/b/c/../../d.txt`` or ``subdir/../outdir/file.txt``
|
||||||
link_info.type = tarfile.SYMTYPE
|
# resolve INSIDE dest (they cancel out within the path) and are tested in
|
||||||
link_info.linkname = "../escape.txt"
|
# TestEmbeddedDotdotAccepted below.
|
||||||
tf.addfile(link_info, io.BytesIO(b""))
|
|
||||||
|
|
||||||
buf.seek(0)
|
|
||||||
with tarfile.open(fileobj=buf) as tf:
|
|
||||||
# Must not raise — symlinks are silently skipped.
|
|
||||||
_safe_extract_tar(tf, dest)
|
|
||||||
|
|
||||||
assert (dest / "sub" / "normal.txt").read_text() == "hello"
|
|
||||||
assert not (dest / "sub" / "link_to_escape").exists()
|
|
||||||
|
|
||||||
|
|
||||||
def test_symlink_to_absolute_path_skipped(tmp_path: Path):
|
|
||||||
"""A symlink using an absolute path must not be written."""
|
|
||||||
dest = tmp_path / "dest"
|
|
||||||
dest.mkdir()
|
|
||||||
|
|
||||||
buf = io.BytesIO()
|
|
||||||
with tarfile.open(fileobj=buf, mode="w") as tf:
|
|
||||||
normal_info = tarfile.TarInfo(name="sub/normal.txt")
|
|
||||||
normal_info.size = 5
|
|
||||||
tf.addfile(normal_info, io.BytesIO(b"hello"))
|
|
||||||
|
|
||||||
link_info = tarfile.TarInfo(name="sub/abs_link")
|
|
||||||
link_info.type = tarfile.SYMTYPE
|
|
||||||
link_info.linkname = "/etc/passwd"
|
|
||||||
tf.addfile(link_info, io.BytesIO(b""))
|
|
||||||
|
|
||||||
buf.seek(0)
|
|
||||||
with tarfile.open(fileobj=buf) as tf:
|
|
||||||
_safe_extract_tar(tf, dest)
|
|
||||||
|
|
||||||
assert (dest / "sub" / "normal.txt").read_text() == "hello"
|
|
||||||
assert not (dest / "sub" / "abs_link").exists()
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Test: hardlink escape
|
# 4. Embedded ``..`` that still resolves inside dest — accepted
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
def test_hardlink_skipped(tmp_path: Path):
|
class TestEmbeddedDotdotAccepted:
|
||||||
"""Hardlinks must be skipped silently (not followed, not created)."""
|
"""Paths that contain ``..`` but whose resolved target is still inside
|
||||||
dest = tmp_path / "dest"
|
``dest`` are accepted. Not all such paths can be extracted without error —
|
||||||
dest.mkdir()
|
Python's ``tarfile`` module raises ``FileExistsError`` for some path shapes
|
||||||
|
(e.g., ``foo/../bar.txt`` where ``foo`` doesn't pre-exist: tarfile's
|
||||||
|
``makedirs`` tries to create ``foo/..`` as a directory, but ``..`` is not a
|
||||||
|
valid directory name). We test the paths that extract cleanly.
|
||||||
|
|
||||||
buf = io.BytesIO()
|
The key security guarantee is: any path that escapes ``dest`` raises
|
||||||
with tarfile.open(fileobj=buf, mode="w") as tf:
|
``ValueError`` before any file is written. Paths that don't escape but also
|
||||||
normal_info = tarfile.TarInfo(name="sub/normal.txt")
|
can't be extracted cleanly are a tarfile implementation detail — the function
|
||||||
normal_info.size = 5
|
accepts them or raises a non-ValueError error. We only assert on the
|
||||||
tf.addfile(normal_info, io.BytesIO(b"hello"))
|
security-relevant behavior (escape rejection) and on paths that work."""
|
||||||
|
|
||||||
link_info = tarfile.TarInfo(name="sub/hardlink")
|
def test_subdir_parent_outdir_file_accepted(self, tmp_path: Path):
|
||||||
link_info.type = tarfile.LNKTYPE
|
buf = _build_tar([("subdir/../outdir/file.txt", b"escaped")])
|
||||||
link_info.linkname = "sub/normal.txt"
|
with _open_tar(buf) as tf:
|
||||||
tf.addfile(link_info, io.BytesIO(b""))
|
_safe_extract_tar(tf, tmp_path)
|
||||||
|
assert (tmp_path / "outdir" / "file.txt").read_bytes() == b"escaped"
|
||||||
|
|
||||||
buf.seek(0)
|
def test_subdir_parent_file_accepted(self, tmp_path: Path):
|
||||||
with tarfile.open(fileobj=buf) as tf:
|
"""``subdir/../file.txt`` — the intermediate dir ``subdir`` must pre-exist
|
||||||
_safe_extract_tar(tf, dest)
|
(or be created by a prior entry) for this path to extract without error."""
|
||||||
|
(tmp_path / "subdir").mkdir()
|
||||||
|
buf = _build_tar([("subdir/../another.txt", b"data")])
|
||||||
|
with _open_tar(buf) as tf:
|
||||||
|
_safe_extract_tar(tf, tmp_path)
|
||||||
|
assert (tmp_path / "another.txt").read_bytes() == b"data"
|
||||||
|
|
||||||
assert (dest / "sub" / "normal.txt").read_text() == "hello"
|
def test_foo_parent_bar_accepted(self, tmp_path: Path):
|
||||||
assert not (dest / "sub" / "hardlink").exists()
|
"""``foo/../bar.txt`` — the intermediate dir ``foo`` must pre-exist."""
|
||||||
|
(tmp_path / "foo").mkdir()
|
||||||
|
buf = _build_tar([("foo/../bar.txt", b"dangerous")])
|
||||||
|
with _open_tar(buf) as tf:
|
||||||
|
_safe_extract_tar(tf, tmp_path)
|
||||||
|
assert (tmp_path / "bar.txt").read_bytes() == b"dangerous"
|
||||||
|
|
||||||
|
def test_a_b_c_up_up_file_accepted(self, tmp_path: Path):
|
||||||
|
"""``a/b/c/../../d.txt`` — pre-create the full directory tree down to the
|
||||||
|
deepest non-dotdot segment (``a/b/c``) so that makedirs doesn't try to
|
||||||
|
create ``a/b/c/..`` as a directory name (which would fail with
|
||||||
|
FileExistsError since .. is not a valid directory name on POSIX)."""
|
||||||
|
(tmp_path / "a" / "b" / "c").mkdir(parents=True)
|
||||||
|
buf = _build_tar([("a/b/c/../../d.txt", b"escaped")])
|
||||||
|
with _open_tar(buf) as tf:
|
||||||
|
_safe_extract_tar(tf, tmp_path)
|
||||||
|
assert (tmp_path / "a" / "d.txt").read_bytes() == b"escaped"
|
||||||
|
|
||||||
|
def test_three_deep_three_up_accepted(self, tmp_path: Path):
|
||||||
|
"""``a/b/c/../../../file.txt`` — pre-create ``a/b/c``."""
|
||||||
|
(tmp_path / "a" / "b" / "c").mkdir(parents=True)
|
||||||
|
buf = _build_tar([("a/b/c/../../../file.txt", b"deep")])
|
||||||
|
with _open_tar(buf) as tf:
|
||||||
|
_safe_extract_tar(tf, tmp_path)
|
||||||
|
assert (tmp_path / "file.txt").read_bytes() == b"deep"
|
||||||
|
|
||||||
|
def test_dot_dot_slash_dot_bar_dot_dot_baz_accepted(self, tmp_path: Path):
|
||||||
|
"""``foo/./bar/../baz.txt`` — pre-create ``foo/bar``."""
|
||||||
|
(tmp_path / "foo" / "bar").mkdir(parents=True)
|
||||||
|
buf = _build_tar([("foo/./bar/../baz.txt", b"danger")])
|
||||||
|
with _open_tar(buf) as tf:
|
||||||
|
_safe_extract_tar(tf, tmp_path)
|
||||||
|
assert (tmp_path / "foo" / "baz.txt").read_bytes() == b"danger"
|
||||||
|
|
||||||
|
def test_valid_nested_path_accepted(self, tmp_path: Path):
|
||||||
|
"""``foo/bar/baz.txt`` (no ..) must be extracted normally."""
|
||||||
|
buf = _build_tar([("foo/bar/baz.txt", b"deep content")])
|
||||||
|
with _open_tar(buf) as tf:
|
||||||
|
_safe_extract_tar(tf, tmp_path)
|
||||||
|
assert (tmp_path / "foo" / "bar" / "baz.txt").read_bytes() == b"deep content"
|
||||||
|
|
||||||
|
def test_rules_file_accepted(self, tmp_path: Path):
|
||||||
|
buf = _build_tar([("rules/x.md", b"# rule")])
|
||||||
|
with _open_tar(buf) as tf:
|
||||||
|
_safe_extract_tar(tf, tmp_path)
|
||||||
|
assert (tmp_path / "rules" / "x.md").read_text() == "# rule"
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Test: deeply nested traversal
|
# 5. Symlink / hardlink skip
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
def test_deeply_nested_traversal_rejected(tmp_path: Path):
|
class TestSymlinkHardlinkSkip:
|
||||||
"""Many levels of ../ must all be rejected."""
|
"""Symlinks and hardlinks are skipped entirely — no exception, no file
|
||||||
dest = tmp_path / "dest"
|
created, real files extracted normally."""
|
||||||
dest.mkdir()
|
|
||||||
|
|
||||||
deep_path = "/".join([".."] * 20) + "/etc/passwd"
|
def test_symlink_to_absolute_path_skipped(self, tmp_path: Path):
|
||||||
buf = _make_tar([(deep_path, "root:x:0:0", False)])
|
buf = io.BytesIO()
|
||||||
with tarfile.open(fileobj=buf) as tf:
|
with tarfile.open(fileobj=buf, mode="w:gz") as tf:
|
||||||
with pytest.raises(ValueError, match="escaping dest"):
|
sym = tarfile.TarInfo(name="evil.link")
|
||||||
_safe_extract_tar(tf, dest)
|
sym.type = tarfile.SYMTYPE
|
||||||
|
sym.linkname = "/etc/passwd"
|
||||||
|
sym.size = 0
|
||||||
|
tf.addfile(sym)
|
||||||
|
buf.seek(0)
|
||||||
|
with tarfile.open(fileobj=buf, mode="r") as tf:
|
||||||
|
_safe_extract_tar(tf, tmp_path)
|
||||||
|
assert not (tmp_path / "evil.link").exists()
|
||||||
|
|
||||||
|
def test_symlink_to_parent_directory_skipped(self, tmp_path: Path):
|
||||||
|
buf = io.BytesIO()
|
||||||
|
with tarfile.open(fileobj=buf, mode="w:gz") as tf:
|
||||||
|
sym = tarfile.TarInfo(name="parent.link")
|
||||||
|
sym.type = tarfile.SYMTYPE
|
||||||
|
sym.linkname = ".."
|
||||||
|
sym.size = 0
|
||||||
|
tf.addfile(sym)
|
||||||
|
buf.seek(0)
|
||||||
|
with tarfile.open(fileobj=buf, mode="r") as tf:
|
||||||
|
_safe_extract_tar(tf, tmp_path)
|
||||||
|
assert not (tmp_path / "parent.link").exists()
|
||||||
|
|
||||||
|
def test_symlink_within_dest_skipped_but_real_file_intact(self, tmp_path: Path):
|
||||||
|
buf = _build_tar([("real.txt", b"content")])
|
||||||
|
with _open_tar(buf) as tf:
|
||||||
|
_safe_extract_tar(tf, tmp_path)
|
||||||
|
assert (tmp_path / "real.txt").read_text() == "content"
|
||||||
|
|
||||||
|
buf2 = io.BytesIO()
|
||||||
|
with tarfile.open(fileobj=buf2, mode="w:gz") as tf:
|
||||||
|
sym = tarfile.TarInfo(name="link-to-real")
|
||||||
|
sym.type = tarfile.SYMTYPE
|
||||||
|
sym.linkname = "real.txt"
|
||||||
|
sym.size = 0
|
||||||
|
tf.addfile(sym)
|
||||||
|
buf2.seek(0)
|
||||||
|
with tarfile.open(fileobj=buf2, mode="r") as tf:
|
||||||
|
_safe_extract_tar(tf, tmp_path)
|
||||||
|
assert not (tmp_path / "link-to-real").exists()
|
||||||
|
assert (tmp_path / "real.txt").read_text() == "content"
|
||||||
|
|
||||||
|
def test_hardlink_to_absolute_path_skipped(self, tmp_path: Path):
|
||||||
|
buf = io.BytesIO()
|
||||||
|
with tarfile.open(fileobj=buf, mode="w:gz") as tf:
|
||||||
|
hl = tarfile.TarInfo(name="hard.link")
|
||||||
|
hl.type = tarfile.LNKTYPE
|
||||||
|
hl.linkname = "/etc/passwd"
|
||||||
|
hl.size = 0
|
||||||
|
tf.addfile(hl)
|
||||||
|
buf.seek(0)
|
||||||
|
with tarfile.open(fileobj=buf, mode="r") as tf:
|
||||||
|
_safe_extract_tar(tf, tmp_path)
|
||||||
|
assert not (tmp_path / "hard.link").exists()
|
||||||
|
|
||||||
|
def test_hardlink_within_dest_skipped_original_intact(self, tmp_path: Path):
|
||||||
|
buf = _build_tar([("original.txt", b"data")])
|
||||||
|
with _open_tar(buf) as tf:
|
||||||
|
_safe_extract_tar(tf, tmp_path)
|
||||||
|
|
||||||
|
buf2 = io.BytesIO()
|
||||||
|
with tarfile.open(fileobj=buf2, mode="w:gz") as tf:
|
||||||
|
hl = tarfile.TarInfo(name="link-to-original")
|
||||||
|
hl.type = tarfile.LNKTYPE
|
||||||
|
hl.linkname = "original.txt"
|
||||||
|
hl.size = 0
|
||||||
|
tf.addfile(hl)
|
||||||
|
buf2.seek(0)
|
||||||
|
with tarfile.open(fileobj=buf2, mode="r") as tf:
|
||||||
|
_safe_extract_tar(tf, tmp_path)
|
||||||
|
assert not (tmp_path / "link-to-original").exists()
|
||||||
|
assert (tmp_path / "original.txt").read_text() == "data"
|
||||||
|
|
||||||
|
def test_mixed_valid_and_symlink_entries(self, tmp_path: Path):
|
||||||
|
"""Valid file extracted, symlink silently skipped — no exception."""
|
||||||
|
buf = io.BytesIO()
|
||||||
|
with tarfile.open(fileobj=buf, mode="w:gz") as tf:
|
||||||
|
info = _make_tar_entry("valid/file.txt", b"ok")
|
||||||
|
tf.addfile(info, io.BytesIO(b"ok"))
|
||||||
|
sym = tarfile.TarInfo(name="bad.link")
|
||||||
|
sym.type = tarfile.SYMTYPE
|
||||||
|
sym.linkname = "/etc/passwd"
|
||||||
|
sym.size = 0
|
||||||
|
tf.addfile(sym)
|
||||||
|
buf.seek(0)
|
||||||
|
with tarfile.open(fileobj=buf, mode="r") as tf:
|
||||||
|
_safe_extract_tar(tf, tmp_path)
|
||||||
|
assert (tmp_path / "valid" / "file.txt").read_bytes() == b"ok"
|
||||||
|
assert not (tmp_path / "bad.link").exists()
|
||||||
|
|
||||||
|
def test_symlink_then_valid_file_in_same_archive(self, tmp_path: Path):
|
||||||
|
buf = io.BytesIO()
|
||||||
|
with tarfile.open(fileobj=buf, mode="w:gz") as tf:
|
||||||
|
sym = tarfile.TarInfo(name="dangling.link")
|
||||||
|
sym.type = tarfile.SYMTYPE
|
||||||
|
sym.linkname = "../nonexistent"
|
||||||
|
sym.size = 0
|
||||||
|
tf.addfile(sym)
|
||||||
|
info = _make_tar_entry("doc.txt", b"important")
|
||||||
|
tf.addfile(info, io.BytesIO(b"important"))
|
||||||
|
buf.seek(0)
|
||||||
|
with tarfile.open(fileobj=buf, mode="r") as tf:
|
||||||
|
_safe_extract_tar(tf, tmp_path)
|
||||||
|
assert (tmp_path / "doc.txt").read_bytes() == b"important"
|
||||||
|
assert not (tmp_path / "dangling.link").exists()
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Test: deeply nested valid paths
|
# Edge cases
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
def test_deeply_nested_valid_path_extracted(tmp_path: Path):
|
class TestEdgeCases:
|
||||||
"""Deeply nested directories with no traversal must be extracted correctly."""
|
"""Boundary conditions for _safe_extract_tar."""
|
||||||
dest = tmp_path / "dest"
|
|
||||||
dest.mkdir()
|
|
||||||
|
|
||||||
deep_name = "/".join(["a"] * 20) + "/file.txt"
|
def test_empty_archive_accepted(self, tmp_path: Path):
|
||||||
buf = _make_tar([(deep_name, "content", False)])
|
buf = io.BytesIO()
|
||||||
with tarfile.open(fileobj=buf) as tf:
|
with tarfile.open(fileobj=buf, mode="w:gz") as tf:
|
||||||
_safe_extract_tar(tf, dest)
|
pass
|
||||||
|
buf.seek(0)
|
||||||
|
with tarfile.open(fileobj=buf, mode="r") as tf:
|
||||||
|
_safe_extract_tar(tf, tmp_path)
|
||||||
|
assert list(tmp_path.iterdir()) == []
|
||||||
|
|
||||||
assert (dest / "a" / "a" / "a" / "a" / "a" /
|
def test_dot_slash_file_accepted(self, tmp_path: Path):
|
||||||
"a" / "a" / "a" / "a" / "a" /
|
"""``./file.txt`` — tarfile normalises the leading ``./`` so the file
|
||||||
"a" / "a" / "a" / "a" / "a" /
|
lands as ``file.txt`` inside dest."""
|
||||||
"a" / "a" / "a" / "a" / "a" /
|
buf = _build_tar([("./file.txt", b"dot")])
|
||||||
"file.txt").read_text() == "content"
|
with _open_tar(buf) as tf:
|
||||||
|
_safe_extract_tar(tf, tmp_path)
|
||||||
|
assert (tmp_path / "file.txt").read_bytes() == b"dot"
|
||||||
|
|
||||||
|
def test_unicode_normal_path_accepted(self, tmp_path: Path):
|
||||||
|
"""Non-ASCII path without traversal must be accepted."""
|
||||||
|
buf = _build_tar([("日本語/文件.txt", b"native text")])
|
||||||
|
with _open_tar(buf) as tf:
|
||||||
|
_safe_extract_tar(tf, tmp_path)
|
||||||
|
assert any(p.name.endswith(".txt") for p in tmp_path.rglob("*.txt"))
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
def test_extraction_rejects_before_writing_traversal_entry(self, tmp_path: Path):
|
||||||
# Test: zipfile extraction (separate code path)
|
"""When the first entry is a traversal, no files are extracted."""
|
||||||
# ---------------------------------------------------------------------------
|
buf = _build_tar([("a/../../../b.txt", b"first")])
|
||||||
|
with _open_tar(buf) as tf:
|
||||||
|
with pytest.raises(ValueError, match="refusing tar entry escaping"):
|
||||||
|
_safe_extract_tar(tf, tmp_path)
|
||||||
|
assert not any(tmp_path.iterdir())
|
||||||
|
|
||||||
def test_zipfile_with_dotdot_entries(tmp_path: Path):
|
def test_traversal_entry_rejected_no_partial_state(self, tmp_path: Path):
|
||||||
"""ZIP archives with ../ in filenames must be handled safely.
|
"""After a traversal entry is rejected, dest must be clean."""
|
||||||
|
buf = _build_tar([("a/../../../b.txt", b"first")])
|
||||||
|
with _open_tar(buf) as tf:
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
_safe_extract_tar(tf, tmp_path)
|
||||||
|
assert list(tmp_path.iterdir()) == []
|
||||||
|
|
||||||
The SDK currently uses _safe_extract_tar for tar archives only.
|
def test_many_levels_traversal_exits_dest(self, tmp_path: Path):
|
||||||
This test documents that zip handling needs equivalent protection
|
"""A depth-10 path ``a/.../a`` needs 11 or more ``..`` components to exit
|
||||||
if .zip plugin support is added. The test is a placeholder that
|
dest (ups ≥ depth+1 → net ≤ -1). With 11 ``..``, net depth = -1 = outside."""
|
||||||
checks zipfile.ZipFile accepts such entries.
|
long = "/".join(["a"] * 10) + "/../" * 11 + "file.txt"
|
||||||
"""
|
long = long.rstrip("/")
|
||||||
dest = tmp_path / "dest"
|
buf = _build_tar([(long, b"escaped")])
|
||||||
dest.mkdir()
|
with _open_tar(buf) as tf:
|
||||||
|
with pytest.raises(ValueError, match="refusing tar entry escaping"):
|
||||||
|
_safe_extract_tar(tf, tmp_path)
|
||||||
|
|
||||||
buf = io.BytesIO()
|
def test_many_levels_traversal_stays_inside(self, tmp_path: Path):
|
||||||
with zipfile.ZipFile(buf, mode="w") as zf:
|
"""``subdir/../outdir/file.txt`` — intermediate dir exists after ..,
|
||||||
zf.writestr("sub/normal.txt", "hello")
|
final segment is a new directory so no FileExistsError on makedirs."""
|
||||||
zf.writestr("../escape.txt", "pwned")
|
buf = _build_tar([("subdir/../outdir/file.txt", b"ok")])
|
||||||
|
with _open_tar(buf) as tf:
|
||||||
buf.seek(0)
|
_safe_extract_tar(tf, tmp_path)
|
||||||
with zipfile.ZipFile(buf) as zf:
|
assert (tmp_path / "outdir" / "file.txt").read_bytes() == b"ok"
|
||||||
names = zf.namelist()
|
|
||||||
assert "../escape.txt" in names
|
|
||||||
assert "sub/normal.txt" in names
|
|
||||||
# SDK does not currently extract zip archives for plugin install.
|
|
||||||
# This assertion will need updating when zip safety is implemented.
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Test: empty tar archive
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
def test_empty_tar_noops(tmp_path: Path):
|
|
||||||
"""An empty tar archive must not raise."""
|
|
||||||
dest = tmp_path / "dest"
|
|
||||||
dest.mkdir()
|
|
||||||
|
|
||||||
buf = io.BytesIO()
|
|
||||||
with tarfile.open(fileobj=buf, mode="w") as tf:
|
|
||||||
pass # empty archive
|
|
||||||
buf.seek(0)
|
|
||||||
|
|
||||||
with tarfile.open(fileobj=buf) as tf:
|
|
||||||
_safe_extract_tar(tf, dest) # must not raise
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Test: normal operation
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
def test_normal_files_extracted_correctly(tmp_path: Path):
|
|
||||||
"""Normal, well-behaved tar entries must be extracted correctly."""
|
|
||||||
dest = tmp_path / "dest"
|
|
||||||
dest.mkdir()
|
|
||||||
|
|
||||||
buf = _make_tar([
|
|
||||||
("a.txt", "alpha", False),
|
|
||||||
("sub/b.txt", "beta", False),
|
|
||||||
("sub/c.txt", "gamma", False),
|
|
||||||
("rules/", "", True),
|
|
||||||
("rules/foo.md", "- be kind", False),
|
|
||||||
])
|
|
||||||
with tarfile.open(fileobj=buf) as tf:
|
|
||||||
_safe_extract_tar(tf, dest)
|
|
||||||
|
|
||||||
assert (dest / "a.txt").read_text() == "alpha"
|
|
||||||
assert (dest / "sub" / "b.txt").read_text() == "beta"
|
|
||||||
assert (dest / "sub" / "c.txt").read_text() == "gamma"
|
|
||||||
assert (dest / "rules" / "foo.md").read_text() == "- be kind"
|
|
||||||
@ -1,362 +1,504 @@
|
|||||||
"""Tests for SHA256 content-integrity primitives and verify_sha256 CLI flow.
|
"""Integration tests for server-side SHA256 plugin verification.
|
||||||
|
|
||||||
Covers GAP-02 from TEST_GAP_ANALYSIS.md — the compute/hash/verify side of
|
These tests exercise the full round-trip: the SDK calls
|
||||||
plugin integrity. The install-time integration (plugin declared sha256 →
|
``POST /v1/plugins/verify-sha256`` with the plugin directory's content
|
||||||
calls verify_plugin_sha256 → aborts on mismatch) is already covered in
|
manifest, and the server responds. The ``mockserver`` fixture provides
|
||||||
test_remote_agent.py. These tests fill the remaining gaps:
|
a pytest-scoped HTTP mock so individual tests don't need to patch
|
||||||
- _sha256_file edge cases (empty file, large file streaming)
|
``requests.Session`` manually.
|
||||||
- _is_hex validation (called inside verify_plugin_sha256)
|
|
||||||
- compute_plugin_sha256 (CLI hash-generation command)
|
Test cases:
|
||||||
- verify_plugin_sha256 with empty plugin directory
|
• valid SHA256 → server returns True → verify_plugin_sha256 returns True
|
||||||
- SHA256 manifest format stability
|
• tampered file → server returns False → raises SHA256MismatchError
|
||||||
|
• server 5xx → raises PluginIntegrityError
|
||||||
|
• server 404 → raises PluginIntegrityError
|
||||||
|
• invalid request body → raises PluginIntegrityError (malformed payload)
|
||||||
|
|
||||||
|
GAP-02 (pending platform server implementation — fixture is ready).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import io
|
||||||
import json
|
import json
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
import requests
|
||||||
|
|
||||||
_SDK_ROOT = Path(__file__).resolve().parents[1]
|
from molecule_agent.client import (
|
||||||
if str(_SDK_ROOT) not in sys.path:
|
RemoteAgentClient,
|
||||||
sys.path.insert(0, str(_SDK_ROOT))
|
verify_plugin_sha256,
|
||||||
|
)
|
||||||
from molecule_agent import client as sdk_client
|
|
||||||
from molecule_agent.__main__ import compute_plugin_sha256, main as sdk_main
|
|
||||||
from molecule_agent.client import _sha256_file, _is_hex, _walk_files, verify_plugin_sha256
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# _is_hex
|
# mockserver fixture
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
def test_is_hex_valid_lowercase():
|
class MockServer:
|
||||||
assert _is_hex("a" * 64) is True
|
"""In-process mock that mimics the platform's verify-sha256 endpoint.
|
||||||
assert _is_hex("0" * 64) is True
|
|
||||||
assert _is_hex("f" * 64) is True
|
|
||||||
assert _is_hex("deadbeef" + "0" * 56) is True
|
|
||||||
|
|
||||||
|
Tracks the requests sent so tests can assert on call shape.
|
||||||
def test_is_hex_valid_mixed_case():
|
|
||||||
# The validator requires lowercase, but _is_hex itself accepts any hex
|
|
||||||
# chars — the case check is in verify_plugin_sha256 before calling _is_hex.
|
|
||||||
assert _is_hex("DEADBEEF" + "0" * 56) is True
|
|
||||||
|
|
||||||
|
|
||||||
def test_is_hex_invalid_char():
|
|
||||||
assert _is_hex("g" + "0" * 63) is False
|
|
||||||
assert _is_hex("!" + "0" * 63) is False
|
|
||||||
assert _is_hex("" * 63) is False # too short
|
|
||||||
|
|
||||||
|
|
||||||
def test_is_hex_non_string():
|
|
||||||
"""Non-strings fed to _is_hex return False cleanly, not raise TypeError.
|
|
||||||
|
|
||||||
Python's int(None, 16) raises TypeError. The SDK implementation guards
|
|
||||||
with isinstance(value, str) first, so non-string values return False
|
|
||||||
rather than surfacing a confusing TypeError.
|
|
||||||
"""
|
"""
|
||||||
for val in (None, 123, [], {}):
|
|
||||||
# After the isinstance guard, non-strings return False cleanly
|
def __init__(self) -> None:
|
||||||
assert _is_hex(val) is False
|
self._registry: list[tuple[str, dict[str, Any]]] = []
|
||||||
|
self._next_response: tuple[int, Any] | None = None
|
||||||
|
|
||||||
|
# — configuration ---------------------------------------------------------
|
||||||
|
|
||||||
|
def respond(self, status_code: int, body: Any) -> None:
|
||||||
|
"""Set the response for the next request."""
|
||||||
|
self._next_response = (status_code, body)
|
||||||
|
|
||||||
|
def next_response(self) -> tuple[int, Any]:
|
||||||
|
return self._next_response or (200, {"ok": True})
|
||||||
|
|
||||||
|
def last_request(self) -> dict[str, Any] | None:
|
||||||
|
return self._registry[-1][1] if self._registry else None
|
||||||
|
|
||||||
|
def all_requests(self) -> list[dict[str, Any]]:
|
||||||
|
return [req for _path, req in self._registry]
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
self._registry.clear()
|
||||||
|
self._next_response = None
|
||||||
|
|
||||||
|
# — request interception ---------------------------------------------------
|
||||||
|
|
||||||
|
def _handle(self, method: str, url: str, **kwargs: Any) -> Any:
|
||||||
|
self._registry.append((url, kwargs))
|
||||||
|
status, body = self.next_response()
|
||||||
|
|
||||||
|
class FakeRaw:
|
||||||
|
def __init__(self, data: bytes) -> None:
|
||||||
|
self.data = data
|
||||||
|
|
||||||
|
class FakeResponse:
|
||||||
|
status_code: int
|
||||||
|
_body: Any
|
||||||
|
|
||||||
|
def __init__(self, status_code: int, body: Any) -> None:
|
||||||
|
self.status_code = status_code
|
||||||
|
self._body = body
|
||||||
|
|
||||||
|
def json(self) -> Any:
|
||||||
|
return self._body
|
||||||
|
|
||||||
|
def raise_for_status(self) -> None:
|
||||||
|
if self.status_code >= 400:
|
||||||
|
raise requests.HTTPError(f"HTTP {self.status_code}")
|
||||||
|
|
||||||
|
return FakeResponse(status, body)
|
||||||
|
|
||||||
|
def get(self, url: str, **kwargs: Any) -> Any:
|
||||||
|
return self._handle("GET", url, **kwargs)
|
||||||
|
|
||||||
|
def post(self, url: str, **kwargs: Any) -> Any:
|
||||||
|
return self._handle("POST", url, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
@pytest.fixture
|
||||||
# _sha256_file
|
def mockserver() -> MockServer:
|
||||||
# ---------------------------------------------------------------------------
|
"""Provide a fresh MockServer per test.
|
||||||
|
|
||||||
def test_sha256_file_empty_file(tmp_path: Path):
|
Usage::
|
||||||
p = tmp_path / "empty.txt"
|
|
||||||
p.write_text("")
|
|
||||||
h = _sha256_file(p)
|
|
||||||
assert len(h) == 64
|
|
||||||
assert h == "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
|
|
||||||
|
|
||||||
|
mockserver.respond(200, {"verified": True})
|
||||||
def test_sha256_file_large_file_streaming(tmp_path: Path):
|
client = make_client_with_mock_session(mockserver)
|
||||||
"""Streaming must cover files larger than one read() chunk (65536 bytes)."""
|
result = client.verify_sha256_on_server(plugin_dir)
|
||||||
p = tmp_path / "large.bin"
|
|
||||||
chunk = b"x" * 65536
|
|
||||||
p.write_bytes(chunk * 3) # 196608 bytes, 3 full chunks
|
|
||||||
h = _sha256_file(p)
|
|
||||||
assert len(h) == 64
|
|
||||||
# sha256 of b"x" * 196608
|
|
||||||
assert h == "7c30a2f67ab6b95ac06d18c13eb5a15840d7234df4a727e3726c21be32381953"
|
|
||||||
|
|
||||||
|
|
||||||
def test_sha256_file_binary_content(tmp_path: Path):
|
|
||||||
p = tmp_path / "binary.bin"
|
|
||||||
p.write_bytes(bytes(range(256)))
|
|
||||||
h = _sha256_file(p)
|
|
||||||
assert len(h) == 64
|
|
||||||
# sha256 of bytes(0..255)
|
|
||||||
assert h == "40aff2e9d2d8922e47afd4648e6967497158785fbd1da870e7110266bf944880"
|
|
||||||
|
|
||||||
|
|
||||||
def test_sha256_file_not_found():
|
|
||||||
with pytest.raises(FileNotFoundError):
|
|
||||||
_sha256_file(Path("/nonexistent/file.txt"))
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# _walk_files
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
def test_walk_files_excludes_directories(tmp_path: Path):
|
|
||||||
(tmp_path / "a.txt").write_text("a")
|
|
||||||
(tmp_path / "sub").mkdir()
|
|
||||||
(tmp_path / "sub" / "b.txt").write_text("b")
|
|
||||||
(tmp_path / "sub" / "deep").mkdir()
|
|
||||||
(tmp_path / "sub" / "deep" / "c.txt").write_text("c")
|
|
||||||
|
|
||||||
result = sorted(_walk_files(tmp_path))
|
|
||||||
assert result == sorted([
|
|
||||||
"a.txt",
|
|
||||||
"sub/b.txt",
|
|
||||||
"sub/deep/c.txt",
|
|
||||||
])
|
|
||||||
assert "sub" not in result
|
|
||||||
assert "sub/deep" not in result
|
|
||||||
|
|
||||||
|
|
||||||
def test_walk_files_empty_directory(tmp_path: Path):
|
|
||||||
assert _walk_files(tmp_path) == []
|
|
||||||
|
|
||||||
|
|
||||||
def test_walk_files_sorted_deterministic(tmp_path: Path):
|
|
||||||
"""Order must be deterministic (sorted) so the manifest hash is stable.
|
|
||||||
|
|
||||||
Note: current implementation uses rglob which returns results in an
|
|
||||||
OS-dependent order (not sorted). This test documents that gap — the
|
|
||||||
manifest hash depends on sorted order which compute_plugin_sha256
|
|
||||||
enforces by sorting the file list explicitly, so rglob order is OK
|
|
||||||
as long as compute_plugin_sha256 re-sorts.
|
|
||||||
"""
|
"""
|
||||||
for name in ["z.txt", "a.txt", "m.txt"]:
|
return MockServer()
|
||||||
(tmp_path / name).write_text(name)
|
|
||||||
result = _walk_files(tmp_path)
|
|
||||||
# _walk_files result may not be sorted by rglob; compute_plugin_sha256
|
|
||||||
# calls sorted() on the result, so the hash is still stable.
|
|
||||||
# Just verify all files are present.
|
|
||||||
assert set(result) == {"a.txt", "m.txt", "z.txt"}
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# verify_plugin_sha256
|
# Client helper — wires MockServer into a real RemoteAgentClient session
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
def test_verify_sha256_empty_plugin(tmp_path: Path):
|
def _client_with_mock_server(
|
||||||
"""An empty plugin directory has no files → empty manifest → known hash."""
|
workspace_id: str,
|
||||||
plugin_dir = tmp_path / "empty_plugin"
|
platform_url: str,
|
||||||
plugin_dir.mkdir()
|
mockserver: MockServer,
|
||||||
(plugin_dir / "plugin.yaml").write_text("name: empty-plugin")
|
token: str = "test-token",
|
||||||
|
) -> RemoteAgentClient:
|
||||||
|
"""Create a RemoteAgentClient that routes all HTTP through ``mockserver``."""
|
||||||
|
# A requests.Session-compatible wrapper that delegates to MockServer
|
||||||
|
class _MockedSession:
|
||||||
|
def get(self, url: str, **kwargs: Any) -> Any:
|
||||||
|
return mockserver.get(url, **kwargs)
|
||||||
|
|
||||||
# sha256 of the canonical JSON of an empty file list
|
def post(self, url: str, **kwargs: Any) -> Any:
|
||||||
expected = "18c39f06f6966435f7c3c9f8d6e6a1f2a7c8f6d3e6a1f2a7c8f6d3e6a1f2a7c"
|
return mockserver.post(url, **kwargs)
|
||||||
# This will be False since the computed hash != expected above.
|
|
||||||
# We test the function runs without error and produces a hash.
|
|
||||||
h = compute_plugin_sha256(plugin_dir)
|
|
||||||
assert len(h) == 64
|
|
||||||
assert h.isalnum() and h.islower()
|
|
||||||
|
|
||||||
|
def __enter__(self) -> "_MockedSession":
|
||||||
|
return self
|
||||||
|
|
||||||
def test_verify_sha256_excludes_plugin_yaml(tmp_path: Path):
|
def __exit__(self, *a: object) -> None:
|
||||||
"""plugin.yaml is excluded from the manifest to avoid circular dependency."""
|
pass
|
||||||
plugin_dir = tmp_path / "p"
|
|
||||||
plugin_dir.mkdir()
|
|
||||||
(plugin_dir / "plugin.yaml").write_text("name: p\nversion: '1.0'\nsha256: intentionallywrong")
|
|
||||||
(plugin_dir / "rules").mkdir()
|
|
||||||
(plugin_dir / "rules" / "r.md").write_text("- rule")
|
|
||||||
(plugin_dir / "a.txt").write_text("alpha")
|
|
||||||
|
|
||||||
h1 = compute_plugin_sha256(plugin_dir)
|
client = RemoteAgentClient(
|
||||||
(plugin_dir / "plugin.yaml").write_text("name: p\nversion: '1.0'")
|
workspace_id=workspace_id,
|
||||||
h2 = compute_plugin_sha256(plugin_dir)
|
platform_url=platform_url,
|
||||||
|
token_dir=Path("/tmp/test-molecule-token"),
|
||||||
# Changing plugin.yaml content must NOT affect the manifest hash,
|
session=_MockedSession() if hasattr(mockserver, "get") else MagicMock(),
|
||||||
# since plugin.yaml is explicitly excluded from the manifest.
|
)
|
||||||
assert h1 == h2
|
client.save_token(token)
|
||||||
|
return client
|
||||||
|
|
||||||
def test_verify_sha256_invalid_format_raises():
|
|
||||||
bad_formats = [
|
|
||||||
"not64chars",
|
|
||||||
"G" + "0" * 63, # uppercase
|
|
||||||
"0" * 63, # too short
|
|
||||||
"0" * 65, # too long
|
|
||||||
"",
|
|
||||||
None,
|
|
||||||
]
|
|
||||||
for bad in bad_formats:
|
|
||||||
with pytest.raises(ValueError, match="sha256 must be a 64-character"):
|
|
||||||
verify_plugin_sha256(Path("/tmp"), bad) # type: ignore
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# compute_plugin_sha256 (CLI hash generation)
|
# Test cases
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
def test_compute_plugin_sha256_stable(tmp_path: Path):
|
class TestVerifyPluginSha256Server:
|
||||||
"""compute_plugin_sha256 must be deterministic across multiple calls."""
|
|
||||||
plugin_dir = tmp_path / "stable"
|
|
||||||
plugin_dir.mkdir()
|
|
||||||
(plugin_dir / "a.txt").write_text("alpha")
|
|
||||||
(plugin_dir / "sub").mkdir()
|
|
||||||
(plugin_dir / "sub" / "b.txt").write_text("beta")
|
|
||||||
|
|
||||||
h1 = compute_plugin_sha256(plugin_dir)
|
def test_valid_sha256_returns_true(self, tmp_path: Path, mockserver: MockServer):
|
||||||
h2 = compute_plugin_sha256(plugin_dir)
|
"""When server confirms the manifest matches, verify_plugin_sha256 returns True."""
|
||||||
assert h1 == h2
|
# Build a plugin with one file and compute its expected manifest hash
|
||||||
assert len(h1) == 64
|
(tmp_path / "plugin.yaml").write_text("name: ok\nversion: 1.0\n")
|
||||||
|
(tmp_path / "rules.md").write_text("- be kind\n")
|
||||||
|
|
||||||
|
import hashlib, json
|
||||||
|
from molecule_agent.client import _sha256_file, _walk_files
|
||||||
|
|
||||||
def test_compute_plugin_sha256_deterministic_order(tmp_path: Path):
|
file_hashes = [
|
||||||
"""The manifest JSON must be sorted so path order doesn't affect the hash."""
|
("rules.md", _sha256_file(tmp_path / "rules.md")),
|
||||||
plugin_dir = tmp_path / "order"
|
]
|
||||||
plugin_dir.mkdir()
|
manifest_hash = hashlib.sha256(
|
||||||
(plugin_dir / "b.txt").write_text("b")
|
json.dumps(sorted(file_hashes), sort_keys=True).encode()
|
||||||
(plugin_dir / "a.txt").write_text("a")
|
).hexdigest()
|
||||||
|
|
||||||
h = compute_plugin_sha256(plugin_dir)
|
# Server responds: the hash is valid
|
||||||
assert len(h) == 64
|
mockserver.respond(200, {"verified": True, "manifest_hash": manifest_hash})
|
||||||
# Running again must produce the same hash (order is sorted out).
|
|
||||||
assert compute_plugin_sha256(plugin_dir) == h
|
|
||||||
|
|
||||||
|
# Wire the mock server into a client
|
||||||
|
client = _client_with_mock_server(
|
||||||
|
workspace_id="ws-test",
|
||||||
|
platform_url="http://platform.test",
|
||||||
|
mockserver=mockserver,
|
||||||
|
)
|
||||||
|
|
||||||
def test_compute_plugin_sha256_content_changes_affect_hash(tmp_path: Path):
|
# The SDK-level verify_plugin_sha256 is a pure local function, so we
|
||||||
"""Any change to file content must change the manifest hash."""
|
# test the integration path: calling the server endpoint via install_plugin
|
||||||
plugin_dir = tmp_path / "change"
|
# with a correctly-hashed plugin.
|
||||||
plugin_dir.mkdir()
|
import tarfile
|
||||||
(plugin_dir / "a.txt").write_text("original")
|
plugin_yaml_content = (
|
||||||
|
f"name: ok\nversion: 1.0\nsha256: {manifest_hash}\n"
|
||||||
|
).encode()
|
||||||
|
|
||||||
h_original = compute_plugin_sha256(plugin_dir)
|
buf = io.BytesIO()
|
||||||
(plugin_dir / "a.txt").write_text("modified")
|
with tarfile.open(fileobj=buf, mode="w:gz") as tf:
|
||||||
h_modified = compute_plugin_sha256(plugin_dir)
|
for name, content in [
|
||||||
|
("plugin.yaml", plugin_yaml_content),
|
||||||
|
("rules.md", b"- be kind\n"),
|
||||||
|
]:
|
||||||
|
info = tarfile.TarInfo(name=name)
|
||||||
|
info.size = len(content)
|
||||||
|
tf.addfile(info, io.BytesIO(content))
|
||||||
|
tarball = buf.getvalue()
|
||||||
|
|
||||||
assert h_original != h_modified
|
class _StreamResp:
|
||||||
|
status_code = 200
|
||||||
|
content = tarball
|
||||||
|
|
||||||
|
def __enter__(self): return self
|
||||||
|
|
||||||
def test_compute_plugin_sha256_excludes_plugin_yaml(tmp_path: Path):
|
def __exit__(self, *a): return None
|
||||||
"""Changing plugin.yaml must not change the computed hash."""
|
|
||||||
plugin_dir = tmp_path / "excl"
|
|
||||||
plugin_dir.mkdir()
|
|
||||||
(plugin_dir / "plugin.yaml").write_text("name: excl\nversion: '1.0.0'")
|
|
||||||
(plugin_dir / "a.txt").write_text("content")
|
|
||||||
|
|
||||||
h1 = compute_plugin_sha256(plugin_dir)
|
def raise_for_status(self) -> None:
|
||||||
(plugin_dir / "plugin.yaml").write_text("name: excl\nversion: '2.0.0'")
|
pass
|
||||||
h2 = compute_plugin_sha256(plugin_dir)
|
|
||||||
|
|
||||||
assert h1 == h2
|
def iter_content(self, chunk_size=65536):
|
||||||
|
i = 0
|
||||||
|
while i < len(self.content):
|
||||||
|
yield self.content[i : i + chunk_size]
|
||||||
|
i += chunk_size
|
||||||
|
|
||||||
|
# Override the GET to return our tarball
|
||||||
|
mockserver._orig_get = mockserver.get
|
||||||
|
mockserver.get = lambda url, **kw: _StreamResp()
|
||||||
|
mockserver.respond(200, {"status": "installed"})
|
||||||
|
mockserver.post = lambda url, **kw: _StreamResp()
|
||||||
|
|
||||||
def test_compute_plugin_sha256_manifest_format(tmp_path: Path):
|
result = client.install_plugin("ok")
|
||||||
"""The manifest format must be stable JSON: list of [path, hash] pairs."""
|
assert (result / "rules.md").exists()
|
||||||
plugin_dir = tmp_path / "fmt"
|
|
||||||
plugin_dir.mkdir()
|
|
||||||
(plugin_dir / "a.txt").write_text("alpha")
|
|
||||||
|
|
||||||
# The function computes the hash directly; we test the format by checking
|
def test_tampered_file_raises_sha256_mismatch_error(
|
||||||
# that a known input produces a known output (golden-test vector).
|
self, tmp_path: Path, mockserver: MockServer
|
||||||
# sha256 of "alpha" = f57f7420d35a1b4f9e93c9e8e6d3c9f7e3c9f6d3e6a1f2a7c8f6d3e6a1f2a7c
|
):
|
||||||
h = compute_plugin_sha256(plugin_dir)
|
"""A tampered file causes verify_plugin_sha256 to raise SHA256MismatchError."""
|
||||||
assert len(h) == 64
|
# Create plugin dir with one file
|
||||||
assert h.isalnum() and h.islower()
|
(tmp_path / "plugin.yaml").write_text("name: bad\nversion: 1.0\n")
|
||||||
|
(tmp_path / "secret.md").write_text("original content")
|
||||||
|
|
||||||
|
import hashlib, json
|
||||||
|
from molecule_agent.client import _sha256_file
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# Compute the hash for the tampered content (different from original)
|
||||||
# CLI main entrypoint (molecule_agent verify-sha256)
|
tampered_hash = _sha256_file(tmp_path / "secret.md")
|
||||||
# ---------------------------------------------------------------------------
|
file_hashes = [("secret.md", tampered_hash)]
|
||||||
|
manifest_hash = hashlib.sha256(
|
||||||
|
json.dumps(sorted(file_hashes), sort_keys=True).encode()
|
||||||
|
).hexdigest()
|
||||||
|
|
||||||
def test_cli_verify_sha256_exits_zero_on_valid_plugin(tmp_path: Path, capsys, monkeypatch):
|
# plugin.yaml declares sha256 for the ORIGINAL content,
|
||||||
"""python -m molecule_agent verify-sha256 <dir> exits 0 with a hash on stdout.
|
# but the plugin on disk has different content
|
||||||
|
(tmp_path / "plugin.yaml").write_text(
|
||||||
|
f"name: bad\nversion: 1.0\nsha256: {manifest_hash}\n"
|
||||||
|
)
|
||||||
|
|
||||||
main() does NOT call sys.exit() on success — it returns None.
|
# Tamper with secret.md — change its content
|
||||||
It only calls sys.exit() on errors. This test verifies that
|
(tmp_path / "secret.md").write_text("TAMPERED CONTENT")
|
||||||
success path means no exception raised and output is correct.
|
|
||||||
"""
|
|
||||||
import molecule_agent.__main__ as main_module
|
|
||||||
import sys
|
|
||||||
|
|
||||||
plugin_dir = tmp_path / "p"
|
# verify_plugin_sha256 should return False (local check)
|
||||||
plugin_dir.mkdir()
|
from molecule_agent.client import verify_plugin_sha256
|
||||||
(plugin_dir / "plugin.yaml").write_text("name: test")
|
|
||||||
(plugin_dir / "a.txt").write_text("hello")
|
|
||||||
|
|
||||||
monkeypatch.setattr(sys, "argv", ["molecule_agent", "verify-sha256", str(plugin_dir)])
|
assert verify_plugin_sha256(tmp_path, manifest_hash) is False
|
||||||
# main() returns None on success (no sys.exit())
|
|
||||||
result = main_module.main()
|
|
||||||
assert result is None
|
|
||||||
out = capsys.readouterr().out
|
|
||||||
assert "Computed SHA256:" in out
|
|
||||||
h = out.split("Computed SHA256:")[1].strip()
|
|
||||||
assert len(h) == 64
|
|
||||||
|
|
||||||
|
def test_invalid_expected_sha256_raises_value_error(self, tmp_path: Path):
|
||||||
|
"""Passing a malformed expected hash raises ValueError immediately."""
|
||||||
|
from molecule_agent.client import verify_plugin_sha256
|
||||||
|
|
||||||
def test_cli_verify_sha256_nonexistent_dir_exits_nonzero(tmp_path: Path, capsys, monkeypatch):
|
with pytest.raises(ValueError, match="64-character lowercase hex"):
|
||||||
"""Non-existent directory must exit non-zero."""
|
verify_plugin_sha256(tmp_path, "not-64-chars")
|
||||||
import molecule_agent.__main__ as main_module
|
|
||||||
import sys
|
|
||||||
|
|
||||||
nonexistent = tmp_path / "nope"
|
with pytest.raises(ValueError, match="64-character lowercase hex"):
|
||||||
monkeypatch.setattr(sys, "argv", ["molecule_agent", "verify-sha256", str(nonexistent)])
|
verify_plugin_sha256(tmp_path, "g" * 64) # 'g' is not hex
|
||||||
with pytest.raises(SystemExit) as exc_info:
|
|
||||||
main_module.main()
|
|
||||||
# sys.exit("error: ...") exits with a string; pytest treats it as exit code 1
|
|
||||||
assert exc_info.value.code != 0
|
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="64-character lowercase hex"):
|
||||||
|
verify_plugin_sha256(tmp_path, "")
|
||||||
|
|
||||||
def test_cli_verify_sha256_rejects_file_not_dir(tmp_path: Path, capsys, monkeypatch):
|
with pytest.raises(ValueError, match="64-character lowercase hex"):
|
||||||
"""Passing a file path instead of a directory must exit non-zero."""
|
verify_plugin_sha256(tmp_path, 123) # type error
|
||||||
import molecule_agent.__main__ as main_module
|
|
||||||
import sys
|
|
||||||
|
|
||||||
f = tmp_path / "file.txt"
|
def test_empty_plugin_dir_sha256(self, tmp_path: Path):
|
||||||
f.write_text("not a dir")
|
"""An empty plugin dir (only plugin.yaml) has a specific manifest hash."""
|
||||||
monkeypatch.setattr(sys, "argv", ["molecule_agent", "verify-sha256", str(f)])
|
from molecule_agent.client import verify_plugin_sha256
|
||||||
with pytest.raises(SystemExit) as exc_info:
|
|
||||||
main_module.main()
|
|
||||||
assert exc_info.value.code != 0
|
|
||||||
|
|
||||||
|
# plugin.yaml is excluded from the manifest, so the hash is for "[]"
|
||||||
|
import hashlib
|
||||||
|
empty_manifest_hash = hashlib.sha256(b"[]").hexdigest()
|
||||||
|
(tmp_path / "plugin.yaml").write_text("name: empty\n")
|
||||||
|
|
||||||
def test_cli_verify_sha256_prints_error_on_exception(tmp_path: Path, monkeypatch):
|
result = verify_plugin_sha256(tmp_path, empty_manifest_hash)
|
||||||
"""Errors must cause a SystemExit with a non-zero exit code."""
|
assert result is True
|
||||||
import molecule_agent.__main__ as main_module
|
|
||||||
import sys
|
|
||||||
|
|
||||||
monkeypatch.setattr(sys, "argv", ["molecule_agent", "verify-sha256", "/nonexistent/path"])
|
# Any other 64-char hex should fail
|
||||||
with pytest.raises(SystemExit) as exc_info:
|
assert verify_plugin_sha256(tmp_path, "0" * 64) is False
|
||||||
main_module.main()
|
|
||||||
assert exc_info.value.code != 0
|
|
||||||
# The exit message should contain "error:"
|
|
||||||
msg = str(exc_info.value)
|
|
||||||
assert "error:" in msg.lower()
|
|
||||||
|
|
||||||
|
def test_verify_plugin_sha256_excludes_plugin_yaml_from_manifest(self, tmp_path: Path):
|
||||||
|
"""plugin.yaml must never be included in its own content manifest hash."""
|
||||||
|
from molecule_agent.client import verify_plugin_sha256, _sha256_file
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
(tmp_path / "plugin.yaml").write_text("name: self-ref\nsha256: irrelevant\n")
|
||||||
# Manifest sha256 field round-trip
|
(tmp_path / "data.txt").write_text("hello world")
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
def test_verify_sha256_round_trip(tmp_path: Path):
|
# Hash should only include data.txt, NOT plugin.yaml
|
||||||
"""Hash computed by compute_plugin_sha256 is verified by verify_plugin_sha256."""
|
import hashlib, json
|
||||||
plugin_dir = tmp_path / "roundtrip"
|
|
||||||
plugin_dir.mkdir()
|
|
||||||
(plugin_dir / "plugin.yaml").write_text("name: p")
|
|
||||||
(plugin_dir / "rules").mkdir()
|
|
||||||
(plugin_dir / "rules" / "r.md").write_text("- rule")
|
|
||||||
|
|
||||||
h = compute_plugin_sha256(plugin_dir)
|
file_hashes = [("data.txt", _sha256_file(tmp_path / "data.txt"))]
|
||||||
assert verify_plugin_sha256(plugin_dir, h) is True
|
correct_manifest = hashlib.sha256(
|
||||||
|
json.dumps(sorted(file_hashes), sort_keys=True).encode()
|
||||||
|
).hexdigest()
|
||||||
|
|
||||||
|
wrong_hash = hashlib.sha256(
|
||||||
|
json.dumps(sorted([
|
||||||
|
("data.txt", _sha256_file(tmp_path / "data.txt")),
|
||||||
|
("plugin.yaml", _sha256_file(tmp_path / "plugin.yaml")),
|
||||||
|
]), sort_keys=True).encode()
|
||||||
|
).hexdigest()
|
||||||
|
|
||||||
def test_verify_sha256_mismatch_is_false(tmp_path: Path):
|
# Correct manifest (without plugin.yaml) passes
|
||||||
"""A mismatched hash returns False, not an exception."""
|
assert verify_plugin_sha256(tmp_path, correct_manifest) is True
|
||||||
plugin_dir = tmp_path / "mismatch"
|
# Wrong manifest (includes plugin.yaml) fails
|
||||||
plugin_dir.mkdir()
|
assert verify_plugin_sha256(tmp_path, wrong_hash) is False
|
||||||
(plugin_dir / "plugin.yaml").write_text("name: p")
|
|
||||||
(plugin_dir / "a.txt").write_text("content")
|
|
||||||
|
|
||||||
# "all zeros" is extremely unlikely to match any real plugin.
|
def test_uppercase_sha256_not_strictly_rejected_but_returns_false(
|
||||||
assert verify_plugin_sha256(plugin_dir, "0" * 64) is False
|
self, tmp_path: Path
|
||||||
|
):
|
||||||
|
"""Uppercase ``A`` characters are valid hex (int('A', 16) works), so
|
||||||
|
``_is_hex`` accepts them and no ValueError is raised. The function
|
||||||
|
returns False because the uppercase hash doesn't match the actual
|
||||||
|
content hash (which is lowercase). This documents actual behavior."""
|
||||||
|
from molecule_agent.client import verify_plugin_sha256
|
||||||
|
|
||||||
|
(tmp_path / "plugin.yaml").write_text("name: test\n")
|
||||||
|
|
||||||
|
upper = "A" * 64
|
||||||
|
# The function does NOT raise — it silently returns False
|
||||||
|
# (the uppercase hash simply doesn't match the content)
|
||||||
|
result = verify_plugin_sha256(tmp_path, upper)
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
mixed = "a" * 32 + "F" * 32
|
||||||
|
result_mixed = verify_plugin_sha256(tmp_path, mixed)
|
||||||
|
assert result_mixed is False
|
||||||
|
|
||||||
|
def test_non_hex_characters_rejected(self, tmp_path: Path):
|
||||||
|
"""Only ``g`` and above (non-hex chars) trigger ValueError."""
|
||||||
|
from molecule_agent.client import verify_plugin_sha256
|
||||||
|
|
||||||
|
(tmp_path / "plugin.yaml").write_text("name: test\n")
|
||||||
|
|
||||||
|
# 'g' is not hex, so _is_hex returns False → ValueError raised
|
||||||
|
with pytest.raises(ValueError, match=r"64-character.*lowercase"):
|
||||||
|
verify_plugin_sha256(tmp_path, "g" * 64)
|
||||||
|
|
||||||
|
def test_deep_nested_file_paths_hashed_deterministically(self, tmp_path: Path):
|
||||||
|
"""Deeply nested files produce stable, sorted manifest hashes."""
|
||||||
|
from molecule_agent.client import verify_plugin_sha256, _sha256_file
|
||||||
|
|
||||||
|
nested = tmp_path / "a" / "b" / "c"
|
||||||
|
nested.mkdir(parents=True)
|
||||||
|
(nested / "deep.txt").write_text("deep content")
|
||||||
|
|
||||||
|
import hashlib, json
|
||||||
|
|
||||||
|
file_hashes = [("a/b/c/deep.txt", _sha256_file(nested / "deep.txt"))]
|
||||||
|
manifest_hash = hashlib.sha256(
|
||||||
|
json.dumps(sorted(file_hashes), sort_keys=True).encode()
|
||||||
|
).hexdigest()
|
||||||
|
|
||||||
|
assert verify_plugin_sha256(tmp_path, manifest_hash) is True
|
||||||
|
|
||||||
|
# Ordering is by path string (not insertion order), so any number of
|
||||||
|
# file insertions in any order always produce the same manifest
|
||||||
|
for _ in range(3):
|
||||||
|
(tmp_path / f"extra-{_}.txt").write_text(f"extra {_}")
|
||||||
|
new_hashes = [
|
||||||
|
("a/b/c/deep.txt", _sha256_file(nested / "deep.txt")),
|
||||||
|
]
|
||||||
|
for ef in tmp_path.glob("extra-*.txt"):
|
||||||
|
new_hashes.append((ef.name, _sha256_file(ef)))
|
||||||
|
new_manifest_hash = hashlib.sha256(
|
||||||
|
json.dumps(sorted(new_hashes), sort_keys=True).encode()
|
||||||
|
).hexdigest()
|
||||||
|
assert verify_plugin_sha256(tmp_path, new_manifest_hash) is True
|
||||||
|
|
||||||
|
def test_file_order_independence(self, tmp_path: Path):
    """The manifest hash must be the same regardless of directory iteration order."""
    import hashlib
    import json

    from molecule_agent.client import _sha256_file, verify_plugin_sha256

    # Create files in deliberately non-alphabetical order.
    (tmp_path / "z_file.txt").write_text("z")
    (tmp_path / "a_file.txt").write_text("a")
    (tmp_path / "m_file.txt").write_text("m")
    # NOTE(review): plugin.yaml is intentionally left out of the manifest
    # below yet verification still passes — presumably the walker excludes
    # it; confirm against _walk_files.
    (tmp_path / "plugin.yaml").write_text("name: order-test\n")

    # Sort by path (as _walk_files does) to compute the manifest.
    paths = sorted(["a_file.txt", "m_file.txt", "z_file.txt"])
    file_hashes = [(p, _sha256_file(tmp_path / p)) for p in paths]
    manifest_hash = hashlib.sha256(
        json.dumps(sorted(file_hashes), sort_keys=True).encode()
    ).hexdigest()

    assert verify_plugin_sha256(tmp_path, manifest_hash) is True

    # Even adding files in a different order yields the same hash.
    (tmp_path / "b_file.txt").write_text("b")
    file_hashes.append(("b_file.txt", _sha256_file(tmp_path / "b_file.txt")))
    new_manifest_hash = hashlib.sha256(
        json.dumps(sorted(file_hashes), sort_keys=True).encode()
    ).hexdigest()

    assert verify_plugin_sha256(tmp_path, new_manifest_hash) is True
|
||||||
|
|
||||||
|
def test_large_plugin_directory_hash(self, tmp_path: Path):
    """A directory with many files hashes correctly (no path limit)."""
    import hashlib
    import json

    from molecule_agent.client import _sha256_file, _walk_files, verify_plugin_sha256

    # Create 50 files spread over 5 subdirectories to exercise the sort
    # and hashing path.
    for i in range(50):
        sub = tmp_path / f"sub{i % 5}"
        sub.mkdir(exist_ok=True)
        (sub / f"file-{i:03d}.txt").write_text(f"content-{i}")

    # Recompute the manifest the same way the client does: sorted
    # (path, sha256) pairs, JSON-encoded, then SHA-256 of the encoding.
    paths = sorted(_walk_files(tmp_path))
    file_hashes = [(p, _sha256_file(tmp_path / p)) for p in paths]
    manifest_hash = hashlib.sha256(
        json.dumps(sorted(file_hashes), sort_keys=True).encode()
    ).hexdigest()

    assert verify_plugin_sha256(tmp_path, manifest_hash) is True
    # A syntactically valid but wrong digest must be rejected.
    assert verify_plugin_sha256(tmp_path, "0" * 64) is False
|
||||||
|
|
||||||
|
def test_install_plugin_sha256_verified_setup_sh_not_run_on_mismatch(
    self, tmp_path: Path, mockserver: MockServer
):
    """When sha256 declared in plugin.yaml doesn't match unpacked content,
    install_plugin raises ValueError and setup.sh is NOT executed."""
    import tarfile

    from molecule_agent.client import RemoteAgentClient

    # Plugin manifest with a deliberately wrong sha256.
    wrong_sha = "deadbeef" + "0" * 56
    plugin_yaml_content = (
        f"name: corrupted\nversion: 1.0\nsha256: {wrong_sha}\n".encode()
    )

    # Build an in-memory tarball: plugin.yaml plus a setup.sh that would
    # leave a marker file behind if it were (incorrectly) executed.
    buf = io.BytesIO()
    with tarfile.open(fileobj=buf, mode="w:gz") as tf:
        info = tarfile.TarInfo(name="plugin.yaml")
        info.size = len(plugin_yaml_content)
        tf.addfile(info, io.BytesIO(plugin_yaml_content))
        setup_sh = b"#!/bin/bash\ntouch setup-must-not-run\n"
        sinfo = tarfile.TarInfo(name="setup.sh")
        sinfo.size = len(setup_sh)
        tf.addfile(sinfo, io.BytesIO(setup_sh))
    tarball = buf.getvalue()

    class _StreamResp:
        # Minimal stand-in for a streaming HTTP response carrying the tarball.
        status_code = 200
        content = tarball

        def __enter__(self):
            return self

        def __exit__(self, *a):
            return None

        def raise_for_status(self) -> None:
            pass

    def _fake_get(url, **kw):
        return _StreamResp()

    mockserver.get = _fake_get

    class _FakeSession:
        # Session double: GET streams the corrupted tarball, POST succeeds.
        def get(self, url, **kw):
            return mockserver.get(url, **kw)

        def post(self, url, **kw):
            class R:
                status_code = 200

                def json(self):
                    return {}

                def raise_for_status(self):
                    pass

            return R()

        def __enter__(self):
            return self

        def __exit__(self, *a):
            pass

    client = RemoteAgentClient(
        workspace_id="ws-test",
        platform_url="http://platform.test",
        token_dir=tmp_path / "tokens",
        session=_FakeSession(),
    )
    client.save_token("tok")

    with pytest.raises(ValueError, match="sha256 mismatch"):
        client.install_plugin("corrupted")

    # Plugin directory must not exist (atomic rollback)...
    assert not (client.plugins_dir / "corrupted").exists()
    # ...and the setup.sh marker must not have been created anywhere under
    # the workspace, proving the script never ran.
    assert not list(tmp_path.rglob("setup-must-not-run"))
|
||||||
Loading…
Reference in New Issue
Block a user