molecule-sdk-python/tests/test_remote_agent.py
Molecule AI SDK-Dev d55b2b951c fix(sdk): resolve KI-003 — log warning for skipped symlinks in _safe_extract_tar
Symlink entries in plugin tarballs are skipped (security posture, correct) but
now emit a logger.warning so operators can audit what was dropped:

  "skipping symlink in plugin tarball (not supported for security): <name> -> <target>"

Added test_safe_extract_logs_warning_for_skipped_symlink asserting the warning
is present in caplog records at WARNING level.  All 211 tests pass (+1 new).

known-issues.md updated.
2026-04-21 22:03:13 +00:00

967 lines
35 KiB
Python

"""Tests for the molecule_agent Phase 30.8 remote-agent client.
The client is pure HTTP — we mock the network via ``requests_mock``-style
monkey-patching of ``requests.Session.get`` / ``.post`` instead of pulling
in a third-party mock library.
"""
from __future__ import annotations
import stat
import time
from pathlib import Path
from typing import Any
from unittest.mock import MagicMock
import pytest
from molecule_agent import PeerInfo, RemoteAgentClient, WorkspaceState
# ---------------------------------------------------------------------------
# FakeResponse / FakeSession — minimal stand-ins for requests
# ---------------------------------------------------------------------------
class FakeResponse:
    """Minimal stand-in for a ``requests.Response``.

    Holds a canned status code, JSON body, and text, and mimics just the
    surface the client code touches: ``.json()`` and ``.raise_for_status()``.
    """

    def __init__(self, status_code: int = 200, json_body: Any = None, text: str = ""):
        self.status_code = status_code
        self._json = json_body
        self.text = text

    def json(self) -> Any:
        """Return the canned JSON body supplied at construction."""
        return self._json

    def raise_for_status(self) -> None:
        """Raise ``requests.HTTPError`` for any 4xx/5xx status, like requests does."""
        is_error = self.status_code >= 400
        if is_error:
            import requests
            raise requests.HTTPError(f"HTTP {self.status_code}")
@pytest.fixture
def tmp_token_dir(tmp_path: Path) -> Path:
    # Per-test token cache directory. Intentionally NOT created here so
    # tests can also exercise the "directory absent" code path.
    return tmp_path / "molecule-token-cache"
@pytest.fixture
def client(tmp_token_dir: Path) -> RemoteAgentClient:
    """Client wired to a MagicMock session, so no real HTTP ever happens."""
    session = MagicMock()
    return RemoteAgentClient(
        workspace_id="ws-abc-123",
        platform_url="http://platform.test",
        agent_card={"name": "test-agent"},
        token_dir=tmp_token_dir,
        session=session,
    )
# ---------------------------------------------------------------------------
# Token persistence
# ---------------------------------------------------------------------------
def test_save_and_load_token_roundtrip(client: RemoteAgentClient, tmp_token_dir: Path):
    """save_token persists to disk with 0600 perms; load_token reads it back."""
    client.save_token("secret-token-abc")
    assert client.token_file.exists()
    # File must be 0600 so other local users can't read the credential.
    mode = stat.S_IMODE(client.token_file.stat().st_mode)
    assert mode == 0o600, f"expected 0600, got 0o{mode:o}"
    assert client.load_token() == "secret-token-abc"


def test_save_empty_token_rejected(client: RemoteAgentClient):
    # Blank and whitespace-only tokens are both rejected as invalid.
    with pytest.raises(ValueError):
        client.save_token("")
    with pytest.raises(ValueError):
        client.save_token(" ")


def test_load_token_returns_none_when_absent(client: RemoteAgentClient):
    # No token file on disk -> None, not an exception.
    assert client.load_token() is None


def test_load_token_returns_none_when_file_empty(client: RemoteAgentClient, tmp_token_dir: Path):
    # An empty token file is treated the same as no file at all.
    tmp_token_dir.mkdir(parents=True, exist_ok=True)
    (tmp_token_dir / ".auth_token").write_text("")
    assert client.load_token() is None


def test_token_dir_default_is_under_home(tmp_path: Path):
    # Just verifies the default path shape — we don't want to actually
    # write to $HOME during tests.
    c = RemoteAgentClient(
        workspace_id="ws-xyz",
        platform_url="http://p",
    )
    assert "ws-xyz" in str(c.token_file)
    assert ".molecule" in str(c.token_file)
# ---------------------------------------------------------------------------
# register()
# ---------------------------------------------------------------------------
def test_register_saves_token_when_issued(client: RemoteAgentClient):
    """register() returns the platform-issued token and caches it on disk."""
    client._session.post.return_value = FakeResponse(
        200, {"status": "registered", "auth_token": "fresh-token-xyz"}
    )
    tok = client.register()
    assert tok == "fresh-token-xyz"
    assert client.load_token() == "fresh-token-xyz"
    # Verify call shape
    url, kwargs = client._session.post.call_args[0][0], client._session.post.call_args[1]
    assert url == "http://platform.test/registry/register"
    assert kwargs["json"]["id"] == "ws-abc-123"
    assert kwargs["json"]["agent_card"] == {"name": "test-agent"}


