From ffcf781fa5d07b1216cd1e6c76bec2855df37642 Mon Sep 17 00:00:00 2001 From: core-devops Date: Sat, 23 May 2026 19:43:40 -0700 Subject: [PATCH] fix: download poll-mode inbound attachments --- molecule_agent/client.py | 195 +++++++++++++++++++++++++++++++- molecule_agent/inbound.py | 27 +++++ tests/test_inbound.py | 229 +++++++++++++++++++++++++++++++++++++- 3 files changed, 449 insertions(+), 2 deletions(-) diff --git a/molecule_agent/client.py b/molecule_agent/client.py index f4f0713..948a989 100644 --- a/molecule_agent/client.py +++ b/molecule_agent/client.py @@ -33,6 +33,7 @@ import uuid from dataclasses import dataclass, field from pathlib import Path from typing import TYPE_CHECKING, Any +from urllib.parse import quote import requests @@ -65,6 +66,9 @@ _RETRY_JITTER_FRAC = 0.25 # ±25% jitter around base delay # that concurrent restarts within the same 60-second window produce the # same key, while distinct tasks or distinct minutes produce distinct keys. _IDEMPOTENCY_ROUND_SECONDS = 60 +MAX_INBOUND_ATTACHMENT_BYTES = 100 * 1024 * 1024 +INBOUND_ATTACHMENT_DOWNLOAD_TIMEOUT = 60.0 +_UNSAFE_ATTACHMENT_NAME_CHARS = set('/\\\0') def make_idempotency_key(task_text: str) -> str: @@ -165,6 +169,59 @@ def _safe_extract_tar(tf: tarfile.TarFile, dest: Path) -> None: tf.extract(member, dest) +def _safe_attachment_name(name: str) -> str: + cleaned = "".join("_" if ch in _UNSAFE_ATTACHMENT_NAME_CHARS else ch for ch in name) + cleaned = cleaned.strip().strip(".") or "attachment" + return cleaned[:100] + + +def _parse_platform_pending_uri(uri: str) -> tuple[str, str]: + rest = uri[len("platform-pending:"):] + parts = rest.split("/", 1) + if len(parts) != 2 or not parts[0] or not parts[1]: + raise ValueError(f"invalid platform-pending attachment uri: {uri!r}") + workspace_id, file_id = parts[0], parts[1] + try: + file_id = str(uuid.UUID(file_id)) + except (TypeError, ValueError) as exc: + raise ValueError(f"invalid pending upload file id in uri: {uri!r}") from exc + return workspace_id, file_id + + +def _resolve_workspace_attachment_path(uri: str) -> str | None: + if uri.startswith("workspace:"): + path = uri[len("workspace:"):] + elif uri.startswith("file://"): + path = uri[len("file://"):] + elif uri.startswith("/"): + path = uri + else: + return None + try: + resolved = Path(path).resolve(strict=False) + except (OSError, RuntimeError): + return None + workspace_root = Path("/workspace") + if resolved != workspace_root and workspace_root not in resolved.parents: + return None + return str(resolved) + + +def _content_length(resp: requests.Response) -> int | None: + raw = resp.headers.get("Content-Length") if resp.headers else None + if raw is None: + return None + try: + value = int(raw) + except (TypeError, ValueError): + return None + return value if value >= 0 else None + + +def _ack_marker_path(cache_path: Path) -> Path: + return cache_path.with_name(f"{cache_path.name}.acked") + + def _rmtree_quiet(path: Path) -> None: """rm -rf swallowing missing-file errors. Used for atomic install rollback where we sometimes call this on a non-existent @@ -796,7 +853,11 @@ class RemoteAgentClient: # inbound module references RemoteAgentClient via TYPE_CHECKING. from .inbound import CursorLostError, _parse_activity_row - params: dict[str, str] = {"type": type, "limit": str(int(limit))} + params: dict[str, str] = { + "type": type, + "limit": str(int(limit)), + "include": "peer_info", + } if since_id: params["since_id"] = since_id if peer_id: @@ -841,6 +902,138 @@ class RemoteAgentClient: out.append(msg) return out + # ------------------------------------------------------------------ + # Inbound attachments (poll-mode external workspaces) + # ------------------------------------------------------------------ + + def download_inbound_attachment( + self, + attachment: dict[str, Any], + dest_dir: Path | None = None, + *, + ack: bool = True, + ) -> Path: + """Download one inbound attachment and return the local file path. + + Poll-mode external agents receive attachment metadata by reference in + :attr:`InboundMessage.attachments`. This method fetches the bytes using + the workspace bearer token: + + * ``platform-pending:/`` → pending-upload content, + then optional ack. + * ``workspace:/workspace/...`` / ``file:///workspace/...`` / + ``/workspace/...`` → the platform's chat download endpoint. + + The download is capped at 100 MB and cached by URI under + ``/attachments`` by default. + """ + uri = str(attachment.get("uri") or "") + if not uri: + raise ValueError("attachment is missing uri") + name = _safe_attachment_name(str(attachment.get("name") or "attachment")) + cache_dir = dest_dir or (self._token_dir / "attachments") + url: str + params: dict[str, str] | None = None + ack_url: str | None = None + if uri.startswith("platform-pending:"): + workspace_id, file_id = _parse_platform_pending_uri(uri) + if workspace_id != self.workspace_id: + raise ValueError( + "refusing to fetch attachment for another workspace " + f"({workspace_id!r} != {self.workspace_id!r})" + ) + quoted_ws = quote(workspace_id, safe="") + quoted_file = quote(file_id, safe="") + url = f"{self.platform_url}/workspaces/{quoted_ws}/pending-uploads/{quoted_file}/content" + ack_url = f"{self.platform_url}/workspaces/{quoted_ws}/pending-uploads/{quoted_file}/ack" + else: + path = _resolve_workspace_attachment_path(uri) + if not path: + raise ValueError(f"unsupported attachment uri: {uri!r}") + quoted_ws = quote(self.workspace_id, safe="") + url = f"{self.platform_url}/workspaces/{quoted_ws}/chat/download" + params = {"path": path} + + cache_path = cache_dir / hashlib.sha256(uri.encode("utf-8")).hexdigest()[:24] / name + if cache_path.exists() and cache_path.is_file(): + ack_marker = _ack_marker_path(cache_path) + if ack and ack_url and not ack_marker.exists(): + ack_resp = self._session.post( + ack_url, + headers=self._auth_headers(), + timeout=INBOUND_ATTACHMENT_DOWNLOAD_TIMEOUT, + ) + if ack_resp.status_code == 404: + logger.info( + "pending attachment %s already unavailable on ack; using cached file", + uri, + ) + else: + ack_resp.raise_for_status() + ack_marker.touch() + return cache_path + + resp = self._session.get( + url, + headers=self._auth_headers(), + params=params, + timeout=INBOUND_ATTACHMENT_DOWNLOAD_TIMEOUT, + stream=True, + ) + resp.raise_for_status() + content_length = _content_length(resp) + if content_length is not None and content_length > MAX_INBOUND_ATTACHMENT_BYTES: + raise ValueError( + f"attachment {name!r} is {content_length} bytes; cap is " + f"{MAX_INBOUND_ATTACHMENT_BYTES}" + ) + + cache_path.parent.mkdir(parents=True, exist_ok=True) + tmp_path = cache_path.with_name(f"{cache_path.name}.tmp-{os.getpid()}") + size = 0 + try: + with tmp_path.open("wb") as fh: + for chunk in resp.iter_content(chunk_size=1024 * 1024): + if not chunk: + continue + size += len(chunk) + if size > MAX_INBOUND_ATTACHMENT_BYTES: + raise ValueError( + f"attachment {name!r} exceeds cap " + f"{MAX_INBOUND_ATTACHMENT_BYTES}" + ) + fh.write(chunk) + tmp_path.replace(cache_path) + except Exception: + try: + tmp_path.unlink() + except FileNotFoundError: + pass + raise + + if ack and ack_url: + ack_resp = self._session.post( + ack_url, + headers=self._auth_headers(), + timeout=INBOUND_ATTACHMENT_DOWNLOAD_TIMEOUT, + ) + ack_resp.raise_for_status() + _ack_marker_path(cache_path).touch() + return cache_path + + def download_inbound_attachments( + self, + message: "InboundMessage", + dest_dir: Path | None = None, + *, + ack: bool = True, + ) -> list[Path]: + """Download every attachment on an inbound poll message.""" + return [ + self.download_inbound_attachment(att, dest_dir=dest_dir, ack=ack) + for att in message.attachments + ] + def reply(self, message: "InboundMessage", text: str) -> None: """Reply to an inbound message. diff --git a/molecule_agent/inbound.py b/molecule_agent/inbound.py index dc5d541..5f4d8df 100644 --- a/molecule_agent/inbound.py +++ b/molecule_agent/inbound.py @@ -79,6 +79,7 @@ class InboundMessage: source: InboundSource source_id: str text: str + attachments: list[dict[str, Any]] = field(default_factory=list) raw: dict[str, Any] = field(default_factory=dict) # Enrichment fields — populated from row["data"]["peer_name"], # row["data"]["peer_role"], row["data"]["agent_card_url"]. @@ -139,12 +140,14 @@ def _parse_activity_row(row: dict[str, Any]) -> InboundMessage | None: source = "unknown" text = str(data.get("text") or data.get("message") or "") + attachments = _extract_attachments(row, data) return InboundMessage( activity_id=aid, source=source, source_id=source_id, text=text, + attachments=attachments, raw=row, peer_name=str(data.get("peer_name") or ""), peer_role=str(data.get("peer_role") or ""), @@ -152,6 +155,30 @@ def _parse_activity_row(row: dict[str, Any]) -> InboundMessage | None: ) +def _extract_attachments(row: dict[str, Any], data: dict[str, Any]) -> list[dict[str, Any]]: + """Return the flat ``attachments[]`` projected by platform activity APIs. + + Newer workspace-server builds put attachment metadata at the activity-row + top level when callers request ``include=peer_info``. Some + older or hand-built rows put it under ``data.attachments``. Preserve only + dict entries with a URI; byte fetching remains an explicit client action. + """ + raw = row.get("attachments") + if not isinstance(raw, list): + raw = data.get("attachments") + if not isinstance(raw, list): + return [] + out: list[dict[str, Any]] = [] + for item in raw: + if not isinstance(item, dict): + continue + uri = item.get("uri") + if not isinstance(uri, str) or not uri: + continue + out.append(dict(item)) + return out + + # --------------------------------------------------------------------------- # Handler + delivery protocol # --------------------------------------------------------------------------- diff --git a/tests/test_inbound.py b/tests/test_inbound.py index f139f54..b1e773a 100644 --- a/tests/test_inbound.py +++ b/tests/test_inbound.py @@ -17,6 +17,7 @@ Mocking style matches ``tests/test_remote_agent.py``: a ``FakeResponse`` / from __future__ import annotations import asyncio +import hashlib from pathlib import Path from typing import Any from unittest.mock import MagicMock @@ -41,10 +42,19 @@ from molecule_agent.inbound import _parse_activity_row class FakeResponse: - def __init__(self, status_code: int = 200, json_body: Any = None, text: str = ""): + def __init__( + self, + status_code: int = 200, + json_body: Any = None, + text: str = "", + content: bytes = b"", + chunks: list[bytes] | None = None, + ): self.status_code = status_code self._json = json_body self.text = text + self.content = content + self._chunks = chunks self.headers: dict[str, str] = {} def json(self) -> Any: @@ -54,6 +64,13 @@ class FakeResponse: if self.status_code >= 400: raise requests.HTTPError(f"HTTP {self.status_code}") + def iter_content(self, chunk_size: int = 1): + if self._chunks is not None: + yield from self._chunks + return + for i in range(0, len(self.content), chunk_size): + yield self.content[i:i + chunk_size] + @pytest.fixture def tmp_token_dir(tmp_path: Path) -> Path: @@ -230,6 +247,33 @@ def test_parse_activity_row_enrichment_in_canvas_user_row(): assert msg.agent_card_url == "https://platform.example/registry/discover/user-uuid" +def test_parse_activity_row_preserves_projected_attachments(): + row = { + "id": "act-8", + "data": {"source": "canvas_user", "text": "see image"}, + "attachments": [ + { + "kind": "image", + "uri": "platform-pending:ws-abc-123/11111111-1111-1111-1111-111111111111", + "name": "shape.png", + "mimeType": "image/png", + }, + {"name": "broken.png"}, + "not a dict", + ], + } + msg = _parse_activity_row(row) + assert msg is not None + assert msg.attachments == [ + { + "kind": "image", + "uri": "platform-pending:ws-abc-123/11111111-1111-1111-1111-111111111111", + "name": "shape.png", + "mimeType": "image/png", + } + ] + + # --------------------------------------------------------------------------- # fetch_inbound # --------------------------------------------------------------------------- @@ -252,6 +296,7 @@ def test_fetch_inbound_happy_path(client: RemoteAgentClient): assert call_args.args[0] == "http://platform.test/workspaces/ws-abc-123/activity" assert call_args.kwargs["params"]["type"] == "a2a_receive" assert call_args.kwargs["params"]["limit"] == "100" + assert call_args.kwargs["params"]["include"] == "peer_info" assert "since_id" not in call_args.kwargs["params"] @@ -382,6 +427,188 @@ def test_fetch_inbound_combined_filters(): assert params["before_ts"] == "2026-05-09T12:00:00Z" +def test_fetch_inbound_parses_attachments_from_include_peer_info(client: RemoteAgentClient): + rows = [ + { + "id": "act-with-file", + "data": {"source": "canvas_user", "text": "describe this"}, + "attachments": [ + { + "kind": "image", + "uri": "platform-pending:ws-abc-123/22222222-2222-2222-2222-222222222222", + "name": "shape.png", + "mimeType": "image/png", + } + ], + } + ] + client._session.get.return_value = FakeResponse(200, rows) + + out = client.fetch_inbound() + + assert out[0].attachments[0]["name"] == "shape.png" + assert client._session.get.call_args.kwargs["params"]["include"] == "peer_info" + + +def test_download_inbound_attachment_fetches_pending_upload_and_acks( + client: RemoteAgentClient, tmp_path: Path +): + attachment = { + "uri": "platform-pending:ws-abc-123/33333333-3333-3333-3333-333333333333", + "name": "shape.png", + } + client._session.get.return_value = FakeResponse(200, content=b"png-bytes") + client._session.post.return_value = FakeResponse(204) + + path = client.download_inbound_attachment(attachment, dest_dir=tmp_path) + + assert path.read_bytes() == b"png-bytes" + get_call = client._session.get.call_args + assert get_call.args[0] == ( + "http://platform.test/workspaces/ws-abc-123/pending-uploads/" + "33333333-3333-3333-3333-333333333333/content" + ) + assert get_call.kwargs["stream"] is True + assert get_call.kwargs["headers"]["Authorization"] == "Bearer test-token-secret" + post_call = client._session.post.call_args + assert post_call.args[0].endswith( + "/workspaces/ws-abc-123/pending-uploads/" + "33333333-3333-3333-3333-333333333333/ack" + ) + + +def test_download_inbound_attachment_rejects_cross_workspace_pending_uri( + client: RemoteAgentClient, tmp_path: Path +): + attachment = { + "uri": "platform-pending:other-ws/33333333-3333-3333-3333-333333333333", + "name": "shape.png", + } + + with pytest.raises(ValueError, match="another workspace"): + client.download_inbound_attachment(attachment, dest_dir=tmp_path) + + client._session.get.assert_not_called() + + +def test_download_inbound_attachment_fetches_workspace_uri( + client: RemoteAgentClient, tmp_path: Path +): + attachment = { + "uri": "workspace:/workspace/.molecule/chat-uploads/report.txt", + "name": "../report.txt", + } + client._session.get.return_value = FakeResponse(200, content=b"hello") + + path = client.download_inbound_attachment(attachment, dest_dir=tmp_path) + + assert path.name == "_report.txt" + assert path.read_bytes() == b"hello" + get_call = client._session.get.call_args + assert get_call.args[0] == "http://platform.test/workspaces/ws-abc-123/chat/download" + assert get_call.kwargs["params"] == { + "path": "/workspace/.molecule/chat-uploads/report.txt" + } + + +def test_download_inbound_attachment_cached_pending_still_acks( + client: RemoteAgentClient, tmp_path: Path +): + attachment = { + "uri": "platform-pending:ws-abc-123/44444444-4444-4444-4444-444444444444", + "name": "shape.png", + } + digest = "platform-pending:ws-abc-123/44444444-4444-4444-4444-444444444444" + cache_path = tmp_path / hashlib.sha256(digest.encode("utf-8")).hexdigest()[:24] / "shape.png" + cache_path.parent.mkdir(parents=True) + cache_path.write_bytes(b"cached") + client._session.post.return_value = FakeResponse(204) + + path = client.download_inbound_attachment(attachment, dest_dir=tmp_path) + + assert path == cache_path + client._session.get.assert_not_called() + assert client._session.post.call_args.args[0].endswith( + "/workspaces/ws-abc-123/pending-uploads/" + "44444444-4444-4444-4444-444444444444/ack" + ) + assert path.with_name(f"{path.name}.acked").exists() + + +def test_download_inbound_attachment_cached_pending_skips_ack_when_marked( + client: RemoteAgentClient, tmp_path: Path +): + attachment = { + "uri": "platform-pending:ws-abc-123/55555555-5555-5555-5555-555555555555", + "name": "shape.png", + } + digest = "platform-pending:ws-abc-123/55555555-5555-5555-5555-555555555555" + cache_path = tmp_path / hashlib.sha256(digest.encode("utf-8")).hexdigest()[:24] / "shape.png" + cache_path.parent.mkdir(parents=True) + cache_path.write_bytes(b"cached") + cache_path.with_name("shape.png.acked").write_text("") + + path = client.download_inbound_attachment(attachment, dest_dir=tmp_path) + + assert path == cache_path + client._session.get.assert_not_called() + client._session.post.assert_not_called() + + +def test_download_inbound_attachment_cached_pending_treats_404_ack_as_gone( + client: RemoteAgentClient, tmp_path: Path +): + attachment = { + "uri": "platform-pending:ws-abc-123/66666666-6666-6666-6666-666666666666", + "name": "shape.png", + } + digest = "platform-pending:ws-abc-123/66666666-6666-6666-6666-666666666666" + cache_path = tmp_path / hashlib.sha256(digest.encode("utf-8")).hexdigest()[:24] / "shape.png" + cache_path.parent.mkdir(parents=True) + cache_path.write_bytes(b"cached") + client._session.post.return_value = FakeResponse(404) + + path = client.download_inbound_attachment(attachment, dest_dir=tmp_path) + + assert path == cache_path + assert path.with_name(f"{path.name}.acked").exists() + + +def test_download_inbound_attachment_rejects_large_content_length( + client: RemoteAgentClient, tmp_path: Path +): + attachment = { + "uri": "workspace:/workspace/.molecule/chat-uploads/huge.bin", + "name": "huge.bin", + } + response = FakeResponse(200, content=b"") + response.headers["Content-Length"] = str(100 * 1024 * 1024 + 1) + client._session.get.return_value = response + + with pytest.raises(ValueError, match="cap"): + client.download_inbound_attachment(attachment, dest_dir=tmp_path) + + assert not list(tmp_path.rglob("huge.bin")) + + +def test_download_inbound_attachment_rejects_large_stream( + client: RemoteAgentClient, tmp_path: Path +): + attachment = { + "uri": "workspace:/workspace/.molecule/chat-uploads/huge.bin", + "name": "huge.bin", + } + client._session.get.return_value = FakeResponse( + 200, + chunks=[b"x" * (100 * 1024 * 1024), b"x"], + ) + + with pytest.raises(ValueError, match="exceeds cap"): + client.download_inbound_attachment(attachment, dest_dir=tmp_path) + + assert not list(tmp_path.rglob("huge.bin")) + + # --------------------------------------------------------------------------- # reply() # --------------------------------------------------------------------------- -- 2.52.0