diff --git a/molecule_runtime/heartbeat.py b/molecule_runtime/heartbeat.py index 194d52e..c08f944 100644 --- a/molecule_runtime/heartbeat.py +++ b/molecule_runtime/heartbeat.py @@ -17,7 +17,7 @@ from pathlib import Path import httpx -from molecule_runtime.platform_auth import auth_headers +from molecule_runtime.platform_auth import auth_headers, refresh_from_disk logger = logging.getLogger(__name__) @@ -85,18 +85,34 @@ class HeartbeatLoop: while True: # 1. Send heartbeat (Phase 30.1: include auth header if token known) try: - await client.post( + hb_payload = { + "workspace_id": self.workspace_id, + "error_rate": self.error_rate, + "sample_error": self.sample_error, + "active_tasks": self.active_tasks, + "current_task": self.current_task, + "uptime_seconds": int(time.time() - self.start_time), + } + resp = await client.post( f"{self.platform_url}/registry/heartbeat", - json={ - "workspace_id": self.workspace_id, - "error_rate": self.error_rate, - "sample_error": self.sample_error, - "active_tasks": self.active_tasks, - "current_task": self.current_task, - "uptime_seconds": int(time.time() - self.start_time), - }, + json=hb_payload, headers=auth_headers(), ) + # #1877: auto-restart rotates the workspace token AFTER + # container start, so the first heartbeat after a restart + # can race the token write and send the stale cached + # value → 401. Re-read /configs/.auth_token and retry ONCE + # to break the 401 loop without needing another restart. + if resp.status_code == 401: + if refresh_from_disk() is not None: + logger.info( + "Heartbeat: got 401, refreshed token from disk, retrying" + ) + resp = await client.post( + f"{self.platform_url}/registry/heartbeat", + json=hb_payload, + headers=auth_headers(), + ) self.error_count = 0 self.request_count = 0 self._consecutive_failures = 0 @@ -135,10 +151,12 @@ class HeartbeatLoop: async def _check_delegations(self, client: httpx.AsyncClient): """Check for completed delegations and store results for the agent.""" try: - resp = await client.get( - f"{self.platform_url}/workspaces/{self.workspace_id}/delegations", - headers=auth_headers(), - ) + url = f"{self.platform_url}/workspaces/{self.workspace_id}/delegations" + resp = await client.get(url, headers=auth_headers()) + # #1877: refresh token on 401 and retry ONCE — same post-restart + # token-rotation race as the heartbeat path above. + if resp.status_code == 401 and refresh_from_disk() is not None: + resp = await client.get(url, headers=auth_headers()) if resp.status_code != 200: return diff --git a/molecule_runtime/platform_auth.py b/molecule_runtime/platform_auth.py index 36bcfe7..34c1abf 100644 --- a/molecule_runtime/platform_auth.py +++ b/molecule_runtime/platform_auth.py @@ -171,3 +171,22 @@ def clear_cache() -> None: files between cases.""" global _cached_token _cached_token = None + + +def refresh_from_disk() -> str | None: + """Force-reload the token from ``/configs/.auth_token``, bypassing the + in-memory cache. Used by callers (e.g. heartbeat loop) that got a 401 + from the platform and suspect the on-disk token was rotated after boot. + + Returns the fresh token on success, ``None`` if the file is missing or + unreadable. Updates the in-memory cache as a side-effect so subsequent + :func:`auth_headers` calls pick up the new value. + + Context (#1877): on auto-restart, the platform revokes the old token + and writes a new ``.auth_token`` AFTER ``ContainerStart``, so the + runtime's first heartbeat can race the token write and send the stale + cached value. Re-reading from disk on 401 breaks the loop without + needing another full container restart. + """ + clear_cache() + return get_token() diff --git a/pyproject.toml b/pyproject.toml index c904023..ae5f3be 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta" [project] name = "molecule-ai-workspace-runtime" -version = "0.1.8" +version = "0.1.9" description = "Molecule AI workspace runtime — shared infrastructure for all agent adapters" requires-python = ">=3.11" diff --git a/tests/test_token_refresh_1877.py b/tests/test_token_refresh_1877.py new file mode 100644 index 0000000..6c8e3c0 --- /dev/null +++ b/tests/test_token_refresh_1877.py @@ -0,0 +1,151 @@ +"""Tests for #1877 fix — runtime re-reads /configs/.auth_token on 401. + +Covers two surfaces: + +1. ``platform_auth.refresh_from_disk()`` — pure helper that clears the + in-memory cache and re-reads the file. +2. The HeartbeatLoop 401-then-retry pattern (verified by replaying it + against an httpx MockTransport). +""" + +from __future__ import annotations + +import os +from pathlib import Path +from typing import Any + +# WORKSPACE_ID must be set BEFORE importing platform_auth — the module +# validates the env var at import time. +os.environ.setdefault("WORKSPACE_ID", "00000000-0000-0000-0000-000000000001") + +import httpx +import pytest + +import molecule_runtime.platform_auth as pa +from molecule_runtime.platform_auth import ( + auth_headers, + clear_cache, + get_token, + refresh_from_disk, + save_token, +) + + +# ---------- platform_auth.refresh_from_disk ---------- + + +def test_refresh_picks_up_rotated_token(tmp_path: Path, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setenv("CONFIGS_DIR", str(tmp_path)) + clear_cache() + + save_token("token-v1") + assert get_token() == "token-v1" + + # Simulate platform rotating the token on disk while runtime had it cached + (tmp_path / ".auth_token").write_text("token-v2") + assert auth_headers().get("Authorization") == "Bearer token-v1" # cache stale + + fresh = refresh_from_disk() + assert fresh == "token-v2" + assert auth_headers().get("Authorization") == "Bearer token-v2" + + +def test_refresh_returns_none_when_file_missing(tmp_path: Path, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setenv("CONFIGS_DIR", str(tmp_path)) + clear_cache() + assert refresh_from_disk() is None + assert "Authorization" not in auth_headers() + + +def test_refresh_clears_stale_cache_when_file_disappears( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +): + monkeypatch.setenv("CONFIGS_DIR", str(tmp_path)) + clear_cache() + save_token("token-v1") + assert get_token() == "token-v1" + + (tmp_path / ".auth_token").unlink() + assert refresh_from_disk() is None + + +def test_refresh_is_idempotent(tmp_path: Path, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setenv("CONFIGS_DIR", str(tmp_path)) + clear_cache() + (tmp_path / ".auth_token").write_text("stable-token") + + a = refresh_from_disk() + b = refresh_from_disk() + assert a == b == "stable-token" + + +# ---------- 401 retry pattern (replayed manually against MockTransport) ---------- + + +def test_401_retry_pattern_uses_refreshed_token( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +): + """Models the #1877 fix path: 401 -> refresh_from_disk -> retry succeeds. + + Uses httpx sync Client + MockTransport so the test doesn't require + pytest-asyncio in CI (the production code is async, but the retry + *logic* — refresh-on-401 — is identical sync or async). + """ + monkeypatch.setenv("CONFIGS_DIR", str(tmp_path)) + clear_cache() + + save_token("token-v1") + (tmp_path / ".auth_token").write_text("token-v2") + pa._cached_token = "token-v1" # explicit stale cache + + calls: list[dict[str, Any]] = [] + + def handler(request: httpx.Request) -> httpx.Response: + calls.append({"auth": request.headers.get("authorization", "")}) + if "token-v1" in request.headers.get("authorization", ""): + return httpx.Response(401, json={"error": "invalid token"}) + return httpx.Response(200, json={}) + + with httpx.Client(transport=httpx.MockTransport(handler), timeout=5.0) as client: + payload = {"workspace_id": "ws-test", "active_tasks": 0} + url = "http://platform:8080/registry/heartbeat" + + # Mirror exactly what heartbeat.py now does: + resp = client.post(url, json=payload, headers=auth_headers()) + if resp.status_code == 401 and refresh_from_disk() is not None: + resp = client.post(url, json=payload, headers=auth_headers()) + + assert resp.status_code == 200 + assert len(calls) == 2 + assert calls[0]["auth"] == "Bearer token-v1" # stale, rejected + assert calls[1]["auth"] == "Bearer token-v2" # fresh, accepted + + +def test_401_retry_no_loop_when_disk_token_also_stale( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +): + """If both cached AND disk tokens are stale, the retry uses the same value + as the original — and the loop must NOT retry forever. The production code + only retries ONCE.""" + monkeypatch.setenv("CONFIGS_DIR", str(tmp_path)) + clear_cache() + + save_token("token-everywhere-stale") # disk + cache match, both invalid + + calls: list[str] = [] + + def handler(request: httpx.Request) -> httpx.Response: + calls.append(request.headers.get("authorization", "")) + return httpx.Response(401, json={"error": "invalid token"}) + + with httpx.Client(transport=httpx.MockTransport(handler), timeout=5.0) as client: + payload = {"workspace_id": "ws-test"} + url = "http://platform:8080/registry/heartbeat" + + resp = client.post(url, json=payload, headers=auth_headers()) + if resp.status_code == 401 and refresh_from_disk() is not None: + resp = client.post(url, json=payload, headers=auth_headers()) + + # Both attempts 401, no third call — bounded retry budget + assert resp.status_code == 401 + assert len(calls) == 2