def test_register_keeps_cached_token_when_platform_omits(client: RemoteAgentClient):
    # Simulate re-register of an already-tokened workspace: platform returns
    # no auth_token, SDK must keep using the cached one.
    client.save_token("cached-from-earlier")
    client._session.post.return_value = FakeResponse(200, {"status": "registered"})
    tok = client.register()
    assert tok == "cached-from-earlier"


def test_register_http_error_propagates(client: RemoteAgentClient):
    # A 5xx from the platform must surface to the caller, not be swallowed.
    client._session.post.return_value = FakeResponse(500, {"error": "boom"})
    with pytest.raises(Exception):
        client.register()
# ---------------------------------------------------------------------------
# pull_secrets()
# ---------------------------------------------------------------------------
def test_pull_secrets_sends_bearer_token(client: RemoteAgentClient):
    """pull_secrets hits the per-workspace values endpoint with Bearer auth."""
    client.save_token("tok-for-secrets")
    client._session.get.return_value = FakeResponse(200, {"API_KEY": "v1", "DB_URL": "v2"})
    out = client.pull_secrets()
    assert out == {"API_KEY": "v1", "DB_URL": "v2"}
    url, kwargs = client._session.get.call_args[0][0], client._session.get.call_args[1]
    assert url == "http://platform.test/workspaces/ws-abc-123/secrets/values"
    assert kwargs["headers"]["Authorization"] == "Bearer tok-for-secrets"


def test_pull_secrets_empty_body_yields_empty_dict(client: RemoteAgentClient):
    # A null JSON body normalizes to {} rather than None.
    client.save_token("t")
    client._session.get.return_value = FakeResponse(200, None)
    assert client.pull_secrets() == {}


def test_pull_secrets_401_raises(client: RemoteAgentClient):
    # Auth failures must propagate, not return an empty dict.
    client.save_token("t")
    client._session.get.return_value = FakeResponse(401, {"error": "missing token"})
    with pytest.raises(Exception):
        client.pull_secrets()
# ---------------------------------------------------------------------------
# poll_state()
# ---------------------------------------------------------------------------
def test_poll_state_returns_normal_state(client: RemoteAgentClient):
    """A healthy workspace maps to a WorkspaceState with should_stop False."""
    client.save_token("t")
    client._session.get.return_value = FakeResponse(
        200, {"workspace_id": "ws-abc-123", "status": "online", "paused": False, "deleted": False}
    )
    state = client.poll_state()
    assert state is not None
    assert state.status == "online"
    assert state.paused is False
    assert state.deleted is False
    assert state.should_stop is False


def test_poll_state_detects_paused(client: RemoteAgentClient):
    # paused=True from the platform flips should_stop.
    client.save_token("t")
    client._session.get.return_value = FakeResponse(
        200, {"workspace_id": "ws-abc-123", "status": "paused", "paused": True, "deleted": False}
    )
    state = client.poll_state()
    assert state.should_stop is True


def test_poll_state_404_means_deleted(client: RemoteAgentClient):
    # A 404 is NOT an error here: it means the workspace was removed, and
    # poll_state synthesizes a deleted state instead of raising.
    client.save_token("t")
    client._session.get.return_value = FakeResponse(404, {"deleted": True})
    state = client.poll_state()
    assert state is not None
    assert state.deleted is True
    assert state.should_stop is True


def test_poll_state_server_error_raises(client: RemoteAgentClient):
    # Unlike 404, a 5xx is a genuine failure and must raise.
    client.save_token("t")
    client._session.get.return_value = FakeResponse(500, {"error": "boom"})
    with pytest.raises(Exception):
        client.poll_state()
# ---------------------------------------------------------------------------
# heartbeat()
# ---------------------------------------------------------------------------
def test_heartbeat_sends_full_payload(client: RemoteAgentClient):
    """heartbeat() posts workspace id, task info, error stats and uptime."""
    client.save_token("t")
    client._session.post.return_value = FakeResponse(200, {"status": "ok"})
    client.heartbeat(current_task="indexing", active_tasks=1, error_rate=0.1, sample_error="err")
    url = client._session.post.call_args[0][0]
    kwargs = client._session.post.call_args[1]
    assert url == "http://platform.test/registry/heartbeat"
    body = kwargs["json"]
    assert body["workspace_id"] == "ws-abc-123"
    assert body["current_task"] == "indexing"
    assert body["active_tasks"] == 1
    assert body["error_rate"] == 0.1
    assert body["sample_error"] == "err"
    # Exact uptime value is timing-dependent, so only presence is asserted.
    assert "uptime_seconds" in body
    assert kwargs["headers"]["Authorization"] == "Bearer t"
