forked from molecule-ai/molecule-core
Fix A — platform/internal/middleware/wsauth_middleware.go (NEW):
WorkspaceAuth() gin middleware enforces per-workspace bearer-token auth on
ALL /workspaces/:id/* sub-routes. Same lazy-bootstrap contract as
secrets.Values: workspaces with no live token are grandfathered through.
Addresses audit findings C2, C3, C4, C5, C7, C8, C9, C12 and C13 in a single change.
Fix A — platform/internal/router/router.go:
Reorganised route registration: bare CRUD (/workspaces, /workspaces/:id)
and /a2a remain on root router; all other /workspaces/:id/* sub-routes
moved into wsAuth = r.Group("/workspaces/:id", middleware.WorkspaceAuth(db.DB)).
CORS AllowHeaders updated to include Authorization so browser/agent callers
can send the bearer token cross-origin.
Fix B — workspace-template/heartbeat.py:
_check_delegations(): validate source_id == self.workspace_id before
accepting a delegation result. Attacker-crafted records with a foreign
source_id are skipped and logged at WARNING level as a suspected injection attempt.
trigger_msg no longer embeds raw response_preview text; references
delegation_id + status only — removes the prompt-injection vector.
Fix C — workspace-template/skill_loader/loader.py:
load_skill_tools(): before exec_module(), verify script is within
scripts_dir (path traversal guard) and temporarily scrub sensitive env
vars (CLAUDE_CODE_OAUTH_TOKEN, ANTHROPIC_API_KEY, OPENAI_API_KEY,
WORKSPACE_AUTH_TOKEN, GITHUB_TOKEN, GH_TOKEN) from os.environ; restore
in finally block. Defence-in-depth even if /plugins auth gate is bypassed.
Fix D — platform/internal/handlers/socket.go:
HandleConnect(): agent connections (X-Workspace-ID present) validated via
wsauth.HasAnyLiveToken + wsauth.ValidateToken before WebSocket upgrade.
Canvas clients (no X-Workspace-ID) remain unauthenticated.
Fix D — workspace-template/events.py:
PlatformEventSubscriber._connect(): include platform_auth bearer token in
WebSocket upgrade headers alongside X-Workspace-ID.
Fix E — workspace-template/executor_helpers.py:
recall_memories() and commit_memory() now pass platform_auth bearer token
in Authorization header so WorkspaceAuth middleware allows access.
Fix F — workspace-template/a2a_client.py:
send_a2a_message(): replaced timeout=None with httpx.Timeout(connect=30, read=300,
write=30, pool=30) (seconds). Resolves H2, flagged across 5 consecutive audits.
Tests: 149/149 Python tests pass (test_heartbeat + test_events updated to
assert new source_id validation behaviour and allow Authorization header).
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
440 lines
14 KiB
Python
440 lines
14 KiB
Python
"""Tests for events.py — PlatformEventSubscriber WebSocket handling."""
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import sys
|
|
from types import ModuleType
|
|
from unittest.mock import AsyncMock, MagicMock, patch, call
|
|
|
|
import pytest
|
|
|
|
from events import PlatformEventSubscriber, REBUILD_EVENTS
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _make_ws_mock(messages):
    """Build a MagicMock that stands in for a ``websockets`` connection.

    The returned mock works as an async context manager (``async with``) that
    yields the mock itself. It also works as an async iterator that replays
    ``messages`` in order. Any element of ``messages`` that is an exception
    instance is raised at that point in the iteration instead of being yielded.

    Args:
        messages: sequence of raw message strings and/or exception instances.

    Returns:
        The configured MagicMock.
    """
    conn = MagicMock()

    async def _replay():
        # A fresh generator is created for every ``async for``, so each
        # iteration starts again from the first message.
        for entry in messages:
            if not isinstance(entry, BaseException):
                yield entry
                continue
            raise entry

    # MagicMock passes the mock itself as the first argument to plain
    # functions assigned to magic methods.
    conn.__aiter__ = lambda _self: _replay()
    conn.__aenter__ = AsyncMock(return_value=conn)
    # A falsy __aexit__ result lets exceptions raised inside ``async with``
    # propagate to the caller.
    conn.__aexit__ = AsyncMock(return_value=False)
    return conn
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# __init__ — URL conversion
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def test_init_http_to_ws():
    """An http:// platform base URL maps to a ws:// endpoint ending in /ws."""
    subscriber = PlatformEventSubscriber("http://platform:8080", "ws-1")
    expected_url = "ws://platform:8080/ws"
    assert subscriber.ws_url == expected_url
|
|
|
|
|
|
def test_init_https_to_wss():
    """An https:// platform base URL maps to a wss:// endpoint ending in /ws."""
    subscriber = PlatformEventSubscriber("https://platform:8080", "ws-1")
    expected_url = "wss://platform:8080/ws"
    assert subscriber.ws_url == expected_url
|
|
|
|
|
|
def test_init_stores_attrs():
    """The constructor keeps its arguments and starts stopped with a 1.0 delay."""
    callback = MagicMock()
    subscriber = PlatformEventSubscriber(
        "http://p:8080", "ws-42", on_peer_change=callback
    )

    assert subscriber.workspace_id == "ws-42"
    assert subscriber.on_peer_change is callback
    # Stopped until start() flips the flag.
    assert subscriber._running is False
    # Initial reconnect delay; start() sleeps this long after the first failure.
    assert subscriber._reconnect_delay == 1.0
|
|
|
|
|
|
def test_init_on_peer_change_defaults_none():
    """Omitting on_peer_change leaves the callback unset (None)."""
    subscriber = PlatformEventSubscriber("http://p:8080", "ws-1")
    callback = subscriber.on_peer_change
    assert callback is None
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# stop()
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def test_stop_sets_running_false():
    """Calling stop() on a running subscriber clears the _running flag."""
    subscriber = PlatformEventSubscriber("http://p:8080", "ws-1")
    subscriber._running = True  # pretend start() is already looping

    subscriber.stop()

    assert subscriber._running is False
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# _connect() — websockets ImportError path
|
|
# ---------------------------------------------------------------------------
|
|
|
|
@pytest.mark.asyncio
async def test_connect_no_websockets_package(monkeypatch):
    """_connect() disables running and returns when websockets is not installed."""
    sub = PlatformEventSubscriber("http://p:8080", "ws-1")
    sub._running = True

    # A ``None`` entry in sys.modules makes ``import websockets`` (and any
    # ``from websockets import ...`` or ``import websockets.<sub>``) raise
    # ModuleNotFoundError, which is a subclass of ImportError. monkeypatch
    # restores the previous entry, or removes the key if it was absent, after
    # the test. This avoids a process-wide ``builtins.__import__`` hook and
    # any hand-written cleanup.
    monkeypatch.setitem(sys.modules, "websockets", None)

    await sub._connect()

    assert sub._running is False
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# _connect() — message processing
|
|
# ---------------------------------------------------------------------------
|
|
|
|
@pytest.mark.asyncio
async def test_connect_rebuild_event_calls_on_peer_change():
    """A message whose event is in REBUILD_EVENTS is forwarded to on_peer_change."""
    received = []

    async def record(event):
        received.append(event)

    subscriber = PlatformEventSubscriber(
        "http://p:8080", "ws-1", on_peer_change=record
    )
    subscriber._running = True

    frame = json.dumps({"event": "WORKSPACE_ONLINE", "workspace_id": "ws-2"})
    fake_websockets = MagicMock()
    fake_websockets.connect = MagicMock(return_value=_make_ws_mock([frame]))

    with patch.dict(sys.modules, {"websockets": fake_websockets}):
        await subscriber._connect()

    assert len(received) == 1
    (only_event,) = received
    assert only_event["event"] == "WORKSPACE_ONLINE"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_connect_all_rebuild_event_types():
    """Every event type in REBUILD_EVENTS triggers on_peer_change with that event."""
    # Guard against a vacuous pass: an empty REBUILD_EVENTS would skip the
    # loop entirely and the test would succeed without checking anything.
    assert REBUILD_EVENTS, "REBUILD_EVENTS should not be empty"

    for event_type in REBUILD_EVENTS:
        received = []

        # Bind this iteration's list explicitly via a default argument so the
        # callback never depends on late binding of the loop-local name.
        async def on_peer_change(event, _sink=received):
            _sink.append(event)

        sub = PlatformEventSubscriber("http://p:8080", "ws-1", on_peer_change=on_peer_change)
        sub._running = True

        msg = json.dumps({"event": event_type, "workspace_id": "ws-x"})
        ws_mock = _make_ws_mock([msg])

        websockets_mod = MagicMock()
        websockets_mod.connect = MagicMock(return_value=ws_mock)

        with patch.dict(sys.modules, {"websockets": websockets_mod}):
            await sub._connect()

        assert len(received) == 1, f"Expected callback for {event_type}"
        # The callback must see the event it was sent, not just "some" event.
        assert received[0]["event"] == event_type
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_connect_ignored_event_no_callback():
    """Messages whose event type is outside REBUILD_EVENTS never reach on_peer_change."""
    invocations = []

    async def on_peer_change(event):
        invocations.append(event)

    subscriber = PlatformEventSubscriber(
        "http://p:8080", "ws-1", on_peer_change=on_peer_change
    )
    subscriber._running = True

    heartbeat = json.dumps({"event": "HEARTBEAT", "workspace_id": "ws-2"})
    fake_websockets = MagicMock()
    fake_websockets.connect = MagicMock(return_value=_make_ws_mock([heartbeat]))

    with patch.dict(sys.modules, {"websockets": fake_websockets}):
        await subscriber._connect()

    assert not invocations
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_connect_no_on_peer_change_rebuild_event(caplog):
    """REBUILD_EVENTS are handled without error when on_peer_change is None."""
    sub = PlatformEventSubscriber("http://p:8080", "ws-1", on_peer_change=None)
    sub._running = True

    msg = json.dumps({"event": "WORKSPACE_ONLINE", "workspace_id": "ws-3"})
    ws_mock = _make_ws_mock([msg])

    websockets_mod = MagicMock()
    websockets_mod.connect = MagicMock(return_value=ws_mock)

    with patch.dict(sys.modules, {"websockets": websockets_mod}):
        with caplog.at_level(logging.WARNING, logger="events"):
            await sub._connect()  # Should not raise

    # Without this, the test would pass vacuously if _connect() returned
    # before ever opening the socket.
    websockets_mod.connect.assert_called_once()
    # _connect() catches per-event exceptions and logs "Error processing
    # event" (see test_connect_processing_exception_logged). "Did not raise"
    # therefore proves nothing; a None callback must not trigger that warning.
    assert "Error processing event" not in caplog.text
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_connect_json_decode_error_continues():
    """A malformed frame is skipped, and later valid frames are still dispatched."""
    hits = []

    async def on_peer_change(event):
        hits.append(event)

    subscriber = PlatformEventSubscriber(
        "http://p:8080", "ws-1", on_peer_change=on_peer_change
    )
    subscriber._running = True

    frames = [
        "not-valid-json{{{",  # undecodable; must not stop the receive loop
        json.dumps({"event": "WORKSPACE_ONLINE", "workspace_id": "ws-4"}),
    ]
    fake_websockets = MagicMock()
    fake_websockets.connect = MagicMock(return_value=_make_ws_mock(frames))

    with patch.dict(sys.modules, {"websockets": fake_websockets}):
        await subscriber._connect()

    # Only the well-formed frame that follows the bad one reaches the callback.
    assert len(hits) == 1
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_connect_processing_exception_logged(caplog):
    """A callback that raises is logged at WARNING and does not escape _connect()."""

    async def exploding_callback(event):
        raise RuntimeError("callback blew up")

    subscriber = PlatformEventSubscriber(
        "http://p:8080", "ws-1", on_peer_change=exploding_callback
    )
    subscriber._running = True

    frame = json.dumps({"event": "WORKSPACE_ONLINE", "workspace_id": "ws-5"})
    fake_websockets = MagicMock()
    fake_websockets.connect = MagicMock(return_value=_make_ws_mock([frame]))

    with patch.dict(sys.modules, {"websockets": fake_websockets}), caplog.at_level(
        logging.WARNING, logger="events"
    ):
        await subscriber._connect()

    assert "Error processing event" in caplog.text
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_connect_resets_reconnect_delay():
    """Opening a connection puts _reconnect_delay back to its 1.0 baseline."""
    subscriber = PlatformEventSubscriber("http://p:8080", "ws-1")
    subscriber._running = True
    subscriber._reconnect_delay = 16.0  # as if earlier attempts had backed off

    # Empty stream: the connection opens and then closes with no messages.
    fake_websockets = MagicMock()
    fake_websockets.connect = MagicMock(return_value=_make_ws_mock([]))

    with patch.dict(sys.modules, {"websockets": fake_websockets}):
        await subscriber._connect()

    assert subscriber._reconnect_delay == 1.0
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_connect_uses_workspace_id_header():
    """_connect() sends X-Workspace-ID, plus a bearer Authorization header when one is available."""
    sub = PlatformEventSubscriber("http://p:8080", "ws-hdr", on_peer_change=None)
    sub._running = True

    ws_mock = _make_ws_mock([])

    websockets_mod = MagicMock()
    websockets_mod.connect = MagicMock(return_value=ws_mock)

    with patch.dict(sys.modules, {"websockets": websockets_mod}):
        await sub._connect()

    websockets_mod.connect.assert_called_once()
    # Fix D (Cycle 5): headers now include Authorization when platform_auth
    # has a token. X-Workspace-ID is mandatory; Authorization is optional
    # here, but if it is sent it must be well-formed.
    actual_headers = websockets_mod.connect.call_args.kwargs.get("additional_headers", {})
    assert actual_headers.get("X-Workspace-ID") == "ws-hdr"

    auth = actual_headers.get("Authorization")
    if auth is not None:
        # Bearer scheme: "Bearer <token>" with a non-empty token.
        prefix = "Bearer "
        assert auth.startswith(prefix), auth
        assert auth[len(prefix):].strip(), "empty bearer token"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# start() — reconnect with backoff
|
|
# ---------------------------------------------------------------------------
|
|
|
|
@pytest.mark.asyncio
async def test_start_sets_running_true():
    """start() sets _running=True before entering the loop."""
    sub = PlatformEventSubscriber("http://p:8080", "ws-1")
    # A fresh subscriber starts stopped, so any True observed below was set
    # by start() itself.
    assert sub._running is False

    connect_calls = [0]
    running_at_connect = []

    async def fake_connect():
        connect_calls[0] += 1
        # Snapshot the flag as start() left it before calling us; this is the
        # property the docstring claims and the old test never checked.
        running_at_connect.append(sub._running)
        sub._running = False  # Stop after first connect

    sub._connect = fake_connect
    await sub.start()

    assert connect_calls[0] == 1
    assert running_at_connect == [True]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_start_reconnects_on_exception():
    """After a failed _connect(), start() sleeps for the backoff and tries again."""
    sub = PlatformEventSubscriber("http://p:8080", "ws-1")

    attempts = 0
    slept = []

    async def fake_connect():
        nonlocal attempts
        attempts += 1
        if attempts > 1:
            sub._running = False
            return
        raise ConnectionError("refused")

    async def fake_sleep(secs):
        slept.append(secs)

    sub._connect = fake_connect

    with patch("events.asyncio.sleep", side_effect=fake_sleep):
        await sub.start()

    assert attempts == 2
    # Exactly one backoff, using the initial _reconnect_delay.
    assert slept == [1.0]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_start_backoff_doubles_each_reconnect():
    """Each consecutive connection failure doubles the reconnect delay."""
    sub = PlatformEventSubscriber("http://p:8080", "ws-1")

    attempts = 0
    slept = []

    async def fake_connect():
        nonlocal attempts
        attempts += 1
        if attempts >= 4:
            sub._running = False
            return
        raise ConnectionError("fail")

    async def fake_sleep(secs):
        slept.append(secs)

    sub._connect = fake_connect

    with patch("events.asyncio.sleep", side_effect=fake_sleep):
        await sub.start()

    # Three failures -> three sleeps, starting at 1.0 and doubling each time.
    assert slept == [1.0, 2.0, 4.0]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_start_backoff_capped_at_30():
    """Doubling the reconnect delay never pushes it past 30 seconds."""
    sub = PlatformEventSubscriber("http://p:8080", "ws-1")
    sub._reconnect_delay = 20.0  # one doubling away from exceeding the cap

    attempts = 0
    slept = []

    async def fake_connect():
        nonlocal attempts
        attempts += 1
        if attempts >= 3:
            sub._running = False
            return
        raise ConnectionError("fail")

    async def fake_sleep(secs):
        slept.append(secs)

    sub._connect = fake_connect

    with patch("events.asyncio.sleep", side_effect=fake_sleep):
        await sub.start()

    # First sleep uses the preset 20.0; the next is min(40.0, 30.0).
    assert slept == [20.0, 30.0]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_start_stops_when_running_false_after_exception():
    """If stop() is called while reconnecting, the loop exits without a backoff sleep."""
    sub = PlatformEventSubscriber("http://p:8080", "ws-1")

    connect_calls = [0]
    sleep_calls = []

    async def fake_connect():
        connect_calls[0] += 1
        # Mark stopped before raising so the 'if not self._running: break' fires
        sub._running = False
        raise ConnectionError("closed")

    async def fake_sleep(secs):
        # Recorded rather than ignored so the test can prove no backoff happened.
        sleep_calls.append(secs)

    sub._connect = fake_connect

    with patch("events.asyncio.sleep", side_effect=fake_sleep):
        await sub.start()

    # Connected once, then saw _running=False and broke out
    assert connect_calls[0] == 1
    # The break must happen before the reconnect sleep, not after it.
    assert sleep_calls == []
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_start_logs_reconnect_warning(caplog):
    """A failed connection attempt produces a WARNING mentioning the reconnect."""
    sub = PlatformEventSubscriber("http://p:8080", "ws-1")

    attempts = 0

    async def fake_connect():
        nonlocal attempts
        attempts += 1
        if attempts == 1:
            raise ConnectionError("timed out")
        sub._running = False

    async def fake_sleep(secs):
        return None

    sub._connect = fake_connect

    with patch("events.asyncio.sleep", side_effect=fake_sleep), caplog.at_level(
        logging.WARNING, logger="events"
    ):
        await sub.start()

    log_text = caplog.text
    assert "WebSocket disconnected" in log_text
    assert "Reconnecting" in log_text
|