molecule-ai-workspace-runtime/molecule_runtime/platform_auth.py
rabbitblood 050c2412b3 fix(heartbeat): refresh on-disk auth token on 401 + retry once (#1877)
## Problem

Auto-restart rotates the workspace's auth token in two non-atomic steps:
  1. Platform issues new token via wsauth.IssueToken
  2. Provisioner writes the new token to /configs/.auth_token AFTER
     ContainerStart returns

Between steps 1 and 2, the new container has booted and the runtime has
already loaded the OLD cached value of .auth_token (or no value if the
file was empty during boot). The runtime's first /registry/heartbeat
call sends the stale token, gets 401, but the loop never re-reads the
on-disk token — so subsequent heartbeats also send the stale value.

Each 401 means the platform never sees the workspace as alive →
status stays 'provisioning' → scheduler won't dispatch → workspace
looks dead from every angle even though the container is actually
running.

The existing code comment in workspace_provision.go acknowledges this:
"the workspace will get 401 on its first heartbeat and can recover on
the next restart." That recovery only worked because workspaces used
to crash for unrelated reasons and get restarted. After PR #1861
(provisioner empty-volume auto-recover) removed those crashes,
workspaces get stuck in the 401 loop with no exit.

## Fix

Two-part runtime-side fix in molecule-ai-workspace-runtime:

1. **platform_auth.refresh_from_disk()** — new helper that clears the
   in-memory cache and re-reads /configs/.auth_token. Returns the
   fresh value (or None if missing). Updates the cache as a side effect.

2. **HeartbeatLoop._loop()** — on 401 from /registry/heartbeat, calls
   refresh_from_disk() and retries the request ONCE with the new token.
   Same pattern in _check_delegations(). Bounded retry budget — if the
   on-disk token is also stale (bug elsewhere), no infinite loop.
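The retry-once shape described above can be sketched as follows. This is an illustrative stand-in, not the actual `HeartbeatLoop._loop()` code: `post_with_token_refresh` and its callable parameters are hypothetical names, with the real `refresh_from_disk()` semantics assumed (returns the fresh token or `None`).

```python
import logging
from typing import Callable, Optional

logger = logging.getLogger(__name__)


def post_with_token_refresh(
    post: Callable[[Optional[str]], int],
    current_token: Callable[[], Optional[str]],
    refresh_from_disk: Callable[[], Optional[str]],
) -> int:
    """POST once; on 401, re-read the on-disk token and retry exactly once.

    `post` takes a token and returns an HTTP status code. The retry budget
    is bounded: if the refreshed token also yields 401, give up and let the
    next heartbeat tick try again -- no infinite loop.
    """
    status = post(current_token())
    if status != 401:
        return status
    fresh = refresh_from_disk()
    if fresh is None:
        logger.warning("401 and no token on disk; skipping retry")
        return status
    # Exactly one retry with the freshly loaded token.
    return post(fresh)
```

The same wrapper applies unchanged to `_check_delegations()`, which is what keeps the fix a small core change.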

## Tests

6/6 new tests in tests/test_token_refresh_1877.py:

  - refresh_picks_up_rotated_token              — happy path
  - refresh_returns_none_when_file_missing      — defensive
  - refresh_clears_stale_cache_when_file_disappears
  - refresh_is_idempotent
  - 401_retry_pattern_uses_refreshed_token      — the production fix path
  - 401_retry_no_loop_when_disk_token_also_stale — bounded retry budget

All pass locally on Python 3.13 + pytest 9.

## Why this fix and not the alternatives

- **Alternative B (platform writes token before ContainerStart):**
  Right architecturally but invasive — needs provisioner refactor to
  prep volumes before docker run.
- **Alternative C (skip rotation on auto-restart):** Breaks the
  multi-instance-safety invariant the existing code calls out
  (revoke prevents stale tokens from sister deployments).
- **This fix (A):** 3-line core change + helper. Self-healing for any
  timing edge case, not just the post-restart one. Costs nothing in
  the happy path (only triggers on 401).

## Version

Bumped to 0.1.9. Once published to PyPI + workspace template image
rebuilt, deployed workspaces auto-recover from token-rotation races
without operator intervention.

Closes #1877.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-23 13:26:36 -07:00

"""Workspace auth-token store (Phase 30.1).

Single source of truth for this workspace's authentication token. The
token is issued by the platform on the first successful
``POST /registry/register`` call and travels with every subsequent
heartbeat / update-card / (later) secrets-pull / A2A request.

The token is persisted to ``<configs>/.auth_token`` so it survives
restarts — we only expect to receive it once from the platform, since
``/registry/register`` no-ops token issuance for workspaces that already
have one on file.

Storage:

    ${CONFIGS_DIR}/.auth_token  # 0600, one line, no trailing newline

Callers interact with four functions:

:func:`get_token`         — returns the cached token or None
:func:`save_token`        — persists a freshly-issued token
:func:`auth_headers`      — builds the Authorization header dict for httpx
:func:`refresh_from_disk` — re-reads the on-disk token after a 401 (#1877)
"""
from __future__ import annotations

import logging
import os
import re
from pathlib import Path

logger = logging.getLogger(__name__)

# Valid workspace ID: lowercase alphanumeric + hyphens (UUIDs and org-generated IDs).
# Rejects /, \, .., #, ?, &, newlines — all chars that could break URL paths
# or HTTP header values. This is the single validation gate for WORKSPACE_ID.
_WORKSPACE_ID_RE = re.compile(r"^[a-z0-9][a-z0-9\-]{0,127}$")

# Cached result — validated once per process startup, not on every call.
_validated_workspace_id: str | None = None


def validate_workspace_id(workspace_id: str) -> str:
    """Validate *workspace_id* and return it.

    Raises ValueError if the ID is empty, contains unsafe characters, or
    does not match the expected format. This function is the single
    validation gate — call it once at startup and reuse the result.

    Fixes issue #14 (CWE-20): prevents URL/header injection when WORKSPACE_ID
    is used in platform API URLs and ``X-Workspace-ID`` headers.
    """
    global _validated_workspace_id
    if _validated_workspace_id is not None:
        return _validated_workspace_id  # pragma: no cover — cached fast path
    if not workspace_id:
        raise ValueError("WORKSPACE_ID is empty — set the WORKSPACE_ID env var")
    # Strip, then re-check the stripped value.
    workspace_id = workspace_id.strip()
    if not _WORKSPACE_ID_RE.match(workspace_id):
        raise ValueError(
            f"WORKSPACE_ID contains invalid characters: {workspace_id!r}. "
            "Only lowercase letters, digits, and hyphens are allowed. "
            "Ensure WORKSPACE_ID is a valid UUID or alphanumeric ID."
        )
    _validated_workspace_id = workspace_id
    return workspace_id


# In-process cache so we don't hit disk on every heartbeat. The heartbeat
# loop fires on a short interval and reading a tiny file 10x per minute
# is wasteful. The file is the durable copy; this var is the hot path.
_cached_token: str | None = None

# Validated WORKSPACE_ID — read once at import time so every caller gets the
# same validated value without re-checking. Raises on bad input.
WORKSPACE_ID: str = validate_workspace_id(os.environ.get("WORKSPACE_ID", ""))
def get_workspace_id() -> str:
    """Return the validated workspace ID.

    Cached result from the module-level WORKSPACE_ID constant. Call this
    instead of reading WORKSPACE_ID directly — it guarantees the ID passed
    validation.
    """
    return WORKSPACE_ID


def _token_file() -> Path:
    """Path to the on-disk token file. Respects CONFIGS_DIR, falls back
    to /configs for the default container layout."""
    return Path(os.environ.get("CONFIGS_DIR", "/configs")) / ".auth_token"


def get_token() -> str | None:
    """Return the cached token, reading it from disk on first call."""
    global _cached_token
    if _cached_token is not None:
        return _cached_token
    path = _token_file()
    if not path.exists():
        return None
    try:
        tok = path.read_text().strip()
    except OSError as exc:
        logger.warning("platform_auth: failed to read %s: %s", path, exc)
        return None
    if not tok:
        return None
    _cached_token = tok
    return tok
def save_token(token: str) -> None:
    """Persist a newly-issued token. Creates the file with 0600 mode atomically.

    Uses ``os.open(O_CREAT, 0o600)`` so the file is never world-readable,
    even transiently. The previous ``write_text()`` + ``chmod()`` approach
    had a TOCTOU window where a concurrent reader could access the token
    between the two syscalls (M4 — flagged in security audit cycle 10).

    Idempotent — if an identical token is already on disk we skip the
    write so we don't churn the file's mtime or trigger spurious
    filesystem watchers.
    """
    global _cached_token
    token = token.strip()
    if not token:
        raise ValueError("platform_auth: refusing to save empty token")
    if get_token() == token:
        return
    path = _token_file()
    path.parent.mkdir(parents=True, exist_ok=True)
    # O_CREAT | O_WRONLY | O_TRUNC with mode=0o600 creates (or truncates)
    # the file with restricted permissions in a single syscall, eliminating
    # the TOCTOU window.
    fd = os.open(str(path), os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    try:
        os.write(fd, token.encode())
    finally:
        os.close(fd)
    _cached_token = token
def auth_headers() -> dict[str, str]:
    """Return a header dict to merge into every outbound platform call.

    Two headers, both optional:

    - ``Authorization: Bearer <token>`` — the workspace-scoped auth
      token issued on first /registry/register. Empty if not yet
      issued; the platform grandfathers pre-token workspaces through.
    - ``X-Molecule-Org-Id: <uuid>`` — the SaaS cross-org routing tag
      the tenant platform's TenantGuard requires on every non-
      allowlisted route. Read from the ``MOLECULE_ORG_ID`` env var
      that the control plane exports into workspace user-data.
      Unset on self-hosted / dev deployments where TenantGuard is a
      no-op, so omitting the header keeps those paths working.
    """
    headers: dict[str, str] = {}
    tok = get_token()
    if tok:
        headers["Authorization"] = f"Bearer {tok}"
    org_id = os.environ.get("MOLECULE_ORG_ID", "").strip()
    if org_id:
        headers["X-Molecule-Org-Id"] = org_id
    return headers
def clear_cache() -> None:
    """Reset the in-memory cache. Used by :func:`refresh_from_disk` and by
    tests that write fresh token files between cases."""
    global _cached_token
    _cached_token = None
def refresh_from_disk() -> str | None:
    """Force-reload the token from ``/configs/.auth_token``, bypassing the
    in-memory cache.

    Used by callers (e.g. the heartbeat loop) that got a 401 from the
    platform and suspect the on-disk token was rotated after boot.

    Returns the fresh token on success, ``None`` if the file is missing or
    unreadable. Updates the in-memory cache as a side effect so subsequent
    :func:`auth_headers` calls pick up the new value.

    Context (#1877): on auto-restart, the platform revokes the old token
    and writes a new ``.auth_token`` AFTER ``ContainerStart``, so the
    runtime's first heartbeat can race the token write and send the stale
    cached value. Re-reading from disk on 401 breaks the loop without
    needing another full container restart.
    """
    clear_cache()
    return get_token()