# ---------------------------------------------------------------------------
# run_heartbeat_loop()
# ---------------------------------------------------------------------------
def test_run_loop_exits_on_max_iterations(client: RemoteAgentClient, monkeypatch):
    """Loop terminates with 'max_iterations' after exactly N cycles."""
    # Stub sleep so the test doesn't actually wait
    import molecule_agent.client as mod
    monkeypatch.setattr(mod.time, "sleep", lambda s: None)
    client.save_token("t")
    client._session.post.return_value = FakeResponse(200, {"status": "ok"})
    client._session.get.return_value = FakeResponse(
        200, {"status": "online", "paused": False, "deleted": False}
    )
    terminal = client.run_heartbeat_loop(max_iterations=3)
    assert terminal == "max_iterations"
    # 3 heartbeats + 3 state polls
    assert client._session.post.call_count == 3
    assert client._session.get.call_count == 3


def test_run_loop_exits_on_paused(client: RemoteAgentClient, monkeypatch):
    """Loop stops as soon as a state poll reports paused=True."""
    import molecule_agent.client as mod
    monkeypatch.setattr(mod.time, "sleep", lambda s: None)
    client.save_token("t")
    client._session.post.return_value = FakeResponse(200, {"status": "ok"})
    # First iteration: online. Second: paused.
    responses = [
        FakeResponse(200, {"status": "online", "paused": False, "deleted": False}),
        FakeResponse(200, {"status": "paused", "paused": True, "deleted": False}),
    ]
    client._session.get.side_effect = responses
    terminal = client.run_heartbeat_loop(max_iterations=10)
    assert terminal == "paused"
    assert client._session.post.call_count == 2
    assert client._session.get.call_count == 2


def test_run_loop_exits_on_deleted_404(client: RemoteAgentClient, monkeypatch):
    """A 404 state poll (workspace deleted) terminates with 'removed'."""
    import molecule_agent.client as mod
    monkeypatch.setattr(mod.time, "sleep", lambda s: None)
    client.save_token("t")
    client._session.post.return_value = FakeResponse(200, {"status": "ok"})
    client._session.get.return_value = FakeResponse(404, {"deleted": True})
    terminal = client.run_heartbeat_loop(max_iterations=10)
    assert terminal == "removed"
    assert client._session.get.call_count == 1


def test_run_loop_continues_through_transient_errors(client: RemoteAgentClient, monkeypatch):
    """Network hiccups must log-and-continue, never crash the loop."""
    import molecule_agent.client as mod
    monkeypatch.setattr(mod.time, "sleep", lambda s: None)
    client.save_token("t")
    # Heartbeat fails on iter 1, succeeds on iter 2
    client._session.post.side_effect = [
        ConnectionError("flaky net"),
        FakeResponse(200, {"status": "ok"}),
    ]
    # State poll returns online both times
    client._session.get.return_value = FakeResponse(
        200, {"status": "online", "paused": False, "deleted": False}
    )
    terminal = client.run_heartbeat_loop(max_iterations=2)
    assert terminal == "max_iterations"
    # Both iterations completed despite the first post failing
    assert client._session.post.call_count == 2


def test_run_loop_task_supplier_reported(client: RemoteAgentClient, monkeypatch):
    """task_supplier's dict is merged into the heartbeat payload."""
    import molecule_agent.client as mod
    monkeypatch.setattr(mod.time, "sleep", lambda s: None)
    client.save_token("t")
    client._session.post.return_value = FakeResponse(200, {"status": "ok"})
    client._session.get.return_value = FakeResponse(
        200, {"status": "online", "paused": False, "deleted": False}
    )
    reports = [{"current_task": "step-1", "active_tasks": 1}]
    client.run_heartbeat_loop(max_iterations=1, task_supplier=lambda: reports[0])
    body = client._session.post.call_args[1]["json"]
    assert body["current_task"] == "step-1"
    assert body["active_tasks"] == 1
# ---------------------------------------------------------------------------
# WorkspaceState dataclass
# ---------------------------------------------------------------------------
def test_workspace_state_should_stop_semantics():
    # As exercised here: should_stop tracks the paused/deleted flags, and a
    # degraded status string alone does not stop the agent.
    assert WorkspaceState("w", "online", False, False).should_stop is False
    assert WorkspaceState("w", "degraded", False, False).should_stop is False
    assert WorkspaceState("w", "paused", True, False).should_stop is True
    assert WorkspaceState("w", "removed", False, True).should_stop is True
# ---------------------------------------------------------------------------
# Phase 30.6 — sibling URL cache + call_peer
# ---------------------------------------------------------------------------
def test_get_peers_seeds_cache(client: RemoteAgentClient):
    """get_peers returns PeerInfo objects and pre-warms the URL cache."""
    client.save_token("t")
    client._session.get.return_value = FakeResponse(200, [
        {"id": "sibling-1", "name": "Research", "url": "http://10.0.0.5:8000", "role": "researcher", "tier": 2, "status": "online"},
        {"id": "sibling-2", "name": "Dev", "url": "http://10.0.0.6:8000", "role": "developer", "tier": 2, "status": "online"},
    ])
    peers = client.get_peers()
    assert len(peers) == 2
    assert peers[0].id == "sibling-1"
    assert peers[0].name == "Research"
    assert peers[0].url == "http://10.0.0.5:8000"
    # Cache seeded for both
    assert client._url_cache["sibling-1"][0] == "http://10.0.0.5:8000"
    assert client._url_cache["sibling-2"][0] == "http://10.0.0.6:8000"
    # Request included bearer + X-Workspace-ID
    headers = client._session.get.call_args[1]["headers"]
    assert headers["Authorization"] == "Bearer t"
    assert headers["X-Workspace-ID"] == "ws-abc-123"


