diff --git a/molecule_runtime/builtin_tools/a2a_tools.py b/molecule_runtime/builtin_tools/a2a_tools.py index ac05290..61c21c1 100644 --- a/molecule_runtime/builtin_tools/a2a_tools.py +++ b/molecule_runtime/builtin_tools/a2a_tools.py @@ -9,15 +9,21 @@ import uuid import httpx +from builtin_tools.validation import WorkspaceIdValidationError, get_validated_workspace_id + PLATFORM_URL = os.environ.get("PLATFORM_URL", "http://platform:8080") -WORKSPACE_ID = os.environ.get("WORKSPACE_ID", "") +WORKSPACE_ID = os.environ.get("WORKSPACE_ID", "") # used only for tracing headers; URLs use validated version async def list_peers() -> list[dict]: """Get this workspace's peers from the platform registry.""" + try: + ws_id = get_validated_workspace_id(caller="a2a_tools.list_peers") + except WorkspaceIdValidationError: + return [] async with httpx.AsyncClient(timeout=10.0) as client: try: - resp = await client.get(f"{PLATFORM_URL}/registry/{WORKSPACE_ID}/peers") + resp = await client.get(f"{PLATFORM_URL}/registry/{ws_id}/peers") if resp.status_code == 200: return resp.json() return [] @@ -27,12 +33,17 @@ async def list_peers() -> list[dict]: async def delegate_task(workspace_id: str, task: str) -> str: """Send a task to a peer workspace via A2A and return the response text.""" + try: + ws_id = get_validated_workspace_id(caller="a2a_tools.delegate_task") + except WorkspaceIdValidationError as e: + return f"Error: {e}" + async with httpx.AsyncClient(timeout=120.0) as client: # Discover target URL try: resp = await client.get( f"{PLATFORM_URL}/registry/discover/{workspace_id}", - headers={"X-Workspace-ID": WORKSPACE_ID}, + headers={"X-Workspace-ID": ws_id}, ) if resp.status_code != 200: return f"Error: cannot reach workspace {workspace_id} (status {resp.status_code})" diff --git a/molecule_runtime/builtin_tools/approval.py b/molecule_runtime/builtin_tools/approval.py index 39c8721..4a87b0c 100644 --- a/molecule_runtime/builtin_tools/approval.py +++ b/molecule_runtime/builtin_tools/approval.py @@ -51,6 +51,7 @@ import httpx from langchain_core.tools import tool from builtin_tools.audit import check_permission, get_workspace_roles, log_event +from builtin_tools.validation import WorkspaceIdValidationError, get_validated_workspace_id logger = logging.getLogger(__name__) @@ -91,10 +92,16 @@ async def _create_approval_request(action: str, reason: str) -> dict: Returns {"approval_id": str} on success or {"error": str} on failure. """ + # --- Workspace ID validation (CWE-20 / CWE-88) ---------------------------- + try: + ws_id = get_validated_workspace_id(caller="approval._create_approval_request") + except WorkspaceIdValidationError as e: + return {"error": str(e)} + async with httpx.AsyncClient(timeout=10.0) as client: try: resp = await client.post( - f"{PLATFORM_URL}/workspaces/{WORKSPACE_ID}/approvals", + f"{PLATFORM_URL}/workspaces/{ws_id}/approvals", json={"action": action, "reason": reason}, ) if resp.status_code != 201: @@ -156,6 +163,13 @@ async def _wait_websocket(approval_id: str, timeout: float) -> dict: async def _wait_polling(approval_id: str, timeout: float) -> dict: """Legacy polling loop — checks platform REST endpoint every APPROVAL_POLL_INTERVAL seconds.""" + # --- Workspace ID validation (CWE-20 / CWE-88) ---------------------------- + try: + ws_id = get_validated_workspace_id(caller="approval._wait_polling") + except WorkspaceIdValidationError: + # Transient — propagate as timeout so the caller handles it gracefully + raise asyncio.TimeoutError("WORKSPACE_ID validation failed") + elapsed = 0.0 async with httpx.AsyncClient(timeout=10.0) as client: while elapsed < timeout: @@ -163,7 +177,7 @@ async def _wait_polling(approval_id: str, timeout: float) -> dict: elapsed += APPROVAL_POLL_INTERVAL try: resp = await client.get( - f"{PLATFORM_URL}/workspaces/{WORKSPACE_ID}/approvals", + f"{PLATFORM_URL}/workspaces/{ws_id}/approvals", ) if resp.status_code == 200: for a in resp.json(): diff --git a/molecule_runtime/builtin_tools/delegation.py b/molecule_runtime/builtin_tools/delegation.py index 0d608c2..4e42438 100644 --- a/molecule_runtime/builtin_tools/delegation.py +++ b/molecule_runtime/builtin_tools/delegation.py @@ -19,6 +19,7 @@ import httpx from langchain_core.tools import tool from builtin_tools.audit import check_permission, get_workspace_roles, log_event +from builtin_tools.validation import WorkspaceIdValidationError, get_validated_workspace_id from builtin_tools.telemetry import ( A2A_SOURCE_WORKSPACE, A2A_TARGET_WORKSPACE, @@ -82,10 +83,15 @@ def _on_task_done(task: asyncio.Task): async def _notify_completion(task_id: str, target_workspace_id: str, status: str): """Push notification to platform when delegation completes/fails.""" + try: + ws_id = get_validated_workspace_id(caller="delegation._notify_completion") + except WorkspaceIdValidationError: + logger.debug("Delegation notify skipped: invalid WORKSPACE_ID") + return try: async with httpx.AsyncClient(timeout=10) as client: await client.post( - f"{PLATFORM_URL}/workspaces/{WORKSPACE_ID}/notify", + f"{PLATFORM_URL}/workspaces/{ws_id}/notify", json={ "type": "delegation_complete", "task_id": task_id, @@ -105,10 +111,15 @@ async def _record_delegation_on_platform(task_id: str, target_workspace_id: str, GET /delegations endpoint now mirrors the same set an agent's local check_delegation_status sees. """ + try: + ws_id = get_validated_workspace_id(caller="delegation._record_delegation_on_platform") + except WorkspaceIdValidationError: + logger.debug("Delegation record skipped: invalid WORKSPACE_ID") + return try: async with httpx.AsyncClient(timeout=10) as client: await client.post( - f"{PLATFORM_URL}/workspaces/{WORKSPACE_ID}/delegations/record", + f"{PLATFORM_URL}/workspaces/{ws_id}/delegations/record", json={ "target_id": target_workspace_id, "task": task, @@ -125,10 +136,15 @@ async def _update_delegation_on_platform(task_id: str, status: str, error: str = Paired with _record_delegation_on_platform — fires on completion/failure so the platform view stays in sync with the agent's local dict. """ + try: + ws_id = get_validated_workspace_id(caller="delegation._update_delegation_on_platform") + except WorkspaceIdValidationError: + logger.debug("Delegation update skipped: invalid WORKSPACE_ID") + return try: async with httpx.AsyncClient(timeout=10) as client: await client.post( - f"{PLATFORM_URL}/workspaces/{WORKSPACE_ID}/delegations/{task_id}/update", + f"{PLATFORM_URL}/workspaces/{ws_id}/delegations/{task_id}/update", json={ "status": status, "error": error, diff --git a/molecule_runtime/builtin_tools/memory.py b/molecule_runtime/builtin_tools/memory.py index 5b7d5cd..268494b 100644 --- a/molecule_runtime/builtin_tools/memory.py +++ b/molecule_runtime/builtin_tools/memory.py @@ -33,6 +33,7 @@ from typing import Any from langchain_core.tools import tool from builtin_tools.awareness_client import _normalise_namespace, build_awareness_client +from builtin_tools.validation import WorkspaceIdValidationError, get_validated_workspace_id from builtin_tools.audit import check_permission, get_workspace_roles, log_event from builtin_tools.telemetry import MEMORY_QUERY, MEMORY_SCOPE, WORKSPACE_ID_ATTR, get_tracer @@ -59,6 +60,12 @@ async def commit_memory(content: str, scope: str = "LOCAL", *, namespace: str | if scope not in ("LOCAL", "TEAM", "GLOBAL"): return {"error": "scope must be LOCAL, TEAM, or GLOBAL"} + # --- Workspace ID validation (CWE-20 / CWE-88) ---------------------------- + try: + ws_id = get_validated_workspace_id(caller="memory.commit_memory") + except WorkspaceIdValidationError as e: + return {"success": False, "error": str(e)} + # --- RBAC check ----------------------------------------------------------- roles, custom_perms = get_workspace_roles() if not check_permission("memory.write", roles, custom_perms): @@ -129,7 +136,7 @@ async def commit_memory(content: str, scope: str = "LOCAL", *, namespace: str | async with httpx.AsyncClient(timeout=10.0) as client: try: resp = await client.post( - f"{PLATFORM_URL}/workspaces/{WORKSPACE_ID}/memories", + f"{PLATFORM_URL}/workspaces/{ws_id}/memories", json={"content": content, "scope": scope, "namespace": _normalise_namespace(namespace)}, headers=_headers, ) @@ -200,6 +207,12 @@ async def search_memory(query: str = "", scope: str = "", *, namespace: str | No if scope and scope not in ("LOCAL", "TEAM", "GLOBAL"): return {"error": "scope must be LOCAL, TEAM, GLOBAL, or empty"} + # --- Workspace ID validation (CWE-20 / CWE-88) ---------------------------- + try: + ws_id = get_validated_workspace_id(caller="memory.search_memory") + except WorkspaceIdValidationError as e: + return {"success": False, "error": str(e)} + # --- RBAC check ----------------------------------------------------------- roles, custom_perms = get_workspace_roles() if not check_permission("memory.read", roles, custom_perms): @@ -292,7 +305,7 @@ async def search_memory(query: str = "", scope: str = "", *, namespace: str | No async with httpx.AsyncClient(timeout=10.0) as client: try: resp = await client.get( - f"{PLATFORM_URL}/workspaces/{WORKSPACE_ID}/memories", + f"{PLATFORM_URL}/workspaces/{ws_id}/memories", params=params, headers=_headers, ) diff --git a/molecule_runtime/builtin_tools/validation.py b/molecule_runtime/builtin_tools/validation.py new file mode 100644 index 0000000..f8e202b --- /dev/null +++ b/molecule_runtime/builtin_tools/validation.py @@ -0,0 +1,80 @@ +"""Shared input validation helpers for builtin tools. + +Defence-in-depth: validate environment-derived values before they are used +in URLs, headers, or other security-sensitive positions (CWE-20 / CWE-88). +""" + +from __future__ import annotations + +import os +import re +from typing import assert_never + +# Pattern: alphanumeric + hyphen + underscore + dot; no path-traversal chars. +# This deliberately rejects `/`, `\`, `..`, `#`, `?`, `&` which could +# manipulate URL path segments or query strings. +_WORKSPACE_ID_RE = re.compile(r"^[A-Za-z0-9_\-.]{1,256}$") + +# Error message prefix used by callers so callers can surface context. +_WORKSPACE_ID_INVALID_MSG = ( + "WORKSPACE_ID has an invalid format. " + "Expected an alphanumeric identifier (hyphens, underscores, dots allowed); " + "got: {value!r} " + "(path-traversal characters such as / \\ .. or fragment chars such as # ? & are not permitted)" +) + + +class WorkspaceIdValidationError(ValueError): + """Raised when WORKSPACE_ID fails format validation. + + This is intentionally a ValueError subclass so callers that currently + swallow generic Exceptions still get a clear signal. + """ + + pass + + +def _validate_workspace_id(workspace_id: str, *, caller: str = "unknown") -> None: + """Validate WORKSPACE_ID and raise WorkspaceIdValidationError if unsafe. + + Args: + workspace_id: The raw WORKSPACE_ID value to check. + caller: Human-readable name of the calling module/function (for the error). + + Raises: + WorkspaceIdValidationError: If workspace_id is empty or contains unsafe chars. + """ + if not workspace_id: + raise WorkspaceIdValidationError( + f"[{caller}] WORKSPACE_ID is empty — cannot construct platform URLs. " + "Set the WORKSPACE_ID environment variable." + ) + if not _WORKSPACE_ID_RE.match(workspace_id): + raise WorkspaceIdValidationError( + f"[{caller}] " + _WORKSPACE_ID_INVALID_MSG.format(value=workspace_id) + ) + + +# --------------------------------------------------------------------------- +# Lazy validation — call once at module initialisation, then cache the result. +# --------------------------------------------------------------------------- + +_cached_workspace_id: str | None = None +_cached_validated: bool = False + + +def get_validated_workspace_id(*, caller: str = "builtin_tools") -> str: + """Return the validated WORKSPACE_ID, raising on the first bad call. + + Result is cached so repeated calls are cheap. + """ + global _cached_workspace_id, _cached_validated + if _cached_validated: + assert _cached_workspace_id is not None + return _cached_workspace_id + + ws_id = os.environ.get("WORKSPACE_ID", "") + _validate_workspace_id(ws_id, caller=caller) + _cached_workspace_id = ws_id + _cached_validated = True + return ws_id