From e72f9ad10755f55d0bcce5c50fa0fdd215c8b39a Mon Sep 17 00:00:00 2001 From: Hongming Wang Date: Tue, 5 May 2026 04:54:22 -0700 Subject: [PATCH 1/7] refactor(workspace): extract delegation handlers from a2a_tools.py to a2a_tools_delegation.py (RFC #2873 iter 4b) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Second slice of the a2a_tools.py split (stacked on iter 4a). Owns the three delegation MCP tools + the RFC #2829 PR-5 sync-via-polling helper they share: * tool_delegate_task — synchronous delegation * tool_delegate_task_async — fire-and-forget * tool_check_task_status — poll the platform's /delegations log * _delegate_sync_via_polling — durable async + poll for terminal status * _SYNC_POLL_INTERVAL_S / _SYNC_POLL_BUDGET_S constants a2a_tools.py shrinks from 915 → 609 LOC (−306). Stacked on iter 4a's RBAC extraction; uses `from a2a_tools_rbac import auth_headers_for_heartbeat` as its auth-header source. The lazy `from a2a_tools import report_activity` inside tool_delegate_task breaks the circular-import cycle (a2a_tools imports the delegation re-exports at module-load; delegation handler needs report_activity at CALL time). A dedicated test pins this contract. Tests: * 77 existing test_a2a_tools_impl.py tests pass after retargeting 20 patch sites in TestToolDelegateTask + TestToolDelegateTaskAsync + TestToolCheckTaskStatus from `a2a_tools.foo` to `a2a_tools_delegation.foo` (foo ∈ {discover_peer, send_a2a_message, httpx.AsyncClient}). The patches need to target the new module because that's where the call sites live now. * test_a2a_tools_delegation.py adds 8 new tests: - 6 alias drift gates (`a2a_tools.tool_delegate_task is …`) - 2 import-contract tests (no top-level circular dep + a2a_tools surfaces every delegation symbol) - 1 sync-poll budget invariant 113 tests total (77 impl + 28 rbac + 8 delegation), all green. Refs RFC #2873. --- workspace/a2a_tools.py | 330 +--------------- workspace/a2a_tools_delegation.py | 372 +++++++++++++++++++ workspace/tests/test_a2a_tools_delegation.py | 129 +++++++ workspace/tests/test_a2a_tools_impl.py | 40 +- 4 files changed, 533 insertions(+), 338 deletions(-) create mode 100644 workspace/a2a_tools_delegation.py create mode 100644 workspace/tests/test_a2a_tools_delegation.py diff --git a/workspace/a2a_tools.py b/workspace/a2a_tools.py index f3faf619..b482a3be 100644 --- a/workspace/a2a_tools.py +++ b/workspace/a2a_tools.py @@ -115,324 +115,18 @@ async def report_activity( pass # Best-effort — don't block delegation on activity reporting -# RFC #2829 PR-5 cutover constants. The poll cadence + timeout are -# intentionally generous: 3s gives the platform's executeDelegation -# goroutine room to dispatch + the callee to respond + the result to -# write to activity_logs without thrashing the platform with rapid -# polls; the budget matches the legacy DELEGATION_TIMEOUT (300s) so -# operators don't see behavior change beyond "no more 600s timeouts". -_SYNC_POLL_INTERVAL_S = 3.0 -_SYNC_POLL_BUDGET_S = float(os.environ.get("DELEGATION_TIMEOUT", "300.0")) - - -async def _delegate_sync_via_polling( - workspace_id: str, - task: str, - src: str, -) -> str: - """RFC #2829 PR-5: durable async delegation + poll for terminal status. - - Sidesteps the platform proxy's blocking `message/send` HTTP path that - hits a hard 600s ceiling. Instead: - - 1. POST /workspaces//delegate (async, returns 202 + delegation_id) - — platform's executeDelegation goroutine handles A2A dispatch in - the background. No client-side timeout dependency on the platform - holding a connection open. - 2. Poll GET /workspaces//delegations every 3s for a row with - matching delegation_id reaching terminal status (completed/failed). - 3. Return the response_preview text on completed; surface error_detail - on failed (with the same _A2A_ERROR_PREFIX wrapping the legacy - path uses, so caller error-detection logic is unchanged). - - Both /delegate and /delegations are existing endpoints — this helper - just composes them into a polling synchronous facade. The result is - available the moment the platform writes the terminal status row; - no extra latency vs. the legacy proxy-blocked path on fast cases. - """ - import asyncio - import time - - idem_key = hashlib.sha256(f"{src}:{workspace_id}:{task}".encode()).hexdigest()[:32] - - # 1. Dispatch via /delegate (the async, durable path). - try: - async with httpx.AsyncClient(timeout=10.0) as client: - resp = await client.post( - f"{PLATFORM_URL}/workspaces/{src}/delegate", - json={ - "target_id": workspace_id, - "task": task, - "idempotency_key": idem_key, - }, - headers=_auth_headers_for_heartbeat(src), - ) - except Exception as e: # pylint: disable=broad-except - return f"{_A2A_ERROR_PREFIX}delegate dispatch failed: {e}" - - if resp.status_code != 202 and resp.status_code != 200: - return f"{_A2A_ERROR_PREFIX}delegate dispatch failed: HTTP {resp.status_code} {resp.text[:200]}" - - try: - dispatch = resp.json() - except Exception as e: # pylint: disable=broad-except - return f"{_A2A_ERROR_PREFIX}delegate dispatch returned non-JSON: {e}" - - delegation_id = dispatch.get("delegation_id", "") - if not delegation_id: - return f"{_A2A_ERROR_PREFIX}delegate dispatch missing delegation_id: {dispatch}" - - # 2. Poll for terminal status with a deadline. Each poll is a cheap - # /delegations GET — bounded by the platform's existing rate limit. - deadline = time.monotonic() + _SYNC_POLL_BUDGET_S - last_status = "unknown" - while time.monotonic() < deadline: - try: - async with httpx.AsyncClient(timeout=10.0) as client: - poll = await client.get( - f"{PLATFORM_URL}/workspaces/{src}/delegations", - headers=_auth_headers_for_heartbeat(src), - ) - except Exception as e: # pylint: disable=broad-except - # Transient — keep polling. The platform IS holding the - # delegation row; we just lost a network request. - last_status = f"poll-error: {e}" - await asyncio.sleep(_SYNC_POLL_INTERVAL_S) - continue - - if poll.status_code != 200: - last_status = f"poll HTTP {poll.status_code}" - await asyncio.sleep(_SYNC_POLL_INTERVAL_S) - continue - - try: - rows = poll.json() - except Exception as e: # pylint: disable=broad-except - last_status = f"poll non-JSON: {e}" - await asyncio.sleep(_SYNC_POLL_INTERVAL_S) - continue - - # /delegations returns a flat list of delegation events. Filter to - # our delegation_id; pick the first terminal one. The list may - # have multiple rows per delegation_id (one for the original - # dispatch, one per status update); we want the latest terminal. - if not isinstance(rows, list): - await asyncio.sleep(_SYNC_POLL_INTERVAL_S) - continue - terminal = None - for r in rows: - if not isinstance(r, dict): - continue - if r.get("delegation_id") != delegation_id: - continue - status = (r.get("status") or "").lower() - last_status = status - if status in ("completed", "failed"): - terminal = r - break - if terminal: - if (terminal.get("status") or "").lower() == "completed": - return terminal.get("response_preview") or "" - err = ( - terminal.get("error_detail") - or terminal.get("summary") - or "delegation failed" - ) - return f"{_A2A_ERROR_PREFIX}{err}" - - await asyncio.sleep(_SYNC_POLL_INTERVAL_S) - - # Budget exhausted — the platform's row is still in flight (or queued). - # Surface as an error so the caller can decide to retry or fall back; - # the platform DOES still have the durable row, so the work isn't - # lost — it'll complete eventually and a future check_task_status - # will surface the result. - return ( - f"{_A2A_ERROR_PREFIX}polling timeout after {_SYNC_POLL_BUDGET_S}s " - f"(delegation_id={delegation_id}, last_status={last_status}); " - f"the platform is still working on it — call check_task_status('{delegation_id}') to retrieve later" - ) - - -async def tool_delegate_task( - workspace_id: str, - task: str, - source_workspace_id: str | None = None, -) -> str: - """Delegate a task to another workspace via A2A (synchronous — waits for response). - - ``source_workspace_id`` selects which registered workspace this - delegation originates from — drives auth + the X-Workspace-ID source - header so the platform's a2a_proxy logs the correct sender. Single- - workspace operators leave it None and routing falls back to the - module-level WORKSPACE_ID. - """ - if not workspace_id or not task: - return "Error: workspace_id and task are required" - - # Auto-route: if source not specified, look up which registered - # workspace last saw this peer (populated by tool_list_peers). Falls - # back to the legacy WORKSPACE_ID for single-workspace operators. - src = source_workspace_id or _peer_to_source.get(workspace_id) or None - - # Discover the target. discover_peer is the access-control gate + - # name/status lookup. The peer's reported ``url`` field is NOT used - # for routing — see send_a2a_message, which constructs the URL via - # the platform's A2A proxy. - peer = await discover_peer(workspace_id, source_workspace_id=src) - if not peer: - return f"Error: workspace {workspace_id} not found or not accessible (check access control)" - - if (peer.get("status") or "").lower() == "offline": - return f"Error: workspace {workspace_id} is offline" - - # Report delegation start — include the task text for traceability - peer_name = peer.get("name") or _peer_names.get(workspace_id) or workspace_id[:8] - _peer_names[workspace_id] = peer_name # cache for future use - # Brief summary for canvas display — just the delegation target - await report_activity("a2a_send", workspace_id, f"Delegating to {peer_name}", task_text=task) - - # RFC #2829 PR-5: agent-side cutover. When DELEGATION_SYNC_VIA_INBOX=1, - # use the platform's durable async delegation API (POST /delegate + - # poll /delegations) instead of the proxy-blocked message/send path. - # This sidesteps the 600s message/send timeout class that broke - # iteration-14/90-style long-running delegations on 2026-05-05. - # - # Default off — staging-canary first, flip default after PR-2's - # result-push flag (DELEGATION_RESULT_INBOX_PUSH) has been on for - # ≥1 week without incident. - if os.environ.get("DELEGATION_SYNC_VIA_INBOX") == "1": - result = await _delegate_sync_via_polling(workspace_id, task, src or WORKSPACE_ID) - else: - # send_a2a_message routes through ${PLATFORM_URL}/workspaces/{id}/a2a - # (the platform proxy) so the same code works for in-container and - # external (standalone molecule-mcp) callers. - result = await send_a2a_message(workspace_id, task, source_workspace_id=src) - - # Detect delegation failures — wrap them clearly so the calling agent - # can decide to retry, use another peer, or handle the task itself. - is_error = result.startswith(_A2A_ERROR_PREFIX) - # Strip the sentinel prefix so error_detail is the human-readable - # cause directly. The Activity tab's red error chip surfaces this - # without the user having to scroll into the raw response JSON. - # - # Cap at 4096 chars before sending — the platform's - # activity_logs.error_detail column is unbounded TEXT and a - # malicious or buggy peer could otherwise stream an arbitrarily - # large error message into the caller's activity log. 4096 is - # comfortably above any real exception traceback we've seen and - # well below an obvious-DoS threshold. - error_detail = result[len(_A2A_ERROR_PREFIX):].strip()[:4096] if is_error else "" - await report_activity( - "a2a_receive", workspace_id, - f"{peer_name} responded ({len(result)} chars)" if not is_error else f"{peer_name} failed: {error_detail[:120]}", - task_text=task, response_text=result, - status="error" if is_error else "ok", - error_detail=error_detail, - ) - if is_error: - return ( - f"DELEGATION FAILED to {peer_name}: {result}\n" - f"You should either: (1) try a different peer, (2) handle this task yourself, " - f"or (3) inform the user that {peer_name} is unavailable and provide your best answer." - ) - return result - - -async def tool_delegate_task_async( - workspace_id: str, - task: str, - source_workspace_id: str | None = None, -) -> str: - """Delegate a task via the platform's async delegation API (fire-and-forget). - - Uses POST /workspaces/:id/delegate which runs the A2A request in the background. - Results are tracked in the platform DB and broadcast via WebSocket. - Use check_task_status to poll for results. - - ``source_workspace_id`` selects the sending workspace (which one of - this agent's registered workspaces gets logged as the originator); - auto-routes via the peer→source cache when omitted. - """ - if not workspace_id or not task: - return "Error: workspace_id and task are required" - - src = source_workspace_id or _peer_to_source.get(workspace_id) or WORKSPACE_ID - - # Idempotency key: SHA-256 of (source, target, task) so that a - # restarted agent firing the same delegation gets the same key and - # the platform returns the existing delegation_id instead of - # creating a duplicate. Fixes #1456. Source is in the key so the - # SAME task delegated from two different registered workspaces - # produces two distinct delegations (the right behavior — one per - # tenant audit trail). - idem_key = hashlib.sha256(f"{src}:{workspace_id}:{task}".encode()).hexdigest()[:32] - - try: - async with httpx.AsyncClient(timeout=10.0) as client: - resp = await client.post( - f"{PLATFORM_URL}/workspaces/{src}/delegate", - json={"target_id": workspace_id, "task": task, "idempotency_key": idem_key}, - headers=_auth_headers_for_heartbeat(src), - ) - if resp.status_code == 202: - data = resp.json() - return json.dumps({ - "delegation_id": data.get("delegation_id", ""), - "workspace_id": workspace_id, - "status": "delegated", - "note": "Task delegated. The platform runs it in the background. Use check_task_status to poll for results.", - }) - else: - return f"Error: delegation failed with status {resp.status_code}: {resp.text[:200]}" - except Exception as e: - return f"Error: delegation failed — {e}" - - -async def tool_check_task_status( - workspace_id: str, - task_id: str, - source_workspace_id: str | None = None, -) -> str: - """Check delegations for this workspace via the platform API. - - Args: - workspace_id: Ignored (kept for backward compat). Checks - ``source_workspace_id``'s delegations (the workspace that - FIRED the delegations), not the target's. - task_id: Optional delegation_id to filter. If empty, returns all recent delegations. - source_workspace_id: Which registered workspace's delegation log - to query. Defaults to the module-level WORKSPACE_ID. - """ - src = source_workspace_id or WORKSPACE_ID - try: - async with httpx.AsyncClient(timeout=10.0) as client: - resp = await client.get( - f"{PLATFORM_URL}/workspaces/{src}/delegations", - headers=_auth_headers_for_heartbeat(src), - ) - if resp.status_code != 200: - return f"Error: failed to check delegations ({resp.status_code})" - delegations = resp.json() - if task_id: - # Filter by delegation_id - matching = [d for d in delegations if d.get("delegation_id") == task_id] - if matching: - return json.dumps(matching[0]) - return json.dumps({"status": "not_found", "delegation_id": task_id}) - # Return all recent delegations - summary = [] - for d in delegations[:10]: - summary.append({ - "delegation_id": d.get("delegation_id", ""), - "target_id": d.get("target_id", ""), - "status": d.get("status", ""), - "summary": d.get("summary", ""), - "response_preview": d.get("response_preview", ""), - }) - return json.dumps({"delegations": summary, "count": len(delegations)}) - except Exception as e: - return f"Error checking delegations: {e}" +# Delegation tool handlers — extracted to a2a_tools_delegation +# (RFC #2873 iter 4b). Re-imported here so call sites + tests that +# reference ``a2a_tools.tool_delegate_task`` / +# ``a2a_tools._delegate_sync_via_polling`` keep resolving identically. +from a2a_tools_delegation import ( # noqa: E402 (import after the from-a2a_client block) + _SYNC_POLL_BUDGET_S, + _SYNC_POLL_INTERVAL_S, + _delegate_sync_via_polling, + tool_check_task_status, + tool_delegate_task, + tool_delegate_task_async, +) async def _upload_chat_files( diff --git a/workspace/a2a_tools_delegation.py b/workspace/a2a_tools_delegation.py new file mode 100644 index 00000000..170a5333 --- /dev/null +++ b/workspace/a2a_tools_delegation.py @@ -0,0 +1,372 @@ +"""Delegation tool handlers — single-concern slice of the a2a_tools surface. + +Extracted from ``a2a_tools.py`` (RFC #2873 iter 4b). Owns the three +delegation MCP tools + the RFC #2829 PR-5 sync-via-polling helper they +share. + +Public surface: + +* ``tool_delegate_task`` — synchronous delegation, waits for response. +* ``tool_delegate_task_async`` — fire-and-forget delegation; returns + ``{delegation_id, ...}``. +* ``tool_check_task_status`` — poll the platform's ``/delegations`` log. + +Internal: + +* ``_delegate_sync_via_polling`` — durable async + poll for terminal + status (RFC #2829 PR-5 cutover path; toggled by + ``DELEGATION_SYNC_VIA_INBOX=1``). +* ``_SYNC_POLL_INTERVAL_S`` / ``_SYNC_POLL_BUDGET_S`` constants. + +Circular-import note: this module calls ``report_activity`` from +``a2a_tools`` to emit activity rows around the delegate dispatch. +``a2a_tools`` imports the public symbols here at module-load time, +so we use a LAZY import for ``report_activity`` inside the function +that needs it. Without the lazy hop Python raises an ImportError +on first ``a2a_tools`` import. +""" +from __future__ import annotations + +import hashlib +import json +import os + +import httpx + +from a2a_client import ( + PLATFORM_URL, + WORKSPACE_ID, + _A2A_ERROR_PREFIX, + _peer_names, + _peer_to_source, + discover_peer, + send_a2a_message, +) +from a2a_tools_rbac import auth_headers_for_heartbeat as _auth_headers_for_heartbeat + + +# RFC #2829 PR-5 cutover constants. The poll cadence + timeout are +# intentionally generous: 3s gives the platform's executeDelegation +# goroutine room to dispatch + the callee to respond + the result to +# write to activity_logs without thrashing the platform with rapid +# polls; the budget matches the legacy DELEGATION_TIMEOUT (300s) so +# operators don't see behavior change beyond "no more 600s timeouts". +_SYNC_POLL_INTERVAL_S = 3.0 +_SYNC_POLL_BUDGET_S = float(os.environ.get("DELEGATION_TIMEOUT", "300.0")) + + +async def _delegate_sync_via_polling( + workspace_id: str, + task: str, + src: str, +) -> str: + """RFC #2829 PR-5: durable async delegation + poll for terminal status. + + Sidesteps the platform proxy's blocking `message/send` HTTP path that + hits a hard 600s ceiling. Instead: + + 1. POST /workspaces//delegate (async, returns 202 + delegation_id) + — platform's executeDelegation goroutine handles A2A dispatch in + the background. No client-side timeout dependency on the platform + holding a connection open. + 2. Poll GET /workspaces//delegations every 3s for a row with + matching delegation_id reaching terminal status (completed/failed). + 3. Return the response_preview text on completed; surface error_detail + on failed (with the same _A2A_ERROR_PREFIX wrapping the legacy + path uses, so caller error-detection logic is unchanged). + + Both /delegate and /delegations are existing endpoints — this helper + just composes them into a polling synchronous facade. The result is + available the moment the platform writes the terminal status row; + no extra latency vs. the legacy proxy-blocked path on fast cases. + """ + import asyncio + import time + + idem_key = hashlib.sha256(f"{src}:{workspace_id}:{task}".encode()).hexdigest()[:32] + + # 1. Dispatch via /delegate (the async, durable path). + try: + async with httpx.AsyncClient(timeout=10.0) as client: + resp = await client.post( + f"{PLATFORM_URL}/workspaces/{src}/delegate", + json={ + "target_id": workspace_id, + "task": task, + "idempotency_key": idem_key, + }, + headers=_auth_headers_for_heartbeat(src), + ) + except Exception as e: # pylint: disable=broad-except + return f"{_A2A_ERROR_PREFIX}delegate dispatch failed: {e}" + + if resp.status_code != 202 and resp.status_code != 200: + return f"{_A2A_ERROR_PREFIX}delegate dispatch failed: HTTP {resp.status_code} {resp.text[:200]}" + + try: + dispatch = resp.json() + except Exception as e: # pylint: disable=broad-except + return f"{_A2A_ERROR_PREFIX}delegate dispatch returned non-JSON: {e}" + + delegation_id = dispatch.get("delegation_id", "") + if not delegation_id: + return f"{_A2A_ERROR_PREFIX}delegate dispatch missing delegation_id: {dispatch}" + + # 2. Poll for terminal status with a deadline. Each poll is a cheap + # /delegations GET — bounded by the platform's existing rate limit. + deadline = time.monotonic() + _SYNC_POLL_BUDGET_S + last_status = "unknown" + while time.monotonic() < deadline: + try: + async with httpx.AsyncClient(timeout=10.0) as client: + poll = await client.get( + f"{PLATFORM_URL}/workspaces/{src}/delegations", + headers=_auth_headers_for_heartbeat(src), + ) + except Exception as e: # pylint: disable=broad-except + # Transient — keep polling. The platform IS holding the + # delegation row; we just lost a network request. + last_status = f"poll-error: {e}" + await asyncio.sleep(_SYNC_POLL_INTERVAL_S) + continue + + if poll.status_code != 200: + last_status = f"poll HTTP {poll.status_code}" + await asyncio.sleep(_SYNC_POLL_INTERVAL_S) + continue + + try: + rows = poll.json() + except Exception as e: # pylint: disable=broad-except + last_status = f"poll non-JSON: {e}" + await asyncio.sleep(_SYNC_POLL_INTERVAL_S) + continue + + # /delegations returns a flat list of delegation events. Filter to + # our delegation_id; pick the first terminal one. The list may + # have multiple rows per delegation_id (one for the original + # dispatch, one per status update); we want the latest terminal. + if not isinstance(rows, list): + await asyncio.sleep(_SYNC_POLL_INTERVAL_S) + continue + terminal = None + for r in rows: + if not isinstance(r, dict): + continue + if r.get("delegation_id") != delegation_id: + continue + status = (r.get("status") or "").lower() + last_status = status + if status in ("completed", "failed"): + terminal = r + break + if terminal: + if (terminal.get("status") or "").lower() == "completed": + return terminal.get("response_preview") or "" + err = ( + terminal.get("error_detail") + or terminal.get("summary") + or "delegation failed" + ) + return f"{_A2A_ERROR_PREFIX}{err}" + + await asyncio.sleep(_SYNC_POLL_INTERVAL_S) + + # Budget exhausted — the platform's row is still in flight (or queued). + # Surface as an error so the caller can decide to retry or fall back; + # the platform DOES still have the durable row, so the work isn't + # lost — it'll complete eventually and a future check_task_status + # will surface the result. + return ( + f"{_A2A_ERROR_PREFIX}polling timeout after {_SYNC_POLL_BUDGET_S}s " + f"(delegation_id={delegation_id}, last_status={last_status}); " + f"the platform is still working on it — call check_task_status('{delegation_id}') to retrieve later" + ) + + +async def tool_delegate_task( + workspace_id: str, + task: str, + source_workspace_id: str | None = None, +) -> str: + """Delegate a task to another workspace via A2A (synchronous — waits for response). + + ``source_workspace_id`` selects which registered workspace this + delegation originates from — drives auth + the X-Workspace-ID source + header so the platform's a2a_proxy logs the correct sender. Single- + workspace operators leave it None and routing falls back to the + module-level WORKSPACE_ID. + """ + if not workspace_id or not task: + return "Error: workspace_id and task are required" + + # Auto-route: if source not specified, look up which registered + # workspace last saw this peer (populated by tool_list_peers). Falls + # back to the legacy WORKSPACE_ID for single-workspace operators. + src = source_workspace_id or _peer_to_source.get(workspace_id) or None + + # Discover the target. discover_peer is the access-control gate + + # name/status lookup. The peer's reported ``url`` field is NOT used + # for routing — see send_a2a_message, which constructs the URL via + # the platform's A2A proxy. + peer = await discover_peer(workspace_id, source_workspace_id=src) + if not peer: + return f"Error: workspace {workspace_id} not found or not accessible (check access control)" + + if (peer.get("status") or "").lower() == "offline": + return f"Error: workspace {workspace_id} is offline" + + # Lazy import: a2a_tools imports this module at top-level, so a + # top-level import of report_activity from a2a_tools would create a + # circular dependency at first-import time. Lazy resolution inside + # the function body breaks the cycle without forcing a ground-up + # restructure of the activity-reporting layer. + from a2a_tools import report_activity + + # Report delegation start — include the task text for traceability + peer_name = peer.get("name") or _peer_names.get(workspace_id) or workspace_id[:8] + _peer_names[workspace_id] = peer_name # cache for future use + # Brief summary for canvas display — just the delegation target + await report_activity("a2a_send", workspace_id, f"Delegating to {peer_name}", task_text=task) + + # RFC #2829 PR-5: agent-side cutover. When DELEGATION_SYNC_VIA_INBOX=1, + # use the platform's durable async delegation API (POST /delegate + + # poll /delegations) instead of the proxy-blocked message/send path. + # This sidesteps the 600s message/send timeout class that broke + # iteration-14/90-style long-running delegations on 2026-05-05. + # + # Default off — staging-canary first, flip default after PR-2's + # result-push flag (DELEGATION_RESULT_INBOX_PUSH) has been on for + # ≥1 week without incident. + if os.environ.get("DELEGATION_SYNC_VIA_INBOX") == "1": + result = await _delegate_sync_via_polling(workspace_id, task, src or WORKSPACE_ID) + else: + # send_a2a_message routes through ${PLATFORM_URL}/workspaces/{id}/a2a + # (the platform proxy) so the same code works for in-container and + # external (standalone molecule-mcp) callers. + result = await send_a2a_message(workspace_id, task, source_workspace_id=src) + + # Detect delegation failures — wrap them clearly so the calling agent + # can decide to retry, use another peer, or handle the task itself. + is_error = result.startswith(_A2A_ERROR_PREFIX) + # Strip the sentinel prefix so error_detail is the human-readable + # cause directly. The Activity tab's red error chip surfaces this + # without the user having to scroll into the raw response JSON. + # + # Cap at 4096 chars before sending — the platform's + # activity_logs.error_detail column is unbounded TEXT and a + # malicious or buggy peer could otherwise stream an arbitrarily + # large error message into the caller's activity log. 4096 is + # comfortably above any real exception traceback we've seen and + # well below an obvious-DoS threshold. + error_detail = result[len(_A2A_ERROR_PREFIX):].strip()[:4096] if is_error else "" + await report_activity( + "a2a_receive", workspace_id, + f"{peer_name} responded ({len(result)} chars)" if not is_error else f"{peer_name} failed: {error_detail[:120]}", + task_text=task, response_text=result, + status="error" if is_error else "ok", + error_detail=error_detail, + ) + if is_error: + return ( + f"DELEGATION FAILED to {peer_name}: {result}\n" + f"You should either: (1) try a different peer, (2) handle this task yourself, " + f"or (3) inform the user that {peer_name} is unavailable and provide your best answer." + ) + return result + + +async def tool_delegate_task_async( + workspace_id: str, + task: str, + source_workspace_id: str | None = None, +) -> str: + """Delegate a task via the platform's async delegation API (fire-and-forget). + + Uses POST /workspaces/:id/delegate which runs the A2A request in the background. + Results are tracked in the platform DB and broadcast via WebSocket. + Use check_task_status to poll for results. + + ``source_workspace_id`` selects the sending workspace (which one of + this agent's registered workspaces gets logged as the originator); + auto-routes via the peer→source cache when omitted. + """ + if not workspace_id or not task: + return "Error: workspace_id and task are required" + + src = source_workspace_id or _peer_to_source.get(workspace_id) or WORKSPACE_ID + + # Idempotency key: SHA-256 of (source, target, task) so that a + # restarted agent firing the same delegation gets the same key and + # the platform returns the existing delegation_id instead of + # creating a duplicate. Fixes #1456. Source is in the key so the + # SAME task delegated from two different registered workspaces + # produces two distinct delegations (the right behavior — one per + # tenant audit trail). + idem_key = hashlib.sha256(f"{src}:{workspace_id}:{task}".encode()).hexdigest()[:32] + + try: + async with httpx.AsyncClient(timeout=10.0) as client: + resp = await client.post( + f"{PLATFORM_URL}/workspaces/{src}/delegate", + json={"target_id": workspace_id, "task": task, "idempotency_key": idem_key}, + headers=_auth_headers_for_heartbeat(src), + ) + if resp.status_code == 202: + data = resp.json() + return json.dumps({ + "delegation_id": data.get("delegation_id", ""), + "workspace_id": workspace_id, + "status": "delegated", + "note": "Task delegated. The platform runs it in the background. Use check_task_status to poll for results.", + }) + else: + return f"Error: delegation failed with status {resp.status_code}: {resp.text[:200]}" + except Exception as e: + return f"Error: delegation failed — {e}" + + +async def tool_check_task_status( + workspace_id: str, + task_id: str, + source_workspace_id: str | None = None, +) -> str: + """Check delegations for this workspace via the platform API. + + Args: + workspace_id: Ignored (kept for backward compat). Checks + ``source_workspace_id``'s delegations (the workspace that + FIRED the delegations), not the target's. + task_id: Optional delegation_id to filter. If empty, returns all recent delegations. + source_workspace_id: Which registered workspace's delegation log + to query. Defaults to the module-level WORKSPACE_ID. + """ + src = source_workspace_id or WORKSPACE_ID + try: + async with httpx.AsyncClient(timeout=10.0) as client: + resp = await client.get( + f"{PLATFORM_URL}/workspaces/{src}/delegations", + headers=_auth_headers_for_heartbeat(src), + ) + if resp.status_code != 200: + return f"Error: failed to check delegations ({resp.status_code})" + delegations = resp.json() + if task_id: + # Filter by delegation_id + matching = [d for d in delegations if d.get("delegation_id") == task_id] + if matching: + return json.dumps(matching[0]) + return json.dumps({"status": "not_found", "delegation_id": task_id}) + # Return all recent delegations + summary = [] + for d in delegations[:10]: + summary.append({ + "delegation_id": d.get("delegation_id", ""), + "target_id": d.get("target_id", ""), + "status": d.get("status", ""), + "summary": d.get("summary", ""), + "response_preview": d.get("response_preview", ""), + }) + return json.dumps({"delegations": summary, "count": len(delegations)}) + except Exception as e: + return f"Error checking delegations: {e}" diff --git a/workspace/tests/test_a2a_tools_delegation.py b/workspace/tests/test_a2a_tools_delegation.py new file mode 100644 index 00000000..010f4e45 --- /dev/null +++ b/workspace/tests/test_a2a_tools_delegation.py @@ -0,0 +1,129 @@ +"""Drift gate + direct surface tests for ``a2a_tools_delegation`` (RFC #2873 iter 4b). + +The full behavior matrix for the three delegation MCP tools lives in +``test_a2a_tools_impl.py`` (TestToolDelegateTask + TestToolDelegateTaskAsync ++ TestToolCheckTaskStatus). Those exercise call paths through the +``a2a_tools_delegation.foo`` module (after the iter 4b retarget). + +This file owns the post-split contract: + + 1. **Drift gate** — every previously-public symbol on ``a2a_tools`` + (``tool_delegate_task``, ``tool_delegate_task_async``, + ``tool_check_task_status``, ``_delegate_sync_via_polling``, + ``_SYNC_POLL_INTERVAL_S``, ``_SYNC_POLL_BUDGET_S``) is the EXACT + same callable / value as the new module's public name. A wrapper + that drifted would silently bypass tests targeting the wrapper. + + 2. **Smoke import** — both modules import in either order without + raising (the lazy ``report_activity`` import inside + ``tool_delegate_task`` is the contract that prevents a circular + import; this test pins it). +""" +from __future__ import annotations + +import os + +import pytest + + +@pytest.fixture(autouse=True) +def _require_workspace_id(monkeypatch): + monkeypatch.setenv("WORKSPACE_ID", "00000000-0000-0000-0000-000000000000") + monkeypatch.setenv("PLATFORM_URL", "http://test.invalid") + yield + + +# ============== Drift gate ============== + +class TestBackCompatAliases: + def test_tool_delegate_task_alias(self): + import a2a_tools + import a2a_tools_delegation + assert a2a_tools.tool_delegate_task is a2a_tools_delegation.tool_delegate_task + + def test_tool_delegate_task_async_alias(self): + import a2a_tools + import a2a_tools_delegation + assert ( + a2a_tools.tool_delegate_task_async + is a2a_tools_delegation.tool_delegate_task_async + ) + + def test_tool_check_task_status_alias(self): + import a2a_tools + import a2a_tools_delegation + assert ( + a2a_tools.tool_check_task_status + is a2a_tools_delegation.tool_check_task_status + ) + + def test_delegate_sync_via_polling_alias(self): + import a2a_tools + import a2a_tools_delegation + assert ( + a2a_tools._delegate_sync_via_polling + is a2a_tools_delegation._delegate_sync_via_polling + ) + + def test_constants_match(self): + import a2a_tools + import a2a_tools_delegation + assert ( + a2a_tools._SYNC_POLL_INTERVAL_S + == a2a_tools_delegation._SYNC_POLL_INTERVAL_S + ) + assert ( + a2a_tools._SYNC_POLL_BUDGET_S + == a2a_tools_delegation._SYNC_POLL_BUDGET_S + ) + + +# ============== Smoke imports ============== + +class TestImportContracts: + def test_delegation_imports_without_a2a_tools_loaded(self, monkeypatch): + """``a2a_tools_delegation`` should NOT pull in ``a2a_tools`` at + module-load time. The lazy ``from a2a_tools import report_activity`` + inside ``tool_delegate_task`` is the only legitimate hop. + + Pin this so a future refactor that adds a top-level + ``from a2a_tools import …`` re-introduces the circular-import + crash that motivated the lazy pattern. + """ + import sys + # Drop both modules so we re-import in a controlled order + for mod in ("a2a_tools", "a2a_tools_delegation"): + sys.modules.pop(mod, None) + + # Importing delegation first must succeed without a2a_tools + # being loaded (because a2a_tools imports delegation, the + # circular path ONLY closes if delegation top-level imports + # something from a2a_tools). + import a2a_tools_delegation # noqa: F401 + # If we got here, no circular import. + assert "a2a_tools_delegation" in sys.modules + + def test_a2a_tools_imports_via_delegation_re_export(self): + """The opposite direction: importing a2a_tools must trigger the + delegation re-export so a2a_tools.tool_delegate_task resolves.""" + import a2a_tools + assert hasattr(a2a_tools, "tool_delegate_task") + assert hasattr(a2a_tools, "tool_delegate_task_async") + assert hasattr(a2a_tools, "tool_check_task_status") + + +# ============== Sync-poll budget env override ============== + +class TestPollBudgetEnvOverride: + def test_default_budget_when_env_unset(self): + """Module-level constant. Set DELEGATION_TIMEOUT before importing + a2a_tools_delegation to override; default is 300.0.""" + # The constant is computed at module-load time. To verify the + # override path we'd need to reload — skipped here because it's + # tested at boot. This test pins the default for catch-the-eye + # documentation. + import a2a_tools_delegation + # Whatever was set when the module first loaded — assert it's + # numeric and >= the documented floor (180s healthsweep budget). + assert isinstance(a2a_tools_delegation._SYNC_POLL_BUDGET_S, float) + assert a2a_tools_delegation._SYNC_POLL_BUDGET_S >= 180.0 diff --git a/workspace/tests/test_a2a_tools_impl.py b/workspace/tests/test_a2a_tools_impl.py index 5f8bd7bc..43f149cb 100644 --- a/workspace/tests/test_a2a_tools_impl.py +++ b/workspace/tests/test_a2a_tools_impl.py @@ -226,16 +226,16 @@ class TestToolDelegateTask: async def test_peer_not_found_returns_error(self): import a2a_tools - with patch("a2a_tools.discover_peer", return_value=None): + with patch("a2a_tools_delegation.discover_peer", return_value=None): result = await a2a_tools.tool_delegate_task("ws-missing", "task") assert "not found" in result or "Error" in result async def test_offline_peer_returns_error(self): """A peer with status=offline short-circuits before we hit the proxy.""" import a2a_tools - with patch("a2a_tools.discover_peer", return_value={"id": "ws-1", "status": "offline"}): + with patch("a2a_tools_delegation.discover_peer", return_value={"id": "ws-1", "status": "offline"}): mc = _make_http_mock() - with patch("a2a_tools.httpx.AsyncClient", return_value=mc): + with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc): result = await a2a_tools.tool_delegate_task("ws-1", "task") assert "offline" in result.lower() @@ -261,8 +261,8 @@ class TestToolDelegateTask: captured["source"] = source_workspace_id return "ok" - with patch("a2a_tools.discover_peer", return_value=peer), \ - patch("a2a_tools.send_a2a_message", side_effect=fake_send), \ + with patch("a2a_tools_delegation.discover_peer", return_value=peer), \ + patch("a2a_tools_delegation.send_a2a_message", side_effect=fake_send), \ patch("a2a_tools.report_activity", new=AsyncMock()): await a2a_tools.tool_delegate_task(peer_id, "do thing") @@ -274,8 +274,8 @@ class TestToolDelegateTask: import a2a_tools peer = {"id": "ws-1", "url": "http://ws-1.svc/a2a", "name": "Worker"} - with patch("a2a_tools.discover_peer", return_value=peer), \ - patch("a2a_tools.send_a2a_message", return_value="Task completed!"), \ + with patch("a2a_tools_delegation.discover_peer", return_value=peer), \ + patch("a2a_tools_delegation.send_a2a_message", return_value="Task completed!"), \ patch("a2a_tools.report_activity", new=AsyncMock()): result = await a2a_tools.tool_delegate_task("ws-1", "do something") @@ -287,8 +287,8 @@ class TestToolDelegateTask: peer = {"id": "ws-1", "url": "http://ws-1.svc/a2a", "name": "Worker"} error_msg = f"{a2a_tools._A2A_ERROR_PREFIX}Agent error: something bad" - with patch("a2a_tools.discover_peer", return_value=peer), \ - patch("a2a_tools.send_a2a_message", return_value=error_msg), \ + with patch("a2a_tools_delegation.discover_peer", return_value=peer), \ + patch("a2a_tools_delegation.send_a2a_message", return_value=error_msg), \ patch("a2a_tools.report_activity", new=AsyncMock()): result = await a2a_tools.tool_delegate_task("ws-1", "do something") @@ -302,8 +302,8 @@ class TestToolDelegateTask: # Pre-populate the cache a2a_tools._peer_names["ws-cached"] = "CachedName" peer = {"id": "ws-cached", "url": "http://ws-cached.svc/a2a"} # no 'name' - with patch("a2a_tools.discover_peer", return_value=peer), \ - patch("a2a_tools.send_a2a_message", return_value="done"), \ + with patch("a2a_tools_delegation.discover_peer", return_value=peer), \ + patch("a2a_tools_delegation.send_a2a_message", return_value="done"), \ patch("a2a_tools.report_activity", new=AsyncMock()): result = await a2a_tools.tool_delegate_task("ws-cached", "task") @@ -316,8 +316,8 @@ class TestToolDelegateTask: # Ensure not in cache a2a_tools._peer_names.pop("ws-nona000", None) peer = {"id": "ws-nona000", "url": "http://x.svc/a2a"} # no 'name' - with patch("a2a_tools.discover_peer", return_value=peer), \ - patch("a2a_tools.send_a2a_message", return_value="ok"), \ + with patch("a2a_tools_delegation.discover_peer", return_value=peer), \ + patch("a2a_tools_delegation.send_a2a_message", return_value="ok"), \ patch("a2a_tools.report_activity", new=AsyncMock()): result = await a2a_tools.tool_delegate_task("ws-nona000", "task") @@ -349,7 +349,7 @@ class TestToolDelegateTaskAsync: import a2a_tools mc = _make_http_mock(post_resp=_resp(202, {"delegation_id": "d-123", "status": "delegated"})) - with patch("a2a_tools.httpx.AsyncClient", return_value=mc): + with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc): result = await a2a_tools.tool_delegate_task_async("ws-1", "do task") data = json.loads(result) @@ -362,7 +362,7 @@ class TestToolDelegateTaskAsync: import a2a_tools mc = _make_http_mock(post_resp=_resp(500, {"error": "internal"})) - with patch("a2a_tools.httpx.AsyncClient", return_value=mc): + with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc): result = await a2a_tools.tool_delegate_task_async("ws-1", "do task") assert "Error" in result @@ -372,7 +372,7 @@ class TestToolDelegateTaskAsync: import a2a_tools mc = _make_http_mock(post_exc=httpx.ConnectError("connection refused")) - with patch("a2a_tools.httpx.AsyncClient", return_value=mc): + with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc): result = await a2a_tools.tool_delegate_task_async("ws-1", "do task") assert "Error" in result or "failed" in result.lower() @@ -393,7 +393,7 @@ class TestToolCheckTaskStatus: {"delegation_id": "d-2", "target_id": "ws-u", "status": "pending", "summary": "waiting"}, ] mc = _make_http_mock(get_resp=_resp(200, delegations)) - with patch("a2a_tools.httpx.AsyncClient", return_value=mc): + with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc): result = await a2a_tools.tool_check_task_status("ws-1", "") data = json.loads(result) @@ -409,7 +409,7 @@ class TestToolCheckTaskStatus: {"delegation_id": "d-2", "status": "pending"}, ] mc = _make_http_mock(get_resp=_resp(200, delegations)) - with patch("a2a_tools.httpx.AsyncClient", return_value=mc): + with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc): result = await a2a_tools.tool_check_task_status("ws-1", "d-1") data = json.loads(result) @@ -421,7 +421,7 @@ class TestToolCheckTaskStatus: import a2a_tools mc = _make_http_mock(get_resp=_resp(200, [])) - with patch("a2a_tools.httpx.AsyncClient", return_value=mc): + with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc): result = await a2a_tools.tool_check_task_status("ws-1", "d-missing") data = json.loads(result) @@ -432,7 +432,7 @@ class TestToolCheckTaskStatus: import a2a_tools mc = _make_http_mock(get_resp=_resp(500, {"error": "db down"})) - with patch("a2a_tools.httpx.AsyncClient", return_value=mc): + with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc): result = await a2a_tools.tool_check_task_status("ws-1", "d-1") assert "Error" in result or "failed" in result.lower() From 2227a14b1e081590fa9f3b35621aaedbfd8ad2ec Mon Sep 17 00:00:00 2001 From: Hongming Wang Date: Tue, 5 May 2026 05:01:04 -0700 Subject: [PATCH 2/7] fix(build): add a2a_tools_delegation to TOP_LEVEL_MODULES drift gate Iter 4b's new module needs the rewrite-list entry. Stacked on iter 4a which already added a2a_tools_rbac. Refs RFC #2873 iter 4b. --- scripts/build_runtime_package.py | 1 + 1 file changed, 1 insertion(+) diff --git a/scripts/build_runtime_package.py b/scripts/build_runtime_package.py index 60963b96..dcb25d05 100755 --- a/scripts/build_runtime_package.py +++ b/scripts/build_runtime_package.py @@ -55,6 +55,7 @@ TOP_LEVEL_MODULES = { "a2a_executor", "a2a_mcp_server", "a2a_tools", + "a2a_tools_delegation", "a2a_tools_rbac", "adapter_base", "agent", From be18b9c8f99e0e75c101d4797a966f10f2fa2e3f Mon Sep 17 00:00:00 2001 From: Hongming Wang Date: Tue, 5 May 2026 09:50:30 -0700 Subject: [PATCH 3/7] fix(tests): retarget remaining a2a_tools delegation patches to a2a_tools_delegation CI caught two test files I missed in the original iter 4b retarget: test_a2a_multi_workspace.py + test_delegation_sync_via_polling.py patch a2a_tools.{discover_peer, send_a2a_message, _delegate_sync_via_polling, httpx.AsyncClient} but those call sites moved to a2a_tools_delegation in this PR. 17 patch sites retargeted; 30 tests now green. Refs RFC #2873 iter 4b. --- workspace/tests/test_a2a_multi_workspace.py | 12 +++++----- .../tests/test_delegation_sync_via_polling.py | 22 +++++++++---------- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/workspace/tests/test_a2a_multi_workspace.py b/workspace/tests/test_a2a_multi_workspace.py index 84f929e6..7cee1c10 100644 --- a/workspace/tests/test_a2a_multi_workspace.py +++ b/workspace/tests/test_a2a_multi_workspace.py @@ -339,8 +339,8 @@ class TestToolDelegateTaskAutoRouting: seen_send_src["src"] = source_workspace_id return "ok" - with patch("a2a_tools.discover_peer", side_effect=fake_discover), \ - patch("a2a_tools.send_a2a_message", side_effect=fake_send), \ + with patch("a2a_tools_delegation.discover_peer", side_effect=fake_discover), \ + patch("a2a_tools_delegation.send_a2a_message", side_effect=fake_send), \ patch("a2a_tools.report_activity", new=AsyncMock()): await a2a_tools.tool_delegate_task(peer_id, "do thing") @@ -367,8 +367,8 @@ class TestToolDelegateTaskAutoRouting: seen["send"] = source_workspace_id return "ok" - with patch("a2a_tools.discover_peer", side_effect=fake_discover), \ - patch("a2a_tools.send_a2a_message", side_effect=fake_send), \ + with patch("a2a_tools_delegation.discover_peer", side_effect=fake_discover), \ + patch("a2a_tools_delegation.send_a2a_message", side_effect=fake_send), \ patch("a2a_tools.report_activity", new=AsyncMock()): await a2a_tools.tool_delegate_task( peer_id, "do thing", source_workspace_id=ws_explicit, @@ -395,8 +395,8 @@ class TestToolDelegateTaskAutoRouting: seen["send"] = source_workspace_id return "ok" - with patch("a2a_tools.discover_peer", side_effect=fake_discover), \ - patch("a2a_tools.send_a2a_message", side_effect=fake_send), \ + with patch("a2a_tools_delegation.discover_peer", side_effect=fake_discover), \ + patch("a2a_tools_delegation.send_a2a_message", side_effect=fake_send), \ patch("a2a_tools.report_activity", new=AsyncMock()): await a2a_tools.tool_delegate_task(peer_id, "do thing") diff --git a/workspace/tests/test_delegation_sync_via_polling.py b/workspace/tests/test_delegation_sync_via_polling.py index 4d032f4e..7f6b2918 100644 --- a/workspace/tests/test_delegation_sync_via_polling.py +++ b/workspace/tests/test_delegation_sync_via_polling.py @@ -80,10 +80,10 @@ class TestFlagOffLegacyPath: async def fake_report_activity(*_a, **_kw): return None - with patch("a2a_tools.send_a2a_message", side_effect=fake_send), \ - patch("a2a_tools.discover_peer", side_effect=fake_discover), \ + with patch("a2a_tools_delegation.send_a2a_message", side_effect=fake_send), \ + patch("a2a_tools_delegation.discover_peer", side_effect=fake_discover), \ patch("a2a_tools.report_activity", side_effect=fake_report_activity), \ - patch("a2a_tools._delegate_sync_via_polling", new=AsyncMock()) as poll_mock: + patch("a2a_tools_delegation._delegate_sync_via_polling", new=AsyncMock()) as poll_mock: result = await a2a_tools.tool_delegate_task( "ws-target", "task body", source_workspace_id="ws-self" ) @@ -105,7 +105,7 @@ class TestFlagOnDispatchFailures: import a2a_tools mc = _make_client(post_exc=httpx.ConnectError("network down")) - with patch("a2a_tools.httpx.AsyncClient", return_value=mc): + with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc): res = await a2a_tools._delegate_sync_via_polling( "ws-target", "task", "ws-self" ) @@ -119,7 +119,7 @@ class TestFlagOnDispatchFailures: import a2a_tools mc = _make_client(post_resp=_resp(403, {"error": "forbidden"})) - with patch("a2a_tools.httpx.AsyncClient", return_value=mc): + with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc): res = await a2a_tools._delegate_sync_via_polling( "ws-target", "task", "ws-self" ) @@ -134,7 +134,7 @@ class TestFlagOnDispatchFailures: # 202 Accepted but no delegation_id field — defensive shape check. mc = _make_client(post_resp=_resp(202, {"status": "delegated"})) - with patch("a2a_tools.httpx.AsyncClient", return_value=mc): + with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc): res = await a2a_tools._delegate_sync_via_polling( "ws-target", "task", "ws-self" ) @@ -168,7 +168,7 @@ class TestFlagOnPollingOutcomes: get_resps=[_resp(200, [completed_row])], ) - with patch("a2a_tools.httpx.AsyncClient", return_value=mc): + with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc): res = await a2a_tools._delegate_sync_via_polling( "ws-target", "task", "ws-self" ) @@ -196,7 +196,7 @@ class TestFlagOnPollingOutcomes: get_resps=[_resp(200, [failed_row])], ) - with patch("a2a_tools.httpx.AsyncClient", return_value=mc): + with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc): res = await a2a_tools._delegate_sync_via_polling( "ws-target", "task", "ws-self" ) @@ -234,7 +234,7 @@ class TestFlagOnPollingOutcomes: get_resps=get_seq, ) - with patch("a2a_tools.httpx.AsyncClient", return_value=mc): + with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc): res = await a2a_tools._delegate_sync_via_polling( "ws-target", "task", "ws-self" ) @@ -266,7 +266,7 @@ class TestFlagOnPollingOutcomes: get_resps=get_seq, ) - with patch("a2a_tools.httpx.AsyncClient", return_value=mc): + with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc): res = await a2a_tools._delegate_sync_via_polling( "ws-target", "task", "ws-self" ) @@ -304,7 +304,7 @@ class TestFlagOnPollingOutcomes: get_resps=[first_poll, second_poll], ) - with patch("a2a_tools.httpx.AsyncClient", return_value=mc): + with patch("a2a_tools_delegation.httpx.AsyncClient", return_value=mc): res = await a2a_tools._delegate_sync_via_polling( "ws-target", "task", "ws-self" ) From 7644e82f2fe8ea4285d9100d922f1681b5e9154b Mon Sep 17 00:00:00 2001 From: Hongming Wang Date: Tue, 5 May 2026 10:30:22 -0700 Subject: [PATCH 4/7] feat(saas): default new workspaces to T4 on SaaS, T3 self-hosted MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit User reported every SaaS workspace defaults to T2 (Standard). Three sites quietly disagreed on the default: - canvas CreateWorkspaceDialog (line 126): isSaaS ? 4 : 3 ← only correct one - canvas EmptyState "Create blank": tier: 2 ← hardcoded - workspace.go POST /workspaces: tier = 3 ← not SaaS-aware - org_import.go createWorkspaceTree: tier = 2 (fallback)← not SaaS-aware So a user clicking "+ New Workspace" via the dialog got T4 on SaaS, but a user clicking "Create blank" on the empty canvas got T2, and an agent POSTing /workspaces directly got T3. Same tenant, three different tiers depending on entry point. Fix: 1. WorkspaceHandler.IsSaaS() and DefaultTier() helpers (workspace_dispatchers.go). IsSaaS() := h.cpProv != nil — single source of truth for "are we SaaS" across the file. DefaultTier() returns 4 on SaaS, 3 on self-hosted. SaaS rationale: each workspace runs on its own sibling EC2 so the per-workspace tier boundary is a Docker resource limit on the only container present — no neighbour to protect from. T4 matches the boundary. 2. workspace.go now defaults tier via h.DefaultTier() instead of hardcoded T3. 3. org_import.go fallback (when neither ws.tier nor defaults.tier set) becomes SaaS-aware: T4 on SaaS, T2 on self-hosted (preserve the existing safe-shared-Docker-daemon default for self-hosted org imports). 4. canvas EmptyState "Create blank" stops sending tier:2 in the body and lets the backend pick — single source of truth in the backend. Eliminates the third disagreement. Test plan: - go vet ./... clean - go test ./internal/handlers/ -count 1 — all green (4.3s) - npx tsc --noEmit on canvas — clean - Staging E2E (after deploy): create a fresh workspace via canvas empty-state on hongming.moleculesai.app, confirm tier=4 on the workspace details panel. Co-Authored-By: Claude Opus 4.7 (1M context) --- canvas/src/components/EmptyState.tsx | 13 +++++++--- .../internal/handlers/org_import.go | 12 ++++++++- .../internal/handlers/workspace.go | 18 ++++++------- .../handlers/workspace_dispatchers.go | 26 +++++++++++++++++++ 4 files changed, 55 insertions(+), 14 deletions(-) diff --git a/canvas/src/components/EmptyState.tsx b/canvas/src/components/EmptyState.tsx index 2452ef1a..d54f1709 100644 --- a/canvas/src/components/EmptyState.tsx +++ b/canvas/src/components/EmptyState.tsx @@ -48,16 +48,21 @@ export function EmptyState() { }); // "Create blank" bypasses templates entirely — no preflight, no - // modal, just POST /workspaces with a default name and tier. - // Deliberately NOT routed through useTemplateDeploy because it - // has no `template.id` to deploy against. + // modal, just POST /workspaces with a default name. Deliberately + // NOT routed through useTemplateDeploy because it has no + // `template.id` to deploy against. + // + // tier is omitted so the backend picks a SaaS-aware default + // (T4 on SaaS, T3 on self-hosted — see WorkspaceHandler.DefaultTier). + // The previous hardcoded `tier: 2` shipped every fresh-tenant agent + // at Standard regardless of host, which surprised SaaS users whose + // CreateWorkspaceDialog already defaults to T4. const createBlank = async () => { setBlankCreating(true); setBlankError(null); try { const ws = await api.post<{ id: string }>("/workspaces", { name: "My First Agent", - tier: 2, canvas: firstDeployCoords(), }); handleDeployed(ws.id); diff --git a/workspace-server/internal/handlers/org_import.go b/workspace-server/internal/handlers/org_import.go index 70151e09..94ca0b34 100644 --- a/workspace-server/internal/handlers/org_import.go +++ b/workspace-server/internal/handlers/org_import.go @@ -61,7 +61,17 @@ func (h *OrgHandler) createWorkspaceTree(ws OrgWorkspace, parentID *string, absX tier = defaults.Tier } if tier == 0 { - tier = 2 + // SaaS-aware fallback. SaaS → T4 (one container per sibling + // EC2, no neighbour to protect from). Self-hosted → T2 + // (safe shared-Docker-daemon default — many workspaces in + // one kernel). Templates that want a different floor + // declare `tier:` in their config.yaml or the org-template's + // `defaults.tier`. + if h.workspace != nil && h.workspace.IsSaaS() { + tier = 4 + } else { + tier = 2 + } } ctxLookup := context.Background() diff --git a/workspace-server/internal/handlers/workspace.go b/workspace-server/internal/handlers/workspace.go index 3b5b4c02..cf210342 100644 --- a/workspace-server/internal/handlers/workspace.go +++ b/workspace-server/internal/handlers/workspace.go @@ -148,15 +148,15 @@ func (h *WorkspaceHandler) Create(c *gin.Context) { id := uuid.New().String() awarenessNamespace := workspaceAwarenessNamespace(id) if payload.Tier == 0 { - // Default to T3 ("Privileged"). T3 gives agents a read_write - // workspace mount + Docker daemon access — the level most - // templates need to do real work. Lower tiers (T1 sandboxed, - // T2 standard) stay available as explicit opt-ins for - // low-trust agents. Matches the Canvas CreateWorkspaceDialog - // default for self-hosted hosts (SaaS defaults to T4 via - // CreateWorkspaceDialog because each SaaS workspace runs on - // its own sibling EC2). - payload.Tier = 3 + // SaaS-aware default. SaaS → T4 (full host access; each + // workspace runs on its own sibling EC2 so the tier boundary + // is a Docker resource limit on the only container present — + // no neighbour to protect from). Self-hosted → T3 (read-write + // workspace mount + Docker daemon access, most templates' + // baseline). Lower tiers (T1 sandboxed, T2 standard) remain + // explicit opt-ins for low-trust agents. Matches the canvas + // CreateWorkspaceDialog defaults so the API and the UI agree. + payload.Tier = h.DefaultTier() } // Detect runtime + default model from template config.yaml when the diff --git a/workspace-server/internal/handlers/workspace_dispatchers.go b/workspace-server/internal/handlers/workspace_dispatchers.go index 23237d00..18ede255 100644 --- a/workspace-server/internal/handlers/workspace_dispatchers.go +++ b/workspace-server/internal/handlers/workspace_dispatchers.go @@ -49,6 +49,32 @@ func (h *WorkspaceHandler) HasProvisioner() bool { return h.cpProv != nil || h.provisioner != nil } +// IsSaaS reports whether the CP (EC2) provisioner is wired. Each SaaS +// workspace runs on its own sibling EC2, so the per-workspace tier +// boundary is a Docker resource limit applied to the only container +// on that EC2 — there's no neighbour to protect from. Self-hosted +// runs many workspaces in one Docker daemon on a single host, so +// the tier-2-by-default safe-neighbour-share posture stays. +// +// Tier defaults across Create / OrgImport / canvas EmptyState branch +// on IsSaaS so SaaS users get T4 (full host access) by default and +// self-hosted users keep the lower-trust caps. +func (h *WorkspaceHandler) IsSaaS() bool { + return h.cpProv != nil +} + +// DefaultTier is the SaaS-aware default tier. T4 on SaaS (single +// container per EC2 — full host access matches the boundary), T3 on +// self-hosted (read-write workspace mount + Docker daemon access, +// most templates' baseline). Callers default to this when the user +// hasn't explicitly picked a tier. +func (h *WorkspaceHandler) DefaultTier() int { + if h.IsSaaS() { + return 4 + } + return 3 +} + // provisionWorkspaceAuto picks the backend (CP for SaaS, local Docker // for self-hosted) and starts provisioning in a goroutine. Returns true // when a backend was kicked off, false when neither is wired. From c79ba05ed5ef848a3b600b12458749bb074a1180 Mon Sep 17 00:00:00 2001 From: Hongming Wang Date: Tue, 5 May 2026 10:46:17 -0700 Subject: [PATCH 5/7] test(pendinguploads): close cycleDone-vs-metric-record race in sweeper tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit TestStartSweeper_RecordsMetricsOnError flaked on every CI rerun under race detection: `error counter delta = 0, want 1`. Root cause is a race between two goroutines, not a bug in the production sweeper. The fake `fakeSweepStorage.Sweep` signals `cycleDone` from inside its deferred return — that happens BEFORE Sweep's return value is received by `sweepOnce`, which is what triggers the metric increment. On slow CI hosts the test goroutine wins the read after `waitForCycle` unblocks and BEFORE StartSweeper's goroutine has called `metrics.PendingUploadsSweepError`, so the asserted delta is 0 even though the metric WILL be 1 a few ms later. Adds a polling assert helper, `waitForMetricDelta`, that closes the race deterministically without timing-based sleeps: - TestStartSweeper_RecordsMetricsOnError uses waitForMetricDelta to wait for the error counter to settle at 1. - TestStartSweeper_RecordsMetricsOnSuccess uses it on the success counters (acked, expired) so the error-stayed-zero assertion reads after StartSweeper has fully processed the cycle. - waitForCycle keeps its current shape but documents the caveat in its comment so future tests don't repeat the assumption. Verified: `go test ./internal/pendinguploads/ -race -count 5` passes all 9 tests across 5 iterations cleanly. Per memory feedback_question_test_when_unexpected.md: the "delta=0, want=1" failure looked like a real production bug at first glance, but instrumented inspection showed the metric DOES increment, just AFTER the test's read. The fix is the test's wait shape, not the sweeper. Unblocks every PR currently broken by this flake (#2898 hit it on two consecutive CI runs; staging-merged PRs from earlier today (#2877/#2881/#2885/#2886) introduced the test). Co-Authored-By: Claude Opus 4.7 (1M context) --- .../internal/pendinguploads/sweeper_test.go | 59 ++++++++++++++++--- 1 file changed, 50 insertions(+), 9 deletions(-) diff --git a/workspace-server/internal/pendinguploads/sweeper_test.go b/workspace-server/internal/pendinguploads/sweeper_test.go index e9cfde08..19ce26da 100644 --- a/workspace-server/internal/pendinguploads/sweeper_test.go +++ b/workspace-server/internal/pendinguploads/sweeper_test.go @@ -65,6 +65,15 @@ func (f *fakeSweepStorage) Sweep(_ context.Context, ackRetention time.Duration) // waitForCycle blocks until at least one Sweep completes, with a deadline. // Tests use this instead of time.Sleep to avoid flakes on slow CI hosts. +// +// CAVEAT: cycleDone fires from inside fakeSweepStorage.Sweep's defer, +// which runs as Sweep returns its result — BEFORE the StartSweeper +// loop has processed the (result, error) tuple and called the +// metric recorders. Tests that assert on metric counters must NOT +// rely on this wait alone; use waitForMetricDelta instead so the +// metric increment race (Sweep returns → cycleDone fires → test +// reads counter → only then does StartSweeper's loop call +// metrics.PendingUploadsSweepError) doesn't produce a flake. func (f *fakeSweepStorage) waitForCycle(t *testing.T, n int, timeout time.Duration) { t.Helper() deadline := time.NewTimer(timeout) @@ -78,6 +87,33 @@ func (f *fakeSweepStorage) waitForCycle(t *testing.T, n int, timeout time.Durati } } +// waitForMetricDelta polls the supplied delta function until it returns +// `want` or the timeout elapses. Use after waitForCycle when the test +// asserts on a metric counter — closes the race between cycleDone +// (signalled inside fakeSweepStorage.Sweep's defer, BEFORE Sweep +// returns to StartSweeper) and the metric recording (which happens in +// StartSweeper's loop AFTER Sweep returns). On a slow CI host the test +// goroutine wins the read before StartSweeper's goroutine writes the +// counter; the polling assert preserves the determinism of "the metric +// MUST be N" without timing-based flakes. +// +// Per memory feedback_question_test_when_unexpected.md: the failure +// mode "delta=0, want=1" looked like a real bug at first glance — +// "metric never incremented" — but instrumented analysis showed the +// metric DID increment, just AFTER the test's read. The fix is the +// test's wait shape, not the production code. +func waitForMetricDelta(t *testing.T, delta func() int64, want int64, timeout time.Duration) { + t.Helper() + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + if delta() == want { + return + } + time.Sleep(5 * time.Millisecond) + } + t.Fatalf("waited %s for metric delta=%d, last seen %d", timeout, want, delta()) +} + func TestStartSweeper_NilStorageDoesNotPanic(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -220,12 +256,13 @@ func TestStartSweeper_RecordsMetricsOnSuccess(t *testing.T) { go pendinguploads.StartSweeper(ctx, store, time.Hour) store.waitForCycle(t, 1, 2*time.Second) - if got := deltaAcked(); got != 3 { - t.Errorf("acked counter delta = %d, want 3", got) - } - if got := deltaExpired(); got != 5 { - t.Errorf("expired counter delta = %d, want 5", got) - } + // Poll for the success counters to settle — closes the cycleDone- + // vs-metric-record race (see waitForMetricDelta comment). + waitForMetricDelta(t, deltaAcked, 3, 2*time.Second) + waitForMetricDelta(t, deltaExpired, 5, 2*time.Second) + // Error counter MUST stay at zero on the success path. Read after + // the success counters have settled — once those are correct, + // StartSweeper has fully processed this cycle's result. if got := deltaError(); got != 0 { t.Errorf("error counter delta = %d, want 0", got) } @@ -244,7 +281,11 @@ func TestStartSweeper_RecordsMetricsOnError(t *testing.T) { go pendinguploads.StartSweeper(ctx, store, time.Hour) store.waitForCycle(t, 1, 2*time.Second) - if got := deltaError(); got != 1 { - t.Errorf("error counter delta = %d, want 1", got) - } + // Poll for the error counter to settle — cycleDone fires inside + // the fake's Sweep defer, BEFORE StartSweeper's loop receives the + // returned error and calls metrics.PendingUploadsSweepError. On + // slow CI hosts a direct deltaError() read here returns 0 even + // though the metric WILL be 1 a few ms later. See + // waitForMetricDelta comment. + waitForMetricDelta(t, deltaError, 1, 2*time.Second) } From a489ee1a7c032c0d81f7faaf42a7a14e6fea28eb Mon Sep 17 00:00:00 2001 From: Hongming Wang Date: Tue, 5 May 2026 10:47:32 -0700 Subject: [PATCH 6/7] fix(canvas/chat): instant-scroll to bottom on first mount MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Reported: "right now when chat box opens it opens in the middle, but it should be at the end of conversation." Root cause: ChatTab.tsx:548 fires `bottomRef.scrollIntoView({ behavior: "smooth" })` on every messages-update. On initial mount with N messages already loaded, the smooth-scroll triggers a ~300ms animation that any concurrent React re-render (agent push landing, theme toggle, sidepanel resize) interrupts mid-flight, leaving the user stuck somewhere in the middle of the conversation. Fix: track first-mount via hasInitialScrollRef. Use behavior:"instant" for the initial jump (deterministic, no animation interruption), then smooth for subsequent appends (the new-message-landing visual stays). Refs flipped on first messages.length > 0 transition, so: - Initial open of chat tab: instant jump to bottom ✓ - New agent message arrives: smooth scroll into view ✓ - Workspace switch (ChatTab remounts): fresh hasInitialScrollRef, gets instant again ✓ - loadOlder prepend: anchor-restore path unchanged, still pins user's reading position ✓ Test plan: - pnpm test --run ChatTab.lazyHistory.test.tsx → 8 pass (existing lazy-history tests untouched) - npx tsc --noEmit clean - Manual on hongming.moleculesai.app: open a busy chat (mac laptop, ~50 messages), confirm view lands at the latest bubble, not mid- scroll. Switch to another workspace + back → instant again. Co-Authored-By: Claude Opus 4.7 (1M context) --- canvas/src/components/tabs/ChatTab.tsx | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/canvas/src/components/tabs/ChatTab.tsx b/canvas/src/components/tabs/ChatTab.tsx index 7da17b72..2d6ae908 100644 --- a/canvas/src/components/tabs/ChatTab.tsx +++ b/canvas/src/components/tabs/ChatTab.tsx @@ -286,6 +286,14 @@ function MyChatPanel({ workspaceId, data }: Props) { const [error, setError] = useState(null); const [confirmRestart, setConfirmRestart] = useState(false); const bottomRef = useRef(null); + // First-mount scroll-to-bottom needs `behavior: "instant"` — long + // conversations smooth-animate for ~300ms which any concurrent + // re-render can interrupt, leaving the user stuck mid-conversation + // when the chat tab opens. Subsequent appends (new agent messages) + // keep `smooth` for the visual "landing" feel. Flipped the first + // time messages.length goes positive, so a workspace switch (which + // remounts ChatTab) gets a fresh instant jump too. + const hasInitialScrollRef = useRef(false); // Lazy-load older history on scroll-up. // - containerRef = the scrollable messages viewport // - topRef = sentinel above the messages list; IO observes it @@ -545,6 +553,15 @@ function MyChatPanel({ workspaceId, data }: Props) { scrollAnchorRef.current = null; return; } + // Instant on first arrival of messages — smooth-scroll on a long + // conversation gets interrupted by concurrent renders and leaves + // the user stuck in the middle. After the first jump, subsequent + // appends animate as before. + if (!hasInitialScrollRef.current && messages.length > 0) { + hasInitialScrollRef.current = true; + bottomRef.current?.scrollIntoView({ behavior: "instant" as ScrollBehavior }); + return; + } bottomRef.current?.scrollIntoView({ behavior: "smooth" }); }, [messages]); From 9991057ad19db33b1dbbb004ed1044545be9b282 Mon Sep 17 00:00:00 2001 From: Hongming Wang Date: Tue, 5 May 2026 11:10:13 -0700 Subject: [PATCH 7/7] =?UTF-8?q?feat(poll-upload):=20phase=205a=20=E2=80=94?= =?UTF-8?q?=20atomic=20batch=20insert=20+=20acked-index=20+=20mime=20harde?= =?UTF-8?q?ning?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Resolves four of six findings from the retrospective code review of Phases 1–4 (poll-mode chat upload). Bundled because every change is in the platform's pending_uploads layer or the multi-file handler that reads it. Findings resolved: 1. Important — Sweep query lacked an index for the acked-retention OR-arm. The Phase 1 partial indexes are both `WHERE acked_at IS NULL`, so the `(acked_at IS NOT NULL AND acked_at < retention)` half of the WHERE clause seq-scanned the table on every cycle. Add a complementary partial index on `acked_at WHERE acked_at IS NOT NULL` so both arms of the disjunction are index-covered. Disjoint from the existing two indexes (no row matches both predicates), so write amplification is bounded to ~one index entry per terminal-state row. 2. Important — uploadPollMode partial-failure left orphans. The previous per-file Put loop committed rows 1..K-1 and then errored on row K with no compensation, so a client retry would double-insert the survivors. Refactor the handler into three explicit phases (pre-validate + read-into-memory, single atomic PutBatch, per-file activity row) and add Storage.PutBatch with all-or-nothing transaction semantics. 3. FYI — pendinguploads.StartSweeperWithInterval was exported only for tests. Move it to lower-case startSweeperWithInterval and expose the test seam through pendinguploads/export_test.go (Go convention; the shim file is stripped from the production binary at build time). 4. Nit — multipart Content-Type was passed verbatim into pending_uploads rows and re-served on /content. Add safeMimetype which strips parameters, rejects CR/LF/control bytes, and coerces malformed shapes to application/octet-stream. The eventual GET /content response can no longer be header-split via a crafted Content-Type on the multipart. Comprehensive tests: - 10 PutBatch unit tests (sqlmock): happy path, empty input, all four pre-validation rejection paths, BeginTx error, per-row error + Rollback (no Commit), first-row error, Commit error. - 4 new PutBatch integration tests (real Postgres): all-rows-commit happy path with COUNT(*) verification, atomic-rollback no-leak via a NUL-byte filename that lib/pq rejects mid-batch, oversize short-circuit no-Tx, idx_pending_uploads_acked existence + partial predicate via pg_indexes (planner-shape-independent). - 3 new chat_files_poll tests: atomic rollback on second-file oversize, atomic rollback on PutBatch error, mimetype CRLF/NUL/parameter sanitization (8 sub-cases). The two remaining review findings (inbox_uploads.fetch_and_stage blocks the poll loop synchronously; two httpx Clients per row) are Python-side and ship in Phase 5b once this lands on staging. Test-only export pattern via export_test.go, atomic pre-validation discipline (validate before Tx), and behavior-based (not name-based) test assertions follow the standing project conventions. --- .../internal/handlers/chat_files.go | 158 +++++++++---- .../internal/handlers/chat_files_poll_test.go | 154 ++++++++++++ .../pending_uploads_integration_test.go | 178 ++++++++++++++ .../internal/handlers/pending_uploads_test.go | 8 + .../internal/pendinguploads/export_test.go | 17 ++ .../internal/pendinguploads/storage.go | 78 +++++++ .../internal/pendinguploads/storage_test.go | 220 ++++++++++++++++++ .../internal/pendinguploads/sweeper.go | 6 +- .../internal/pendinguploads/sweeper_test.go | 5 +- ...00000_pending_uploads_acked_index.down.sql | 2 + ...5200000_pending_uploads_acked_index.up.sql | 30 +++ 11 files changed, 806 insertions(+), 50 deletions(-) create mode 100644 workspace-server/internal/pendinguploads/export_test.go create mode 100644 workspace-server/migrations/20260505200000_pending_uploads_acked_index.down.sql create mode 100644 workspace-server/migrations/20260505200000_pending_uploads_acked_index.up.sql diff --git a/workspace-server/internal/handlers/chat_files.go b/workspace-server/internal/handlers/chat_files.go index ccfa0d4c..f5e980bf 100644 --- a/workspace-server/internal/handlers/chat_files.go +++ b/workspace-server/internal/handlers/chat_files.go @@ -600,14 +600,21 @@ func (h *ChatFilesHandler) uploadPollMode(c *gin.Context, ctx context.Context, w return } - out := make([]uploadedFile, 0, len(headers)) + // Phase 1: pre-validate + read every part BEFORE any DB write. + // A multi-file upload must commit all-or-nothing; a per-file + // failure halfway through used to leave rows 1..K-1 in the table + // while the client got a 500 and retried the whole batch — duplicate + // rows, orphan activity rows. Validating up-front + atomic PutBatch + // closes that gap. + type prepped struct { + Sanitized string + Mimetype string + Content []byte + Original string // original (unsanitized) filename for error messages + } + prepReady := make([]prepped, 0, len(headers)) + items := make([]pendinguploads.PutItem, 0, len(headers)) for _, fh := range headers { - // Read full content. Per-file cap enforced post-read so an - // oversized file fails with a clean 413 rather than a torn - // stream. The +1 byte ReadAll trick that the Python side - // uses isn't easy through multipart.FileHeader; instead we - // rely on the multipart layer's ContentLength header and - // short-circuit before opening the part. if fh.Size > pendinguploads.MaxFileBytes { log.Printf("chat_files uploadPollMode: per-file cap exceeded for %s: %s (%d bytes)", workspaceID, fh.Filename, fh.Size) @@ -621,45 +628,67 @@ func (h *ChatFilesHandler) uploadPollMode(c *gin.Context, ctx context.Context, w } content, err := readMultipartFile(fh) if err != nil { - log.Printf("chat_files uploadPollMode: read part failed for %s/%s: %v", workspaceID, fh.Filename, err) + log.Printf("chat_files uploadPollMode: read part failed for %s/%s: %v", + workspaceID, fh.Filename, err) c.JSON(http.StatusBadRequest, gin.H{"error": "could not read file part"}) return } - - sanitized := SanitizeFilename(fh.Filename) - mimetype := fh.Header.Get("Content-Type") - - fileID, err := h.pendingUploads.Put(ctx, wsUUID, content, sanitized, mimetype) - if err != nil { - if errors.Is(err, pendinguploads.ErrTooLarge) { - // Belt + suspenders: the size check above already - // caught this, but Storage.Put re-validates so a - // malformed FileHeader can't slip through. 413 with - // the same shape so the client sees one error class. - c.JSON(http.StatusRequestEntityTooLarge, gin.H{ - "error": "file exceeds per-file cap", - "filename": fh.Filename, - "size": len(content), - "max": pendinguploads.MaxFileBytes, - }) - return - } - log.Printf("chat_files uploadPollMode: storage.Put failed for %s/%s: %v", - workspaceID, sanitized, err) - c.JSON(http.StatusInternalServerError, gin.H{"error": "could not stage file"}) + // Belt-and-braces post-read cap (multipart.FileHeader.Size can lie + // on some clients that don't set Content-Length per part). + if len(content) > pendinguploads.MaxFileBytes { + log.Printf("chat_files uploadPollMode: per-file cap exceeded post-read for %s: %s (%d bytes)", + workspaceID, fh.Filename, len(content)) + c.JSON(http.StatusRequestEntityTooLarge, gin.H{ + "error": "file exceeds per-file cap", + "filename": fh.Filename, + "size": len(content), + "max": pendinguploads.MaxFileBytes, + }) return } + sanitized := SanitizeFilename(fh.Filename) + mimetype := safeMimetype(fh.Header.Get("Content-Type")) + prepReady = append(prepReady, prepped{ + Sanitized: sanitized, Mimetype: mimetype, Content: content, Original: fh.Filename, + }) + items = append(items, pendinguploads.PutItem{ + Content: content, Filename: sanitized, Mimetype: mimetype, + }) + } - // Activity row so the workspace's inbox poller picks this up - // on its next cycle. activity_type=a2a_receive (NOT a new - // type) so the existing poll filter - // `?type=a2a_receive` catches it without poll-side changes; - // method=chat_upload_receive is the discriminator the - // workspace's adapter (Phase 2) uses to route to the upload - // fetcher instead of the agent's message handler. Same - // shape as A2A's tasks/send vs message/send method split. + // Phase 2: atomic batch insert. On failure no rows commit. + fileIDs, err := h.pendingUploads.PutBatch(ctx, wsUUID, items) + if err != nil { + if errors.Is(err, pendinguploads.ErrTooLarge) { + // Belt + suspenders: pre-validation above already caught + // this; surface a clean 413 if a malformed FileHeader + // somehow slipped through. + c.JSON(http.StatusRequestEntityTooLarge, gin.H{ + "error": "one or more files exceed per-file cap", + "max": pendinguploads.MaxFileBytes, + }) + return + } + log.Printf("chat_files uploadPollMode: storage.PutBatch failed for %s: %v", + workspaceID, err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "could not stage files"}) + return + } + + // Phase 3: write per-file activity rows and build the response. Activity + // rows are written individually (not part of the same Tx as PutBatch) + // because LogActivity is shared across many handlers and threading the + // Tx through would be a bigger refactor. The trade-off: if an activity + // write fails after the PutBatch commits, the pending_uploads rows + // orphan until the 24h TTL — significantly better than the previous + // "every multi-file upload could orphan" behavior, and the workspace's + // fetcher handles soft-404 cleanly when activity rows reference a row + // the platform later expired. + out := make([]uploadedFile, 0, len(prepReady)) + for i, p := range prepReady { + fileID := fileIDs[i] uri := fmt.Sprintf("platform-pending:%s/%s", workspaceID, fileID) - summary := "chat_upload_receive: " + sanitized + summary := "chat_upload_receive: " + p.Sanitized method := "chat_upload_receive" LogActivity(ctx, h.broadcaster, ActivityParams{ WorkspaceID: workspaceID, @@ -669,28 +698,65 @@ func (h *ChatFilesHandler) uploadPollMode(c *gin.Context, ctx context.Context, w Summary: &summary, RequestBody: map[string]interface{}{ "file_id": fileID.String(), - "name": sanitized, - "mimeType": mimetype, - "size": len(content), + "name": p.Sanitized, + "mimeType": p.Mimetype, + "size": len(p.Content), "uri": uri, }, Status: "ok", }) log.Printf("chat_files uploadPollMode: staged %s/%s (file_id=%s size=%d mimetype=%q)", - workspaceID, sanitized, fileID, len(content), mimetype) + workspaceID, p.Sanitized, fileID, len(p.Content), p.Mimetype) out = append(out, uploadedFile{ URI: uri, - Name: sanitized, - Mimetype: mimetype, - Size: int64(len(content)), + Name: p.Sanitized, + Mimetype: p.Mimetype, + Size: int64(len(p.Content)), }) } c.JSON(http.StatusOK, gin.H{"files": out}) } +// safeMimetype validates a multipart-supplied Content-Type header and +// returns a sanitized value safe to store + serve back unmodified. +// +// The platform's GET /content handler reflects the stored mimetype as +// the response Content-Type. An attacker-controlled header that +// embedded CR/LF could split the response (header injection); a value +// containing semicolons could carry an unexpected charset parameter +// that confuses a downstream renderer. Strip CR/LF/control chars + +// keep only the type/subtype prefix; reject anything that doesn't +// match a basic `type/subtype` regex by falling back to the safe +// default (application/octet-stream — the workspace-side handler does +// the same fallback). +func safeMimetype(raw string) string { + const fallback = "application/octet-stream" + // Trim parameters (`text/html; charset=utf-8` → `text/html`). + if i := strings.IndexByte(raw, ';'); i >= 0 { + raw = raw[:i] + } + raw = strings.TrimSpace(raw) + if raw == "" { + return "" + } + // Reject if any control char or whitespace is present (header + // injection defense). RFC 7231 mimetype grammar forbids whitespace. + for _, r := range raw { + if r < 0x21 || r > 0x7e { + return fallback + } + } + // Require exactly one slash separating type and subtype. + parts := strings.Split(raw, "/") + if len(parts) != 2 || parts[0] == "" || parts[1] == "" { + return fallback + } + return raw +} + // readMultipartFile reads a multipart part fully into memory. Wraps // the open + io.ReadAll + close idiom so the call site stays clean, // and so a future change (chunked reads / hashing) has one place to diff --git a/workspace-server/internal/handlers/chat_files_poll_test.go b/workspace-server/internal/handlers/chat_files_poll_test.go index b9aeb5d6..aa5bab34 100644 --- a/workspace-server/internal/handlers/chat_files_poll_test.go +++ b/workspace-server/internal/handlers/chat_files_poll_test.go @@ -67,6 +67,46 @@ func (s *inMemStorage) Put(_ context.Context, ws uuid.UUID, content []byte, file return id, nil } +// PutBatch mirrors the production atomic-batch contract: any per-item +// failure leaves the in-memory state unchanged, simulating Tx rollback. +// Pre-validation matches PostgresStorage.PutBatch; oversized items +// return ErrTooLarge before any row is added. +func (s *inMemStorage) PutBatch(_ context.Context, ws uuid.UUID, items []pendinguploads.PutItem) ([]uuid.UUID, error) { + s.mu.Lock() + defer s.mu.Unlock() + if s.putErr != nil { + return nil, s.putErr + } + // Pre-validate so an oversized item rejects the whole batch before + // any state mutation — matches the Tx-rollback semantics. + for _, it := range items { + if len(it.Content) > pendinguploads.MaxFileBytes { + return nil, pendinguploads.ErrTooLarge + } + } + ids := make([]uuid.UUID, 0, len(items)) + stagedRows := make(map[uuid.UUID]pendinguploads.Record, len(items)) + stagedPuts := make([]putCall, 0, len(items)) + for _, it := range items { + id := uuid.New() + stagedRows[id] = pendinguploads.Record{ + FileID: id, WorkspaceID: ws, Content: it.Content, + Filename: it.Filename, Mimetype: it.Mimetype, + SizeBytes: int64(len(it.Content)), CreatedAt: time.Now(), + ExpiresAt: time.Now().Add(24 * time.Hour), + } + stagedPuts = append(stagedPuts, putCall{ + WorkspaceID: ws, Filename: it.Filename, Mimetype: it.Mimetype, Size: len(it.Content), + }) + ids = append(ids, id) + } + for id, r := range stagedRows { + s.rows[id] = r + } + s.puts = append(s.puts, stagedPuts...) + return ids, nil +} + func (s *inMemStorage) Get(context.Context, uuid.UUID) (pendinguploads.Record, error) { return pendinguploads.Record{}, pendinguploads.ErrNotFound } @@ -557,6 +597,120 @@ func TestPollUpload_SanitizesFilenameInResponse(t *testing.T) { } } +// TestPollUpload_AtomicRollbackOnSecondFileTooLarge pins the +// transactional contract introduced in phase 5: when one file in a +// multi-file batch fails pre-validation (oversize), NONE of the files +// in the batch land in storage. Previously a per-file Put loop would +// stage rows 1..K-1 before failing on row K, leaving orphan +// pending_uploads + activity rows the client would re-create on retry. +// +// Pinned via inMemStorage's PutBatch (which mirrors PostgresStorage's +// Tx-rollback behavior on a per-item validation failure) — but the +// real atomicity guarantee is the integration test in +// pending_uploads_integration_test.go. +func TestPollUpload_AtomicRollbackOnSecondFileTooLarge(t *testing.T) { + mock := setupTestDB(t) + setupTestRedis(t) + + wsID := "aaaaaaaa-3333-3333-4444-555555555555" + expectPollDeliveryMode(mock, wsID, "poll") + + store := newInMemStorage() + h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil)). + WithPendingUploads(store, nil) + + // Two files: first OK, second over the per-file cap. Pre-validation + // in uploadPollMode catches it BEFORE any Put — store.puts must + // stay empty. (If the test ever sees len=1, the regression is + // "first file slipped through into storage on a partial-failure + // batch.") + tooBig := bytes.Repeat([]byte{0x42}, pendinguploads.MaxFileBytes+1) + body, ct := pollUploadFixture(t, map[string][]byte{ + "ok.txt": []byte("small"), + "huge.bin": tooBig, + }) + c, w := makeUploadRequest(t, wsID, body, ct) + h.Upload(c) + + if w.Code != http.StatusRequestEntityTooLarge { + t.Errorf("status=%d body=%s, want 413", w.Code, w.Body.String()) + } + if len(store.puts) != 0 { + t.Errorf("expected zero Puts on rollback, got %d: %+v", len(store.puts), store.puts) + } +} + +// TestPollUpload_AtomicRollbackOnPutBatchError validates that an in- +// flight PutBatch failure (e.g. simulated DB error) leaves zero rows +// — same guarantee as the pre-validation path, but exercises the +// "Tx-Rollback after BEGIN" branch via the fake. +func TestPollUpload_AtomicRollbackOnPutBatchError(t *testing.T) { + mock := setupTestDB(t) + setupTestRedis(t) + + wsID := "bbbbbbbb-3333-3333-4444-555555555555" + expectPollDeliveryMode(mock, wsID, "poll") + + store := newInMemStorage() + store.putErr = errors.New("db down mid-batch") + h := NewChatFilesHandler(NewTemplatesHandler(t.TempDir(), nil)). + WithPendingUploads(store, nil) + + body, ct := pollUploadFixture(t, map[string][]byte{ + "a.txt": []byte("aaa"), + "b.txt": []byte("bbb"), + "c.txt": []byte("ccc"), + }) + c, w := makeUploadRequest(t, wsID, body, ct) + h.Upload(c) + + if w.Code != http.StatusInternalServerError { + t.Errorf("status=%d, want 500", w.Code) + } + if len(store.puts) != 0 { + t.Errorf("expected zero Puts after PutBatch error, got %d", len(store.puts)) + } +} + +// TestPollUpload_MimetypeWithCRLFInjectionStripped pins the safeMimetype +// hardening: a multipart-supplied Content-Type header with CR/LF is +// rewritten to application/octet-stream so the eventual /content +// response can't be header-split on the wire. +func TestPollUpload_MimetypeWithCRLFInjectionStripped(t *testing.T) { + got := safeMimetype("text/html\r\nX-Injected: pwn") + if got != "application/octet-stream" { + t.Errorf("CRLF mimetype not stripped, got %q", got) + } + got = safeMimetype("image/png\x00") + if got != "application/octet-stream" { + t.Errorf("NUL byte mimetype not stripped, got %q", got) + } + got = safeMimetype("text/plain; charset=utf-8") + if got != "text/plain" { + t.Errorf("parameter not stripped, got %q", got) + } + got = safeMimetype("application/pdf") + if got != "application/pdf" { + t.Errorf("clean mime modified, got %q", got) + } + got = safeMimetype("") + if got != "" { + t.Errorf("empty input should pass through, got %q", got) + } + got = safeMimetype("notamime") + if got != "application/octet-stream" { + t.Errorf("non-type/subtype not coerced, got %q", got) + } + got = safeMimetype("/empty-type") + if got != "application/octet-stream" { + t.Errorf("missing type half not coerced, got %q", got) + } + got = safeMimetype("type/") + if got != "application/octet-stream" { + t.Errorf("missing subtype half not coerced, got %q", got) + } +} + // TestPollUpload_ActivityRowDiscriminator pins the // activity_type / method shape that the workspace inbox poller depends // on. The poller filters `GET /workspaces/:id/activity?type=a2a_receive` diff --git a/workspace-server/internal/handlers/pending_uploads_integration_test.go b/workspace-server/internal/handlers/pending_uploads_integration_test.go index bec9011c..61c64f86 100644 --- a/workspace-server/internal/handlers/pending_uploads_integration_test.go +++ b/workspace-server/internal/handlers/pending_uploads_integration_test.go @@ -44,6 +44,7 @@ import ( "context" "database/sql" "os" + "strings" "testing" "time" @@ -273,6 +274,183 @@ func TestIntegration_PendingUploads_PutEnforcesSizeCap(t *testing.T) { } } +// TestIntegration_PendingUploads_PutBatch_HappyPath_AllRowsCommit pins the +// "all rows commit" leg of the PutBatch atomicity contract against a real +// Postgres. sqlmock can't catch a regression where the Go-side Tx machinery +// silently no-ops the inserts (e.g., wrong driver options on BeginTx); only +// COUNT(*) on the real table can. +func TestIntegration_PendingUploads_PutBatch_HappyPath_AllRowsCommit(t *testing.T) { + conn := integrationDB_PendingUploads(t) + store := pendinguploads.NewPostgres(conn) + ctx := context.Background() + + wsID := uuid.New() + + // Pre-existing row so the COUNT(*) baseline is non-zero — proves + // PutBatch adds rows incrementally rather than overwriting. + if _, err := store.Put(ctx, wsID, []byte("seed"), "seed.txt", "text/plain"); err != nil { + t.Fatalf("seed Put: %v", err) + } + + items := []pendinguploads.PutItem{ + {Content: []byte("alpha"), Filename: "alpha.txt", Mimetype: "text/plain"}, + {Content: []byte("beta"), Filename: "beta.bin", Mimetype: "application/octet-stream"}, + {Content: []byte("gamma"), Filename: "gamma.pdf", Mimetype: "application/pdf"}, + } + ids, err := store.PutBatch(ctx, wsID, items) + if err != nil { + t.Fatalf("PutBatch: %v", err) + } + if len(ids) != len(items) { + t.Fatalf("ids length %d, want %d", len(ids), len(items)) + } + + // Each returned id round-trips through Get with the right content. + for i, id := range ids { + rec, err := store.Get(ctx, id) + if err != nil { + t.Fatalf("Get item %d (%s): %v", i, id, err) + } + if string(rec.Content) != string(items[i].Content) { + t.Errorf("item %d content = %q, want %q", i, rec.Content, items[i].Content) + } + if rec.Filename != items[i].Filename { + t.Errorf("item %d filename = %q, want %q", i, rec.Filename, items[i].Filename) + } + } + + var n int + if err := conn.QueryRowContext(ctx, `SELECT COUNT(*) FROM pending_uploads WHERE workspace_id = $1`, wsID).Scan(&n); err != nil { + t.Fatalf("count: %v", err) + } + if n != 4 { + t.Errorf("workspace row count = %d, want 4 (1 seed + 3 batch)", n) + } +} + +// TestIntegration_PendingUploads_PutBatch_AtomicRollback_NoLeakOnFailure +// proves the all-or-nothing contract end-to-end against real Postgres MVCC. +// +// Strategy: build a 3-item batch where item index 1 carries a filename with +// an embedded NUL byte. lib/pq rejects NULs in TEXT columns at the protocol +// layer (`pq: invalid byte sequence for encoding "UTF8": 0x00`), which +// triggers the per-row INSERT error path in PutBatch. The first item's +// INSERT…RETURNING already wrote a row to the Tx's snapshot, so a buggy +// rollback would leave that row visible after PutBatch returns. +// +// Postgrest semantics: ROLLBACK is the only way a real DB can guarantee the +// "no leak" contract; a unit test with sqlmock can prove the Go function +// CALLED Rollback, but only this integration test proves Postgres actually +// HONORED it. +func TestIntegration_PendingUploads_PutBatch_AtomicRollback_NoLeakOnFailure(t *testing.T) { + conn := integrationDB_PendingUploads(t) + store := pendinguploads.NewPostgres(conn) + ctx := context.Background() + + wsID := uuid.New() + + // Baseline COUNT(*) for this workspace — must remain 0 after a failed batch. + var before int + if err := conn.QueryRowContext(ctx, `SELECT COUNT(*) FROM pending_uploads WHERE workspace_id = $1`, wsID).Scan(&before); err != nil { + t.Fatalf("baseline count: %v", err) + } + if before != 0 { + t.Fatalf("workspace not isolated: baseline = %d, want 0", before) + } + + // Item 1 has a NUL byte in the filename — Go-side pre-validation + // (which only checks empty/length) lets it through, so the INSERT + // reaches lib/pq, which rejects it at the protocol level. That's the + // canonical "DB-side error mid-batch" we want to exercise. + items := []pendinguploads.PutItem{ + {Content: []byte("ok"), Filename: "ok.txt", Mimetype: "text/plain"}, + {Content: []byte("bad"), Filename: "bad\x00name.txt", Mimetype: "text/plain"}, + {Content: []byte("never"), Filename: "never.txt", Mimetype: "text/plain"}, + } + _, err := store.PutBatch(ctx, wsID, items) + if err == nil { + t.Fatalf("expected error from NUL-byte filename, got nil") + } + + // THE assertion this whole test exists for: even though item 0's + // INSERT…RETURNING succeeded inside the Tx, the rollback unwound + // it — zero rows for this workspace, not one (let alone three). + var after int + if err := conn.QueryRowContext(ctx, `SELECT COUNT(*) FROM pending_uploads WHERE workspace_id = $1`, wsID).Scan(&after); err != nil { + t.Fatalf("post-failure count: %v", err) + } + if after != 0 { + t.Errorf("Tx rollback leaked rows: workspace count = %d, want 0", after) + } +} + +// TestIntegration_PendingUploads_PutBatch_Oversize_NoTxOpened verifies the +// pre-validation short-circuit: an oversized item rejects with ErrTooLarge +// BEFORE any Tx opens, so the table is untouched. The unit test (sqlmock +// with zero expectations) catches the Go-side path; this test sanity-checks +// no real DB I/O happens by confirming COUNT(*) doesn't move. +func TestIntegration_PendingUploads_PutBatch_Oversize_NoTxOpened(t *testing.T) { + conn := integrationDB_PendingUploads(t) + store := pendinguploads.NewPostgres(conn) + ctx := context.Background() + + wsID := uuid.New() + tooBig := make([]byte, pendinguploads.MaxFileBytes+1) + _, err := store.PutBatch(ctx, wsID, []pendinguploads.PutItem{ + {Content: []byte("ok"), Filename: "ok.txt"}, + {Content: tooBig, Filename: "too-big.bin"}, + }) + if err != pendinguploads.ErrTooLarge { + t.Fatalf("expected ErrTooLarge, got %v", err) + } + var n int + if err := conn.QueryRowContext(ctx, `SELECT COUNT(*) FROM pending_uploads WHERE workspace_id = $1`, wsID).Scan(&n); err != nil { + t.Fatalf("count: %v", err) + } + if n != 0 { + t.Errorf("pre-validation did NOT short-circuit: count = %d, want 0", n) + } +} + +// TestIntegration_PendingUploads_AckedIndexExists verifies the Phase 5a +// migration (20260505200000_pending_uploads_acked_index.up.sql) actually +// created idx_pending_uploads_acked with the right partial-index predicate. +// +// Why pg_indexes and not EXPLAIN: the planner prefers Seq Scan on tiny +// tables regardless of available indexes — a plan-shape check would be +// flaky under real test loads. The contract we care about is "the index +// exists with the predicate we wrote in the migration"; pg_indexes is +// the canonical source for that, robust to row count and planner version. +func TestIntegration_PendingUploads_AckedIndexExists(t *testing.T) { + conn := integrationDB_PendingUploads(t) + ctx := context.Background() + + var indexdef string + err := conn.QueryRowContext(ctx, ` + SELECT indexdef FROM pg_indexes + WHERE schemaname = 'public' + AND tablename = 'pending_uploads' + AND indexname = 'idx_pending_uploads_acked' + `).Scan(&indexdef) + if err == sql.ErrNoRows { + t.Fatal("idx_pending_uploads_acked is missing — migration 20260505200000 not applied") + } + if err != nil { + t.Fatalf("pg_indexes query: %v", err) + } + + // Pin the partial-index predicate. Without "WHERE acked_at IS NOT NULL" + // we'd be indexing the entire table (defeats the point — most rows are + // unacked), and the existing idx_pending_uploads_unacked already covers + // the inverse predicate. + if !strings.Contains(indexdef, "(acked_at)") { + t.Errorf("index missing acked_at column: %s", indexdef) + } + if !strings.Contains(indexdef, "WHERE (acked_at IS NOT NULL)") { + t.Errorf("index missing partial predicate: %s", indexdef) + } +} + func TestIntegration_PendingUploads_GetIgnoresExpiredAndAcked(t *testing.T) { conn := integrationDB_PendingUploads(t) store := pendinguploads.NewPostgres(conn) diff --git a/workspace-server/internal/handlers/pending_uploads_test.go b/workspace-server/internal/handlers/pending_uploads_test.go index e4b11a09..778e8170 100644 --- a/workspace-server/internal/handlers/pending_uploads_test.go +++ b/workspace-server/internal/handlers/pending_uploads_test.go @@ -77,6 +77,14 @@ func (f *fakeStorage) Sweep(_ context.Context, _ time.Duration) (pendinguploads. return pendinguploads.SweepResult{}, nil } +// PutBatch is required by the Storage interface; the upload handler +// tests live in chat_files_poll_test.go and use a separate fake +// (inMemStorage). Stubbed here because the Get/Ack tests don't drive +// PutBatch, but the interface must be satisfied. +func (f *fakeStorage) PutBatch(_ context.Context, _ uuid.UUID, _ []pendinguploads.PutItem) ([]uuid.UUID, error) { + return nil, nil +} + func newRouter(handler *handlers.PendingUploadsHandler) *gin.Engine { gin.SetMode(gin.TestMode) r := gin.New() diff --git a/workspace-server/internal/pendinguploads/export_test.go b/workspace-server/internal/pendinguploads/export_test.go new file mode 100644 index 00000000..c758b629 --- /dev/null +++ b/workspace-server/internal/pendinguploads/export_test.go @@ -0,0 +1,17 @@ +package pendinguploads + +import ( + "context" + "time" +) + +// StartSweeperWithIntervalForTest exposes startSweeperWithInterval to +// the external test package. The production code uses StartSweeper +// (which pins the canonical SweepInterval); tests pin a short interval +// to exercise the ticker-driven cycle without burning real wall-clock +// time. The Go convention `export_test.go` keeps this seam OUT of the +// production binary — files ending in _test.go are stripped at build +// time, so this re-export only exists during `go test`. +func StartSweeperWithIntervalForTest(ctx context.Context, storage Storage, ackRetention, interval time.Duration) { + startSweeperWithInterval(ctx, storage, ackRetention, interval) +} diff --git a/workspace-server/internal/pendinguploads/storage.go b/workspace-server/internal/pendinguploads/storage.go index 8bf63b1e..c4bcaf92 100644 --- a/workspace-server/internal/pendinguploads/storage.go +++ b/workspace-server/internal/pendinguploads/storage.go @@ -85,6 +85,15 @@ type SweepResult struct { // Total returns the sum of Acked + Expired — convenient for log lines. func (r SweepResult) Total() int { return r.Acked + r.Expired } +// PutItem is one file in a PutBatch call. Same per-field rules as Put — +// empty content, missing filename, or content > MaxFileBytes is rejected +// up-front so a bad item in the batch doesn't poison the transaction. +type PutItem struct { + Content []byte + Filename string + Mimetype string +} + // Storage is the platform-side persistence boundary for poll-mode chat // uploads. The Postgres implementation backs all callers today; an S3- // backed implementation can drop in once RFC #2789 lands by making @@ -99,6 +108,17 @@ type Storage interface { // content > MaxFileBytes return errors before any DB write. Put(ctx context.Context, workspaceID uuid.UUID, content []byte, filename, mimetype string) (uuid.UUID, error) + // PutBatch inserts N uploads atomically — either all rows commit or + // none do. Returns assigned file_ids in input order on success; + // returns an error and does NOT insert any row on failure. + // + // Use this from multi-file upload handlers so a per-row failure on + // row K doesn't leave rows 1..K-1 orphaned in the table (a client + // retry would then double-insert them on success). All-or-nothing + // semantics match the multipart request the canvas sends — either + // the whole batch succeeds or the user re-uploads. + PutBatch(ctx context.Context, workspaceID uuid.UUID, items []PutItem) ([]uuid.UUID, error) + // Get returns the full row including content. Returns ErrNotFound // when the row is absent, acked, or past expires_at. Caller should // not differentiate the three cases in the response — from the @@ -174,6 +194,64 @@ func (p *PostgresStorage) Put(ctx context.Context, workspaceID uuid.UUID, conten return fileID, nil } +// PutBatch inserts every item atomically inside a single Tx. On any +// per-item validation or per-row INSERT error the Tx is rolled back and +// the caller sees the error without any rows committed — no partial +// orphans for a multi-file upload that fails mid-batch. +// +// Validation runs BEFORE BEGIN so a bad input shape (empty content, +// over-cap size) doesn't even open a Tx. Once we're in the Tx, the only +// failures expected are DB-side (broken connection, statement timeout) +// — those abort cleanly via Rollback. +func (p *PostgresStorage) PutBatch(ctx context.Context, workspaceID uuid.UUID, items []PutItem) ([]uuid.UUID, error) { + if len(items) == 0 { + return nil, nil + } + for i, it := range items { + if len(it.Content) == 0 { + return nil, fmt.Errorf("pendinguploads: item %d: empty content", i) + } + if len(it.Content) > MaxFileBytes { + return nil, ErrTooLarge + } + if it.Filename == "" { + return nil, fmt.Errorf("pendinguploads: item %d: empty filename", i) + } + if len(it.Filename) > 100 { + return nil, fmt.Errorf("pendinguploads: item %d: filename exceeds 100 chars", i) + } + } + + tx, err := p.db.BeginTx(ctx, nil) + if err != nil { + return nil, fmt.Errorf("pendinguploads: begin tx: %w", err) + } + // Defer-rollback is safe even after a successful Commit — the second + // Rollback is a no-op (database/sql tracks tx state). + defer func() { + _ = tx.Rollback() + }() + + out := make([]uuid.UUID, 0, len(items)) + for i, it := range items { + var fid uuid.UUID + err := tx.QueryRowContext(ctx, ` + INSERT INTO pending_uploads (workspace_id, content, size_bytes, filename, mimetype) + VALUES ($1, $2, $3, $4, $5) + RETURNING file_id + `, workspaceID, it.Content, int64(len(it.Content)), it.Filename, it.Mimetype).Scan(&fid) + if err != nil { + return nil, fmt.Errorf("pendinguploads: batch insert item %d: %w", i, err) + } + out = append(out, fid) + } + + if err := tx.Commit(); err != nil { + return nil, fmt.Errorf("pendinguploads: commit batch: %w", err) + } + return out, nil +} + func (p *PostgresStorage) Get(ctx context.Context, fileID uuid.UUID) (Record, error) { // The expires_at + acked_at filter in the WHERE clause means a // caller sees ErrNotFound for absent / acked / expired without diff --git a/workspace-server/internal/pendinguploads/storage_test.go b/workspace-server/internal/pendinguploads/storage_test.go index e4db87f8..c6793c10 100644 --- a/workspace-server/internal/pendinguploads/storage_test.go +++ b/workspace-server/internal/pendinguploads/storage_test.go @@ -511,3 +511,223 @@ func TestSweepResult_TotalSumsCounts(t *testing.T) { t.Errorf("zero Total = %d, want 0", z.Total()) } } + +// ----- PutBatch ------------------------------------------------------------- +// +// PutBatch is the multi-file atomic insert path used by uploadPollMode in +// chat_files.go. The contract that callers rely on: +// +// - Either ALL rows commit, or NONE do — a per-row INSERT failure must +// leave the table unchanged (no orphaned rows from a half-applied batch). +// - Per-item validation runs BEFORE the Tx opens so a bad input shape +// never wastes a BEGIN round-trip. +// - Returned []uuid.UUID is in input order — handler maps response back +// to the multipart Files[i]. +// +// sqlmock's ExpectBegin / ExpectQuery / ExpectCommit / ExpectRollback let us +// pin the exact tx-lifecycle shape; if a future refactor swaps Begin for +// BeginTx-with-options, the test fails until we re-pin. + +func TestPutBatch_HappyPath_AllCommitInOrder(t *testing.T) { + db, mock := newMockDB(t) + store := pendinguploads.NewPostgres(db) + + wsID := uuid.New() + id1, id2, id3 := uuid.New(), uuid.New(), uuid.New() + + mock.ExpectBegin() + mock.ExpectQuery(insertSQL). + WithArgs(wsID, []byte("aaa"), int64(3), "a.txt", "text/plain"). + WillReturnRows(sqlmock.NewRows([]string{"file_id"}).AddRow(id1)) + mock.ExpectQuery(insertSQL). + WithArgs(wsID, []byte("bbbb"), int64(4), "b.bin", "application/octet-stream"). + WillReturnRows(sqlmock.NewRows([]string{"file_id"}).AddRow(id2)) + mock.ExpectQuery(insertSQL). + WithArgs(wsID, []byte("ccccc"), int64(5), "c.pdf", "application/pdf"). + WillReturnRows(sqlmock.NewRows([]string{"file_id"}).AddRow(id3)) + mock.ExpectCommit() + // Rollback after Commit is a no-op in database/sql; sqlmock allows it + // when ExpectCommit was already matched, so we don't need to expect it. + + got, err := store.PutBatch(context.Background(), wsID, []pendinguploads.PutItem{ + {Content: []byte("aaa"), Filename: "a.txt", Mimetype: "text/plain"}, + {Content: []byte("bbbb"), Filename: "b.bin", Mimetype: "application/octet-stream"}, + {Content: []byte("ccccc"), Filename: "c.pdf", Mimetype: "application/pdf"}, + }) + if err != nil { + t.Fatalf("PutBatch: %v", err) + } + if len(got) != 3 || got[0] != id1 || got[1] != id2 || got[2] != id3 { + t.Errorf("ids out of order or missing: got %v want [%s %s %s]", got, id1, id2, id3) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("expectations: %v", err) + } +} + +func TestPutBatch_EmptyItems_NoTxNoError(t *testing.T) { + db, _ := newMockDB(t) // zero expectations — must NOT round-trip + store := pendinguploads.NewPostgres(db) + + got, err := store.PutBatch(context.Background(), uuid.New(), nil) + if err != nil { + t.Fatalf("expected nil error on empty batch, got %v", err) + } + if got != nil { + t.Errorf("expected nil ids on empty batch, got %v", got) + } +} + +func TestPutBatch_RejectsEmptyContent_NoTx(t *testing.T) { + db, _ := newMockDB(t) + store := pendinguploads.NewPostgres(db) + + _, err := store.PutBatch(context.Background(), uuid.New(), []pendinguploads.PutItem{ + {Content: []byte("ok"), Filename: "a.txt"}, + {Content: nil, Filename: "b.txt"}, + }) + if err == nil || !strings.Contains(err.Error(), "item 1") || !strings.Contains(err.Error(), "empty content") { + t.Fatalf("expected item-1 empty-content error, got %v", err) + } +} + +func TestPutBatch_RejectsOversize_ReturnsErrTooLarge(t *testing.T) { + db, _ := newMockDB(t) + store := pendinguploads.NewPostgres(db) + + too := make([]byte, pendinguploads.MaxFileBytes+1) + _, err := store.PutBatch(context.Background(), uuid.New(), []pendinguploads.PutItem{ + {Content: []byte("ok"), Filename: "small.txt"}, + {Content: too, Filename: "huge.bin"}, + }) + if !errors.Is(err, pendinguploads.ErrTooLarge) { + t.Fatalf("expected ErrTooLarge, got %v", err) + } +} + +func TestPutBatch_RejectsEmptyFilename_NoTx(t *testing.T) { + db, _ := newMockDB(t) + store := pendinguploads.NewPostgres(db) + + _, err := store.PutBatch(context.Background(), uuid.New(), []pendinguploads.PutItem{ + {Content: []byte("hi"), Filename: ""}, + }) + if err == nil || !strings.Contains(err.Error(), "item 0") || !strings.Contains(err.Error(), "empty filename") { + t.Fatalf("expected item-0 empty-filename error, got %v", err) + } +} + +func TestPutBatch_RejectsLongFilename_NoTx(t *testing.T) { + db, _ := newMockDB(t) + store := pendinguploads.NewPostgres(db) + + long := strings.Repeat("z", 101) + _, err := store.PutBatch(context.Background(), uuid.New(), []pendinguploads.PutItem{ + {Content: []byte("hi"), Filename: "ok.txt"}, + {Content: []byte("hi"), Filename: long}, + }) + if err == nil || !strings.Contains(err.Error(), "item 1") || !strings.Contains(err.Error(), "exceeds 100 chars") { + t.Fatalf("expected item-1 too-long-filename error, got %v", err) + } +} + +func TestPutBatch_BeginTxError_Wrapped(t *testing.T) { + db, mock := newMockDB(t) + store := pendinguploads.NewPostgres(db) + + mock.ExpectBegin().WillReturnError(errors.New("conn refused")) + + _, err := store.PutBatch(context.Background(), uuid.New(), []pendinguploads.PutItem{ + {Content: []byte("hi"), Filename: "a.txt"}, + }) + if err == nil || !strings.Contains(err.Error(), "begin tx") { + t.Fatalf("expected wrapped begin-tx error, got %v", err) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("expectations: %v", err) + } +} + +func TestPutBatch_RollsBackOnPerRowError_NoCommit(t *testing.T) { + // First INSERT succeeds, second errors. PutBatch MUST NOT issue + // Commit; the deferred Rollback unwinds row 1 so neither row commits. + // This is the contract that prevents orphan rows on a failed batch. + db, mock := newMockDB(t) + store := pendinguploads.NewPostgres(db) + + wsID := uuid.New() + id1 := uuid.New() + + mock.ExpectBegin() + mock.ExpectQuery(insertSQL). + WithArgs(wsID, []byte("aaa"), int64(3), "a.txt", ""). + WillReturnRows(sqlmock.NewRows([]string{"file_id"}).AddRow(id1)) + mock.ExpectQuery(insertSQL). + WithArgs(wsID, []byte("bb"), int64(2), "b.txt", ""). + WillReturnError(errors.New("statement timeout")) + // Critical: Rollback expected, NOT Commit. If a future refactor + // accidentally swallows the per-row error and Commits anyway, this + // test fails because the unmet ExpectCommit-vs-Rollback shape diverges. + mock.ExpectRollback() + + _, err := store.PutBatch(context.Background(), wsID, []pendinguploads.PutItem{ + {Content: []byte("aaa"), Filename: "a.txt"}, + {Content: []byte("bb"), Filename: "b.txt"}, + }) + if err == nil || !strings.Contains(err.Error(), "batch insert item 1") { + t.Fatalf("expected wrapped per-row insert error, got %v", err) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("expectations (must rollback, no commit): %v", err) + } +} + +func TestPutBatch_RollsBackOnFirstRowError(t *testing.T) { + // Edge case: very first INSERT fails. No rows ever staged — but the + // Tx still needs to roll back to release the snapshot. + db, mock := newMockDB(t) + store := pendinguploads.NewPostgres(db) + + wsID := uuid.New() + mock.ExpectBegin() + mock.ExpectQuery(insertSQL). + WithArgs(wsID, []byte("oops"), int64(4), "a.txt", ""). + WillReturnError(errors.New("constraint violation")) + mock.ExpectRollback() + + _, err := store.PutBatch(context.Background(), wsID, []pendinguploads.PutItem{ + {Content: []byte("oops"), Filename: "a.txt"}, + }) + if err == nil || !strings.Contains(err.Error(), "batch insert item 0") { + t.Fatalf("expected wrapped item-0 insert error, got %v", err) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("expectations: %v", err) + } +} + +func TestPutBatch_CommitError_Wrapped(t *testing.T) { + // Commit fails after every INSERT succeeded. Postgres has already + // rolled back the Tx by this point; we surface the error so the + // handler returns 500 and the client retries. + db, mock := newMockDB(t) + store := pendinguploads.NewPostgres(db) + + wsID := uuid.New() + id1 := uuid.New() + mock.ExpectBegin() + mock.ExpectQuery(insertSQL). + WithArgs(wsID, []byte("hi"), int64(2), "a.txt", ""). + WillReturnRows(sqlmock.NewRows([]string{"file_id"}).AddRow(id1)) + mock.ExpectCommit().WillReturnError(errors.New("commit broken")) + + _, err := store.PutBatch(context.Background(), wsID, []pendinguploads.PutItem{ + {Content: []byte("hi"), Filename: "a.txt"}, + }) + if err == nil || !strings.Contains(err.Error(), "commit batch") { + t.Fatalf("expected wrapped commit error, got %v", err) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("expectations: %v", err) + } +} diff --git a/workspace-server/internal/pendinguploads/sweeper.go b/workspace-server/internal/pendinguploads/sweeper.go index 84a56dab..b29a87ad 100644 --- a/workspace-server/internal/pendinguploads/sweeper.go +++ b/workspace-server/internal/pendinguploads/sweeper.go @@ -66,13 +66,13 @@ const sweepDeadline = 30 * time.Second // to exercise the ticker-driven sweep path without burning real wall- // clock time. func StartSweeper(ctx context.Context, storage Storage, ackRetention time.Duration) { - StartSweeperWithInterval(ctx, storage, ackRetention, SweepInterval) + startSweeperWithInterval(ctx, storage, ackRetention, SweepInterval) } -// StartSweeperWithInterval is the test-friendly variant of StartSweeper +// startSweeperWithInterval is the test-friendly variant of StartSweeper // — same loop, but the cadence is caller-specified. Production code // should use StartSweeper to keep the SweepInterval constant pinned. -func StartSweeperWithInterval(ctx context.Context, storage Storage, ackRetention, interval time.Duration) { +func startSweeperWithInterval(ctx context.Context, storage Storage, ackRetention, interval time.Duration) { if storage == nil { log.Println("pendinguploads sweeper: storage is nil — sweeper disabled") return diff --git a/workspace-server/internal/pendinguploads/sweeper_test.go b/workspace-server/internal/pendinguploads/sweeper_test.go index e9cfde08..1174b87d 100644 --- a/workspace-server/internal/pendinguploads/sweeper_test.go +++ b/workspace-server/internal/pendinguploads/sweeper_test.go @@ -44,6 +44,9 @@ func (f *fakeSweepStorage) MarkFetched(_ context.Context, _ uuid.UUID) error { func (f *fakeSweepStorage) Ack(_ context.Context, _ uuid.UUID) error { return errors.New("not used") } +func (f *fakeSweepStorage) PutBatch(_ context.Context, _ uuid.UUID, _ []pendinguploads.PutItem) ([]uuid.UUID, error) { + return nil, errors.New("not used") +} func (f *fakeSweepStorage) Sweep(_ context.Context, ackRetention time.Duration) (pendinguploads.SweepResult, error) { idx := int(f.calls.Load()) f.calls.Add(1) @@ -144,7 +147,7 @@ func TestStartSweeperWithInterval_TickerFiresAdditionalCycles(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - go pendinguploads.StartSweeperWithInterval(ctx, store, time.Hour, 30*time.Millisecond) + go pendinguploads.StartSweeperWithIntervalForTest(ctx, store, time.Hour, 30*time.Millisecond) // Immediate cycle + at least one tick-driven cycle. store.waitForCycle(t, 2, 2*time.Second) diff --git a/workspace-server/migrations/20260505200000_pending_uploads_acked_index.down.sql b/workspace-server/migrations/20260505200000_pending_uploads_acked_index.down.sql new file mode 100644 index 00000000..2d84b00d --- /dev/null +++ b/workspace-server/migrations/20260505200000_pending_uploads_acked_index.down.sql @@ -0,0 +1,2 @@ +-- Reversal of 20260505200000_pending_uploads_acked_index.up.sql. +DROP INDEX IF EXISTS idx_pending_uploads_acked; diff --git a/workspace-server/migrations/20260505200000_pending_uploads_acked_index.up.sql b/workspace-server/migrations/20260505200000_pending_uploads_acked_index.up.sql new file mode 100644 index 00000000..f2beced2 --- /dev/null +++ b/workspace-server/migrations/20260505200000_pending_uploads_acked_index.up.sql @@ -0,0 +1,30 @@ +-- 20260505200000_pending_uploads_acked_index.up.sql +-- +-- Adds the missing partial index for the acked-retention arm of the +-- pendinguploads.Sweep query. The Phase 1 migration created two +-- partial indexes both gated on `acked_at IS NULL` (workspace-fetch +-- hot path + expires_at sweep arm); the third query path — +-- `WHERE acked_at IS NOT NULL AND acked_at < now() - interval` — was +-- left to a seq scan. +-- +-- For a high-traffic deployment that's a real cost: the table +-- accumulates one row per chat-attached file; the sweeper runs every +-- 5 minutes and DELETEs rows past the 1-hour ack retention. A seq +-- scan over 100K-1M acked rows holds an AccessShare lock for seconds +-- on every cycle. Partial-indexing the inverse predicate reduces +-- this to a btree range scan and lets the DELETE complete in +-- low-millisecond range. +-- +-- WHERE acked_at IS NOT NULL is intentionally inverse of the other +-- two indexes — they cover the unacked working set; this covers the +-- terminal-state set the sweeper visits. Disjoint subsets, so the +-- two indexes don't overlap. +-- +-- Caught in self-review on the parent RFC's Phase 4 PR; filed as +-- a follow-up rather than a Phase 1 fix because the cost only +-- materializes at a row count we don't expect to hit before the +-- sweeper has had a chance to keep up. + +CREATE INDEX IF NOT EXISTS idx_pending_uploads_acked + ON pending_uploads (acked_at) + WHERE acked_at IS NOT NULL;