def test_get_peers_skips_non_http_urls_in_cache(client: RemoteAgentClient):
    """Cache seed only accepts http(s); the 'remote://no-inbound' placeholder
    for remote agents without inbound servers must not poison the cache."""
    client.save_token("t")
    client._session.get.return_value = FakeResponse(200, [
        {"id": "sib-remote", "name": "Remote", "url": "remote://no-inbound"},
        {"id": "sib-http", "name": "HTTP", "url": "http://192.168.1.7:8000"},
    ])
    client.get_peers()
    assert "sib-remote" not in client._url_cache
    assert "sib-http" in client._url_cache
def test_discover_peer_cache_hit(client: RemoteAgentClient):
    """A fresh cache entry is served without touching the network."""
    # Cache entries are (url, expiry_timestamp) tuples.
    client._url_cache["sib-x"] = ("http://cached.url:8000", time.time() + 60)
    url = client.discover_peer("sib-x")
    assert url == "http://cached.url:8000"
    # No network call
    client._session.get.assert_not_called()


def test_discover_peer_cache_miss_hits_platform(client: RemoteAgentClient):
    """Cache miss falls through to the platform discover endpoint."""
    client.save_token("t")
    client._session.get.return_value = FakeResponse(
        200, {"id": "sib-y", "url": "http://fresh.url:8000", "name": "Y"}
    )
    url = client.discover_peer("sib-y")
    assert url == "http://fresh.url:8000"
    assert client._url_cache["sib-y"][0] == "http://fresh.url:8000"
    # Request used discover endpoint
    called_url = client._session.get.call_args[0][0]
    assert "/registry/discover/sib-y" in called_url


def test_discover_peer_expired_cache_refreshes(client: RemoteAgentClient, monkeypatch):
    """A stale cache entry triggers a re-discover and is replaced."""
    # Cache entry already expired
    client._url_cache["sib-stale"] = ("http://stale.url", time.time() - 10)
    client.save_token("t")
    client._session.get.return_value = FakeResponse(
        200, {"url": "http://fresh.url:9000"}
    )
    url = client.discover_peer("sib-stale")
    assert url == "http://fresh.url:9000"
    # Cache replaced with fresh entry
    assert client._url_cache["sib-stale"][0] == "http://fresh.url:9000"


def test_discover_peer_404_returns_none(client: RemoteAgentClient):
    # Unknown peer -> None, not an exception.
    client.save_token("t")
    client._session.get.return_value = FakeResponse(404, {"error": "not found"})
    assert client.discover_peer("missing") is None


def test_invalidate_peer_url_drops_cache_entry(client: RemoteAgentClient):
    client._url_cache["sib-x"] = ("http://x", time.time() + 100)
    client.invalidate_peer_url("sib-x")
    assert "sib-x" not in client._url_cache
    # Idempotent — second call is safe
    client.invalidate_peer_url("sib-x")
def test_call_peer_direct_path_on_cache_hit(client: RemoteAgentClient):
    """With a cached URL, call_peer posts A2A JSON-RPC straight to the peer."""
    client.save_token("t")
    client._url_cache["sib"] = ("http://direct.peer:8000", time.time() + 60)
    client._session.post.return_value = FakeResponse(
        200, {"jsonrpc": "2.0", "id": "x", "result": {"ok": True}}
    )
    out = client.call_peer("sib", "hello sibling")
    assert out["result"]["ok"] is True
    # Exactly ONE post: direct to the cached URL, not through proxy
    assert client._session.post.call_count == 1
    called_url = client._session.post.call_args[0][0]
    assert called_url == "http://direct.peer:8000"
    body = client._session.post.call_args[1]["json"]
    assert body["method"] == "message/send"
    assert body["params"]["message"]["parts"][0]["text"] == "hello sibling"
    headers = client._session.post.call_args[1]["headers"]
    assert headers["X-Workspace-ID"] == "ws-abc-123"


def test_call_peer_falls_back_to_proxy_on_direct_error(client: RemoteAgentClient):
    """Direct failure -> cache invalidated, retried through the platform proxy."""
    client.save_token("t")
    client._url_cache["sib"] = ("http://dead.peer:8000", time.time() + 60)
    # First post (direct): connection error. Second post (proxy): success.
    client._session.post.side_effect = [
        ConnectionError("unreachable"),
        FakeResponse(200, {"jsonrpc": "2.0", "result": {"via": "proxy"}}),
    ]
    out = client.call_peer("sib", "hello")
    assert out["result"]["via"] == "proxy"
    assert client._session.post.call_count == 2
    # Direct URL was invalidated so next call re-discovers
    assert "sib" not in client._url_cache
    # Second call went to /workspaces/sib/a2a
    proxy_url = client._session.post.call_args_list[1][0][0]
    assert "/workspaces/sib/a2a" in proxy_url


