diff --git a/workspace/a2a_client.py b/workspace/a2a_client.py index e6569385..4d1c5c7a 100644 --- a/workspace/a2a_client.py +++ b/workspace/a2a_client.py @@ -30,6 +30,23 @@ else: # Cache workspace ID → name mappings (populated by list_peers calls) _peer_names: dict[str, str] = {} +# Cache: peer workspace_id → the source workspace_id whose registry +# returned that peer. Populated by ``a2a_tools.tool_list_peers`` whenever +# it queries a specific workspace's peers — so a later +# ``tool_delegate_task(target)`` can auto-route through the correct +# source workspace without the agent having to specify +# ``source_workspace_id`` explicitly. +# +# Single-workspace mode: dict stays empty, all delegations fall through +# to the module-level WORKSPACE_ID (existing behavior). +# +# Multi-workspace mode: as the agent calls list_peers, this map is +# populated with each peer's source. Subsequent delegate_task calls +# auto-route. If a peer is registered under multiple sources (rare — +# e.g. an org-wide capability) the LAST observed source wins; the agent +# can override by passing ``source_workspace_id`` explicitly. +_peer_to_source: dict[str, str] = {} + # Cache workspace ID → full peer record (id, name, role, status, url, ...). # Populated by tool_list_peers and by the lazy registry lookup in # enrich_peer_metadata. The notification-callback path (channel envelope @@ -49,7 +66,12 @@ _peer_metadata: dict[str, tuple[float, dict | None]] = {} _PEER_METADATA_TTL_SECONDS = 300.0 -def enrich_peer_metadata(peer_id: str, *, now: float | None = None) -> dict | None: +def enrich_peer_metadata( + peer_id: str, + source_workspace_id: str | None = None, + *, + now: float | None = None, +) -> dict | None: """Return cached or freshly-fetched metadata for ``peer_id``. Sync helper — safe to call from the inbox poller's notification @@ -86,10 +108,11 @@ def enrich_peer_metadata(peer_id: str, *, now: float | None = None) -> dict | No # the same as a registry miss, which is the desired UX. return record + src = (source_workspace_id or "").strip() or WORKSPACE_ID url = f"{PLATFORM_URL}/registry/discover/{canon}" try: with httpx.Client(timeout=2.0) as client: - resp = client.get(url, headers={"X-Workspace-ID": WORKSPACE_ID, **auth_headers()}) + resp = client.get(url, headers={"X-Workspace-ID": src, **auth_headers(src)}) except Exception as exc: # noqa: BLE001 logger.debug("enrich_peer_metadata: GET %s failed: %s", url, exc) _peer_metadata[canon] = (current, None) @@ -174,22 +197,30 @@ def _validate_peer_id(peer_id: str) -> str | None: return pid.lower() -async def discover_peer(target_id: str) -> dict | None: +async def discover_peer(target_id: str, source_workspace_id: str | None = None) -> dict | None: """Discover a peer workspace's URL via the platform registry. Validates ``target_id`` is a UUID before constructing the URL — a malformed id can't reach the platform handler now, which both short-circuits an avoidable round-trip AND ensures we never interpolate path-traversal characters into the URL. + + ``source_workspace_id`` selects which registered workspace asks the + question — both the X-Workspace-ID header AND the Authorization + bearer token must come from the same workspace, otherwise the + platform's TenantGuard rejects the request. Defaults to the + module-level WORKSPACE_ID for back-compat with single-workspace + callers. """ safe_id = _validate_peer_id(target_id) if safe_id is None: return None + src = (source_workspace_id or "").strip() or WORKSPACE_ID async with httpx.AsyncClient(timeout=10.0) as client: try: resp = await client.get( f"{PLATFORM_URL}/registry/discover/{safe_id}", - headers={"X-Workspace-ID": WORKSPACE_ID, **auth_headers()}, + headers={"X-Workspace-ID": src, **auth_headers(src)}, ) if resp.status_code == 200: return resp.json() @@ -283,7 +314,7 @@ def _format_a2a_error(exc: BaseException, target_url: str) -> str: return f"{_A2A_ERROR_PREFIX}{detail} [target={target_url}]" -async def send_a2a_message(peer_id: str, message: str) -> str: +async def send_a2a_message(peer_id: str, message: str, source_workspace_id: str | None = None) -> str: """Send an A2A ``message/send`` to a peer workspace via the platform proxy. The target URL is constructed internally as @@ -292,6 +323,12 @@ async def send_a2a_message(peer_id: str, message: str) -> str: in-container and external runtimes — see a2a_tools.tool_delegate_task for the rationale. + ``source_workspace_id`` is the SENDING workspace — drives both the + X-Workspace-ID source-tagging header and the bearer token. Defaults + to the module-level WORKSPACE_ID for back-compat. Multi-workspace + operators pass it explicitly so each registered workspace's peers + are reached via their own auth chain. + Auto-retries up to _DELEGATE_MAX_ATTEMPTS times on transient transport-layer errors (RemoteProtocolError, ConnectError, ReadTimeout, etc.) with exponential-backoff + jitter, capped by @@ -302,6 +339,7 @@ async def send_a2a_message(peer_id: str, message: str) -> str: safe_id = _validate_peer_id(peer_id) if safe_id is None: return f"{_A2A_ERROR_PREFIX}invalid peer_id (expected UUID): {peer_id!r}" + src = (source_workspace_id or "").strip() or WORKSPACE_ID target_url = f"{PLATFORM_URL}/workspaces/{safe_id}/a2a" # Fix F (Cycle 5 / H2 — flagged 5 consecutive audits): timeout=None allowed @@ -322,7 +360,7 @@ async def send_a2a_message(peer_id: str, message: str) -> str: # in the recipient's My Chat tab as user-typed input. resp = await client.post( target_url, - headers=self_source_headers(WORKSPACE_ID), + headers=self_source_headers(src), json={ "jsonrpc": "2.0", "id": str(uuid.uuid4()), @@ -389,7 +427,7 @@ async def send_a2a_message(peer_id: str, message: str) -> str: return _format_a2a_error(last_exc, target_url) -async def get_peers_with_diagnostic() -> tuple[list[dict], str | None]: +async def get_peers_with_diagnostic(source_workspace_id: str | None = None) -> tuple[list[dict], str | None]: """Get this workspace's peers, returning (peers, diagnostic). diagnostic is None when the call succeeded (status 200, even if the list @@ -398,15 +436,22 @@ async def get_peers_with_diagnostic() -> tuple[list[dict], str | None]: diagnostic is a short human-readable string explaining what went wrong so callers can surface it instead of "may be isolated" — see #2397. + ``source_workspace_id`` selects which registered workspace's peers to + enumerate; defaults to the module-level WORKSPACE_ID for + single-workspace back-compat. Multi-workspace operators iterate over + each registered workspace separately so each set of peers is fetched + with the correct auth. + The legacy get_peers() shim below preserves the bare-list contract for non-tool callers. """ - url = f"{PLATFORM_URL}/registry/{WORKSPACE_ID}/peers" + src = (source_workspace_id or "").strip() or WORKSPACE_ID + url = f"{PLATFORM_URL}/registry/{src}/peers" async with httpx.AsyncClient(timeout=10.0) as client: try: resp = await client.get( url, - headers={"X-Workspace-ID": WORKSPACE_ID, **auth_headers()}, + headers={"X-Workspace-ID": src, **auth_headers(src)}, ) except Exception as e: return [], f"Cannot reach platform at {PLATFORM_URL}: {e}" diff --git a/workspace/a2a_mcp_server.py b/workspace/a2a_mcp_server.py index 0c979a18..ea8e7755 100644 --- a/workspace/a2a_mcp_server.py +++ b/workspace/a2a_mcp_server.py @@ -91,16 +91,19 @@ async def handle_tool_call(name: str, arguments: dict) -> str: return await tool_delegate_task( arguments.get("workspace_id", ""), arguments.get("task", ""), + source_workspace_id=arguments.get("source_workspace_id") or None, ) elif name == "delegate_task_async": return await tool_delegate_task_async( arguments.get("workspace_id", ""), arguments.get("task", ""), + source_workspace_id=arguments.get("source_workspace_id") or None, ) elif name == "check_task_status": return await tool_check_task_status( arguments.get("workspace_id", ""), arguments.get("task_id", ""), + source_workspace_id=arguments.get("source_workspace_id") or None, ) elif name == "send_message_to_user": raw_attachments = arguments.get("attachments") @@ -116,7 +119,9 @@ async def handle_tool_call(name: str, arguments: dict) -> str: workspace_id=arguments.get("workspace_id") or None, ) elif name == "list_peers": - return await tool_list_peers() + return await tool_list_peers( + source_workspace_id=arguments.get("source_workspace_id") or None, + ) elif name == "get_workspace_info": return await tool_get_workspace_info() elif name == "commit_memory": diff --git a/workspace/a2a_tools.py b/workspace/a2a_tools.py index e5ce78ec..296bcc72 100644 --- a/workspace/a2a_tools.py +++ b/workspace/a2a_tools.py @@ -16,6 +16,7 @@ from a2a_client import ( WORKSPACE_ID, _A2A_ERROR_PREFIX, _peer_names, + _peer_to_source, discover_peer, get_peers, get_peers_with_diagnostic, @@ -23,6 +24,7 @@ from a2a_client import ( send_a2a_message, ) from builtin_tools.security import _redact_secrets +from platform_auth import list_registered_workspaces # --------------------------------------------------------------------------- @@ -189,16 +191,32 @@ async def report_activity( pass # Best-effort — don't block delegation on activity reporting -async def tool_delegate_task(workspace_id: str, task: str) -> str: - """Delegate a task to another workspace via A2A (synchronous — waits for response).""" +async def tool_delegate_task( + workspace_id: str, + task: str, + source_workspace_id: str | None = None, +) -> str: + """Delegate a task to another workspace via A2A (synchronous — waits for response). + + ``source_workspace_id`` selects which registered workspace this + delegation originates from — drives auth + the X-Workspace-ID source + header so the platform's a2a_proxy logs the correct sender. Single- + workspace operators leave it None and routing falls back to the + module-level WORKSPACE_ID. + """ if not workspace_id or not task: return "Error: workspace_id and task are required" + # Auto-route: if source not specified, look up which registered + # workspace last saw this peer (populated by tool_list_peers). Falls + # back to the legacy WORKSPACE_ID for single-workspace operators. + src = source_workspace_id or _peer_to_source.get(workspace_id) or None + # Discover the target. discover_peer is the access-control gate + # name/status lookup. The peer's reported ``url`` field is NOT used # for routing — see send_a2a_message, which constructs the URL via # the platform's A2A proxy. - peer = await discover_peer(workspace_id) + peer = await discover_peer(workspace_id, source_workspace_id=src) if not peer: return f"Error: workspace {workspace_id} not found or not accessible (check access control)" @@ -214,7 +232,7 @@ async def tool_delegate_task(workspace_id: str, task: str) -> str: # send_a2a_message routes through ${PLATFORM_URL}/workspaces/{id}/a2a # (the platform proxy) so the same code works for in-container and # external (standalone molecule-mcp) callers. - result = await send_a2a_message(workspace_id, task) + result = await send_a2a_message(workspace_id, task, source_workspace_id=src) # Detect delegation failures — wrap them clearly so the calling agent # can decide to retry, use another peer, or handle the task itself. @@ -246,27 +264,41 @@ async def tool_delegate_task(workspace_id: str, task: str) -> str: return result -async def tool_delegate_task_async(workspace_id: str, task: str) -> str: +async def tool_delegate_task_async( + workspace_id: str, + task: str, + source_workspace_id: str | None = None, +) -> str: """Delegate a task via the platform's async delegation API (fire-and-forget). Uses POST /workspaces/:id/delegate which runs the A2A request in the background. Results are tracked in the platform DB and broadcast via WebSocket. Use check_task_status to poll for results. + + ``source_workspace_id`` selects the sending workspace (which one of + this agent's registered workspaces gets logged as the originator); + auto-routes via the peer→source cache when omitted. """ if not workspace_id or not task: return "Error: workspace_id and task are required" - # Idempotency key: SHA-256 of (workspace_id, task) so that a restarted agent - # firing the same delegation gets the same key and the platform returns the - # existing delegation_id instead of creating a duplicate. Fixes #1456. - idem_key = hashlib.sha256(f"{workspace_id}:{task}".encode()).hexdigest()[:32] + src = source_workspace_id or _peer_to_source.get(workspace_id) or WORKSPACE_ID + + # Idempotency key: SHA-256 of (source, target, task) so that a + # restarted agent firing the same delegation gets the same key and + # the platform returns the existing delegation_id instead of + # creating a duplicate. Fixes #1456. Source is in the key so the + # SAME task delegated from two different registered workspaces + # produces two distinct delegations (the right behavior — one per + # tenant audit trail). + idem_key = hashlib.sha256(f"{src}:{workspace_id}:{task}".encode()).hexdigest()[:32] try: async with httpx.AsyncClient(timeout=10.0) as client: resp = await client.post( - f"{PLATFORM_URL}/workspaces/{WORKSPACE_ID}/delegate", + f"{PLATFORM_URL}/workspaces/{src}/delegate", json={"target_id": workspace_id, "task": task, "idempotency_key": idem_key}, - headers=_auth_headers_for_heartbeat(), + headers=_auth_headers_for_heartbeat(src), ) if resp.status_code == 202: data = resp.json() @@ -282,18 +314,27 @@ async def tool_delegate_task_async(workspace_id: str, task: str) -> str: return f"Error: delegation failed — {e}" -async def tool_check_task_status(workspace_id: str, task_id: str) -> str: +async def tool_check_task_status( + workspace_id: str, + task_id: str, + source_workspace_id: str | None = None, +) -> str: """Check delegations for this workspace via the platform API. Args: - workspace_id: Ignored (kept for backward compat). Checks this workspace's delegations. + workspace_id: Ignored (kept for backward compat). Checks + ``source_workspace_id``'s delegations (the workspace that + FIRED the delegations), not the target's. task_id: Optional delegation_id to filter. If empty, returns all recent delegations. + source_workspace_id: Which registered workspace's delegation log + to query. Defaults to the module-level WORKSPACE_ID. """ + src = source_workspace_id or WORKSPACE_ID try: async with httpx.AsyncClient(timeout=10.0) as client: resp = await client.get( - f"{PLATFORM_URL}/workspaces/{WORKSPACE_ID}/delegations", - headers=_auth_headers_for_heartbeat(), + f"{PLATFORM_URL}/workspaces/{src}/delegations", + headers=_auth_headers_for_heartbeat(src), ) if resp.status_code != 200: return f"Error: failed to check delegations ({resp.status_code})" @@ -439,25 +480,68 @@ async def tool_send_message_to_user( return f"Error sending message: {e}" -async def tool_list_peers() -> str: - """List all workspaces this agent can communicate with.""" - peers, diagnostic = await get_peers_with_diagnostic() - if not peers: - if diagnostic is not None: - # Non-trivial empty: auth failure / 404 / 5xx / network — surface - # the actual reason so the user/agent doesn't have to guess. #2397. - return f"No peers found. {diagnostic}" +async def tool_list_peers(source_workspace_id: str | None = None) -> str: + """List all workspaces this agent can communicate with. + + Behavior: + - ``source_workspace_id`` set → list peers of that one workspace. + - Unset, single-workspace mode → list peers of WORKSPACE_ID + (the legacy path, unchanged). + - Unset, multi-workspace mode (MOLECULE_WORKSPACES populated) → + aggregate across every registered workspace, prefixing each + peer with its source so the agent / user can see the full peer + surface in one call. + + Side-effect: populates ``_peer_to_source`` so subsequent + ``tool_delegate_task(target)`` auto-routes through the correct + sending workspace without the agent needing ``source_workspace_id``. + """ + sources: list[str] + aggregate = False + if source_workspace_id: + sources = [source_workspace_id] + else: + registered = list_registered_workspaces() + if len(registered) > 1: + sources = registered + aggregate = True + else: + sources = [WORKSPACE_ID] + + all_peers: list[tuple[str, dict]] = [] # (source, peer_record) + diagnostics: list[tuple[str, str]] = [] # (source, diagnostic) + for src in sources: + peers, diagnostic = await get_peers_with_diagnostic(source_workspace_id=src) + if peers: + for p in peers: + all_peers.append((src, p)) + elif diagnostic is not None: + diagnostics.append((src, diagnostic)) + + if not all_peers: + if diagnostics: + joined = "; ".join(f"[{src[:8]}] {d}" for src, d in diagnostics) + return f"No peers found. {joined}" return ( "You have no peers in the platform registry. " "(No parent, no children, no siblings registered.)" ) + lines = [] - for p in peers: + for src, p in all_peers: status = p.get("status", "unknown") role = p.get("role", "") + peer_id = p["id"] # Cache name for use in delegate_task - _peer_names[p["id"]] = p["name"] - lines.append(f"- {p['name']} (ID: {p['id']}, status: {status}, role: {role})") + _peer_names[peer_id] = p["name"] + # Cache the source workspace so tool_delegate_task auto-routes + _peer_to_source[peer_id] = src + if aggregate: + lines.append( + f"- {p['name']} (ID: {peer_id}, status: {status}, role: {role}, via: {src[:8]})" + ) + else: + lines.append(f"- {p['name']} (ID: {peer_id}, status: {status}, role: {role})") return "\n".join(lines) diff --git a/workspace/platform_auth.py b/workspace/platform_auth.py index 17157428..7c3eb215 100644 --- a/workspace/platform_auth.py +++ b/workspace/platform_auth.py @@ -162,6 +162,22 @@ def get_workspace_token(workspace_id: str) -> str | None: return _WORKSPACE_TOKENS.get((workspace_id or "").strip()) +def list_registered_workspaces() -> list[str]: + """Return the workspace IDs currently in the per-workspace registry. + + Empty list when no multi-workspace registration has happened (i.e. + single-workspace operators using the legacy WORKSPACE_ID env path — + those callers should fall back to the module-level WORKSPACE_ID). + + Used by ``a2a_tools.tool_list_peers`` to aggregate peers across all + workspaces an external agent has registered against, so a + multi-workspace operator can see the full peer surface in one call + instead of having to query each workspace separately. + """ + with _WORKSPACE_TOKENS_LOCK: + return list(_WORKSPACE_TOKENS.keys()) + + def auth_headers(workspace_id: str | None = None) -> dict[str, str]: """Return a header dict to merge into httpx calls. Empty if no token is available yet — callers send the request as-is and the platform's @@ -221,7 +237,12 @@ def self_source_headers(workspace_id: str) -> dict[str, str]: correlation ID) only touches one place — and so that any workspace→A2A POST that doesn't use this helper stands out in review as a probable bug.""" - return {**auth_headers(), "X-Workspace-ID": workspace_id} + # Pass workspace_id through to auth_headers so the bearer token + # comes from the per-workspace registry when set — otherwise a + # multi-workspace operator's source-tagged POST authenticates with + # the legacy single token (or none) and the platform rejects with + # 401, or worse silently logs the wrong source. + return {**auth_headers(workspace_id), "X-Workspace-ID": workspace_id} def clear_cache() -> None: diff --git a/workspace/platform_tools/registry.py b/workspace/platform_tools/registry.py index 6da1bb6c..d026b3c5 100644 --- a/workspace/platform_tools/registry.py +++ b/workspace/platform_tools/registry.py @@ -140,6 +140,16 @@ _DELEGATE_TASK = ToolSpec( "type": "string", "description": "Task description to send to the peer.", }, + "source_workspace_id": { + "type": "string", + "description": ( + "Optional. The registered workspace this delegation " + "originates from when the agent is registered to " + "multiple workspaces (MOLECULE_WORKSPACES). Auto-" + "routes via the peer→source cache when omitted; " + "single-workspace operators can ignore it." + ), + }, }, "required": ["workspace_id", "task"], }, @@ -170,6 +180,14 @@ _DELEGATE_TASK_ASYNC = ToolSpec( "type": "string", "description": "Task description to send to the peer.", }, + "source_workspace_id": { + "type": "string", + "description": ( + "Optional. The registered workspace this delegation " + "originates from. Auto-routes via the peer→source " + "cache when omitted." + ), + }, }, "required": ["workspace_id", "task"], }, @@ -201,6 +219,13 @@ _CHECK_TASK_STATUS = ToolSpec( "type": "string", "description": "task_id returned by delegate_task_async.", }, + "source_workspace_id": { + "type": "string", + "description": ( + "Optional. Which registered workspace's delegation " + "log to query. Defaults to this workspace." + ), + }, }, "required": ["workspace_id", "task_id"], }, @@ -217,9 +242,23 @@ _LIST_PEERS = ToolSpec( when_to_use=( "Call this first when you need to delegate but don't know the " "target's ID. Access control is enforced — you only see " - "siblings, parent, and direct children." + "siblings, parent, and direct children. With " + "MOLECULE_WORKSPACES set, peers from every registered workspace " + "are aggregated and tagged with their source." ), - input_schema={"type": "object", "properties": {}}, + input_schema={ + "type": "object", + "properties": { + "source_workspace_id": { + "type": "string", + "description": ( + "Optional. Restrict to peers of this one registered " + "workspace. Omit to aggregate across all workspaces " + "an external agent has registered against." + ), + }, + }, + }, impl=tool_list_peers, section=A2A_SECTION, ) diff --git a/workspace/tests/snapshots/a2a_instructions_mcp.txt b/workspace/tests/snapshots/a2a_instructions_mcp.txt index 8eacdb1c..6bcf471e 100644 --- a/workspace/tests/snapshots/a2a_instructions_mcp.txt +++ b/workspace/tests/snapshots/a2a_instructions_mcp.txt @@ -21,7 +21,7 @@ Use for long-running work where you want to keep doing other things while the pe Statuses: pending/in_progress (peer still working — wait), queued (peer is busy with a prior task — DO NOT retry, the platform stitches the response when it finishes), completed (result available), failed (real error — fall back to a different peer or handle it yourself). ### list_peers -Call this first when you need to delegate but don't know the target's ID. Access control is enforced — you only see siblings, parent, and direct children. +Call this first when you need to delegate but don't know the target's ID. Access control is enforced — you only see siblings, parent, and direct children. With MOLECULE_WORKSPACES set, peers from every registered workspace are aggregated and tagged with their source. ### get_workspace_info Use to introspect your own identity (e.g. before reporting back to the user, or to determine whether you're a tier-0 root that can write GLOBAL memory). diff --git a/workspace/tests/test_a2a_multi_workspace.py b/workspace/tests/test_a2a_multi_workspace.py new file mode 100644 index 00000000..4278ff11 --- /dev/null +++ b/workspace/tests/test_a2a_multi_workspace.py @@ -0,0 +1,425 @@ +"""Tests for cross-workspace A2A delegation + peer aggregation (PR-2 of +the multi-workspace MCP feature). + +PR-1 made the auth registry per-workspace. PR-2 threads +``source_workspace_id`` through the A2A client + tool surface so an +external agent registered against multiple workspaces can: + + - List peers across every registered workspace in one call. + - Delegate from a specific source workspace (or auto-route via the + peer→source cache populated by list_peers). + - The legacy single-workspace path (no MOLECULE_WORKSPACES) is + untouched — falls back to the module-level WORKSPACE_ID exactly as + before. +""" +from __future__ import annotations + +import sys +from pathlib import Path +from unittest.mock import AsyncMock, patch + +import pytest + +_THIS = Path(__file__).resolve() +sys.path.insert(0, str(_THIS.parent.parent)) + + +@pytest.fixture(autouse=True) +def _isolate_env(monkeypatch): + """Ensure WORKSPACE_ID + PLATFORM_URL are predictable across tests + and the per-workspace token registry doesn't leak between cases.""" + monkeypatch.setenv("WORKSPACE_ID", "00000000-0000-0000-0000-000000000001") + monkeypatch.setenv("PLATFORM_URL", "http://test-platform") + + import platform_auth + platform_auth.clear_cache() + + import a2a_client + a2a_client._peer_to_source.clear() + a2a_client._peer_names.clear() + + yield + + platform_auth.clear_cache() + a2a_client._peer_to_source.clear() + a2a_client._peer_names.clear() + + +# --------------------------------------------------------------------------- +# Lower-layer helpers — discover_peer / send_a2a_message / +# get_peers_with_diagnostic — should route via source_workspace_id when +# set, fall back to module-level WORKSPACE_ID otherwise. +# --------------------------------------------------------------------------- + + +class TestDiscoverPeerSourceRouting: + @pytest.mark.asyncio + async def test_routes_through_source_workspace_id_when_set(self, monkeypatch): + """source_workspace_id drives the X-Workspace-ID header AND the + bearer token (via auth_headers(src)).""" + import platform_auth, a2a_client + + platform_auth.register_workspace_token("aaaa1111-aaaa-aaaa-aaaa-aaaaaaaaaaaa", "token-A") + + captured: dict = {} + + class _Resp: + status_code = 200 + def json(self): + return {"id": "bbbb2222-bbbb-bbbb-bbbb-bbbbbbbbbbbb", "name": "peer-of-A"} + + class _Client: + async def __aenter__(self): + return self + async def __aexit__(self, *a): + return None + async def get(self, url, headers): + captured["url"] = url + captured["headers"] = headers + return _Resp() + + monkeypatch.setattr(a2a_client.httpx, "AsyncClient", lambda timeout: _Client()) + + result = await a2a_client.discover_peer( + "bbbb2222-bbbb-bbbb-bbbb-bbbbbbbbbbbb", + source_workspace_id="aaaa1111-aaaa-aaaa-aaaa-aaaaaaaaaaaa", + ) + assert result == {"id": "bbbb2222-bbbb-bbbb-bbbb-bbbbbbbbbbbb", "name": "peer-of-A"} + assert captured["headers"]["X-Workspace-ID"] == "aaaa1111-aaaa-aaaa-aaaa-aaaaaaaaaaaa" + assert captured["headers"]["Authorization"] == "Bearer token-A" + + @pytest.mark.asyncio + async def test_falls_back_to_module_workspace_id(self, monkeypatch): + """No source_workspace_id → uses module-level WORKSPACE_ID.""" + import a2a_client + + captured: dict = {} + + class _Resp: + status_code = 200 + def json(self): + return {"id": "x", "name": "y"} + + class _Client: + async def __aenter__(self): + return self + async def __aexit__(self, *a): + return None + async def get(self, url, headers): + captured["headers"] = headers + return _Resp() + + monkeypatch.setattr(a2a_client.httpx, "AsyncClient", lambda timeout: _Client()) + + await a2a_client.discover_peer("11111111-1111-1111-1111-111111111111") + # Falls back to the env-var WORKSPACE_ID set in _isolate_env. + assert captured["headers"]["X-Workspace-ID"] == "00000000-0000-0000-0000-000000000001" + + @pytest.mark.asyncio + async def test_invalid_target_id_returns_none_without_routing(self, monkeypatch): + """Validation runs before routing — short-circuits without an + outbound HTTP attempt regardless of source.""" + import a2a_client + + called = {"hit": False} + + class _Client: + async def __aenter__(self): + called["hit"] = True + return self + async def __aexit__(self, *a): + return None + async def get(self, *a, **kw): + called["hit"] = True + + monkeypatch.setattr(a2a_client.httpx, "AsyncClient", lambda timeout: _Client()) + + result = await a2a_client.discover_peer("not-a-uuid", source_workspace_id="anything") + assert result is None + assert not called["hit"] + + +class TestSendA2AMessageSourceRouting: + @pytest.mark.asyncio + async def test_self_source_headers_built_from_source_arg(self, monkeypatch): + """The X-Workspace-ID source header must reflect the SENDING + workspace, not the module-level WORKSPACE_ID. Otherwise + cross-workspace delegations land in the wrong tenant's audit log.""" + import platform_auth, a2a_client + + platform_auth.register_workspace_token("cccc3333-cccc-cccc-cccc-cccccccccccc", "token-C") + + captured: dict = {} + + class _Resp: + status_code = 200 + def json(self): + return {"jsonrpc": "2.0", "result": {"parts": [{"text": "PONG"}]}} + + class _Client: + async def __aenter__(self): + return self + async def __aexit__(self, *a): + return None + async def post(self, url, headers, json): + captured["url"] = url + captured["headers"] = headers + return _Resp() + + monkeypatch.setattr(a2a_client.httpx, "AsyncClient", lambda timeout: _Client()) + + result = await a2a_client.send_a2a_message( + "dddd4444-dddd-dddd-dddd-dddddddddddd", + "ping", + source_workspace_id="cccc3333-cccc-cccc-cccc-cccccccccccc", + ) + assert result == "PONG" + assert captured["headers"]["X-Workspace-ID"] == "cccc3333-cccc-cccc-cccc-cccccccccccc" + assert captured["headers"]["Authorization"] == "Bearer token-C" + + +class TestGetPeersSourceRouting: + @pytest.mark.asyncio + async def test_url_and_headers_use_source_workspace_id(self, monkeypatch): + import platform_auth, a2a_client + + platform_auth.register_workspace_token("eeee5555-eeee-eeee-eeee-eeeeeeeeeeee", "token-E") + + captured: dict = {} + + class _Resp: + status_code = 200 + def json(self): + return [{"id": "x", "name": "peer-x", "status": "online"}] + + class _Client: + async def __aenter__(self): + return self + async def __aexit__(self, *a): + return None + async def get(self, url, headers): + captured["url"] = url + captured["headers"] = headers + return _Resp() + + monkeypatch.setattr(a2a_client.httpx, "AsyncClient", lambda timeout: _Client()) + + peers, diag = await a2a_client.get_peers_with_diagnostic( + source_workspace_id="eeee5555-eeee-eeee-eeee-eeeeeeeeeeee", + ) + assert diag is None + assert peers == [{"id": "x", "name": "peer-x", "status": "online"}] + assert "/registry/eeee5555-eeee-eeee-eeee-eeeeeeeeeeee/peers" in captured["url"] + assert captured["headers"]["X-Workspace-ID"] == "eeee5555-eeee-eeee-eeee-eeeeeeeeeeee" + assert captured["headers"]["Authorization"] == "Bearer token-E" + + +# --------------------------------------------------------------------------- +# Tool surface — tool_list_peers aggregation + tool_delegate_task +# auto-routing via the peer→source cache. +# --------------------------------------------------------------------------- + + +class TestToolListPeersAggregation: + @pytest.mark.asyncio + async def test_aggregates_across_registered_workspaces(self, monkeypatch): + """Multi-workspace mode (>1 registered) → list_peers aggregates.""" + import platform_auth, a2a_tools, a2a_client + + ws_a = "aaaa1111-aaaa-aaaa-aaaa-aaaaaaaaaaaa" + ws_b = "bbbb2222-bbbb-bbbb-bbbb-bbbbbbbbbbbb" + platform_auth.register_workspace_token(ws_a, "token-A") + platform_auth.register_workspace_token(ws_b, "token-B") + + async def fake_get_peers(source_workspace_id=None): + if source_workspace_id == ws_a: + return [{"id": "1111aaaa-1111-1111-1111-111111111111", "name": "alice", "status": "online", "role": "ops"}], None + if source_workspace_id == ws_b: + return [{"id": "2222bbbb-2222-2222-2222-222222222222", "name": "bob", "status": "online", "role": "dev"}], None + return [], None + + with patch("a2a_tools.get_peers_with_diagnostic", side_effect=fake_get_peers): + output = await a2a_tools.tool_list_peers() + + assert "alice" in output + assert "bob" in output + assert f"via: {ws_a[:8]}" in output + assert f"via: {ws_b[:8]}" in output + + # Side-effect: peer→source map populated for downstream auto-routing. + assert a2a_client._peer_to_source["1111aaaa-1111-1111-1111-111111111111"] == ws_a + assert a2a_client._peer_to_source["2222bbbb-2222-2222-2222-222222222222"] == ws_b + + @pytest.mark.asyncio + async def test_single_workspace_unchanged(self, monkeypatch): + """Legacy path: no MOLECULE_WORKSPACES → module WORKSPACE_ID, + no `via:` annotation, no aggregation.""" + import a2a_tools, a2a_client + + async def fake_get_peers(source_workspace_id=None): + assert source_workspace_id == a2a_client.WORKSPACE_ID + return [{"id": "1111aaaa-1111-1111-1111-111111111111", "name": "alice", "status": "online", "role": "ops"}], None + + with patch("a2a_tools.get_peers_with_diagnostic", side_effect=fake_get_peers): + output = await a2a_tools.tool_list_peers() + + assert "alice" in output + assert "via:" not in output + + @pytest.mark.asyncio + async def test_explicit_source_workspace_id_overrides(self, monkeypatch): + """Explicit source_workspace_id arg → query that workspace only, + not aggregated.""" + import platform_auth, a2a_tools + + ws_a = "aaaa1111-aaaa-aaaa-aaaa-aaaaaaaaaaaa" + ws_b = "bbbb2222-bbbb-bbbb-bbbb-bbbbbbbbbbbb" + platform_auth.register_workspace_token(ws_a, "token-A") + platform_auth.register_workspace_token(ws_b, "token-B") + + seen = [] + + async def fake_get_peers(source_workspace_id=None): + seen.append(source_workspace_id) + return [{"id": "1111aaaa-1111-1111-1111-111111111111", "name": "alice", "status": "online", "role": "ops"}], None + + with patch("a2a_tools.get_peers_with_diagnostic", side_effect=fake_get_peers): + output = await a2a_tools.tool_list_peers(source_workspace_id=ws_a) + + assert seen == [ws_a] + # Aggregate annotation not applied when scoped to one source. + assert "via:" not in output + + @pytest.mark.asyncio + async def test_aggregated_diagnostic_per_source(self): + """When all workspaces return empty-with-diagnostic, the message + prefixes each diagnostic with its source workspace's short id.""" + import platform_auth, a2a_tools + + ws_a = "aaaa1111-aaaa-aaaa-aaaa-aaaaaaaaaaaa" + ws_b = "bbbb2222-bbbb-bbbb-bbbb-bbbbbbbbbbbb" + platform_auth.register_workspace_token(ws_a, "token-A") + platform_auth.register_workspace_token(ws_b, "token-B") + + async def fake_get_peers(source_workspace_id=None): + if source_workspace_id == ws_a: + return [], "auth failed" + return [], "platform 5xx" + + with patch("a2a_tools.get_peers_with_diagnostic", side_effect=fake_get_peers): + out = await a2a_tools.tool_list_peers() + + assert "[aaaa1111] auth failed" in out + assert "[bbbb2222] platform 5xx" in out + + +class TestToolDelegateTaskAutoRouting: + @pytest.mark.asyncio + async def test_uses_cached_source_when_available(self, monkeypatch): + """When the peer is in the _peer_to_source cache (populated by a + prior list_peers), delegate_task auto-routes through that + source without the agent specifying source_workspace_id.""" + import a2a_tools, a2a_client + + ws_a = "aaaa1111-aaaa-aaaa-aaaa-aaaaaaaaaaaa" + peer_id = "1111aaaa-1111-1111-1111-111111111111" + a2a_client._peer_to_source[peer_id] = ws_a + + seen_discover_src = {} + seen_send_src = {} + + async def fake_discover(target_id, source_workspace_id=None): + seen_discover_src["src"] = source_workspace_id + return {"id": target_id, "name": "alice", "status": "online"} + + async def fake_send(passed_peer_id, message, source_workspace_id=None): + seen_send_src["src"] = source_workspace_id + return "ok" + + with patch("a2a_tools.discover_peer", side_effect=fake_discover), \ + patch("a2a_tools.send_a2a_message", side_effect=fake_send), \ + patch("a2a_tools.report_activity", new=AsyncMock()): + await a2a_tools.tool_delegate_task(peer_id, "do thing") + + assert seen_discover_src["src"] == ws_a + assert seen_send_src["src"] == ws_a + + @pytest.mark.asyncio + async def test_explicit_source_overrides_cache(self): + """Explicit source_workspace_id beats the auto-routing cache.""" + import a2a_tools, a2a_client + + peer_id = "1111aaaa-1111-1111-1111-111111111111" + ws_cached = "aaaa1111-aaaa-aaaa-aaaa-aaaaaaaaaaaa" + ws_explicit = "cccc3333-cccc-cccc-cccc-cccccccccccc" + a2a_client._peer_to_source[peer_id] = ws_cached + + seen = {} + + async def fake_discover(target_id, source_workspace_id=None): + seen["discover"] = source_workspace_id + return {"id": target_id, "name": "alice", "status": "online"} + + async def fake_send(passed_peer_id, message, source_workspace_id=None): + seen["send"] = source_workspace_id + return "ok" + + with patch("a2a_tools.discover_peer", side_effect=fake_discover), \ + patch("a2a_tools.send_a2a_message", side_effect=fake_send), \ + patch("a2a_tools.report_activity", new=AsyncMock()): + await a2a_tools.tool_delegate_task( + peer_id, "do thing", source_workspace_id=ws_explicit, + ) + + assert seen["discover"] == ws_explicit + assert seen["send"] == ws_explicit + + @pytest.mark.asyncio + async def test_no_cache_no_explicit_falls_back_to_module(self): + """Single-workspace operators see no behavior change — when the + peer isn't cached and no source is passed, source_workspace_id + stays None and the lower layer falls back to WORKSPACE_ID.""" + import a2a_tools + + peer_id = "1111aaaa-1111-1111-1111-111111111111" + seen = {} + + async def fake_discover(target_id, source_workspace_id=None): + seen["discover"] = source_workspace_id + return {"id": target_id, "name": "alice", "status": "online"} + + async def fake_send(passed_peer_id, message, source_workspace_id=None): + seen["send"] = source_workspace_id + return "ok" + + with patch("a2a_tools.discover_peer", side_effect=fake_discover), \ + patch("a2a_tools.send_a2a_message", side_effect=fake_send), \ + patch("a2a_tools.report_activity", new=AsyncMock()): + await a2a_tools.tool_delegate_task(peer_id, "do thing") + + assert seen["discover"] is None + assert seen["send"] is None + + +# --------------------------------------------------------------------------- +# platform_auth registry helper exposed to the tool layer. +# --------------------------------------------------------------------------- + + +class TestListRegisteredWorkspaces: + def test_empty_when_no_registrations(self): + import platform_auth + assert platform_auth.list_registered_workspaces() == [] + + def test_returns_registered_ids(self): + import platform_auth + platform_auth.register_workspace_token("ws-1", "tok-1") + platform_auth.register_workspace_token("ws-2", "tok-2") + result = sorted(platform_auth.list_registered_workspaces()) + assert result == ["ws-1", "ws-2"] + + def test_clear_cache_empties_registry(self): + import platform_auth + platform_auth.register_workspace_token("ws-1", "tok-1") + platform_auth.clear_cache() + assert platform_auth.list_registered_workspaces() == [] diff --git a/workspace/tests/test_a2a_tools_impl.py b/workspace/tests/test_a2a_tools_impl.py index 5d994280..5f8bd7bc 100644 --- a/workspace/tests/test_a2a_tools_impl.py +++ b/workspace/tests/test_a2a_tools_impl.py @@ -255,9 +255,10 @@ class TestToolDelegateTask: "status": "online", } captured = {} - async def fake_send(passed_peer_id, message): + async def fake_send(passed_peer_id, message, source_workspace_id=None): captured["peer_id"] = passed_peer_id captured["message"] = message + captured["source"] = source_workspace_id return "ok" with patch("a2a_tools.discover_peer", return_value=peer), \