def test_call_peer_proxy_only_when_prefer_direct_false(client: RemoteAgentClient):
    """prefer_direct=False skips the direct path even on a cache hit."""
    client.save_token("t")
    client._url_cache["sib"] = ("http://direct.peer:8000", time.time() + 60)
    client._session.post.return_value = FakeResponse(
        200, {"jsonrpc": "2.0", "result": {"via": "proxy-only"}}
    )
    client.call_peer("sib", "hello", prefer_direct=False)
    # Exactly one post — went straight to proxy despite cache hit
    assert client._session.post.call_count == 1
    assert "/workspaces/sib/a2a" in client._session.post.call_args[0][0]


def test_call_peer_no_cached_url_uses_discover_then_direct(client: RemoteAgentClient):
    """Fresh call: no cache entry → discover via GET, then direct POST to the
    returned URL. Tests the full discover-then-call sequence in one shot."""
    client.save_token("t")
    # discover returns a URL
    client._session.get.return_value = FakeResponse(
        200, {"url": "http://newly-discovered:9000"}
    )
    # direct post succeeds
    client._session.post.return_value = FakeResponse(
        200, {"jsonrpc": "2.0", "result": {"ok": True}}
    )
    out = client.call_peer("new-sib", "hi")
    assert out["result"]["ok"] is True
    assert client._url_cache["new-sib"][0] == "http://newly-discovered:9000"
    called_url = client._session.post.call_args[0][0]
    assert called_url == "http://newly-discovered:9000"


def test_peer_info_dataclass_defaults():
    # Only id/name/url are required; everything else has a benign default.
    p = PeerInfo(id="x", name="y", url="http://z")
    assert p.role == ""
    assert p.tier == 2
    assert p.status == "unknown"
    assert p.agent_card == {}
# ---------------------------------------------------------------------------
# Phase 30.3 — install_plugin
# ---------------------------------------------------------------------------
import io
import tarfile
from molecule_agent.client import _safe_extract_tar
def _make_tarball(files: dict[str, bytes]) -> bytes:
"""Build a gzipped tarball in memory from a {name: content} dict."""
buf = io.BytesIO()
with tarfile.open(fileobj=buf, mode="w:gz") as tf:
for name, content in files.items():
info = tarfile.TarInfo(name=name)
info.size = len(content)
info.mode = 0o644
tf.addfile(info, io.BytesIO(content))
return buf.getvalue()
class _StreamingResp:
"""requests-shaped response with .content + .iter_content + context-manager.
install_plugin switched from streaming reads to .content (we hold the
full <=100MiB tarball in memory before extract — see client.py comment),
but we keep iter_content available for any future test that wants to
exercise a streaming path.
"""
def __init__(self, status: int, body: bytes):
self.status_code = status
self._body = body
self.content = body # used by .content readers (install_plugin today)
def __enter__(self): return self
def __exit__(self, *a): return None
def raise_for_status(self):
if self.status_code >= 400:
import requests
raise requests.HTTPError(f"HTTP {self.status_code}")
def iter_content(self, chunk_size=64*1024):
i = 0
while i < len(self._body):
yield self._body[i:i+chunk_size]
i += chunk_size
i += chunk_size
def test_install_plugin_unpacks_into_per_workspace_dir(client: RemoteAgentClient, tmp_path):
    """Happy path: tarball downloaded, extracted, and install reported."""
    client.save_token("t")
    tarball = _make_tarball({
        "plugin.yaml": b"name: hello\nversion: 1.0.0\n",
        "rules.md": b"some rules\n",
        "skills/x/SKILL.md": b"---\nname: x\n---\n",
    })
    # Stub out the streaming GET (used inside `with`)
    def fake_get(url, headers=None, params=None, stream=False, timeout=None):
        assert "/plugins/hello/download" in url
        assert headers["Authorization"] == "Bearer t"
        return _StreamingResp(200, tarball)
    client._session.get.side_effect = fake_get
    # POST install record — also stubbed
    client._session.post.return_value = FakeResponse(200, {"status": "installed"})
    target = client.install_plugin("hello")
    assert target.exists()
    assert (target / "plugin.yaml").read_bytes() == b"name: hello\nversion: 1.0.0\n"
    assert (target / "skills" / "x" / "SKILL.md").read_text().startswith("---\nname: x\n")
    # Atomic-rename means no .staging-* leftover
    assert not any(p.name.startswith(".staging-") for p in client.plugins_dir.iterdir())
    # Reported the install
    post_url = client._session.post.call_args[0][0]
    assert post_url.endswith(f"/workspaces/{client.workspace_id}/plugins")


def test_install_plugin_passes_source_query_when_given(client: RemoteAgentClient):
    """An explicit source is forwarded to the download endpoint as a query param."""
    client.save_token("t")
    tarball = _make_tarball({"plugin.yaml": b"name: gh\nversion: 0.1.0\n"})
    captured = {}
    def fake_get(url, headers=None, params=None, stream=False, timeout=None):
        captured["url"] = url
        captured["params"] = params
        return _StreamingResp(200, tarball)
    client._session.get.side_effect = fake_get
    client._session.post.return_value = FakeResponse(200, {})
    client.install_plugin("gh", source="github://acme/my-plugin")
    assert captured["params"] == {"source": "github://acme/my-plugin"}


def test_install_plugin_atomic_rollback_on_corrupt_tarball(client: RemoteAgentClient):
    """A corrupt download must leave no staging or partial plugin dirs behind."""
    client.save_token("t")
    # Truncated gzip — tarfile.open will raise
    client._session.get.side_effect = lambda *a, **k: _StreamingResp(200, b"not a gzip")
    client._session.post.return_value = FakeResponse(200, {})
    import pytest as _pytest
    with _pytest.raises(Exception):
        client.install_plugin("broken")
    # No .staging-* dir lingering, no half-installed plugin dir
    # NOTE(review): conditional expression — only inspects plugins_dir when it exists.
    assert not list(client.plugins_dir.iterdir()) if client.plugins_dir.exists() else True


def test_install_plugin_overwrites_existing(client: RemoteAgentClient):
    """Reinstall replaces the old plugin dir wholesale (no stale files survive)."""
    client.save_token("t")
    # Pre-populate an old version
    old_dir = client.plugins_dir / "rotateme"
    old_dir.mkdir(parents=True)
    (old_dir / "old-marker").write_text("old")
    new_tarball = _make_tarball({
        "plugin.yaml": b"name: rotateme\nversion: 2.0.0\n",
        "new-marker": b"new",
    })
    client._session.get.side_effect = lambda *a, **k: _StreamingResp(200, new_tarball)
    client._session.post.return_value = FakeResponse(200, {})
    client.install_plugin("rotateme")
    assert not (client.plugins_dir / "rotateme" / "old-marker").exists()
    assert (client.plugins_dir / "rotateme" / "new-marker").read_text() == "new"


def test_install_plugin_runs_setup_sh_when_present(client: RemoteAgentClient, tmp_path):
    client.save_token("t")
    # setup.sh that drops a sentinel file we can verify
    sentinel = tmp_path / "ran"
    setup_script = f"#!/bin/bash\nset -e\ntouch {sentinel}\n".encode()
    tarball = _make_tarball({
        "plugin.yaml": b"name: withsetup\n",
        "setup.sh": setup_script,
    })
    client._session.get.side_effect = lambda *a, **k: _StreamingResp(200, tarball)
    client._session.post.return_value = FakeResponse(200, {})
    client.install_plugin("withsetup")
    # setup.sh extracted with 0644 perms (tar default), so script execution
    # depends on bash interpreting the file contents. The bash invocation
    # runs without the +x bit because we call `bash <setup>` not `<setup>`.
    assert sentinel.exists(), "setup.sh did not run"


def test_install_plugin_skips_setup_when_disabled(client: RemoteAgentClient, tmp_path):
    """run_setup_sh=False must prevent any script execution."""
    client.save_token("t")
    sentinel = tmp_path / "should-not-exist"
    tarball = _make_tarball({
        "setup.sh": f"#!/bin/bash\ntouch {sentinel}\n".encode(),
    })
    client._session.get.side_effect = lambda *a, **k: _StreamingResp(200, tarball)
    client._session.post.return_value = FakeResponse(200, {})
    client.install_plugin("nosetup", run_setup_sh=False)
    assert not sentinel.exists()


def test_install_plugin_skips_platform_report_when_disabled(client: RemoteAgentClient):
    client.save_token("t")
    tarball = _make_tarball({"plugin.yaml": b"name: silent\n"})
    client._session.get.side_effect = lambda *a, **k: _StreamingResp(200, tarball)
    client.install_plugin("silent", report_to_platform=False)
    # POST never called when report disabled
    client._session.post.assert_not_called()


def test_install_plugin_404_raises_with_useful_url(client: RemoteAgentClient):
    # A missing plugin download must raise rather than install an empty dir.
    client.save_token("t")
    client._session.get.side_effect = lambda *a, **k: _StreamingResp(404, b"")
    import pytest as _pytest
    with _pytest.raises(Exception):
        client.install_plugin("missing")
# ---------------------------------------------------------------------------
# _safe_extract_tar
# ---------------------------------------------------------------------------
def test_safe_extract_rejects_path_traversal(tmp_path: Path):
    """Tar slip CVE: an entry named '../escape' must be rejected."""
    buf = io.BytesIO()
    with tarfile.open(fileobj=buf, mode="w") as tf:
        info = tarfile.TarInfo(name="../escape.txt")
        data = b"oops"
        info.size = len(data)
        tf.addfile(info, io.BytesIO(data))
    buf.seek(0)
    with tarfile.open(fileobj=buf, mode="r") as tf:
        import pytest as _pytest
        with _pytest.raises(ValueError, match="refusing tar entry escaping"):
            _safe_extract_tar(tf, tmp_path)


def test_safe_extract_rejects_absolute_paths(tmp_path: Path):
    """Absolute entry names (e.g. /etc/passwd) are refused outright."""
    buf = io.BytesIO()
    with tarfile.open(fileobj=buf, mode="w") as tf:
        info = tarfile.TarInfo(name="/etc/passwd")
        data = b"oops"
        info.size = len(data)
        tf.addfile(info, io.BytesIO(data))
    buf.seek(0)
    with tarfile.open(fileobj=buf, mode="r") as tf:
        import pytest as _pytest
        with _pytest.raises(ValueError):
            _safe_extract_tar(tf, tmp_path)


def test_safe_extract_skips_symlinks_silently(tmp_path: Path):
    """Symlink members are dropped; regular files alongside still extract."""
    buf = io.BytesIO()
    with tarfile.open(fileobj=buf, mode="w") as tf:
        sym = tarfile.TarInfo(name="link.lnk")
        sym.type = tarfile.SYMTYPE
        sym.linkname = "/etc/passwd"
        tf.addfile(sym)
        # Plus a normal file alongside
        info = tarfile.TarInfo(name="real.md")
        data = b"ok"
        info.size = len(data)
        tf.addfile(info, io.BytesIO(data))
    buf.seek(0)
    with tarfile.open(fileobj=buf, mode="r") as tf:
        _safe_extract_tar(tf, tmp_path)
    assert (tmp_path / "real.md").exists()
    assert not (tmp_path / "link.lnk").exists()


def test_safe_extract_logs_warning_for_skipped_symlink(tmp_path: Path, caplog):
    """KI-003: skipped symlinks must be auditable via a WARNING log record."""
    buf = io.BytesIO()
    with tarfile.open(fileobj=buf, mode="w") as tf:
        sym = tarfile.TarInfo(name="link.lnk")
        sym.type = tarfile.SYMTYPE
        sym.linkname = "/etc/passwd"
        tf.addfile(sym)
        info = tarfile.TarInfo(name="real.md")
        info.size = 2
        tf.addfile(info, io.BytesIO(b"ok"))
    buf.seek(0)
    with tarfile.open(fileobj=buf, mode="r") as tf:
        _safe_extract_tar(tf, tmp_path)
    assert (tmp_path / "real.md").exists()
    assert not (tmp_path / "link.lnk").exists()
    # A warning must be emitted so operators know what was dropped.
    assert any("link.lnk" in r.message and "/etc/passwd" in r.message for r in caplog.records)
    assert any(r.levelname == "WARNING" for r in caplog.records)
# ---------------------------------------------------------------------------
# verify_plugin_sha256 + content integrity
# ---------------------------------------------------------------------------
from molecule_agent.client import _sha256_file, _walk_files, verify_plugin_sha256
def test_sha256_file_computes_correct_hash(tmp_path: Path):
    f = tmp_path / "data.txt"
    f.write_bytes(b"hello world")
    h = _sha256_file(f)
    # SHA256("hello world") in lowercase hex
    assert h == "b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9"


def test_walk_files_excludes_directories(tmp_path: Path):
    """_walk_files yields relative POSIX paths of files only, recursively."""
    (tmp_path / "a.txt").write_bytes(b"a")
    (tmp_path / "sub").mkdir()
    (tmp_path / "sub" / "b.txt").write_bytes(b"b")
    (tmp_path / "sub" / "deep").mkdir()
    (tmp_path / "sub" / "deep" / "c.txt").write_bytes(b"c")
    rels = sorted(_walk_files(tmp_path))
    assert rels == ["a.txt", "sub/b.txt", "sub/deep/c.txt"]


def test_verify_plugin_sha256_returns_true_on_match(tmp_path: Path):
    # plugin.yaml is excluded from the manifest (self-referential field),
    # so only non-plugin.yaml files contribute to the manifest hash.
    (tmp_path / "plugin.yaml").write_text("name: test\nversion: '1.0'\n")
    (tmp_path / "rules.md").write_text("# Rules\n")
    # Manually compute: plugin.yaml excluded, only rules.md counts
    import hashlib, json
    file_hashes = [
        ("rules.md", _sha256_file(tmp_path / "rules.md")),
    ]
    manifest_bytes = json.dumps(sorted(file_hashes), sort_keys=True).encode()
    expected = hashlib.sha256(manifest_bytes).hexdigest()
    assert verify_plugin_sha256(tmp_path, expected) is True


def test_verify_plugin_sha256_returns_false_on_mismatch(tmp_path: Path):
    (tmp_path / "f.txt").write_bytes(b"content")
    assert verify_plugin_sha256(tmp_path, "0" * 64) is False


def test_verify_plugin_sha256_rejects_invalid_format():
    """Malformed digests raise ValueError instead of silently returning False."""
    import pytest as _pytest
    with _pytest.raises(ValueError, match="64-character lowercase hex"):
        verify_plugin_sha256(Path("/tmp"), "short")
    with _pytest.raises(ValueError, match="64-character lowercase hex"):
        verify_plugin_sha256(Path("/tmp"), "g" * 64)  # 'g' is not hex
    with _pytest.raises(ValueError, match="64-character lowercase hex"):
        verify_plugin_sha256(Path("/tmp"), 123)  # type error
def test_install_plugin_sha256_verified_setup_sh_run(
    client: RemoteAgentClient, tmp_path: Path
):
    """When sha256 matches, setup.sh runs normally."""
    client.save_token("t")
    import hashlib, json
    setup_sh = b"#!/bin/bash\ntouch setup-ran\n"
    # plugin.yaml is excluded from its own manifest hash (breaks circular dep),
    # so convergence is instant: compute the manifest over other files only,
    # then write sha256=<that hash> into plugin.yaml.
    yaml_no_sha = b"name: withsha\nversion: '1.0'\n"
    file_hashes = [
        ("setup.sh", hashlib.sha256(setup_sh).hexdigest()),
        # plugin.yaml excluded from manifest (see verify_plugin_sha256 docstring)
    ]
    manifest_hash = hashlib.sha256(
        json.dumps(sorted(file_hashes), sort_keys=True).encode()
    ).hexdigest()
    plugin_yaml_bytes = (
        f"name: withsha\nversion: '1.0'\n"
        f"sha256: {manifest_hash}\n"
    ).encode()
    tarball = _make_tarball({
        "plugin.yaml": plugin_yaml_bytes,
        "setup.sh": setup_sh,
    })
    def fake_get(url, headers=None, params=None, stream=False, timeout=None):
        return _StreamingResp(200, tarball)
    client._session.get.side_effect = fake_get
    client._session.post.return_value = FakeResponse(200, {})
    target = client.install_plugin("withsha")
    assert (target / "setup-ran").exists(), "setup.sh should have run"


def test_install_plugin_sha256_mismatch_aborts_setup_sh(
    client: RemoteAgentClient, tmp_path: Path
):
    """When sha256 does not match, install_plugin raises and setup.sh is NOT run."""
    client.save_token("t")
    # Plugin.yaml declares sha256 but the actual content differs
    mismatched_yaml = (
        "name: bad\nversion: '1.0'\n"
        "sha256: " + "f" * 64 + "\n"
    )
    tarball = _make_tarball({
        "plugin.yaml": mismatched_yaml.encode(),
        "setup.sh": b"#!/bin/bash\ntouch must-not-run\n",
    })
    def fake_get(url, headers=None, params=None, stream=False, timeout=None):
        return _StreamingResp(200, tarball)
    client._session.get.side_effect = fake_get
    client._session.post.return_value = FakeResponse(200, {})
    import pytest as _pytest
    with _pytest.raises(ValueError, match="sha256 mismatch"):
        client.install_plugin("bad")
    # Plugin dir must not exist after failure
    assert not (client.plugins_dir / "bad").exists()


def test_install_plugin_missing_sha256_skips_verification(
    client: RemoteAgentClient, tmp_path: Path
):
    """When plugin.yaml has no sha256 field, verification is skipped (backward compat)."""
    client.save_token("t")
    tarball = _make_tarball({
        "plugin.yaml": b"name: nosha\nversion: '1.0'\n",
        "setup.sh": b"#!/bin/bash\ntouch setup-ran\n",
    })
    def fake_get(url, headers=None, params=None, stream=False, timeout=None):
        return _StreamingResp(200, tarball)
    client._session.get.side_effect = fake_get
    client._session.post.return_value = FakeResponse(200, {})
    target = client.install_plugin("nosha")
    assert (target / "setup-ran").exists()
# ---------------------------------------------------------------------------
# validate_manifest — sha256 field
# ---------------------------------------------------------------------------
from molecule_plugin import validate_manifest
def test_validate_manifest_rejects_invalid_sha256(tmp_path: Path):
    # Too-short digest -> error message mentioning the required length (64).
    (tmp_path / "plugin.yaml").write_text("name: test\nsha256: too-short\n")
    errors = validate_manifest(tmp_path / "plugin.yaml")
    assert any("64" in e for e in errors)


def test_validate_manifest_rejects_non_hex_sha256(tmp_path: Path):
    # Right length but non-hex characters -> error mentioning "hex".
    (tmp_path / "plugin.yaml").write_text("name: test\nsha256: " + "g" * 64 + "\n")
    errors = validate_manifest(tmp_path / "plugin.yaml")
    assert any("hex" in e for e in errors)


def test_validate_manifest_accepts_valid_sha256(tmp_path: Path):
    valid_sha = "a" * 64
    (tmp_path / "plugin.yaml").write_text(f"name: test\nsha256: {valid_sha}\n")
    errors = validate_manifest(tmp_path / "plugin.yaml")
    assert not errors


def test_validate_manifest_accepts_absent_sha256(tmp_path: Path):
    # sha256 is optional for backward compatibility.
    (tmp_path / "plugin.yaml").write_text("name: test\nversion: '1.0'\n")
    errors = validate_manifest(tmp_path / "plugin.yaml")
    assert not errors
# NOTE(review): removed an accidental module-level duplicate paste of the body
# of test_safe_extract_skips_symlinks_silently. Executed at import time, it
# referenced the pytest fixture name `tmp_path` (undefined at module scope),
# raising NameError during collection and preventing the entire test module
# from loading. The behavior it duplicated is already covered by
# test_safe_extract_skips_symlinks_silently above.